Merge pull request #1656 from tw4452852/1587-limits

Introduce `limits` middleware
This commit is contained in:
Matt Holt 2017-05-08 11:39:10 -06:00 committed by GitHub
commit f06b825f44
11 changed files with 363 additions and 143 deletions

View file

@ -16,9 +16,9 @@ import (
_ "github.com/mholt/caddy/caddyhttp/header" _ "github.com/mholt/caddy/caddyhttp/header"
_ "github.com/mholt/caddy/caddyhttp/index" _ "github.com/mholt/caddy/caddyhttp/index"
_ "github.com/mholt/caddy/caddyhttp/internalsrv" _ "github.com/mholt/caddy/caddyhttp/internalsrv"
_ "github.com/mholt/caddy/caddyhttp/limits"
_ "github.com/mholt/caddy/caddyhttp/log" _ "github.com/mholt/caddy/caddyhttp/log"
_ "github.com/mholt/caddy/caddyhttp/markdown" _ "github.com/mholt/caddy/caddyhttp/markdown"
_ "github.com/mholt/caddy/caddyhttp/maxrequestbody"
_ "github.com/mholt/caddy/caddyhttp/mime" _ "github.com/mholt/caddy/caddyhttp/mime"
_ "github.com/mholt/caddy/caddyhttp/pprof" _ "github.com/mholt/caddy/caddyhttp/pprof"
_ "github.com/mholt/caddy/caddyhttp/proxy" _ "github.com/mholt/caddy/caddyhttp/proxy"

View file

@ -436,7 +436,7 @@ var directives = []string{
"root", "root",
"index", "index",
"bind", "bind",
"maxrequestbody", // TODO: 'limits' "limits",
"timeouts", "timeouts",
"tls", "tls",

View file

@ -302,7 +302,7 @@ func (r *replacer) getSubstitution(key string) string {
} }
_, err := ioutil.ReadAll(r.request.Body) _, err := ioutil.ReadAll(r.request.Body)
if err != nil { if err != nil {
if _, ok := err.(MaxBytesExceeded); ok { if err == MaxBytesExceededErr {
return r.emptyValue return r.emptyValue
} }
} }

View file

@ -4,8 +4,8 @@ package httpserver
import ( import (
"context" "context"
"crypto/tls" "crypto/tls"
"errors"
"fmt" "fmt"
"io"
"log" "log"
"net" "net"
"net/http" "net/http"
@ -66,6 +66,7 @@ func NewServer(addr string, group []*SiteConfig) (*Server, error) {
sites: group, sites: group,
connTimeout: GracefulTimeout, connTimeout: GracefulTimeout,
} }
s.Server = makeHTTPServerWithHeaderLimit(s.Server, group)
s.Server.Handler = s // this is weird, but whatever s.Server.Handler = s // this is weird, but whatever
// extract TLS settings from each site config to build // extract TLS settings from each site config to build
@ -127,6 +128,32 @@ func NewServer(addr string, group []*SiteConfig) (*Server, error) {
return s, nil return s, nil
} }
// makeHTTPServerWithHeaderLimit apply minimum header limit within a group to given http.Server
func makeHTTPServerWithHeaderLimit(s *http.Server, group []*SiteConfig) *http.Server {
var min int64
for _, cfg := range group {
limit := cfg.Limits.MaxRequestHeaderSize
if limit == 0 {
continue
}
// not set yet
if min == 0 {
min = limit
}
// find a better one
if limit < min {
min = limit
}
}
if min > 0 {
s.MaxHeaderBytes = int(min)
}
return s
}
// makeHTTPServerWithTimeouts makes an http.Server from the group of // makeHTTPServerWithTimeouts makes an http.Server from the group of
// configs in a way that configures timeouts (or, if not set, it uses // configs in a way that configures timeouts (or, if not set, it uses
// the default timeouts) by combining the configuration of each // the default timeouts) by combining the configuration of each
@ -359,20 +386,6 @@ func (s *Server) serveHTTP(w http.ResponseWriter, r *http.Request) (int, error)
} }
} }
// Apply the path-based request body size limit
// The error returned by MaxBytesReader is meant to be handled
// by whichever middleware/plugin that receives it when calling
// .Read() or a similar method on the request body
// TODO: Make this middleware instead?
if r.Body != nil {
for _, pathlimit := range vhost.MaxRequestBodySizes {
if Path(r.URL.Path).Matches(pathlimit.Path) {
r.Body = MaxBytesReader(w, r.Body, pathlimit.Limit)
break
}
}
}
return vhost.middlewareChain.ServeHTTP(w, r) return vhost.middlewareChain.ServeHTTP(w, r)
} }
@ -465,73 +478,9 @@ func (ln tcpKeepAliveListener) File() (*os.File, error) {
return ln.TCPListener.File() return ln.TCPListener.File()
} }
// MaxBytesExceeded is the error type returned by MaxBytesReader // MaxBytesExceeded is the error returned by MaxBytesReader
// when the request body exceeds the limit imposed // when the request body exceeds the limit imposed
type MaxBytesExceeded struct{} var MaxBytesExceededErr = errors.New("http: request body too large")
func (err MaxBytesExceeded) Error() string {
return "http: request body too large"
}
// MaxBytesReader and its associated methods are borrowed from the
// Go Standard library (comments intact). The only difference is that
// it returns a MaxBytesExceeded error instead of a generic error message
// when the request body has exceeded the requested limit
func MaxBytesReader(w http.ResponseWriter, r io.ReadCloser, n int64) io.ReadCloser {
return &maxBytesReader{w: w, r: r, n: n}
}
type maxBytesReader struct {
w http.ResponseWriter
r io.ReadCloser // underlying reader
n int64 // max bytes remaining
err error // sticky error
}
func (l *maxBytesReader) Read(p []byte) (n int, err error) {
if l.err != nil {
return 0, l.err
}
if len(p) == 0 {
return 0, nil
}
// If they asked for a 32KB byte read but only 5 bytes are
// remaining, no need to read 32KB. 6 bytes will answer the
// question of the whether we hit the limit or go past it.
if int64(len(p)) > l.n+1 {
p = p[:l.n+1]
}
n, err = l.r.Read(p)
if int64(n) <= l.n {
l.n -= int64(n)
l.err = err
return n, err
}
n = int(l.n)
l.n = 0
// The server code and client code both use
// maxBytesReader. This "requestTooLarge" check is
// only used by the server code. To prevent binaries
// which only using the HTTP Client code (such as
// cmd/go) from also linking in the HTTP server, don't
// use a static type assertion to the server
// "*response" type. Check this interface instead:
type requestTooLarger interface {
requestTooLarge()
}
if res, ok := l.w.(requestTooLarger); ok {
res.requestTooLarge()
}
l.err = MaxBytesExceeded{}
return n, l.err
}
func (l *maxBytesReader) Close() error {
return l.r.Close()
}
// DefaultErrorFunc responds to an HTTP request with a simple description // DefaultErrorFunc responds to an HTTP request with a simple description
// of the specified HTTP status code. // of the specified HTTP status code.

