diff --git a/caddyhttp/proxy/proxy.go b/caddyhttp/proxy/proxy.go index 4bac8976f..a7922d4a9 100644 --- a/caddyhttp/proxy/proxy.go +++ b/caddyhttp/proxy/proxy.go @@ -42,6 +42,9 @@ type Upstream interface { // Gets the number of upstream hosts. GetHostCount() int + + // Stops the upstream from proxying requests to shutdown goroutines cleanly. + Stop() error } // UpstreamHostDownFunc can be used to customize how Down behaves. diff --git a/caddyhttp/proxy/proxy_test.go b/caddyhttp/proxy/proxy_test.go index 386a16b50..f8c96dd83 100644 --- a/caddyhttp/proxy/proxy_test.go +++ b/caddyhttp/proxy/proxy_test.go @@ -1216,6 +1216,7 @@ func (u *fakeUpstream) AllowedPath(requestPath string) bool { return true } func (u *fakeUpstream) GetTryDuration() time.Duration { return 1 * time.Second } func (u *fakeUpstream) GetTryInterval() time.Duration { return 250 * time.Millisecond } func (u *fakeUpstream) GetHostCount() int { return 1 } +func (u *fakeUpstream) Stop() error { return nil } // newWebSocketTestProxy returns a test proxy that will // redirect to the specified backendAddr. The function @@ -1268,6 +1269,7 @@ func (u *fakeWsUpstream) AllowedPath(requestPath string) bool { return true } func (u *fakeWsUpstream) GetTryDuration() time.Duration { return 1 * time.Second } func (u *fakeWsUpstream) GetTryInterval() time.Duration { return 250 * time.Millisecond } func (u *fakeWsUpstream) GetHostCount() int { return 1 } +func (u *fakeWsUpstream) Stop() error { return nil } // recorderHijacker is a ResponseRecorder that can // be hijacked. diff --git a/caddyhttp/proxy/setup.go b/caddyhttp/proxy/setup.go index c25b041a8..5daffabc3 100644 --- a/caddyhttp/proxy/setup.go +++ b/caddyhttp/proxy/setup.go @@ -21,5 +21,11 @@ func setup(c *caddy.Controller) error { httpserver.GetConfig(c).AddMiddleware(func(next httpserver.Handler) httpserver.Handler { return Proxy{Next: next, Upstreams: upstreams} }) + + // Register shutdown handlers. + for _, upstream := range upstreams { + c.OnShutdown(upstream.Stop) + } + return nil } diff --git a/caddyhttp/proxy/upstream.go b/caddyhttp/proxy/upstream.go index 0e8311248..303f986c4 100644 --- a/caddyhttp/proxy/upstream.go +++ b/caddyhttp/proxy/upstream.go @@ -9,6 +9,7 @@ import ( "path" "strconv" "strings" + "sync" "sync/atomic" "time" @@ -24,6 +25,8 @@ type staticUpstream struct { from string upstreamHeaders http.Header downstreamHeaders http.Header + stop chan struct{} // Signals running goroutines to stop. + wg sync.WaitGroup // Used to wait for running goroutines to stop. Hosts HostPool Policy Policy KeepAlive int @@ -48,8 +51,10 @@ type staticUpstream struct { func NewStaticUpstreams(c caddyfile.Dispenser) ([]Upstream, error) { var upstreams []Upstream for c.Next() { + upstream := &staticUpstream{ from: "", + stop: make(chan struct{}), upstreamHeaders: make(http.Header), downstreamHeaders: make(http.Header), Hosts: nil, @@ -108,7 +113,11 @@ func NewStaticUpstreams(c caddyfile.Dispenser) ([]Upstream, error) { upstream.HealthCheck.Client = http.Client{ Timeout: upstream.HealthCheck.Timeout, } - go upstream.HealthCheckWorker(nil) + upstream.wg.Add(1) + go func() { + defer upstream.wg.Done() + upstream.HealthCheckWorker(upstream.stop) + }() } upstreams = append(upstreams, upstream) } @@ -380,9 +389,8 @@ func (u *staticUpstream) HealthCheckWorker(stop chan struct{}) { case <-ticker.C: u.healthCheck() case <-stop: - // TODO: the library should provide a stop channel and global - // waitgroup to allow goroutines started by plugins a chance - // to clean themselves up. + ticker.Stop() + return } } } @@ -434,6 +442,14 @@ func (u *staticUpstream) GetHostCount() int { return len(u.Hosts) } +// Stop sends a signal to all goroutines started by this staticUpstream to exit +// and waits for them to finish before returning. +func (u *staticUpstream) Stop() error { + close(u.stop) + u.wg.Wait() + return nil +} + // RegisterPolicy adds a custom policy to the proxy. func RegisterPolicy(name string, policy func() Policy) { supportedPolicies[name] = policy diff --git a/caddyhttp/proxy/upstream_test.go b/caddyhttp/proxy/upstream_test.go index 1163fffe1..d84c366e5 100644 --- a/caddyhttp/proxy/upstream_test.go +++ b/caddyhttp/proxy/upstream_test.go @@ -1,8 +1,11 @@ package proxy import ( + "fmt" "net/http" + "net/http/httptest" "strings" + "sync/atomic" "testing" "time" @@ -189,6 +192,75 @@ func TestParseBlockHealthCheck(t *testing.T) { } } +func TestStop(t *testing.T) { + config := "proxy / %s {\n health_check /healthcheck \nhealth_check_interval %dms \n}" + tests := []struct { + name string + intervalInMilliseconds int + numHealthcheckIntervals int + }{ + { + "No Healthchecks After Stop - 5ms, 1 intervals", + 5, + 1, + }, + { + "No Healthchecks After Stop - 5ms, 2 intervals", + 5, + 2, + }, + { + "No Healthchecks After Stop - 5ms, 3 intervals", + 5, + 3, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + + // Set up proxy. + var counter int64 + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + r.Body.Close() + atomic.AddInt64(&counter, 1) + })) + + defer backend.Close() + + upstreams, err := NewStaticUpstreams(caddyfile.NewDispenser("Testfile", strings.NewReader(fmt.Sprintf(config, backend.URL, test.intervalInMilliseconds)))) + if err != nil { + t.Error("Expected no error. Got:", err.Error()) + } + + // Give some time for healthchecks to hit the server. + time.Sleep(time.Duration(test.intervalInMilliseconds*test.numHealthcheckIntervals) * time.Millisecond) + + for _, upstream := range upstreams { + if err := upstream.Stop(); err != nil { + t.Error("Expected no error stopping upstream. Got: ", err.Error()) + } + } + + counterValueAfterShutdown := atomic.LoadInt64(&counter) + + // Give some time to see if healthchecks are still hitting the server. + time.Sleep(time.Duration(test.intervalInMilliseconds*test.numHealthcheckIntervals) * time.Millisecond) + + if counterValueAfterShutdown == 0 { + t.Error("Expected healthchecks to hit test server. Got no healthchecks.") + } + + counterValueAfterWaiting := atomic.LoadInt64(&counter) + if counterValueAfterWaiting != counterValueAfterShutdown { + t.Errorf("Expected no more healthchecks after shutdown. Got: %d healthchecks after shutdown", counterValueAfterWaiting-counterValueAfterShutdown) + } + + }) + + } +} + func TestParseBlock(t *testing.T) { r, _ := http.NewRequest("GET", "/", nil) tests := []struct {