diff --git a/caddyhttp/proxy/proxy.go b/caddyhttp/proxy/proxy.go index 0f48a61f4..c0c2bb4b0 100644 --- a/caddyhttp/proxy/proxy.go +++ b/caddyhttp/proxy/proxy.go @@ -39,6 +39,9 @@ type Upstream interface { // Gets how long to wait between selecting upstream // hosts in the case of cascading failures. GetTryInterval() time.Duration + + // Gets the number of upstream hosts. + GetHostCount() int } // UpstreamHostDownFunc can be used to customize how Down behaves. @@ -94,13 +97,26 @@ func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { // outreq is the request that makes a roundtrip to the backend outreq := createUpstreamRequest(r) - // record and replace outreq body - body, err := newBufferedBody(outreq.Body) - if err != nil { - return http.StatusBadRequest, errors.New("failed to read downstream request body") - } - if body != nil { - outreq.Body = body + // If we have more than one upstream host defined and if retrying is enabled + // by setting try_duration to a non-zero value, caddy will try to + // retry the request at a different host if the first one failed. + // + // This requires us to possibly rewind and replay the request body though, + // which in turn requires us to buffer the request body first. + // + // An unbuffered request is usually preferrable, because it reduces latency + // as well as memory usage. Furthermore it enables different kinds of + // HTTP streaming applications like gRPC for instance. + requiresBuffering := upstream.GetHostCount() > 1 && upstream.GetTryDuration() != 0 + + if requiresBuffering { + body, err := newBufferedBody(outreq.Body) + if err != nil { + return http.StatusBadRequest, errors.New("failed to read downstream request body") + } + if body != nil { + outreq.Body = body + } } // The keepRetrying function will return true if we should @@ -173,15 +189,25 @@ func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { downHeaderUpdateFn = createRespHeaderUpdateFn(host.DownstreamHeaders, replacer) } - // rewind request body to its beginning - if err := body.rewind(); err != nil { - return http.StatusInternalServerError, errors.New("unable to rewind downstream request body") + // Before we retry the request we have to make sure + // that the body is rewound to it's beginning. + if bb, ok := outreq.Body.(*bufferedBody); ok { + if err := bb.rewind(); err != nil { + return http.StatusInternalServerError, errors.New("unable to rewind downstream request body") + } } // tell the proxy to serve the request - atomic.AddInt64(&host.Conns, 1) - backendErr = proxy.ServeHTTP(w, outreq, downHeaderUpdateFn) - atomic.AddInt64(&host.Conns, -1) + // + // NOTE: + // The call to proxy.ServeHTTP can theoretically panic. + // To prevent host.Conns from getting out-of-sync we thus have to + // make sure that it's _always_ correctly decremented afterwards. + func() { + atomic.AddInt64(&host.Conns, 1) + defer atomic.AddInt64(&host.Conns, -1) + backendErr = proxy.ServeHTTP(w, outreq, downHeaderUpdateFn) + }() // if no errors, we're done here if backendErr == nil { diff --git a/caddyhttp/proxy/proxy_test.go b/caddyhttp/proxy/proxy_test.go index 1503eccd6..7ffdb7756 100644 --- a/caddyhttp/proxy/proxy_test.go +++ b/caddyhttp/proxy/proxy_test.go @@ -954,6 +954,90 @@ func TestReverseProxyRetry(t *testing.T) { } } +func TestReverseProxyLargeBody(t *testing.T) { + log.SetOutput(ioutil.Discard) + defer log.SetOutput(os.Stderr) + + // set up proxy + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + io.Copy(ioutil.Discard, r.Body) + r.Body.Close() + })) + defer backend.Close() + + su, err := NewStaticUpstreams(caddyfile.NewDispenser("Testfile", strings.NewReader(`proxy / `+backend.URL))) + if err != nil { + t.Fatal(err) + } + + p := &Proxy{ + Next: httpserver.EmptyNext, // prevents panic in some cases when test fails + Upstreams: su, + } + + // middle is required to simulate closable downstream request body + middle := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, err = p.ServeHTTP(w, r) + if err != nil { + t.Error(err) + } + })) + defer middle.Close() + + // Our request body will be 100MB + bodySize := uint64(100 * 1000 * 1000) + + // We want to see how much memory the proxy module requires for this request. + // So lets record the mem stats before we start it. + begMemstats := &runtime.MemStats{} + runtime.ReadMemStats(begMemstats) + + r, err := http.NewRequest("POST", middle.URL, &noopReader{len: bodySize}) + if err != nil { + t.Fatal(err) + } + resp, err := http.DefaultTransport.RoundTrip(r) + if err != nil { + t.Fatal(err) + } + resp.Body.Close() + + // Finally we need the mem stats after the request is done... + endMemstats := &runtime.MemStats{} + runtime.ReadMemStats(endMemstats) + + // ...to calculate the total amount of allocated memory during the request. + totalAlloc := endMemstats.TotalAlloc - begMemstats.TotalAlloc + + // If that's as much as the size of the body itself it's a serious sign that the + // request was not "streamed" to the upstream without buffering it first. + if totalAlloc >= bodySize { + t.Fatalf("proxy allocated too much memory: %d bytes", totalAlloc) + } +} + +type noopReader struct { + len uint64 + pos uint64 +} + +var _ io.Reader = &noopReader{} + +func (r *noopReader) Read(b []byte) (int, error) { + if r.pos >= r.len { + return 0, io.EOF + } + n := int(r.len - r.pos) + if n > len(b) { + n = len(b) + } + for i := range b[:n] { + b[i] = 0 + } + r.pos += uint64(n) + return n, nil +} + func newFakeUpstream(name string, insecure bool) *fakeUpstream { uri, _ := url.Parse(name) u := &fakeUpstream{ @@ -998,6 +1082,7 @@ func (u *fakeUpstream) Select(r *http.Request) *UpstreamHost { func (u *fakeUpstream) AllowedPath(requestPath string) bool { return true } func (u *fakeUpstream) GetTryDuration() time.Duration { return 1 * time.Second } func (u *fakeUpstream) GetTryInterval() time.Duration { return 250 * time.Millisecond } +func (u *fakeUpstream) GetHostCount() int { return 1 } // newWebSocketTestProxy returns a test proxy that will // redirect to the specified backendAddr. The function @@ -1049,6 +1134,7 @@ func (u *fakeWsUpstream) Select(r *http.Request) *UpstreamHost { func (u *fakeWsUpstream) AllowedPath(requestPath string) bool { return true } func (u *fakeWsUpstream) GetTryDuration() time.Duration { return 1 * time.Second } func (u *fakeWsUpstream) GetTryInterval() time.Duration { return 250 * time.Millisecond } +func (u *fakeWsUpstream) GetHostCount() int { return 1 } // recorderHijacker is a ResponseRecorder that can // be hijacked. diff --git a/caddyhttp/proxy/upstream.go b/caddyhttp/proxy/upstream.go index 0309e2420..5742eff03 100644 --- a/caddyhttp/proxy/upstream.go +++ b/caddyhttp/proxy/upstream.go @@ -423,6 +423,10 @@ func (u *staticUpstream) GetTryInterval() time.Duration { return u.TryInterval } +func (u *staticUpstream) GetHostCount() int { + return len(u.Hosts) +} + // RegisterPolicy adds a custom policy to the proxy. func RegisterPolicy(name string, policy func() Policy) { supportedPolicies[name] = policy