View file

@ -15,7 +15,7 @@ func TestAddress(t *testing.T) {
} }
} }
func TestMakeHTTPServer(t *testing.T) { func TestMakeHTTPServerWithTimeouts(t *testing.T) {
for i, tc := range []struct { for i, tc := range []struct {
group []*SiteConfig group []*SiteConfig
expected Timeouts expected Timeouts
@ -111,3 +111,36 @@ func TestMakeHTTPServer(t *testing.T) {
} }
} }
} }
func TestMakeHTTPServerWithHeaderLimit(t *testing.T) {
for name, c := range map[string]struct {
group []*SiteConfig
expect int
}{
"disable": {
group: []*SiteConfig{{}},
expect: 0,
},
"oneSite": {
group: []*SiteConfig{{Limits: Limits{
MaxRequestHeaderSize: 100,
}}},
expect: 100,
},
"multiSites": {
group: []*SiteConfig{
{Limits: Limits{MaxRequestHeaderSize: 100}},
{Limits: Limits{MaxRequestHeaderSize: 50}},
},
expect: 50,
},
} {
c := c
t.Run(name, func(t *testing.T) {
actual := makeHTTPServerWithHeaderLimit(&http.Server{}, c.group)
if got := actual.MaxHeaderBytes; got != c.expect {
t.Errorf("Expect %d, but got %d", c.expect, got)
}
})
}
}

