mirror of
https://github.com/caddyserver/caddy.git
synced 2025-01-23 17:16:40 +01:00
Merge pull request #1314 from mholt/unbuffered_proxy
proxy: Unbuffered request optimization
This commit is contained in:
commit
696792781a
3 changed files with 129 additions and 13 deletions
|
@ -39,6 +39,9 @@ type Upstream interface {
|
||||||
// Gets how long to wait between selecting upstream
|
// Gets how long to wait between selecting upstream
|
||||||
// hosts in the case of cascading failures.
|
// hosts in the case of cascading failures.
|
||||||
GetTryInterval() time.Duration
|
GetTryInterval() time.Duration
|
||||||
|
|
||||||
|
// Gets the number of upstream hosts.
|
||||||
|
GetHostCount() int
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpstreamHostDownFunc can be used to customize how Down behaves.
|
// UpstreamHostDownFunc can be used to customize how Down behaves.
|
||||||
|
@ -94,7 +97,19 @@ func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
|
||||||
// outreq is the request that makes a roundtrip to the backend
|
// outreq is the request that makes a roundtrip to the backend
|
||||||
outreq := createUpstreamRequest(r)
|
outreq := createUpstreamRequest(r)
|
||||||
|
|
||||||
// record and replace outreq body
|
// If we have more than one upstream host defined and if retrying is enabled
|
||||||
|
// by setting try_duration to a non-zero value, caddy will try to
|
||||||
|
// retry the request at a different host if the first one failed.
|
||||||
|
//
|
||||||
|
// This requires us to possibly rewind and replay the request body though,
|
||||||
|
// which in turn requires us to buffer the request body first.
|
||||||
|
//
|
||||||
|
// An unbuffered request is usually preferrable, because it reduces latency
|
||||||
|
// as well as memory usage. Furthermore it enables different kinds of
|
||||||
|
// HTTP streaming applications like gRPC for instance.
|
||||||
|
requiresBuffering := upstream.GetHostCount() > 1 && upstream.GetTryDuration() != 0
|
||||||
|
|
||||||
|
if requiresBuffering {
|
||||||
body, err := newBufferedBody(outreq.Body)
|
body, err := newBufferedBody(outreq.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return http.StatusBadRequest, errors.New("failed to read downstream request body")
|
return http.StatusBadRequest, errors.New("failed to read downstream request body")
|
||||||
|
@ -102,6 +117,7 @@ func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
|
||||||
if body != nil {
|
if body != nil {
|
||||||
outreq.Body = body
|
outreq.Body = body
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// The keepRetrying function will return true if we should
|
// The keepRetrying function will return true if we should
|
||||||
// loop and try to select another host, or false if we
|
// loop and try to select another host, or false if we
|
||||||
|
@ -173,15 +189,25 @@ func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
|
||||||
downHeaderUpdateFn = createRespHeaderUpdateFn(host.DownstreamHeaders, replacer)
|
downHeaderUpdateFn = createRespHeaderUpdateFn(host.DownstreamHeaders, replacer)
|
||||||
}
|
}
|
||||||
|
|
||||||
// rewind request body to its beginning
|
// Before we retry the request we have to make sure
|
||||||
if err := body.rewind(); err != nil {
|
// that the body is rewound to it's beginning.
|
||||||
|
if bb, ok := outreq.Body.(*bufferedBody); ok {
|
||||||
|
if err := bb.rewind(); err != nil {
|
||||||
return http.StatusInternalServerError, errors.New("unable to rewind downstream request body")
|
return http.StatusInternalServerError, errors.New("unable to rewind downstream request body")
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// tell the proxy to serve the request
|
// tell the proxy to serve the request
|
||||||
|
//
|
||||||
|
// NOTE:
|
||||||
|
// The call to proxy.ServeHTTP can theoretically panic.
|
||||||
|
// To prevent host.Conns from getting out-of-sync we thus have to
|
||||||
|
// make sure that it's _always_ correctly decremented afterwards.
|
||||||
|
func() {
|
||||||
atomic.AddInt64(&host.Conns, 1)
|
atomic.AddInt64(&host.Conns, 1)
|
||||||
|
defer atomic.AddInt64(&host.Conns, -1)
|
||||||
backendErr = proxy.ServeHTTP(w, outreq, downHeaderUpdateFn)
|
backendErr = proxy.ServeHTTP(w, outreq, downHeaderUpdateFn)
|
||||||
atomic.AddInt64(&host.Conns, -1)
|
}()
|
||||||
|
|
||||||
// if no errors, we're done here
|
// if no errors, we're done here
|
||||||
if backendErr == nil {
|
if backendErr == nil {
|
||||||
|
|
|
@ -954,6 +954,90 @@ func TestReverseProxyRetry(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestReverseProxyLargeBody(t *testing.T) {
|
||||||
|
log.SetOutput(ioutil.Discard)
|
||||||
|
defer log.SetOutput(os.Stderr)
|
||||||
|
|
||||||
|
// set up proxy
|
||||||
|
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
io.Copy(ioutil.Discard, r.Body)
|
||||||
|
r.Body.Close()
|
||||||
|
}))
|
||||||
|
defer backend.Close()
|
||||||
|
|
||||||
|
su, err := NewStaticUpstreams(caddyfile.NewDispenser("Testfile", strings.NewReader(`proxy / `+backend.URL)))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
p := &Proxy{
|
||||||
|
Next: httpserver.EmptyNext, // prevents panic in some cases when test fails
|
||||||
|
Upstreams: su,
|
||||||
|
}
|
||||||
|
|
||||||
|
// middle is required to simulate closable downstream request body
|
||||||
|
middle := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
_, err = p.ServeHTTP(w, r)
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer middle.Close()
|
||||||
|
|
||||||
|
// Our request body will be 100MB
|
||||||
|
bodySize := uint64(100 * 1000 * 1000)
|
||||||
|
|
||||||
|
// We want to see how much memory the proxy module requires for this request.
|
||||||
|
// So lets record the mem stats before we start it.
|
||||||
|
begMemstats := &runtime.MemStats{}
|
||||||
|
runtime.ReadMemStats(begMemstats)
|
||||||
|
|
||||||
|
r, err := http.NewRequest("POST", middle.URL, &noopReader{len: bodySize})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
resp, err := http.DefaultTransport.RoundTrip(r)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
resp.Body.Close()
|
||||||
|
|
||||||
|
// Finally we need the mem stats after the request is done...
|
||||||
|
endMemstats := &runtime.MemStats{}
|
||||||
|
runtime.ReadMemStats(endMemstats)
|
||||||
|
|
||||||
|
// ...to calculate the total amount of allocated memory during the request.
|
||||||
|
totalAlloc := endMemstats.TotalAlloc - begMemstats.TotalAlloc
|
||||||
|
|
||||||
|
// If that's as much as the size of the body itself it's a serious sign that the
|
||||||
|
// request was not "streamed" to the upstream without buffering it first.
|
||||||
|
if totalAlloc >= bodySize {
|
||||||
|
t.Fatalf("proxy allocated too much memory: %d bytes", totalAlloc)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type noopReader struct {
|
||||||
|
len uint64
|
||||||
|
pos uint64
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ io.Reader = &noopReader{}
|
||||||
|
|
||||||
|
func (r *noopReader) Read(b []byte) (int, error) {
|
||||||
|
if r.pos >= r.len {
|
||||||
|
return 0, io.EOF
|
||||||
|
}
|
||||||
|
n := int(r.len - r.pos)
|
||||||
|
if n > len(b) {
|
||||||
|
n = len(b)
|
||||||
|
}
|
||||||
|
for i := range b[:n] {
|
||||||
|
b[i] = 0
|
||||||
|
}
|
||||||
|
r.pos += uint64(n)
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
|
|
||||||
func newFakeUpstream(name string, insecure bool) *fakeUpstream {
|
func newFakeUpstream(name string, insecure bool) *fakeUpstream {
|
||||||
uri, _ := url.Parse(name)
|
uri, _ := url.Parse(name)
|
||||||
u := &fakeUpstream{
|
u := &fakeUpstream{
|
||||||
|
@ -998,6 +1082,7 @@ func (u *fakeUpstream) Select(r *http.Request) *UpstreamHost {
|
||||||
func (u *fakeUpstream) AllowedPath(requestPath string) bool { return true }
|
func (u *fakeUpstream) AllowedPath(requestPath string) bool { return true }
|
||||||
func (u *fakeUpstream) GetTryDuration() time.Duration { return 1 * time.Second }
|
func (u *fakeUpstream) GetTryDuration() time.Duration { return 1 * time.Second }
|
||||||
func (u *fakeUpstream) GetTryInterval() time.Duration { return 250 * time.Millisecond }
|
func (u *fakeUpstream) GetTryInterval() time.Duration { return 250 * time.Millisecond }
|
||||||
|
func (u *fakeUpstream) GetHostCount() int { return 1 }
|
||||||
|
|
||||||
// newWebSocketTestProxy returns a test proxy that will
|
// newWebSocketTestProxy returns a test proxy that will
|
||||||
// redirect to the specified backendAddr. The function
|
// redirect to the specified backendAddr. The function
|
||||||
|
@ -1049,6 +1134,7 @@ func (u *fakeWsUpstream) Select(r *http.Request) *UpstreamHost {
|
||||||
func (u *fakeWsUpstream) AllowedPath(requestPath string) bool { return true }
|
func (u *fakeWsUpstream) AllowedPath(requestPath string) bool { return true }
|
||||||
func (u *fakeWsUpstream) GetTryDuration() time.Duration { return 1 * time.Second }
|
func (u *fakeWsUpstream) GetTryDuration() time.Duration { return 1 * time.Second }
|
||||||
func (u *fakeWsUpstream) GetTryInterval() time.Duration { return 250 * time.Millisecond }
|
func (u *fakeWsUpstream) GetTryInterval() time.Duration { return 250 * time.Millisecond }
|
||||||
|
func (u *fakeWsUpstream) GetHostCount() int { return 1 }
|
||||||
|
|
||||||
// recorderHijacker is a ResponseRecorder that can
|
// recorderHijacker is a ResponseRecorder that can
|
||||||
// be hijacked.
|
// be hijacked.
|
||||||
|
|
|
@ -423,6 +423,10 @@ func (u *staticUpstream) GetTryInterval() time.Duration {
|
||||||
return u.TryInterval
|
return u.TryInterval
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (u *staticUpstream) GetHostCount() int {
|
||||||
|
return len(u.Hosts)
|
||||||
|
}
|
||||||
|
|
||||||
// RegisterPolicy adds a custom policy to the proxy.
|
// RegisterPolicy adds a custom policy to the proxy.
|
||||||
func RegisterPolicy(name string, policy func() Policy) {
|
func RegisterPolicy(name string, policy func() Policy) {
|
||||||
supportedPolicies[name] = policy
|
supportedPolicies[name] = policy
|
||||||
|
|
Loading…
Reference in a new issue