diff --git a/caddyhttp/proxy/reverseproxy.go b/caddyhttp/proxy/reverseproxy.go index d32be7cdb..021025a05 100644 --- a/caddyhttp/proxy/reverseproxy.go +++ b/caddyhttp/proxy/reverseproxy.go @@ -349,9 +349,14 @@ func newConnHijackerTransport(base http.RoundTripper) *connHijackerTransport { MaxIdleConnsPerHost: -1, } if b, _ := base.(*http.Transport); b != nil { + tlsClientConfig := b.TLSClientConfig + if tlsClientConfig.NextProtos != nil { + tlsClientConfig = cloneTLSClientConfig(tlsClientConfig) + tlsClientConfig.NextProtos = nil + } + t.Proxy = b.Proxy - t.TLSClientConfig = cloneTLSClientConfig(b.TLSClientConfig) - t.TLSClientConfig.NextProtos = nil + t.TLSClientConfig = tlsClientConfig t.TLSHandshakeTimeout = b.TLSHandshakeTimeout t.Dial = b.Dial t.DialTLS = b.DialTLS @@ -363,19 +368,15 @@ func newConnHijackerTransport(base http.RoundTripper) *connHijackerTransport { dial := getTransportDial(t) dialTLS := getTransportDialTLS(t) - t.Dial = func(network, addr string) (net.Conn, error) { c, err := dial(network, addr) hj.Conn = c return &hijackedConn{c, hj}, err } - - if dialTLS != nil { - t.DialTLS = func(network, addr string) (net.Conn, error) { - c, err := dialTLS(network, addr) - hj.Conn = c - return &hijackedConn{c, hj}, err - } + t.DialTLS = func(network, addr string) (net.Conn, error) { + c, err := dialTLS(network, addr) + hj.Conn = c + return &hijackedConn{c, hj}, err } return hj @@ -390,27 +391,35 @@ func getTransportDial(t *http.Transport) func(network, addr string) (net.Conn, e return defaultDialer.Dial } -// getTransportDial returns a TLS Dialer if TLSClientConfig is non-nil +// getTransportDial always returns a TLS Dialer // and defaults to the existing t.DialTLS. func getTransportDialTLS(t *http.Transport) func(network, addr string) (net.Conn, error) { if t.DialTLS != nil { return t.DialTLS } - if t.TLSClientConfig == nil { - return nil - } // newConnHijackerTransport will modify t.Dial after calling this method // => Create a backup reference. plainDial := getTransportDial(t) + // The following DialTLS implementation stems from the Go stdlib and + // is identical to what happens if DialTLS is not provided. + // Source: https://github.com/golang/go/blob/230a376b5a67f0e9341e1fa47e670ff762213c83/src/net/http/transport.go#L1018-L1051 return func(network, addr string) (net.Conn, error) { plainConn, err := plainDial(network, addr) if err != nil { return nil, err } - tlsConn := tls.Client(plainConn, t.TLSClientConfig) + tlsClientConfig := t.TLSClientConfig + if tlsClientConfig == nil { + tlsClientConfig = &tls.Config{} + } + if !tlsClientConfig.InsecureSkipVerify && tlsClientConfig.ServerName == "" { + tlsClientConfig.ServerName = stripPort(addr) + } + + tlsConn := tls.Client(plainConn, tlsClientConfig) errc := make(chan error, 2) var timer *time.Timer if d := t.TLSHandshakeTimeout; d != 0 { @@ -429,16 +438,12 @@ func getTransportDialTLS(t *http.Transport) func(network, addr string) (net.Conn plainConn.Close() return nil, err } - if !t.TLSClientConfig.InsecureSkipVerify { - serverName := t.TLSClientConfig.ServerName - if serverName == "" { - serverName = addr - idx := strings.LastIndex(serverName, ":") - if idx != -1 { - serverName = serverName[:idx] - } + if !tlsClientConfig.InsecureSkipVerify { + hostname := tlsClientConfig.ServerName + if hostname == "" { + hostname = stripPort(addr) } - if err := tlsConn.VerifyHostname(serverName); err != nil { + if err := tlsConn.VerifyHostname(hostname); err != nil { plainConn.Close() return nil, err } @@ -448,6 +453,22 @@ func getTransportDialTLS(t *http.Transport) func(network, addr string) (net.Conn } } +// stripPort returns address without its port if it has one and +// works with IP addresses as well as hostnames formatted as host:port. +// +// IPv6 addresses (excluding the port) must be enclosed in +// square brackets similar to the requirements of Go's stdlib. +func stripPort(address string) string { + // Keep in mind that the address might be a IPv6 address + // and thus contain a colon, but not have a port. + portIdx := strings.LastIndex(address, ":") + ipv6Idx := strings.LastIndex(address, "]") + if portIdx > ipv6Idx { + address = address[:portIdx] + } + return address +} + type tlsHandshakeTimeoutError struct{} func (tlsHandshakeTimeoutError) Timeout() bool { return true }