From bc453fa6ae36287c90d2bf6941cb686490090df2 Mon Sep 17 00:00:00 2001
From: Mohammed Al Sahaf <msaa1990@gmail.com>
Date: Thu, 17 Sep 2020 19:25:34 +0300
Subject: [PATCH] reverseproxy: Correct alternate port for active health checks
 (#3693)

* reverseproxy: construct active health-check transport from scratch (Fixes #3691)

* reverseproxy: do upstream health-check on the correct alternative port

* reverseproxy: add integration test for health-check on alternative port

* reverseproxy: put back the custom transport for health-check http client

* reverseproxy: cleanup health-check integration test

* reverseproxy: fix health-check of unix socket upstreams

* reverseproxy: skip unix socket tests on Windows

* tabs > spaces

Co-authored-by: Francis Lavoie <lavofr@gmail.com>

* make the linter (and @francislavoie) happy

Co-authored-by: Francis Lavoie <lavofr@gmail.com>

* One more lint fix

Co-authored-by: Francis Lavoie <lavofr@gmail.com>

Co-authored-by: Francis Lavoie <lavofr@gmail.com>
---
 caddytest/integration/reverseproxy_test.go    | 97 +++++++++++++++++++
 .../caddyhttp/reverseproxy/healthchecks.go    | 33 +++----
 modules/caddyhttp/reverseproxy/hosts.go       |  6 +-
 .../caddyhttp/reverseproxy/httptransport.go   |  3 +
 .../caddyhttp/reverseproxy/reverseproxy.go    | 14 +++
 5 files changed, 132 insertions(+), 21 deletions(-)
 create mode 100644 caddytest/integration/reverseproxy_test.go

diff --git a/caddytest/integration/reverseproxy_test.go b/caddytest/integration/reverseproxy_test.go
new file mode 100644
index 000000000..8000546d6
--- /dev/null
+++ b/caddytest/integration/reverseproxy_test.go
@@ -0,0 +1,97 @@
+package integration
+
+import (
+	"fmt"
+	"io/ioutil"
+	"net"
+	"net/http"
+	"os"
+	"runtime"
+	"strings"
+	"testing"
+
+	"github.com/caddyserver/caddy/v2/caddytest"
+)
+
+func TestReverseProxyHealthCheck(t *testing.T) {
+	tester := caddytest.NewTester(t)
+	tester.InitServer(`
+	{
+		http_port     9080
+		https_port    9443
+	}
+	http://localhost:2020 {
+		respond "Hello, World!"
+	}
+	http://localhost:2021 {
+		respond "ok"
+	}
+	http://localhost:9080 {
+		reverse_proxy {
+			to localhost:2020
+	
+			health_path /health
+			health_port 2021
+			health_interval 2s
+			health_timeout 5s
+		}
+	}
+  `, "caddyfile")
+
+	tester.AssertGetResponse("http://localhost:9080/", 200, "Hello, World!")
+}
+
+func TestReverseProxyHealthCheckUnixSocket(t *testing.T) {
+	if runtime.GOOS == "windows" {
+		t.SkipNow()
+	}
+	tester := caddytest.NewTester(t)
+	f, err := ioutil.TempFile("", "*.sock")
+	if err != nil {
+		t.Errorf("failed to create TempFile: %s", err)
+		return
+	}
+	// a hack to get a file name within a valid path to use as socket
+	socketName := f.Name()
+	os.Remove(f.Name())
+
+	server := http.Server{
+		Handler: http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
+			if strings.HasPrefix(req.URL.Path, "/health") {
+				w.Write([]byte("ok"))
+				return
+			}
+			w.Write([]byte("Hello, World!"))
+		}),
+	}
+
+	unixListener, err := net.Listen("unix", socketName)
+	if err != nil {
+		t.Errorf("failed to listen on the socket: %s", err)
+		return
+	}
+	go server.Serve(unixListener)
+	t.Cleanup(func() {
+		server.Close()
+	})
+	runtime.Gosched() // Allow other goroutines to run
+
+	tester.InitServer(fmt.Sprintf(`
+	{
+		http_port     9080
+		https_port    9443
+	}
+	http://localhost:9080 {
+		reverse_proxy {
+			to unix/%s
+	
+			health_path /health
+			health_port 2021
+			health_interval 2s
+			health_timeout 5s
+		}
+	}
+	`, socketName), "caddyfile")
+
+	tester.AssertGetResponse("http://localhost:9080/", 200, "Hello, World!")
+}
diff --git a/modules/caddyhttp/reverseproxy/healthchecks.go b/modules/caddyhttp/reverseproxy/healthchecks.go
index 33cfd82b7..410b9d4a3 100644
--- a/modules/caddyhttp/reverseproxy/healthchecks.go
+++ b/modules/caddyhttp/reverseproxy/healthchecks.go
@@ -153,32 +153,27 @@ func (h *Handler) doActiveHealthCheckForAllHosts() {
 					log.Printf("[PANIC] active health check: %v\n%s", err, debug.Stack())
 				}
 			}()
