diff --git a/caddyhttp/proxy/reverseproxy.go b/caddyhttp/proxy/reverseproxy.go index eadd7e3ad..9d7087c24 100644 --- a/caddyhttp/proxy/reverseproxy.go +++ b/caddyhttp/proxy/reverseproxy.go @@ -183,9 +183,80 @@ var hopHeaders = []string{ type respUpdateFn func(resp *http.Response) +type hijackedConn struct { + net.Conn + hj *connHijackerTransport +} + +func (c *hijackedConn) Read(b []byte) (n int, err error) { + n, err = c.Conn.Read(b) + c.hj.Replay = append(c.hj.Replay, b[:n]...) + return +} + +func (c *hijackedConn) Close() error { + return nil +} + +type connHijackerTransport struct { + *http.Transport + Conn net.Conn + Replay []byte +} + +func newConnHijackerTransport(base http.RoundTripper) *connHijackerTransport { + transport := &http.Transport{ + Proxy: http.ProxyFromEnvironment, + Dial: (&net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + }).Dial, + TLSHandshakeTimeout: 10 * time.Second, + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, + } + if base != nil { + if baseTransport, ok := base.(*http.Transport); ok { + transport.Proxy = baseTransport.Proxy + transport.TLSClientConfig = baseTransport.TLSClientConfig + transport.TLSHandshakeTimeout = baseTransport.TLSHandshakeTimeout + transport.Dial = baseTransport.Dial + transport.DialTLS = baseTransport.DialTLS + transport.DisableKeepAlives = true + } + } + hjTransport := &connHijackerTransport{transport, nil, bufferPool.Get().([]byte)[:0]} + oldDial := transport.Dial + oldDialTLS := transport.DialTLS + if oldDial == nil { + oldDial = (&net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + }).Dial + } + hjTransport.Dial = func(network, addr string) (net.Conn, error) { + c, err := oldDial(network, addr) + hjTransport.Conn = c + return &hijackedConn{c, hjTransport}, err + } + if oldDialTLS != nil { + hjTransport.DialTLS = func(network, addr string) (net.Conn, error) { + c, err := oldDialTLS(network, addr) + hjTransport.Conn = c + return &hijackedConn{c, hjTransport}, err + } + } + return hjTransport +} + +func requestIsWebsocket(req *http.Request) bool { + return !(strings.ToLower(req.Header.Get("Upgrade")) != "websocket" || !strings.Contains(strings.ToLower(req.Header.Get("Connection")), "upgrade")) +} + func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, outreq *http.Request, respUpdateFn respUpdateFn) error { transport := p.Transport - if transport == nil { + if requestIsWebsocket(outreq) { + transport = newConnHijackerTransport(transport) + } else if transport == nil { transport = http.DefaultTransport } @@ -216,13 +287,22 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, outreq *http.Request, r } defer conn.Close() - backendConn, err := net.Dial("tcp", outreq.URL.Host) - if err != nil { - return err - } - defer backendConn.Close() + var backendConn net.Conn + if hj, ok := transport.(*connHijackerTransport); ok { + backendConn = hj.Conn + if _, err := conn.Write(hj.Replay); err != nil { + return err + } + bufferPool.Put(hj.Replay) + } else { + backendConn, err = net.Dial("tcp", outreq.URL.Host) + if err != nil { + return err + } + defer backendConn.Close() - outreq.Write(backendConn) + outreq.Write(backendConn) + } go func() { io.Copy(backendConn, conn) // write tcp stream to backend.