From 86060ef9b4bddc6f21ce0e58feeb0af95ce9934b Mon Sep 17 00:00:00 2001
From: Eugen Kleiner <eklein@ext.uber.com>
Date: Mon, 29 Oct 2018 17:00:44 -0700
Subject: [PATCH] caddy: Add OnRestartFailed callback (#2262)

* Add callback OnRestartFailed to caddy.Controller

* markdown: Fix 500 error (#2266)

* Addressed the comments

* Update paths for filebrowser plugins

* httpserver: update minify ordering (#2273)

* Bump required version of golang to 1.10 in README.md (#2267)

Adding TLS client cert placeholders #2217 uses features of go
v1.10.  Update README requirements accordingly.

* Update CI to use Go 1.11

* caddytls: gofmt (Go 1.11) (#2241)

* Ensure assets path exists before writing UUID file

* Adding {when_unix_ms} requests placeholder (unix timestamp with a milliseconds precision) (#2260)

* update to quic-go v0.10.0 (#2288)

quic-go now vendors all of its dependencies, so we don't need to vendor
them here.

Created by running:
gvt delete github.com/lucas-clemente/quic-go
gvt delete github.com/bifurcation/mint
gvt delete github.com/lucas-clemente/aes12
gvt delete github.com/lucas-clemente/fnv128a
gvt delete github.com/lucas-clemente/quic-go-certificates
gvt delete github.com/aead/chacha20
gvt delete github.com/hashicorp/golang-lru
gvt fetch -tag v0.10.0-no-integrationtests github.com/lucas-clemente/quic-go

* fastcgi: Add default timeouts (#2265)

Default fastcgi timeout is 60 seconds
Add tests

* Fix AppVeyor builds (#2289)

* Attempting to fix AppVeyor builds

* Trying again, 2015 image this time

* Use Appveyor's Go 1.11 stack

* Restore GOPATH\bin to PATH and delete old image config

* Add gcc to path manually

* Addressed the comments

* Fix broken link to sourcegraph in README (#2285)

* Fix deadlock, ensure instances mutex unlocked (#2296)

it's a stupid mistake

* proxy: Use DualStack=true in defaultDialer (#2305)

* ci: get golint tool from `golang.org/x/lint/golint` (#2324)

* templates: TLSVersion (#2323)

* new template action: TLS protocol version

* new template action: use caddytls.GetSupportedProtocolName

Avoids code duplication by reusing existing method to get TLS protocol
version used on connection. Also adds tests

* Don't return error on onRestartFail. Only log it.
---
 caddy.go      | 22 ++++++++++++++--
 caddy_test.go | 69 +++++++++++++++++++++++++++++++++++++++++++++++++++
 controller.go |  6 +++++
 3 files changed, 95 insertions(+), 2 deletions(-)

diff --git a/caddy.go b/caddy.go
index d2f60001d..1adf01134 100644
--- a/caddy.go
+++ b/caddy.go
@@ -111,6 +111,7 @@ type Instance struct {
 	onFirstStartup  []func() error // starting, not as part of a restart
 	onStartup       []func() error // starting, even as part of a restart
 	onRestart       []func() error // before restart commences
+	onRestartFailed []func() error // if restart failed
 	onShutdown      []func() error // stopping, even as part of a restart
 	onFinalShutdown []func() error // stopping, not as part of a restart
 
@@ -186,9 +187,26 @@ func (i *Instance) Restart(newCaddyfile Input) (*Instance, error) {
 	i.wg.Add(1)
 	defer i.wg.Done()
 
+	var err error
+	// if something went wrong on restart then run onRestartFailed callbacks
+	defer func() {
+		r := recover()
+		if err != nil || r != nil {
+			for _, fn := range i.onRestartFailed {
+				err = fn()
+				if err != nil {
+					log.Printf("[ERROR] restart failed: %v", err)
+				}
+			}
+			if r != nil {
+				panic(r)
+			}
+		}
+	}()
+
 	// run restart callbacks
 	for _, fn := range i.onRestart {
-		err := fn()
+		err = fn()
 		if err != nil {
 			return i, err
 		}
@@ -224,7 +242,7 @@ func (i *Instance) Restart(newCaddyfile Input) (*Instance, error) {
 	newInst := &Instance{serverType: newCaddyfile.ServerType(), wg: i.wg, Storage: make(map[interface{}]interface{})}
 
 	// attempt to start new instance
-	err := startWithListenerFds(newCaddyfile, newInst, restartFds)
+	err = startWithListenerFds(newCaddyfile, newInst, restartFds)
 	if err != nil {
 		return i, err
 	}
diff --git a/caddy_test.go b/caddy_test.go
index 76cd2ec25..0a18100f6 100644
--- a/caddy_test.go
+++ b/caddy_test.go
@@ -15,9 +15,14 @@
 package caddy
 
 import (
+	"fmt"
 	"net"
+	"reflect"
 	"strconv"
+	"sync"
 	"testing"
+
+	"github.com/mholt/caddy/caddyfile"
 )
 
 /*
@@ -48,6 +53,70 @@ func TestCaddyStartStop(t *testing.T) {
 }
 */
 
+// CallbackTestContext implements Context interface
+type CallbackTestContext struct {
+	// If MakeServersFail is set to true then MakeServers returns an error
+	MakeServersFail bool
+}
+
+func (h *CallbackTestContext) InspectServerBlocks(name string, sblock []caddyfile.ServerBlock) ([]caddyfile.ServerBlock, error) {
+	return sblock, nil
+}
+func (h *CallbackTestContext) MakeServers() ([]Server, error) {
+	if h.MakeServersFail {
+		return make([]Server, 0), fmt.Errorf("MakeServers failed")
+	}
+	return make([]Server, 0), nil
+}
+
+func TestCaddyRestartCallbacks(t *testing.T) {
+	for i, test := range []struct {
+		restartFail   bool
+		expectedCalls []string
+	}{
+		{false, []string{"OnRestart", "OnShutdown"}},
+		{true, []string{"OnRestart", "OnRestartFailed"}},
+	} {
+		serverName := fmt.Sprintf("%v", i)
+		// RegisterServerType to make successful restart possible
+		RegisterServerType(serverName, ServerType{
+			Directives: func() []string { return []string{} },
+			// If MakeServersFail is true then the restart will fail due to context failure
+			NewContext: func(inst *Instance) Context { return &CallbackTestContext{MakeServersFail: test.restartFail} },
+		})
+		c := NewTestController(serverName, "")
+		c.instance = &Instance{
+			serverType: serverName,
+			wg:         new(sync.WaitGroup),
+		}
+
+		// Register callbacks which save the calls order
+		calls := make([]string, 0)
+		c.OnRestart(func() error {
+			calls = append(calls, "OnRestart")
+			return nil
+		})
+		c.OnRestartFailed(func() error {
+			calls = append(calls, "OnRestartFailed")
+			return nil
+		})
+		c.OnShutdown(func() error {
+			calls = append(calls, "OnShutdown")
+			return nil
+		})
+
+		c.instance.Restart(CaddyfileInput{Contents: []byte(""), ServerTypeName: serverName})
+
+		if !reflect.DeepEqual(calls, test.expectedCalls) {
+			t.Errorf("Test %d: Callbacks expected: %v, got: %v", i, test.expectedCalls, calls)
+		}
+
+		c.instance.Stop()
+		c.instance.Wait()
+	}
+
+}
+
 func TestIsLoopback(t *testing.T) {
 	for i, test := range []struct {
 		input  string
diff --git a/controller.go b/controller.go
index 6015d210f..f63cebe00 100644
--- a/controller.go
+++ b/controller.go
@@ -86,6 +86,12 @@ func (c *Controller) OnRestart(fn func() error) {
 	c.instance.onRestart = append(c.instance.onRestart, fn)
 }
 
+// OnRestartFailed adds fn to the list of callback functions to execute
+// if the server failed to restart.
+func (c *Controller) OnRestartFailed(fn func() error) {
+	c.instance.onRestartFailed = append(c.instance.onRestartFailed, fn)
+}
+
 // OnShutdown adds fn to the list of callback functions to execute
 // when the server is about to be shut down (including restarts).
 func (c *Controller) OnShutdown(fn func() error) {