diff --git a/caddy/setup/gzip.go b/caddy/setup/gzip.go index a81c5f170..40c81252f 100644 --- a/caddy/setup/gzip.go +++ b/caddy/setup/gzip.go @@ -27,9 +27,13 @@ func gzipParse(c *Controller) ([]gzip.Config, error) { for c.Next() { config := gzip.Config{} + // Request Filters pathFilter := gzip.PathFilter{IgnoredPaths: make(gzip.Set)} extFilter := gzip.ExtFilter{Exts: make(gzip.Set)} + // Response Filters + lengthFilter := gzip.LengthFilter(0) + // No extra args expected if len(c.RemainingArgs()) > 0 { return configs, c.ArgErr() @@ -68,24 +72,42 @@ func gzipParse(c *Controller) ([]gzip.Config, error) { } level, _ := strconv.Atoi(c.Val()) config.Level = level + case "min_length": + if !c.NextArg() { + return configs, c.ArgErr() + } + length, err := strconv.ParseInt(c.Val(), 10, 64) + if err != nil { + return configs, err + } else if length == 0 { + return configs, fmt.Errorf(`gzip: min_length must be greater than 0`) + } + lengthFilter = gzip.LengthFilter(length) default: return configs, c.ArgErr() } } - config.Filters = []gzip.Filter{} + // Request Filters + config.RequestFilters = []gzip.RequestFilter{} // If ignored paths are specified, put in front to filter with path first if len(pathFilter.IgnoredPaths) > 0 { - config.Filters = []gzip.Filter{pathFilter} + config.RequestFilters = []gzip.RequestFilter{pathFilter} } // Then, if extensions are specified, use those to filter. // Otherwise, use default extensions filter. if len(extFilter.Exts) > 0 { - config.Filters = append(config.Filters, extFilter) + config.RequestFilters = append(config.RequestFilters, extFilter) } else { - config.Filters = append(config.Filters, gzip.DefaultExtFilter()) + config.RequestFilters = append(config.RequestFilters, gzip.DefaultExtFilter()) + } + + // Response Filters + // If min_length is specified, use it. + if int64(lengthFilter) != 0 { + config.ResponseFilters = append(config.ResponseFilters, lengthFilter) } configs = append(configs, config) diff --git a/caddy/setup/gzip_test.go b/caddy/setup/gzip_test.go index 22d01d7a1..36eeb0aea 100644 --- a/caddy/setup/gzip_test.go +++ b/caddy/setup/gzip_test.go @@ -73,6 +73,18 @@ func TestGzip(t *testing.T) { level 1 } `, false}, + {`gzip { not /file + ext * + level 1 + min_length ab + } + `, true}, + {`gzip { not /file + ext * + level 1 + min_length 1000 + } + `, false}, } for i, test := range tests { c := NewTestController(test.input) diff --git a/middleware/gzip/gzip.go b/middleware/gzip/gzip.go index e2d447753..b5866f682 100644 --- a/middleware/gzip/gzip.go +++ b/middleware/gzip/gzip.go @@ -23,8 +23,9 @@ type Gzip struct { // Config holds the configuration for Gzip middleware type Config struct { - Filters []Filter // Filters to use - Level int // Compression level + RequestFilters []RequestFilter + ResponseFilters []ResponseFilter + Level int // Compression level } // ServeHTTP serves a gzipped response if the client supports it. @@ -36,8 +37,8 @@ func (g Gzip) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { outer: for _, c := range g.Configs { - // Check filters to determine if gzipping is permitted for this request - for _, filter := range c.Filters { + // Check request filters to determine if gzipping is permitted for this request + for _, filter := range c.RequestFilters { if !filter.ShouldCompress(r) { continue outer } @@ -56,8 +57,17 @@ outer: defer gzipWriter.Close() gz := gzipResponseWriter{Writer: gzipWriter, ResponseWriter: w} + var rw http.ResponseWriter + // if no response filter is used + if len(c.ResponseFilters) == 0 { + rw = gz + } else { + // wrap gzip writer with ResponseFilterWriter + rw = NewResponseFilterWriter(c.ResponseFilters, gz) + } + // Any response in forward middleware will now be compressed - status, err := g.Next.ServeHTTP(gz, r) + status, err := g.Next.ServeHTTP(rw, r) // If there was an error that remained unhandled, we need // to send something back before gzipWriter gets closed at diff --git a/middleware/gzip/gzip_test.go b/middleware/gzip/gzip_test.go index 3e7bed996..11ce6b209 100644 --- a/middleware/gzip/gzip_test.go +++ b/middleware/gzip/gzip_test.go @@ -21,7 +21,7 @@ func TestGzipHandler(t *testing.T) { extFilter.Exts.Add(e) } gz := Gzip{Configs: []Config{ - {Filters: []Filter{pathFilter, extFilter}}, + {RequestFilters: []RequestFilter{pathFilter, extFilter}}, }} w := httptest.NewRecorder() diff --git a/middleware/gzip/filter.go b/middleware/gzip/request_filter.go similarity index 91% rename from middleware/gzip/filter.go rename to middleware/gzip/request_filter.go index f4039da35..76a9ef83c 100644 --- a/middleware/gzip/filter.go +++ b/middleware/gzip/request_filter.go @@ -7,8 +7,8 @@ import ( "github.com/mholt/caddy/middleware" ) -// Filter determines if a request should be gzipped. -type Filter interface { +// RequestFilter determines if a request should be gzipped. +type RequestFilter interface { // ShouldCompress tells if gzip compression // should be done on the request. ShouldCompress(*http.Request) bool @@ -26,7 +26,7 @@ func DefaultExtFilter() ExtFilter { return m } -// ExtFilter is Filter for file name extensions. +// ExtFilter is RequestFilter for file name extensions. type ExtFilter struct { // Exts is the file name extensions to accept Exts Set @@ -43,7 +43,7 @@ func (e ExtFilter) ShouldCompress(r *http.Request) bool { return e.Exts.Contains(ExtWildCard) || e.Exts.Contains(ext) } -// PathFilter is Filter for request path. +// PathFilter is RequestFilter for request path. type PathFilter struct { // IgnoredPaths is the paths to ignore IgnoredPaths Set diff --git a/middleware/gzip/filter_test.go b/middleware/gzip/request_filter_test.go similarity index 95% rename from middleware/gzip/filter_test.go rename to middleware/gzip/request_filter_test.go index f6537b9e2..ce31d7faf 100644 --- a/middleware/gzip/filter_test.go +++ b/middleware/gzip/request_filter_test.go @@ -47,7 +47,7 @@ func TestSet(t *testing.T) { } func TestExtFilter(t *testing.T) { - var filter Filter = ExtFilter{make(Set)} + var filter RequestFilter = ExtFilter{make(Set)} for _, e := range []string{".txt", ".html", ".css", ".md"} { filter.(ExtFilter).Exts.Add(e) } @@ -86,7 +86,7 @@ func TestPathFilter(t *testing.T) { paths := []string{ "/a", "/b", "/c", "/de", } - var filter Filter = PathFilter{make(Set)} + var filter RequestFilter = PathFilter{make(Set)} for _, p := range paths { filter.(PathFilter).IgnoredPaths.Add(p) } diff --git a/middleware/gzip/response_filter.go b/middleware/gzip/response_filter.go new file mode 100644 index 000000000..8793dabe1 --- /dev/null +++ b/middleware/gzip/response_filter.go @@ -0,0 +1,62 @@ +package gzip + +import ( + "net/http" + "strconv" +) + +// ResponseFilter determines if the response should be gzipped. +type ResponseFilter interface { + ShouldCompress(http.ResponseWriter) bool +} + +// LengthFilter is ResponseFilter for minimum content length. +type LengthFilter int64 + +// ShouldCompress returns if content length is greater than or +// equals to minimum length. +func (l LengthFilter) ShouldCompress(w http.ResponseWriter) bool { + contentLength := (w.Header().Get("Content-Length")) + length, err := strconv.ParseInt(contentLength, 10, 64) + if err != nil || length == 0 { + return false + } + return l != 0 && int64(l) <= length +} + +// ResponseFilterWriter validates ResponseFilters. It writes +// gzip compressed data if ResponseFilters are satisfied or +// uncompressed data otherwise. +type ResponseFilterWriter struct { + filters []ResponseFilter + validated bool + shouldCompress bool + gzipResponseWriter +} + +// NewResponseFilterWriter creates and initializes a new ResponseFilterWriter. +func NewResponseFilterWriter(filters []ResponseFilter, gz gzipResponseWriter) *ResponseFilterWriter { + return &ResponseFilterWriter{filters: filters, gzipResponseWriter: gz} +} + +// Write wraps underlying Write method and compresses if filters +// are satisfied +func (r *ResponseFilterWriter) Write(b []byte) (int, error) { + // One time validation to determine if compression should + // be used or not. + if !r.validated { + r.shouldCompress = true + for _, filter := range r.filters { + if !filter.ShouldCompress(r) { + r.shouldCompress = false + break + } + } + r.validated = true + } + + if r.shouldCompress { + return r.gzipResponseWriter.Write(b) + } + return r.ResponseWriter.Write(b) +} diff --git a/middleware/gzip/response_filter_test.go b/middleware/gzip/response_filter_test.go new file mode 100644 index 000000000..1a5a1b4f3 --- /dev/null +++ b/middleware/gzip/response_filter_test.go @@ -0,0 +1,70 @@ +package gzip + +import ( + "compress/gzip" + "fmt" + "net/http/httptest" + "testing" +) + +func TestLengthFilter(t *testing.T) { + var filters []ResponseFilter = []ResponseFilter{ + LengthFilter(100), + LengthFilter(1000), + LengthFilter(0), + } + + var tests = []struct { + length int64 + shouldCompress [3]bool + }{ + {20, [3]bool{false, false, false}}, + {50, [3]bool{false, false, false}}, + {100, [3]bool{true, false, false}}, + {500, [3]bool{true, false, false}}, + {1000, [3]bool{true, true, false}}, + {1500, [3]bool{true, true, false}}, + } + + for i, ts := range tests { + for j, filter := range filters { + r := httptest.NewRecorder() + r.Header().Set("Content-Length", fmt.Sprint(ts.length)) + if filter.ShouldCompress(r) != ts.shouldCompress[j] { + t.Errorf("Test %v: Expected %v found %v", i, ts.shouldCompress[j], filter.ShouldCompress(r)) + } + } + } +} + +func TestResponseFilterWriter(t *testing.T) { + tests := []struct { + body string + shouldCompress bool + }{ + {"Hello\t\t\t\n", false}, + {"Hello the \t\t\t world is\n\n\n great", true}, + {"Hello \t\t\nfrom gzip", true}, + {"Hello gzip\n", false}, + } + filters := []ResponseFilter{ + LengthFilter(15), + } + for i, ts := range tests { + w := httptest.NewRecorder() + w.Header().Set("Content-Length", fmt.Sprint(len(ts.body))) + gz := gzipResponseWriter{gzip.NewWriter(w), w} + rw := NewResponseFilterWriter(filters, gz) + rw.Write([]byte(ts.body)) + resp := w.Body.String() + if !ts.shouldCompress { + if resp != ts.body { + t.Errorf("Test %v: No compression expected, found %v", i, resp) + } + } else { + if resp == ts.body { + t.Errorf("Test %v: Compression expected, found %v", i, resp) + } + } + } +}