mirror of
https://github.com/caddyserver/caddy.git
synced 2025-01-24 01:26:47 +01:00
Merge pull request #1656 from tw4452852/1587-limits
Introduce `limits` middleware
This commit is contained in:
commit
f06b825f44
11 changed files with 363 additions and 143 deletions
|
@ -16,9 +16,9 @@ import (
|
||||||
_ "github.com/mholt/caddy/caddyhttp/header"
|
_ "github.com/mholt/caddy/caddyhttp/header"
|
||||||
_ "github.com/mholt/caddy/caddyhttp/index"
|
_ "github.com/mholt/caddy/caddyhttp/index"
|
||||||
_ "github.com/mholt/caddy/caddyhttp/internalsrv"
|
_ "github.com/mholt/caddy/caddyhttp/internalsrv"
|
||||||
|
_ "github.com/mholt/caddy/caddyhttp/limits"
|
||||||
_ "github.com/mholt/caddy/caddyhttp/log"
|
_ "github.com/mholt/caddy/caddyhttp/log"
|
||||||
_ "github.com/mholt/caddy/caddyhttp/markdown"
|
_ "github.com/mholt/caddy/caddyhttp/markdown"
|
||||||
_ "github.com/mholt/caddy/caddyhttp/maxrequestbody"
|
|
||||||
_ "github.com/mholt/caddy/caddyhttp/mime"
|
_ "github.com/mholt/caddy/caddyhttp/mime"
|
||||||
_ "github.com/mholt/caddy/caddyhttp/pprof"
|
_ "github.com/mholt/caddy/caddyhttp/pprof"
|
||||||
_ "github.com/mholt/caddy/caddyhttp/proxy"
|
_ "github.com/mholt/caddy/caddyhttp/proxy"
|
||||||
|
|
|
@ -436,7 +436,7 @@ var directives = []string{
|
||||||
"root",
|
"root",
|
||||||
"index",
|
"index",
|
||||||
"bind",
|
"bind",
|
||||||
"maxrequestbody", // TODO: 'limits'
|
"limits",
|
||||||
"timeouts",
|
"timeouts",
|
||||||
"tls",
|
"tls",
|
||||||
|
|
||||||
|
|
|
@ -302,7 +302,7 @@ func (r *replacer) getSubstitution(key string) string {
|
||||||
}
|
}
|
||||||
_, err := ioutil.ReadAll(r.request.Body)
|
_, err := ioutil.ReadAll(r.request.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if _, ok := err.(MaxBytesExceeded); ok {
|
if err == MaxBytesExceededErr {
|
||||||
return r.emptyValue
|
return r.emptyValue
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,8 +4,8 @@ package httpserver
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
|
||||||
"log"
|
"log"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
@ -66,6 +66,7 @@ func NewServer(addr string, group []*SiteConfig) (*Server, error) {
|
||||||
sites: group,
|
sites: group,
|
||||||
connTimeout: GracefulTimeout,
|
connTimeout: GracefulTimeout,
|
||||||
}
|
}
|
||||||
|
s.Server = makeHTTPServerWithHeaderLimit(s.Server, group)
|
||||||
s.Server.Handler = s // this is weird, but whatever
|
s.Server.Handler = s // this is weird, but whatever
|
||||||
|
|
||||||
// extract TLS settings from each site config to build
|
// extract TLS settings from each site config to build
|
||||||
|
@ -127,6 +128,32 @@ func NewServer(addr string, group []*SiteConfig) (*Server, error) {
|
||||||
return s, nil
|
return s, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// makeHTTPServerWithHeaderLimit apply minimum header limit within a group to given http.Server
|
||||||
|
func makeHTTPServerWithHeaderLimit(s *http.Server, group []*SiteConfig) *http.Server {
|
||||||
|
var min int64
|
||||||
|
for _, cfg := range group {
|
||||||
|
limit := cfg.Limits.MaxRequestHeaderSize
|
||||||
|
if limit == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// not set yet
|
||||||
|
if min == 0 {
|
||||||
|
min = limit
|
||||||
|
}
|
||||||
|
|
||||||
|
// find a better one
|
||||||
|
if limit < min {
|
||||||
|
min = limit
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if min > 0 {
|
||||||
|
s.MaxHeaderBytes = int(min)
|
||||||
|
}
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
// makeHTTPServerWithTimeouts makes an http.Server from the group of
|
// makeHTTPServerWithTimeouts makes an http.Server from the group of
|
||||||
// configs in a way that configures timeouts (or, if not set, it uses
|
// configs in a way that configures timeouts (or, if not set, it uses
|
||||||
// the default timeouts) by combining the configuration of each
|
// the default timeouts) by combining the configuration of each
|
||||||
|
@ -359,20 +386,6 @@ func (s *Server) serveHTTP(w http.ResponseWriter, r *http.Request) (int, error)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Apply the path-based request body size limit
|
|
||||||
// The error returned by MaxBytesReader is meant to be handled
|
|
||||||
// by whichever middleware/plugin that receives it when calling
|
|
||||||
// .Read() or a similar method on the request body
|
|
||||||
// TODO: Make this middleware instead?
|
|
||||||
if r.Body != nil {
|
|
||||||
for _, pathlimit := range vhost.MaxRequestBodySizes {
|
|
||||||
if Path(r.URL.Path).Matches(pathlimit.Path) {
|
|
||||||
r.Body = MaxBytesReader(w, r.Body, pathlimit.Limit)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return vhost.middlewareChain.ServeHTTP(w, r)
|
return vhost.middlewareChain.ServeHTTP(w, r)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -465,73 +478,9 @@ func (ln tcpKeepAliveListener) File() (*os.File, error) {
|
||||||
return ln.TCPListener.File()
|
return ln.TCPListener.File()
|
||||||
}
|
}
|
||||||
|
|
||||||
// MaxBytesExceeded is the error type returned by MaxBytesReader
|
// MaxBytesExceeded is the error returned by MaxBytesReader
|
||||||
// when the request body exceeds the limit imposed
|
// when the request body exceeds the limit imposed
|
||||||
type MaxBytesExceeded struct{}
|
var MaxBytesExceededErr = errors.New("http: request body too large")
|
||||||
|
|
||||||
func (err MaxBytesExceeded) Error() string {
|
|
||||||
return "http: request body too large"
|
|
||||||
}
|
|
||||||
|
|
||||||
// MaxBytesReader and its associated methods are borrowed from the
|
|
||||||
// Go Standard library (comments intact). The only difference is that
|
|
||||||
// it returns a MaxBytesExceeded error instead of a generic error message
|
|
||||||
// when the request body has exceeded the requested limit
|
|
||||||
func MaxBytesReader(w http.ResponseWriter, r io.ReadCloser, n int64) io.ReadCloser {
|
|
||||||
return &maxBytesReader{w: w, r: r, n: n}
|
|
||||||
}
|
|
||||||
|
|
||||||
type maxBytesReader struct {
|
|
||||||
w http.ResponseWriter
|
|
||||||
r io.ReadCloser // underlying reader
|
|
||||||
n int64 // max bytes remaining
|
|
||||||
err error // sticky error
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *maxBytesReader) Read(p []byte) (n int, err error) {
|
|
||||||
if l.err != nil {
|
|
||||||
return 0, l.err
|
|
||||||
}
|
|
||||||
if len(p) == 0 {
|
|
||||||
return 0, nil
|
|
||||||
}
|
|
||||||
// If they asked for a 32KB byte read but only 5 bytes are
|
|
||||||
// remaining, no need to read 32KB. 6 bytes will answer the
|
|
||||||
// question of the whether we hit the limit or go past it.
|
|
||||||
if int64(len(p)) > l.n+1 {
|
|
||||||
p = p[:l.n+1]
|
|
||||||
}
|
|
||||||
n, err = l.r.Read(p)
|
|
||||||
|
|
||||||
if int64(n) <= l.n {
|
|
||||||
l.n -= int64(n)
|
|
||||||
l.err = err
|
|
||||||
return n, err
|
|
||||||
}
|
|
||||||
|
|
||||||
n = int(l.n)
|
|
||||||
l.n = 0
|
|
||||||
|
|
||||||
// The server code and client code both use
|
|
||||||
// maxBytesReader. This "requestTooLarge" check is
|
|
||||||
// only used by the server code. To prevent binaries
|
|
||||||
// which only using the HTTP Client code (such as
|
|
||||||
// cmd/go) from also linking in the HTTP server, don't
|
|
||||||
// use a static type assertion to the server
|
|
||||||
// "*response" type. Check this interface instead:
|
|
||||||
type requestTooLarger interface {
|
|
||||||
requestTooLarge()
|
|
||||||
}
|
|
||||||
if res, ok := l.w.(requestTooLarger); ok {
|
|
||||||
res.requestTooLarge()
|
|
||||||
}
|
|
||||||
l.err = MaxBytesExceeded{}
|
|
||||||
return n, l.err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *maxBytesReader) Close() error {
|
|
||||||
return l.r.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
// DefaultErrorFunc responds to an HTTP request with a simple description
|
// DefaultErrorFunc responds to an HTTP request with a simple description
|
||||||
// of the specified HTTP status code.
|
// of the specified HTTP status code.
|
||||||
|
|
|
@ -15,7 +15,7 @@ func TestAddress(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestMakeHTTPServer(t *testing.T) {
|
func TestMakeHTTPServerWithTimeouts(t *testing.T) {
|
||||||
for i, tc := range []struct {
|
for i, tc := range []struct {
|
||||||
group []*SiteConfig
|
group []*SiteConfig
|
||||||
expected Timeouts
|
expected Timeouts
|
||||||
|
@ -111,3 +111,36 @@ func TestMakeHTTPServer(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestMakeHTTPServerWithHeaderLimit(t *testing.T) {
|
||||||
|
for name, c := range map[string]struct {
|
||||||
|
group []*SiteConfig
|
||||||
|
expect int
|
||||||
|
}{
|
||||||
|
"disable": {
|
||||||
|
group: []*SiteConfig{{}},
|
||||||
|
expect: 0,
|
||||||
|
},
|
||||||
|
"oneSite": {
|
||||||
|
group: []*SiteConfig{{Limits: Limits{
|
||||||
|
MaxRequestHeaderSize: 100,
|
||||||
|
}}},
|
||||||
|
expect: 100,
|
||||||
|
},
|
||||||
|
"multiSites": {
|
||||||
|
group: []*SiteConfig{
|
||||||
|
{Limits: Limits{MaxRequestHeaderSize: 100}},
|
||||||
|
{Limits: Limits{MaxRequestHeaderSize: 50}},
|
||||||
|
},
|
||||||
|
expect: 50,
|
||||||
|
},
|
||||||
|
} {
|
||||||
|
c := c
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
actual := makeHTTPServerWithHeaderLimit(&http.Server{}, c.group)
|
||||||
|
if got := actual.MaxHeaderBytes; got != c.expect {
|
||||||
|
t.Errorf("Expect %d, but got %d", c.expect, got)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -38,8 +38,8 @@ type SiteConfig struct {
|
||||||
// for a request.
|
// for a request.
|
||||||
HiddenFiles []string
|
HiddenFiles []string
|
||||||
|
|
||||||
// Max amount of bytes a request can send on a given path
|
// Max request's header/body size
|
||||||
MaxRequestBodySizes []PathLimit
|
Limits Limits
|
||||||
|
|
||||||
// The path to the Caddyfile used to generate this site config
|
// The path to the Caddyfile used to generate this site config
|
||||||
originCaddyfile string
|
originCaddyfile string
|
||||||
|
@ -71,6 +71,12 @@ type Timeouts struct {
|
||||||
IdleTimeoutSet bool
|
IdleTimeoutSet bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Limits specify size limit of request's header and body.
|
||||||
|
type Limits struct {
|
||||||
|
MaxRequestHeaderSize int64
|
||||||
|
MaxRequestBodySizes []PathLimit
|
||||||
|
}
|
||||||
|
|
||||||
// PathLimit is a mapping from a site's path to its corresponding
|
// PathLimit is a mapping from a site's path to its corresponding
|
||||||
// maximum request body size (in bytes)
|
// maximum request body size (in bytes)
|
||||||
type PathLimit struct {
|
type PathLimit struct {
|
||||||
|
|
90
caddyhttp/limits/handler.go
Normal file
90
caddyhttp/limits/handler.go
Normal file
|
@ -0,0 +1,90 @@
|
||||||
|
package limits
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/mholt/caddy/caddyhttp/httpserver"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Limit is a middleware to control request body size
|
||||||
|
type Limit struct {
|
||||||
|
Next httpserver.Handler
|
||||||
|
BodyLimits []httpserver.PathLimit
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l Limit) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
|
||||||
|
if r.Body == nil {
|
||||||
|
return l.Next.ServeHTTP(w, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
// apply the path-based request body size limit.
|
||||||
|
for _, bl := range l.BodyLimits {
|
||||||
|
if httpserver.Path(r.URL.Path).Matches(bl.Path) {
|
||||||
|
r.Body = MaxBytesReader(w, r.Body, bl.Limit)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return l.Next.ServeHTTP(w, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MaxBytesReader and its associated methods are borrowed from the
|
||||||
|
// Go Standard library (comments intact). The only difference is that
|
||||||
|
// it returns a MaxBytesExceeded error instead of a generic error message
|
||||||
|
// when the request body has exceeded the requested limit
|
||||||
|
func MaxBytesReader(w http.ResponseWriter, r io.ReadCloser, n int64) io.ReadCloser {
|
||||||
|
return &maxBytesReader{w: w, r: r, n: n}
|
||||||
|
}
|
||||||
|
|
||||||
|
type maxBytesReader struct {
|
||||||
|
w http.ResponseWriter
|
||||||
|
r io.ReadCloser // underlying reader
|
||||||
|
n int64 // max bytes remaining
|
||||||
|
err error // sticky error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *maxBytesReader) Read(p []byte) (n int, err error) {
|
||||||
|
if l.err != nil {
|
||||||
|
return 0, l.err
|
||||||
|
}
|
||||||
|
if len(p) == 0 {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
// If they asked for a 32KB byte read but only 5 bytes are
|
||||||
|
// remaining, no need to read 32KB. 6 bytes will answer the
|
||||||
|
// question of the whether we hit the limit or go past it.
|
||||||
|
if int64(len(p)) > l.n+1 {
|
||||||
|
p = p[:l.n+1]
|
||||||
|
}
|
||||||
|
n, err = l.r.Read(p)
|
||||||
|
|
||||||
|
if int64(n) <= l.n {
|
||||||
|
l.n -= int64(n)
|
||||||
|
l.err = err
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
|
||||||
|
n = int(l.n)
|
||||||
|
l.n = 0
|
||||||
|
|
||||||
|
// The server code and client code both use
|
||||||
|
// maxBytesReader. This "requestTooLarge" check is
|
||||||
|
// only used by the server code. To prevent binaries
|
||||||
|
// which only using the HTTP Client code (such as
|
||||||
|
// cmd/go) from also linking in the HTTP server, don't
|
||||||
|
// use a static type assertion to the server
|
||||||
|
// "*response" type. Check this interface instead:
|
||||||
|
type requestTooLarger interface {
|
||||||
|
requestTooLarge()
|
||||||
|
}
|
||||||
|
if res, ok := l.w.(requestTooLarger); ok {
|
||||||
|
res.requestTooLarge()
|
||||||
|
}
|
||||||
|
l.err = httpserver.MaxBytesExceededErr
|
||||||
|
return n, l.err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *maxBytesReader) Close() error {
|
||||||
|
return l.r.Close()
|
||||||
|
}
|
35
caddyhttp/limits/handler_test.go
Normal file
35
caddyhttp/limits/handler_test.go
Normal file
|
@ -0,0 +1,35 @@
|
||||||
|
package limits
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io/ioutil"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/mholt/caddy/caddyhttp/httpserver"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestBodySizeLimit(t *testing.T) {
|
||||||
|
var (
|
||||||
|
gotContent []byte
|
||||||
|
gotError error
|
||||||
|
expectContent = "hello"
|
||||||
|
)
|
||||||
|
l := Limit{
|
||||||
|
Next: httpserver.HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) {
|
||||||
|
gotContent, gotError = ioutil.ReadAll(r.Body)
|
||||||
|
return 0, nil
|
||||||
|
}),
|
||||||
|
BodyLimits: []httpserver.PathLimit{{Path: "/", Limit: int64(len(expectContent))}},
|
||||||
|
}
|
||||||
|
|
||||||
|
r := httptest.NewRequest("GET", "/", strings.NewReader(expectContent+expectContent))
|
||||||
|
l.ServeHTTP(httptest.NewRecorder(), r)
|
||||||
|
if got := string(gotContent); got != expectContent {
|
||||||
|
t.Errorf("expected content[%s], got[%s]", expectContent, got)
|
||||||
|
}
|
||||||
|
if gotError != httpserver.MaxBytesExceededErr {
|
||||||
|
t.Errorf("expect error %v, got %v", httpserver.MaxBytesExceededErr, gotError)
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,4 +1,4 @@
|
||||||
package maxrequestbody
|
package limits
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
|
@ -12,13 +12,13 @@ import (
|
||||||
|
|
||||||
const (
|
const (
|
||||||
serverType = "http"
|
serverType = "http"
|
||||||
pluginName = "maxrequestbody"
|
pluginName = "limits"
|
||||||
)
|
)
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
caddy.RegisterPlugin(pluginName, caddy.Plugin{
|
caddy.RegisterPlugin(pluginName, caddy.Plugin{
|
||||||
ServerType: serverType,
|
ServerType: serverType,
|
||||||
Action: setupMaxRequestBody,
|
Action: setupLimits,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -28,56 +28,97 @@ type pathLimitUnparsed struct {
|
||||||
Limit string
|
Limit string
|
||||||
}
|
}
|
||||||
|
|
||||||
func setupMaxRequestBody(c *caddy.Controller) error {
|
func setupLimits(c *caddy.Controller) error {
|
||||||
|
bls, err := parseLimits(c)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
httpserver.GetConfig(c).AddMiddleware(func(next httpserver.Handler) httpserver.Handler {
|
||||||
|
return Limit{Next: next, BodyLimits: bls}
|
||||||
|
})
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseLimits(c *caddy.Controller) ([]httpserver.PathLimit, error) {
|
||||||
config := httpserver.GetConfig(c)
|
config := httpserver.GetConfig(c)
|
||||||
|
|
||||||
if !c.Next() {
|
if !c.Next() {
|
||||||
return c.ArgErr()
|
return nil, c.ArgErr()
|
||||||
}
|
}
|
||||||
|
|
||||||
args := c.RemainingArgs()
|
args := c.RemainingArgs()
|
||||||
argList := []pathLimitUnparsed{}
|
argList := []pathLimitUnparsed{}
|
||||||
|
headerLimit := ""
|
||||||
|
|
||||||
switch len(args) {
|
switch len(args) {
|
||||||
case 0:
|
case 0:
|
||||||
// Format: { <path> <limit> ... }
|
// Format: limits {
|
||||||
|
// header <limit>
|
||||||
|
// body <path> <limit>
|
||||||
|
// body <limit>
|
||||||
|
// ...
|
||||||
|
// }
|
||||||
for c.NextBlock() {
|
for c.NextBlock() {
|
||||||
path := c.Val()
|
kind := c.Val()
|
||||||
if !c.NextArg() {
|
pathOrLimit := c.RemainingArgs()
|
||||||
// Uneven pairing of path/limit
|
switch kind {
|
||||||
return c.ArgErr()
|
case "header":
|
||||||
|
if len(pathOrLimit) != 1 {
|
||||||
|
return nil, c.ArgErr()
|
||||||
}
|
}
|
||||||
|
headerLimit = pathOrLimit[0]
|
||||||
|
case "body":
|
||||||
|
if len(pathOrLimit) == 1 {
|
||||||
argList = append(argList, pathLimitUnparsed{
|
argList = append(argList, pathLimitUnparsed{
|
||||||
Path: path,
|
Path: "/",
|
||||||
Limit: c.Val(),
|
Limit: pathOrLimit[0],
|
||||||
})
|
})
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(pathOrLimit) == 2 {
|
||||||
|
argList = append(argList, pathLimitUnparsed{
|
||||||
|
Path: pathOrLimit[0],
|
||||||
|
Limit: pathOrLimit[1],
|
||||||
|
})
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
fallthrough
|
||||||
|
default:
|
||||||
|
return nil, c.ArgErr()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
case 1:
|
case 1:
|
||||||
// Format: <limit>
|
// Format: limits <limit>
|
||||||
|
headerLimit = args[0]
|
||||||
argList = []pathLimitUnparsed{{
|
argList = []pathLimitUnparsed{{
|
||||||
Path: "/",
|
Path: "/",
|
||||||
Limit: args[0],
|
Limit: args[0],
|
||||||
}}
|
}}
|
||||||
case 2:
|
|
||||||
// Format: <path> <limit>
|
|
||||||
argList = []pathLimitUnparsed{{
|
|
||||||
Path: args[0],
|
|
||||||
Limit: args[1],
|
|
||||||
}}
|
|
||||||
default:
|
default:
|
||||||
return c.ArgErr()
|
return nil, c.ArgErr()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if headerLimit != "" {
|
||||||
|
size := parseSize(headerLimit)
|
||||||
|
if size < 1 { // also disallow size = 0
|
||||||
|
return nil, c.ArgErr()
|
||||||
|
}
|
||||||
|
config.Limits.MaxRequestHeaderSize = size
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(argList) > 0 {
|
||||||
pathLimit, err := parseArguments(argList)
|
pathLimit, err := parseArguments(argList)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return c.ArgErr()
|
return nil, c.ArgErr()
|
||||||
|
}
|
||||||
|
SortPathLimits(pathLimit)
|
||||||
|
config.Limits.MaxRequestBodySizes = pathLimit
|
||||||
}
|
}
|
||||||
|
|
||||||
SortPathLimits(pathLimit)
|
return config.Limits.MaxRequestBodySizes, nil
|
||||||
|
|
||||||
config.MaxRequestBodySizes = pathLimit
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseArguments(args []pathLimitUnparsed) ([]httpserver.PathLimit, error) {
|
func parseArguments(args []pathLimitUnparsed) ([]httpserver.PathLimit, error) {
|
|
@ -1,4 +1,4 @@
|
||||||
package maxrequestbody
|
package limits
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"reflect"
|
"reflect"
|
||||||
|
@ -14,32 +14,98 @@ const (
|
||||||
GB = 1024 * 1024 * 1024
|
GB = 1024 * 1024 * 1024
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestSetupMaxRequestBody(t *testing.T) {
|
func TestParseLimits(t *testing.T) {
|
||||||
cases := []struct {
|
for name, c := range map[string]struct {
|
||||||
input string
|
input string
|
||||||
hasError bool
|
shouldErr bool
|
||||||
|
expect httpserver.Limits
|
||||||
}{
|
}{
|
||||||
// Format: { <path> <limit> ... }
|
"catchAll": {
|
||||||
{input: "maxrequestbody / 20MB", hasError: false},
|
input: `limits 2kb`,
|
||||||
// Format: <limit>
|
expect: httpserver.Limits{
|
||||||
{input: "maxrequestbody 999KB", hasError: false},
|
MaxRequestHeaderSize: 2 * KB,
|
||||||
// Format: { <path> <limit> ... }
|
MaxRequestBodySizes: []httpserver.PathLimit{{Path: "/", Limit: 2 * KB}},
|
||||||
{input: "maxrequestbody { /images 50MB /upload 10MB\n/test 10KB }", hasError: false},
|
},
|
||||||
|
},
|
||||||
// Wrong formats
|
"onlyHeader": {
|
||||||
{input: "maxrequestbody typo { /images 50MB }", hasError: true},
|
input: `limits {
|
||||||
{input: "maxrequestbody 999MB /home 20KB", hasError: true},
|
header 2kb
|
||||||
}
|
}`,
|
||||||
for caseNum, c := range cases {
|
expect: httpserver.Limits{
|
||||||
|
MaxRequestHeaderSize: 2 * KB,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"onlyBody": {
|
||||||
|
input: `limits {
|
||||||
|
body 2kb
|
||||||
|
}`,
|
||||||
|
expect: httpserver.Limits{
|
||||||
|
MaxRequestBodySizes: []httpserver.PathLimit{{Path: "/", Limit: 2 * KB}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"onlyBodyWithPath": {
|
||||||
|
input: `limits {
|
||||||
|
body /test 2kb
|
||||||
|
}`,
|
||||||
|
expect: httpserver.Limits{
|
||||||
|
MaxRequestBodySizes: []httpserver.PathLimit{{Path: "/test", Limit: 2 * KB}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"mixture": {
|
||||||
|
input: `limits {
|
||||||
|
header 1kb
|
||||||
|
body 2kb
|
||||||
|
body /bar 3kb
|
||||||
|
}`,
|
||||||
|
expect: httpserver.Limits{
|
||||||
|
MaxRequestHeaderSize: 1 * KB,
|
||||||
|
MaxRequestBodySizes: []httpserver.PathLimit{
|
||||||
|
{Path: "/bar", Limit: 3 * KB},
|
||||||
|
{Path: "/", Limit: 2 * KB},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"invalidFormat": {
|
||||||
|
input: `limits a b`,
|
||||||
|
shouldErr: true,
|
||||||
|
},
|
||||||
|
"invalidHeaderFormat": {
|
||||||
|
input: `limits {
|
||||||
|
header / 100
|
||||||
|
}`,
|
||||||
|
shouldErr: true,
|
||||||
|
},
|
||||||
|
"invalidBodyFormat": {
|
||||||
|
input: `limits {
|
||||||
|
body / 100 200
|
||||||
|
}`,
|
||||||
|
shouldErr: true,
|
||||||
|
},
|
||||||
|
"invalidKind": {
|
||||||
|
input: `limits {
|
||||||
|
head 100
|
||||||
|
}`,
|
||||||
|
shouldErr: true,
|
||||||
|
},
|
||||||
|
"invalidLimitSize": {
|
||||||
|
input: `limits 10bk`,
|
||||||
|
shouldErr: true,
|
||||||
|
},
|
||||||
|
} {
|
||||||
|
c := c
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
controller := caddy.NewTestController("", c.input)
|
controller := caddy.NewTestController("", c.input)
|
||||||
err := setupMaxRequestBody(controller)
|
_, err := parseLimits(controller)
|
||||||
|
if c.shouldErr && err == nil {
|
||||||
if c.hasError && (err == nil) {
|
t.Error("failed to get expected error")
|
||||||
t.Errorf("Expecting error for case %v but none encountered", caseNum)
|
|
||||||
}
|
}
|
||||||
if !c.hasError && (err != nil) {
|
if !c.shouldErr && err != nil {
|
||||||
t.Errorf("Expecting no error for case %v but encountered %v", caseNum, err)
|
t.Errorf("got unexpected error: %v", err)
|
||||||
}
|
}
|
||||||
|
if got := httpserver.GetConfig(controller).Limits; !reflect.DeepEqual(got, c.expect) {
|
||||||
|
t.Errorf("expect %#v, but got %#v", c.expect, got)
|
||||||
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -228,7 +228,7 @@ func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
|
||||||
return 0, nil
|
return 0, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, ok := backendErr.(httpserver.MaxBytesExceeded); ok {
|
if backendErr == httpserver.MaxBytesExceededErr {
|
||||||
return http.StatusRequestEntityTooLarge, backendErr
|
return http.StatusRequestEntityTooLarge, backendErr
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue