reverseproxy: refactor HTTP transport layer (#5369)

Co-authored-by: Francis Lavoie <lavofr@gmail.com>
Co-authored-by: Weidi Deng <weidi_deng@icloud.com>
This commit is contained in:
Mohammed Al Sahaf 2023-02-24 22:54:04 +03:00 committed by GitHub
parent be53e432fc
commit e3909cc385
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 148 additions and 140 deletions

View file

@ -33,7 +33,7 @@ func TestSRVReverseProxy(t *testing.T) {
"servers": { "servers": {
"srv0": { "srv0": {
"listen": [ "listen": [
":8080" ":18080"
], ],
"routes": [ "routes": [
{ {
@ -73,7 +73,7 @@ func TestSRVWithDial(t *testing.T) {
"servers": { "servers": {
"srv0": { "srv0": {
"listen": [ "listen": [
":8080" ":18080"
], ],
"routes": [ "routes": [
{ {
@ -149,7 +149,7 @@ func TestDialWithPlaceholderUnix(t *testing.T) {
"servers": { "servers": {
"srv0": { "srv0": {
"listen": [ "listen": [
":8080" ":18080"
], ],
"routes": [ "routes": [
{ {
@ -172,7 +172,7 @@ func TestDialWithPlaceholderUnix(t *testing.T) {
} }
`, "json") `, "json")
req, err := http.NewRequest(http.MethodGet, "http://localhost:8080", nil) req, err := http.NewRequest(http.MethodGet, "http://localhost:18080", nil)
if err != nil { if err != nil {
t.Fail() t.Fail()
return return
@ -201,7 +201,7 @@ func TestReverseProxyWithPlaceholderDialAddress(t *testing.T) {
"servers": { "servers": {
"srv0": { "srv0": {
"listen": [ "listen": [
":8080" ":18080"
], ],
"routes": [ "routes": [
{ {
@ -271,7 +271,7 @@ func TestReverseProxyWithPlaceholderDialAddress(t *testing.T) {
t.Fail() t.Fail()
return return
} }
req.Header.Set("X-Caddy-Upstream-Dial", "localhost:8080") req.Header.Set("X-Caddy-Upstream-Dial", "localhost:18080")
tester.AssertResponse(req, 200, "Hello, World!") tester.AssertResponse(req, 200, "Hello, World!")
} }
@ -295,7 +295,7 @@ func TestReverseProxyWithPlaceholderTCPDialAddress(t *testing.T) {
"servers": { "servers": {
"srv0": { "srv0": {
"listen": [ "listen": [
":8080" ":18080"
], ],
"routes": [ "routes": [
{ {
@ -340,7 +340,7 @@ func TestReverseProxyWithPlaceholderTCPDialAddress(t *testing.T) {
"handler": "reverse_proxy", "handler": "reverse_proxy",
"upstreams": [ "upstreams": [
{ {
"dial": "tcp/{http.request.header.X-Caddy-Upstream-Dial}:8080" "dial": "tcp/{http.request.header.X-Caddy-Upstream-Dial}:18080"
} }
] ]
} }
@ -385,7 +385,7 @@ func TestSRVWithActiveHealthcheck(t *testing.T) {
"servers": { "servers": {
"srv0": { "srv0": {
"listen": [ "listen": [
":8080" ":18080"
], ],
"routes": [ "routes": [
{ {

View file

@ -172,13 +172,20 @@ func (h *HTTPTransport) NewTransport(caddyCtx caddy.Context) (*http.Transport, e
} }
} }
// Set up the dialer to pull the correct information from the context
dialContext := func(ctx context.Context, network, address string) (net.Conn, error) { dialContext := func(ctx context.Context, network, address string) (net.Conn, error) {
// the proper dialing information should be embedded into the request's context // For unix socket upstreams, we need to recover the dial info from
// the request's context, because the Host on the request's URL
// will have been modified by directing the request, overwriting
// the unix socket filename.
// Also, we need to avoid overwriting the address at this point
// when not necessary, because http.ProxyFromEnvironment may have
// modified the address according to the user's env proxy config.
if dialInfo, ok := GetDialInfo(ctx); ok { if dialInfo, ok := GetDialInfo(ctx); ok {
if strings.HasPrefix(dialInfo.Network, "unix") {
network = dialInfo.Network network = dialInfo.Network
address = dialInfo.Address address = dialInfo.Address
} }
}
conn, err := dialer.DialContext(ctx, network, address) conn, err := dialer.DialContext(ctx, network, address)
if err != nil { if err != nil {
@ -188,8 +195,8 @@ func (h *HTTPTransport) NewTransport(caddyCtx caddy.Context) (*http.Transport, e
return nil, DialError{err} return nil, DialError{err}
} }
// if read/write timeouts are configured and this is a TCP connection, enforce the timeouts // if read/write timeouts are configured and this is a TCP connection,
// by wrapping the connection with our own type // enforce the timeouts by wrapping the connection with our own type
if tcpConn, ok := conn.(*net.TCPConn); ok && (h.ReadTimeout > 0 || h.WriteTimeout > 0) { if tcpConn, ok := conn.(*net.TCPConn); ok && (h.ReadTimeout > 0 || h.WriteTimeout > 0) {
conn = &tcpRWTimeoutConn{ conn = &tcpRWTimeoutConn{
TCPConn: tcpConn, TCPConn: tcpConn,
@ -203,6 +210,7 @@ func (h *HTTPTransport) NewTransport(caddyCtx caddy.Context) (*http.Transport, e
} }
rt := &http.Transport{ rt := &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: dialContext, DialContext: dialContext,
MaxConnsPerHost: h.MaxConnsPerHost, MaxConnsPerHost: h.MaxConnsPerHost,
ResponseHeaderTimeout: time.Duration(h.ResponseHeaderTimeout), ResponseHeaderTimeout: time.Duration(h.ResponseHeaderTimeout),