View file

@ -38,8 +38,8 @@ type SiteConfig struct {
// for a request. // for a request.
HiddenFiles []string HiddenFiles []string
// Max amount of bytes a request can send on a given path // Max request's header/body size
MaxRequestBodySizes []PathLimit Limits Limits
// The path to the Caddyfile used to generate this site config // The path to the Caddyfile used to generate this site config
originCaddyfile string originCaddyfile string
@ -71,6 +71,12 @@ type Timeouts struct {
IdleTimeoutSet bool IdleTimeoutSet bool
} }
// Limits specify size limit of request's header and body.
type Limits struct {
MaxRequestHeaderSize int64
MaxRequestBodySizes []PathLimit
}
// PathLimit is a mapping from a site's path to its corresponding // PathLimit is a mapping from a site's path to its corresponding
// maximum request body size (in bytes) // maximum request body size (in bytes)
type PathLimit struct { type PathLimit struct {

View file

@ -0,0 +1,90 @@
package limits
import (
"io"
"net/http"
"github.com/mholt/caddy/caddyhttp/httpserver"
)
// Limit is a middleware to control request body size
type Limit struct {
Next httpserver.Handler
BodyLimits []httpserver.PathLimit
}
func (l Limit) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
if r.Body == nil {
return l.Next.ServeHTTP(w, r)
}
// apply the path-based request body size limit.
for _, bl := range l.BodyLimits {
if httpserver.Path(r.URL.Path).Matches(bl.Path) {
r.Body = MaxBytesReader(w, r.Body, bl.Limit)
break
}
}
return l.Next.ServeHTTP(w, r)
}
// MaxBytesReader and its associated methods are borrowed from the
// Go Standard library (comments intact). The only difference is that
// it returns a MaxBytesExceeded error instead of a generic error message
// when the request body has exceeded the requested limit
func MaxBytesReader(w http.ResponseWriter, r io.ReadCloser, n int64) io.ReadCloser {
return &maxBytesReader{w: w, r: r, n: n}
}
type maxBytesReader struct {
w http.ResponseWriter
r io.ReadCloser // underlying reader
n int64 // max bytes remaining
err error // sticky error
}
func (l *maxBytesReader) Read(p []byte) (n int, err error) {
if l.err != nil {
return 0, l.err
}
if len(p) == 0 {
return 0, nil
}
// If they asked for a 32KB byte read but only 5 bytes are
// remaining, no need to read 32KB. 6 bytes will answer the
// question of the whether we hit the limit or go past it.
if int64(len(p)) > l.n+1 {
p = p[:l.n+1]
}
n, err = l.r.Read(p)
if int64(n) <= l.n {
l.n -= int64(n)
l.err = err
return n, err
}
n = int(l.n)
l.n = 0
// The server code and client code both use
// maxBytesReader. This "requestTooLarge" check is
// only used by the server code. To prevent binaries
// which only using the HTTP Client code (such as
// cmd/go) from also linking in the HTTP server, don't
// use a static type assertion to the server
// "*response" type. Check this interface instead:
type requestTooLarger interface {
requestTooLarge()
}
if res, ok := l.w.(requestTooLarger); ok {
res.requestTooLarge()
}
l.err = httpserver.MaxBytesExceededErr
return n, l.err
}
func (l *maxBytesReader) Close() error {
return l.r.Close()
}

View file

@ -0,0 +1,35 @@
package limits
import (
"io/ioutil"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/mholt/caddy/caddyhttp/httpserver"
)
func TestBodySizeLimit(t *testing.T) {
var (
gotContent []byte
gotError error
expectContent = "hello"
)
l := Limit{
Next: httpserver.HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) {
gotContent, gotError = ioutil.ReadAll(r.Body)
return 0, nil
}),
BodyLimits: []httpserver.PathLimit{{Path: "/", Limit: int64(len(expectContent))}},
}
r := httptest.NewRequest("GET", "/", strings.NewReader(expectContent+expectContent))
l.ServeHTTP(httptest.NewRecorder(), r)
if got := string(gotContent); got != expectContent {
t.Errorf("expected content[%s], got[%s]", expectContent, got)
}
if gotError != httpserver.MaxBytesExceededErr {
t.Errorf("expect error %v, got %v", httpserver.MaxBytesExceededErr, gotError)
}
}