-			networkAddr := upstream.Dial
-			addr, err := caddy.ParseNetworkAddress(networkAddr)
-			if err != nil {
-				h.HealthChecks.Active.logger.Error("bad network address",
-					zap.String("address", networkAddr),
-					zap.Error(err),
-				)
-				return
-			}
-			if addr.PortRangeSize() != 1 {
-				h.HealthChecks.Active.logger.Error("multiple addresses (upstream must map to only one address)",
-					zap.String("address", networkAddr),
-				)
-				return
-			}
-			hostAddr := addr.JoinHostPort(0)
-			if addr.IsUnixNetwork() {
+
+			portStr := strconv.Itoa(upstream.activeHealthCheckPort)
+			hostAddr := net.JoinHostPort(upstream.networkAddress.Host, portStr)
+			if upstream.networkAddress.IsUnixNetwork() {
 				// this will be used as the Host portion of a http.Request URL, and
 				// paths to socket files would produce an error when creating URL,
 				// so use a fake Host value instead; unix sockets are usually local
 				hostAddr = "localhost"
 			}
-			err = h.doActiveHealthCheck(DialInfo{Network: addr.Network, Address: hostAddr}, hostAddr, upstream.Host)
+
+			dialInfo := DialInfo{
+				Upstream: upstream,
+				Network:  upstream.networkAddress.Network,
+				Host:     upstream.networkAddress.Host,
+				Port:     portStr,
+				Address:  hostAddr,
+			}
+			err := h.doActiveHealthCheck(dialInfo, hostAddr, upstream.Host)
 			if err != nil {
 				h.HealthChecks.Active.logger.Error("active health check failed",
-					zap.String("address", networkAddr),
+					zap.String("address", hostAddr),
 					zap.Error(err),
 				)
 			}
diff --git a/modules/caddyhttp/reverseproxy/hosts.go b/modules/caddyhttp/reverseproxy/hosts.go
index 5870b75fd..b7b8c9b59 100644
--- a/modules/caddyhttp/reverseproxy/hosts.go
+++ b/modules/caddyhttp/reverseproxy/hosts.go
@@ -92,8 +92,10 @@ type Upstream struct {
 	// HeaderAffinity string
 	// IPAffinity     string
 
-	healthCheckPolicy *PassiveHealthChecks
-	cb                CircuitBreaker
+	networkAddress        caddy.NetworkAddress
+	activeHealthCheckPort int
+	healthCheckPolicy     *PassiveHealthChecks
+	cb                    CircuitBreaker
 }
 
 func (u Upstream) String() string {
diff --git a/modules/caddyhttp/reverseproxy/httptransport.go b/modules/caddyhttp/reverseproxy/httptransport.go
index dce7b9e87..7e3bb69c0 100644
--- a/modules/caddyhttp/reverseproxy/httptransport.go
+++ b/modules/caddyhttp/reverseproxy/httptransport.go
@@ -182,6 +182,9 @@ func (h *HTTPTransport) NewTransport(ctx caddy.Context) (*http.Transport, error)
 			if dialInfo, ok := GetDialInfo(ctx); ok {
 				network = dialInfo.Network
 				address = dialInfo.Address
+				if dialInfo.Upstream.networkAddress.IsUnixNetwork() {
+					address = dialInfo.Host
+				}
 			}
 			conn, err := dialer.DialContext(ctx, network, address)
 			if err != nil {
diff --git a/modules/caddyhttp/reverseproxy/reverseproxy.go b/modules/caddyhttp/reverseproxy/reverseproxy.go
index 910fbfc7b..138a3fc84 100644
--- a/modules/caddyhttp/reverseproxy/reverseproxy.go
+++ b/modules/caddyhttp/reverseproxy/reverseproxy.go
@@ -208,9 +208,13 @@ func (h *Handler) Provision(ctx caddy.Context) error {
 		if err != nil {
 			return err
 		}
+
 		if addr.PortRangeSize() != 1 {
 			return fmt.Errorf("multiple addresses (upstream must map to only one address): %v", addr)
 		}
+
+		upstream.networkAddress = addr
+
 		// create or get the host representation for this upstream
 		var host Host = new(upstreamHost)
 		existingHost, loaded := hosts.LoadOrStore(upstream.String(), host)
@@ -267,6 +271,16 @@ func (h *Handler) Provision(ctx caddy.Context) error {
 				Transport: h.Transport,
 			}
 
+			for _, upstream := range h.Upstreams {
+				// if there's an alternative port for health-check provided in the config,
+				// then use it, otherwise use the port of upstream.
+				if h.HealthChecks.Active.Port != 0 {
+					upstream.activeHealthCheckPort = h.HealthChecks.Active.Port
+				} else {
+					upstream.activeHealthCheckPort = int(upstream.networkAddress.StartPort)
+				}
+			}
+
 			if h.HealthChecks.Active.Interval == 0 {
 				h.HealthChecks.Active.Interval = caddy.Duration(30 * time.Second)
 			}