diff --git a/caddyhttp/httpserver/replacer.go b/caddyhttp/httpserver/replacer.go index 2844cab7b..b31011a92 100644 --- a/caddyhttp/httpserver/replacer.go +++ b/caddyhttp/httpserver/replacer.go @@ -40,10 +40,40 @@ type Replacer interface { // they will be used to overwrite other replacements // if there is a name conflict. type replacer struct { - replacements map[string]func() string - customReplacements map[string]func() string + customReplacements map[string]string emptyValue string responseRecorder *ResponseRecorder + request *http.Request + requestBody *limitWriter +} + +type limitWriter struct { + w bytes.Buffer + remain int +} + +func newLimitWriter(max int) *limitWriter { + return &limitWriter{ + w: bytes.Buffer{}, + remain: max, + } +} + +func (lw *limitWriter) Write(p []byte) (int, error) { + // skip if we are full + if lw.remain <= 0 { + return len(p), nil + } + if n := len(p); n > lw.remain { + p = p[:lw.remain] + } + n, err := lw.w.Write(p) + lw.remain -= n + return n, err +} + +func (lw *limitWriter) String() string { + return lw.w.String() } // NewReplacer makes a new replacer based on r and rr which @@ -54,91 +84,24 @@ type replacer struct { // emptyValue should be the string that is used in place // of empty string (can still be empty string). func NewReplacer(r *http.Request, rr *ResponseRecorder, emptyValue string) Replacer { + rb := newLimitWriter(MaxLogBodySize) + if r.Body != nil { + r.Body = struct { + io.Reader + io.Closer + }{io.TeeReader(r.Body, rb), io.Closer(r.Body)} + } rep := &replacer{ + request: r, + requestBody: rb, responseRecorder: rr, - customReplacements: make(map[string]func() string), - replacements: map[string]func() string{ - "{method}": func() string { return r.Method }, - "{scheme}": func() string { - if r.TLS != nil { - return "https" - } - return "http" - }, - "{hostname}": func() string { - name, err := os.Hostname() - if err != nil { - return "" - } - return name - }, - "{host}": func() string { return r.Host }, - "{hostonly}": func() string { - host, _, err := net.SplitHostPort(r.Host) - if err != nil { - return r.Host - } - return host - }, - "{path}": func() string { return r.URL.Path }, - "{path_escaped}": func() string { return url.QueryEscape(r.URL.Path) }, - "{query}": func() string { return r.URL.RawQuery }, - "{query_escaped}": func() string { return url.QueryEscape(r.URL.RawQuery) }, - "{fragment}": func() string { return r.URL.Fragment }, - "{proto}": func() string { return r.Proto }, - "{remote}": func() string { - host, _, err := net.SplitHostPort(r.RemoteAddr) - if err != nil { - return r.RemoteAddr - } - return host - }, - "{port}": func() string { - _, port, err := net.SplitHostPort(r.RemoteAddr) - if err != nil { - return "" - } - return port - }, - "{uri}": func() string { return r.URL.RequestURI() }, - "{uri_escaped}": func() string { return url.QueryEscape(r.URL.RequestURI()) }, - "{when}": func() string { return time.Now().Format(timeFormat) }, - "{file}": func() string { - _, file := path.Split(r.URL.Path) - return file - }, - "{dir}": func() string { - dir, _ := path.Split(r.URL.Path) - return dir - }, - "{request}": func() string { - dump, err := httputil.DumpRequest(r, false) - if err != nil { - return "" - } - - return requestReplacer.Replace(string(dump)) - }, - "{request_body}": func() string { - if !canLogRequest(r) { - return "" - } - - body, err := readRequestBody(r, maxLogBodySize) - if err != nil { - return "" - } - - return requestReplacer.Replace(string(body)) - }, - }, - emptyValue: emptyValue, + customReplacements: make(map[string]string), + emptyValue: emptyValue, } // Header placeholders (case-insensitive) for header, values := range r.Header { - values := values - rep.replacements[headerReplacer+strings.ToLower(header)+"}"] = func() string { return strings.Join(values, ",") } + rep.customReplacements["{>"+strings.ToLower(header)+"}"] = strings.Join(values, ",") } return rep @@ -156,27 +119,6 @@ func canLogRequest(r *http.Request) bool { return false } -// readRequestBody reads the request body and sets a -// new io.ReadCloser that has not yet been read. -func readRequestBody(r *http.Request, n int64) ([]byte, error) { - defer r.Body.Close() - - body, err := ioutil.ReadAll(io.LimitReader(r.Body, n)) - if err != nil { - return nil, err - } - - // Read the remaining bytes - remaining, err := ioutil.ReadAll(r.Body) - if err != nil { - return nil, err - } - - buf := bytes.NewBuffer(append(body, remaining...)) - r.Body = ioutil.NopCloser(buf) - return body, nil -} - // Replace performs a replacement of values on s and returns // the string with the replaced values. func (r *replacer) Replace(s string) string { @@ -185,54 +127,37 @@ func (r *replacer) Replace(s string) string { return s } - // Make response placeholders now - if r.responseRecorder != nil { - r.replacements["{status}"] = func() string { return strconv.Itoa(r.responseRecorder.status) } - r.replacements["{size}"] = func() string { return strconv.Itoa(r.responseRecorder.size) } - r.replacements["{latency}"] = func() string { - dur := time.Since(r.responseRecorder.start) - return roundDuration(dur).String() - } - } - - // Include custom placeholders, overwriting existing ones if necessary - for key, val := range r.customReplacements { - r.replacements[key] = val - } - - // Header replacements - these are case-insensitive, so we can't just use strings.Replace() - for strings.Contains(s, headerReplacer) { - idxStart := strings.Index(s, headerReplacer) - endOffset := idxStart + len(headerReplacer) - idxEnd := strings.Index(s[endOffset:], "}") - if idxEnd > -1 { - placeholder := strings.ToLower(s[idxStart : endOffset+idxEnd+1]) - replacement := "" - if getReplacement, ok := r.replacements[placeholder]; ok { - replacement = getReplacement() - } - if replacement == "" { - replacement = r.emptyValue - } - s = s[:idxStart] + replacement + s[endOffset+idxEnd+1:] - } else { + result := "" + for { + idxStart := strings.Index(s, "{") + if idxStart == -1 { + // no placeholder anymore break } + idxEnd := strings.Index(s[idxStart:], "}") + if idxEnd == -1 { + // unpaired placeholder + break + } + idxEnd += idxStart + + // get a replacement + placeholder := s[idxStart : idxEnd+1] + // Header replacements - they are case-insensitive + if placeholder[1] == '>' { + placeholder = strings.ToLower(placeholder) + } + replacement := r.getSubstitution(placeholder) + + // append prefix + replacement + result += s[:idxStart] + replacement + + // strip out scanned parts + s = s[idxEnd+1:] } - // Regular replacements - these are easier because they're case-sensitive - for placeholder, getReplacement := range r.replacements { - if !strings.Contains(s, placeholder) { - continue - } - replacement := getReplacement() - if replacement == "" { - replacement = r.emptyValue - } - s = strings.Replace(s, placeholder, replacement, -1) - } - - return s + // append unscanned parts + return result + s } func roundDuration(d time.Duration) time.Duration { @@ -265,16 +190,117 @@ func round(d, r time.Duration) time.Duration { return d } +// getSubstitution retrieves value from corresponding key +func (r *replacer) getSubstitution(key string) string { + // search custom replacements first + if value, ok := r.customReplacements[key]; ok { + return value + } + + // search default replacements then + switch key { + case "{method}": + return r.request.Method + case "{scheme}": + if r.request.TLS != nil { + return "https" + } + return "http" + case "{hostname}": + name, err := os.Hostname() + if err != nil { + return r.emptyValue + } + return name + case "{host}": + return r.request.Host + case "{hostonly}": + host, _, err := net.SplitHostPort(r.request.Host) + if err != nil { + return r.request.Host + } + return host + case "{path}": + return r.request.URL.Path + case "{path_escaped}": + return url.QueryEscape(r.request.URL.Path) + case "{query}": + return r.request.URL.RawQuery + case "{query_escaped}": + return url.QueryEscape(r.request.URL.RawQuery) + case "{fragment}": + return r.request.URL.Fragment + case "{proto}": + return r.request.Proto + case "{remote}": + host, _, err := net.SplitHostPort(r.request.RemoteAddr) + if err != nil { + return r.request.RemoteAddr + } + return host + case "{port}": + _, port, err := net.SplitHostPort(r.request.RemoteAddr) + if err != nil { + return r.emptyValue + } + return port + case "{uri}": + return r.request.URL.RequestURI() + case "{uri_escaped}": + return url.QueryEscape(r.request.URL.RequestURI()) + case "{when}": + return time.Now().Format(timeFormat) + case "{file}": + _, file := path.Split(r.request.URL.Path) + return file + case "{dir}": + dir, _ := path.Split(r.request.URL.Path) + return dir + case "{request}": + dump, err := httputil.DumpRequest(r.request, false) + if err != nil { + return r.emptyValue + } + return requestReplacer.Replace(string(dump)) + case "{request_body}": + if !canLogRequest(r.request) { + return r.emptyValue + } + _, err := ioutil.ReadAll(r.request.Body) + if err != nil { + return r.emptyValue + } + return requestReplacer.Replace(r.requestBody.String()) + case "{status}": + if r.responseRecorder == nil { + return r.emptyValue + } + return strconv.Itoa(r.responseRecorder.status) + case "{size}": + if r.responseRecorder == nil { + return r.emptyValue + } + return strconv.Itoa(r.responseRecorder.size) + case "{latency}": + if r.responseRecorder == nil { + return r.emptyValue + } + return roundDuration(time.Since(r.responseRecorder.start)).String() + } + + return r.emptyValue +} + // Set sets key to value in the r.customReplacements map. func (r *replacer) Set(key, value string) { - r.customReplacements["{"+key+"}"] = func() string { return value } + r.customReplacements["{"+key+"}"] = value } const ( timeFormat = "02/Jan/2006:15:04:05 -0700" - headerReplacer = "{>" headerContentType = "Content-Type" contentTypeJSON = "application/json" contentTypeXML = "application/xml" - maxLogBodySize = 100 * 1024 + // MaxLogBodySize limits the size of logged request's body + MaxLogBodySize = 100 * 1024 ) diff --git a/caddyhttp/httpserver/replacer_test.go b/caddyhttp/httpserver/replacer_test.go index cfd52fb2e..9fcff8c90 100644 --- a/caddyhttp/httpserver/replacer_test.go +++ b/caddyhttp/httpserver/replacer_test.go @@ -1,8 +1,6 @@ package httpserver import ( - "bytes" - "io/ioutil" "net/http" "net/http/httptest" "os" @@ -24,28 +22,12 @@ func TestNewReplacer(t *testing.T) { switch v := rep.(type) { case *replacer: - if v.replacements["{host}"]() != "localhost" { + if v.getSubstitution("{host}") != "localhost" { t.Error("Expected host to be localhost") } - if v.replacements["{method}"]() != "POST" { + if v.getSubstitution("{method}") != "POST" { t.Error("Expected request method to be POST") } - - // Response placeholders should only be set after call to Replace() - got, want := "", "" - if getReplacement, ok := v.replacements["{status}"]; ok { - got = getReplacement() - } - if want := ""; got != want { - t.Errorf("Expected status to NOT be set before Replace() is called; was: %s", got) - } - rep.Replace("{foobar}") - if getReplacement, ok := v.replacements["{status}"]; ok { - got = getReplacement() - } - if want = "200"; got != want { - t.Errorf("Expected status to be %s, was: %s", want, got) - } default: t.Fatalf("Expected *replacer underlying Replacer type, got: %#v", rep) } @@ -94,19 +76,21 @@ func TestReplace(t *testing.T) { complexCases := []struct { template string - replacements map[string]func() string + replacements map[string]string expect string }{ - {"/a{1}/{2}", - map[string]func() string{ - "{1}": func() string { return "12" }, - "{2}": func() string { return "" }}, + { + "/a{1}/{2}", + map[string]string{ + "{1}": "12", + "{2}": "", + }, "/a12/"}, } for _, c := range complexCases { repl := &replacer{ - replacements: c.replacements, + customReplacements: c.replacements, } if expected, actual := c.expect, repl.Replace(c.template); expected != actual { t.Errorf("for template '%s', expected '%s', got '%s'", c.template, expected, actual) @@ -163,28 +147,3 @@ func TestRound(t *testing.T) { } } } - -func TestReadRequestBody(t *testing.T) { - payload := []byte(`{ "foo": "bar" }`) - var readSize int64 = 5 - r, err := http.NewRequest("POST", "/", bytes.NewReader(payload)) - if err != nil { - t.Error(err) - } - defer r.Body.Close() - - logBody, err := readRequestBody(r, readSize) - if err != nil { - t.Error("readRequestBody failed", err) - } else if !bytes.EqualFold(payload[0:readSize], logBody) { - t.Error("Expected log comparison failed") - } - - // Ensure the Request body is the same as the original. - reqBody, err := ioutil.ReadAll(r.Body) - if err != nil { - t.Error("Unable to read request body", err) - } else if !bytes.EqualFold(payload, reqBody) { - t.Error("Expected request body comparison failed") - } -} diff --git a/caddyhttp/log/log_test.go b/caddyhttp/log/log_test.go index af48f4424..eeb7cf2ea 100644 --- a/caddyhttp/log/log_test.go +++ b/caddyhttp/log/log_test.go @@ -2,6 +2,7 @@ package log import ( "bytes" + "io/ioutil" "log" "net/http" "net/http/httptest" @@ -60,3 +61,59 @@ func TestLoggedStatus(t *testing.T) { t.Errorf("Expected the log entry to contain 'foobar' (custom placeholder), but it didn't: %s", logged) } } + +func TestLogRequestBody(t *testing.T) { + var got bytes.Buffer + logger := Logger{ + Rules: []Rule{{ + PathScope: "/", + Format: "{request_body}", + Log: log.New(&got, "", 0), + }}, + Next: httpserver.HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) { + // drain up body + ioutil.ReadAll(r.Body) + return 0, nil + }), + } + + for i, c := range []struct { + body string + expect string + }{ + {"", "\n"}, + {"{hello} world!", "{hello} world!\n"}, + {func() string { + length := httpserver.MaxLogBodySize + 100 + b := make([]byte, length) + for i := 0; i < length; i++ { + b[i] = 0xab + } + return string(b) + }(), func() string { + b := make([]byte, httpserver.MaxLogBodySize) + for i := 0; i < httpserver.MaxLogBodySize; i++ { + b[i] = 0xab + } + return string(b) + "\n" + }(), + }, + } { + got.Reset() + r, err := http.NewRequest("POST", "/", bytes.NewBufferString(c.body)) + if err != nil { + t.Fatal(err) + } + r.Header.Set("Content-Type", "application/json") + status, err := logger.ServeHTTP(httptest.NewRecorder(), r) + if status != 0 { + t.Errorf("case %d: Expected status to be 0, but was %d", i, status) + } + if err != nil { + t.Errorf("case %d: Expected error to be nil, instead got: %v", i, err) + } + if got.String() != c.expect { + t.Errorf("case %d: Expected body %q, but got %q", i, c.expect, got.String()) + } + } +}