View file

@ -1,4 +1,4 @@
package maxrequestbody package limits
import ( import (
"errors" "errors"
@ -12,13 +12,13 @@ import (
const ( const (
serverType = "http" serverType = "http"
pluginName = "maxrequestbody" pluginName = "limits"
) )
func init() { func init() {
caddy.RegisterPlugin(pluginName, caddy.Plugin{ caddy.RegisterPlugin(pluginName, caddy.Plugin{
ServerType: serverType, ServerType: serverType,
Action: setupMaxRequestBody, Action: setupLimits,
}) })
} }
@ -28,56 +28,97 @@ type pathLimitUnparsed struct {
Limit string Limit string
} }
func setupMaxRequestBody(c *caddy.Controller) error { func setupLimits(c *caddy.Controller) error {
bls, err := parseLimits(c)
if err != nil {
return err
}
httpserver.GetConfig(c).AddMiddleware(func(next httpserver.Handler) httpserver.Handler {
return Limit{Next: next, BodyLimits: bls}
})
return nil
}
func parseLimits(c *caddy.Controller) ([]httpserver.PathLimit, error) {
config := httpserver.GetConfig(c) config := httpserver.GetConfig(c)
if !c.Next() { if !c.Next() {
return c.ArgErr() return nil, c.ArgErr()
} }
args := c.RemainingArgs() args := c.RemainingArgs()
argList := []pathLimitUnparsed{} argList := []pathLimitUnparsed{}
headerLimit := ""
switch len(args) { switch len(args) {
case 0: case 0:
// Format: { <path> <limit> ... } // Format: limits {
// header <limit>
// body <path> <limit>
// body <limit>
// ...
// }
for c.NextBlock() { for c.NextBlock() {
path := c.Val() kind := c.Val()
if !c.NextArg() { pathOrLimit := c.RemainingArgs()
// Uneven pairing of path/limit switch kind {
return c.ArgErr() case "header":
if len(pathOrLimit) != 1 {
return nil, c.ArgErr()
}
headerLimit = pathOrLimit[0]
case "body":
if len(pathOrLimit) == 1 {
argList = append(argList, pathLimitUnparsed{
Path: "/",
Limit: pathOrLimit[0],
})
break
}
if len(pathOrLimit) == 2 {
argList = append(argList, pathLimitUnparsed{
Path: pathOrLimit[0],
Limit: pathOrLimit[1],
})
break
}
fallthrough
default:
return nil, c.ArgErr()
} }
argList = append(argList, pathLimitUnparsed{
Path: path,
Limit: c.Val(),
})
} }
case 1: case 1:
// Format: <limit> // Format: limits <limit>
headerLimit = args[0]
argList = []pathLimitUnparsed{{ argList = []pathLimitUnparsed{{
Path: "/", Path: "/",
Limit: args[0], Limit: args[0],
}} }}
case 2:
// Format: <path> <limit>
argList = []pathLimitUnparsed{{
Path: args[0],
Limit: args[1],
}}
default: default:
return c.ArgErr() return nil, c.ArgErr()
} }
pathLimit, err := parseArguments(argList) if headerLimit != "" {
if err != nil { size := parseSize(headerLimit)
return c.ArgErr() if size < 1 { // also disallow size = 0
return nil, c.ArgErr()
}
config.Limits.MaxRequestHeaderSize = size
} }
SortPathLimits(pathLimit) if len(argList) > 0 {
pathLimit, err := parseArguments(argList)
if err != nil {
return nil, c.ArgErr()
}
SortPathLimits(pathLimit)
config.Limits.MaxRequestBodySizes = pathLimit
}
config.MaxRequestBodySizes = pathLimit return config.Limits.MaxRequestBodySizes, nil
return nil
} }
func parseArguments(args []pathLimitUnparsed) ([]httpserver.PathLimit, error) { func parseArguments(args []pathLimitUnparsed) ([]httpserver.PathLimit, error) {

View file

@ -1,4 +1,4 @@
package maxrequestbody package limits
import ( import (
"reflect" "reflect"
@ -14,32 +14,98 @@ const (
GB = 1024 * 1024 * 1024 GB = 1024 * 1024 * 1024
) )
func TestSetupMaxRequestBody(t *testing.T) { func TestParseLimits(t *testing.T) {
cases := []struct { for name, c := range map[string]struct {
input string input string
hasError bool shouldErr bool
expect httpserver.Limits
}{ }{
// Format: { <path> <limit> ... } "catchAll": {
{input: "maxrequestbody / 20MB", hasError: false}, input: `limits 2kb`,
// Format: <limit> expect: httpserver.Limits{
{input: "maxrequestbody 999KB", hasError: false}, MaxRequestHeaderSize: 2 * KB,
// Format: { <path> <limit> ... } MaxRequestBodySizes: []httpserver.PathLimit{{Path: "/", Limit: 2 * KB}},
{input: "maxrequestbody { /images 50MB /upload 10MB\n/test 10KB }", hasError: false}, },
},
// Wrong formats "onlyHeader": {
{input: "maxrequestbody typo { /images 50MB }", hasError: true}, input: `limits {
{input: "maxrequestbody 999MB /home 20KB", hasError: true}, header 2kb
} }`,
for caseNum, c := range cases { expect: httpserver.Limits{
controller := caddy.NewTestController("", c.input) MaxRequestHeaderSize: 2 * KB,
err := setupMaxRequestBody(controller) },
},
if c.hasError && (err == nil) { "onlyBody": {
t.Errorf("Expecting error for case %v but none encountered", caseNum) input: `limits {
} body 2kb
if !c.hasError && (err != nil) { }`,
t.Errorf("Expecting no error for case %v but encountered %v", caseNum, err) expect: httpserver.Limits{
} MaxRequestBodySizes: []httpserver.PathLimit{{Path: "/", Limit: 2 * KB}},
},
},
"onlyBodyWithPath": {
input: `limits {
body /test 2kb
}`,
expect: httpserver.Limits{
MaxRequestBodySizes: []httpserver.PathLimit{{Path: "/test", Limit: 2 * KB}},
},
},
"mixture": {
input: `limits {
header 1kb
body 2kb
body /bar 3kb
}`,
expect: httpserver.Limits{
MaxRequestHeaderSize: 1 * KB,
MaxRequestBodySizes: []httpserver.PathLimit{
{Path: "/bar", Limit: 3 * KB},
{Path: "/", Limit: 2 * KB},
},
},
},
"invalidFormat": {
input: `limits a b`,
shouldErr: true,
},
"invalidHeaderFormat": {
input: `limits {
header / 100
}`,
shouldErr: true,
},
"invalidBodyFormat": {
input: `limits {
body / 100 200
}`,
shouldErr: true,
},
"invalidKind": {
input: `limits {
head 100
}`,
shouldErr: true,
},
"invalidLimitSize": {
input: `limits 10bk`,
shouldErr: true,
},
} {
c := c
t.Run(name, func(t *testing.T) {
controller := caddy.NewTestController("", c.input)
_, err := parseLimits(controller)
if c.shouldErr && err == nil {
t.Error("failed to get expected error")
}
if !c.shouldErr && err != nil {
t.Errorf("got unexpected error: %v", err)
}
if got := httpserver.GetConfig(controller).Limits; !reflect.DeepEqual(got, c.expect) {
t.Errorf("expect %#v, but got %#v", c.expect, got)
}
})
} }
} }

View file

@ -228,7 +228,7 @@ func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
return 0, nil return 0, nil
} }
if _, ok := backendErr.(httpserver.MaxBytesExceeded); ok { if backendErr == httpserver.MaxBytesExceededErr {
return http.StatusRequestEntityTooLarge, backendErr return http.StatusRequestEntityTooLarge, backendErr
} }