From 2019eec5a573b5939918795af22908c0d7bfe758 Mon Sep 17 00:00:00 2001 From: Matthew Holt Date: Sat, 6 Aug 2016 14:46:52 -0600 Subject: [PATCH] Fix lint warnings; group methods for same type together --- caddyhttp/proxy/reverseproxy.go | 178 ++++++++++++++++---------------- 1 file changed, 90 insertions(+), 88 deletions(-) diff --git a/caddyhttp/proxy/reverseproxy.go b/caddyhttp/proxy/reverseproxy.go index 007bcb1ac..b86b35b2c 100644 --- a/caddyhttp/proxy/reverseproxy.go +++ b/caddyhttp/proxy/reverseproxy.go @@ -144,7 +144,7 @@ func NewSingleHostReverseProxy(target *url.URL, without string, keepalive int) * return rp } -// InsecureTransport is used to facilitate HTTPS proxying +// UseInsecureTransport is used to facilitate HTTPS proxying // when it is OK for upstream to be using a bad certificate, // since this transport skips verification. func (rp *ReverseProxy) UseInsecureTransport() { @@ -163,6 +163,95 @@ func (rp *ReverseProxy) UseInsecureTransport() { } } +// ServeHTTP serves the proxied request to the upstream by performing a roundtrip. +// It is designed to handle websocket connection upgrades as well. +func (rp *ReverseProxy) ServeHTTP(rw http.ResponseWriter, outreq *http.Request, respUpdateFn respUpdateFn) error { + transport := rp.Transport + if requestIsWebsocket(outreq) { + transport = newConnHijackerTransport(transport) + } else if transport == nil { + transport = http.DefaultTransport + } + + rp.Director(outreq) + outreq.Proto = "HTTP/1.1" + outreq.ProtoMajor = 1 + outreq.ProtoMinor = 1 + outreq.Close = false + + res, err := transport.RoundTrip(outreq) + if err != nil { + return err + } + + if respUpdateFn != nil { + respUpdateFn(res) + } + if res.StatusCode == http.StatusSwitchingProtocols && strings.ToLower(res.Header.Get("Upgrade")) == "websocket" { + res.Body.Close() + hj, ok := rw.(http.Hijacker) + if !ok { + return nil + } + + conn, _, err := hj.Hijack() + if err != nil { + return err + } + defer conn.Close() + + var backendConn net.Conn + if hj, ok := transport.(*connHijackerTransport); ok { + backendConn = hj.Conn + if _, err := conn.Write(hj.Replay); err != nil { + return err + } + bufferPool.Put(hj.Replay) + } else { + backendConn, err = net.Dial("tcp", outreq.URL.Host) + if err != nil { + return err + } + outreq.Write(backendConn) + } + defer backendConn.Close() + + go func() { + io.Copy(backendConn, conn) // write tcp stream to backend. + }() + io.Copy(conn, backendConn) // read tcp stream from backend. + } else { + defer res.Body.Close() + for _, h := range hopHeaders { + res.Header.Del(h) + } + copyHeader(rw.Header(), res.Header) + rw.WriteHeader(res.StatusCode) + rp.copyResponse(rw, res.Body) + } + + return nil +} + +func (rp *ReverseProxy) copyResponse(dst io.Writer, src io.Reader) { + buf := bufferPool.Get() + defer bufferPool.Put(buf) + + if rp.FlushInterval != 0 { + if wf, ok := dst.(writeFlusher); ok { + mlw := &maxLatencyWriter{ + dst: wf, + latency: rp.FlushInterval, + done: make(chan bool), + } + go mlw.flushLoop() + defer mlw.stop() + dst = mlw + } + } + io.CopyBuffer(dst, src, buf.([]byte)) +} + func copyHeader(dst, src http.Header) { for k, vv := range src { for _, v := range vv { @@ -255,93 +344,6 @@ func requestIsWebsocket(req *http.Request) bool { return !(strings.ToLower(req.Header.Get("Upgrade")) != "websocket" || !strings.Contains(strings.ToLower(req.Header.Get("Connection")), "upgrade")) } -func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, outreq *http.Request, respUpdateFn respUpdateFn) error { - transport := p.Transport - if requestIsWebsocket(outreq) { - transport = newConnHijackerTransport(transport) - } else if transport == nil { - transport = http.DefaultTransport - } - - p.Director(outreq) - outreq.Proto = "HTTP/1.1" - outreq.ProtoMajor = 1 - outreq.ProtoMinor = 1 - outreq.Close = false - - res, err := transport.RoundTrip(outreq) - if err != nil { - return err - } - - if respUpdateFn != nil { - respUpdateFn(res) - } - if res.StatusCode == http.StatusSwitchingProtocols && strings.ToLower(res.Header.Get("Upgrade")) == "websocket" { - res.Body.Close() - hj, ok := rw.(http.Hijacker) - if !ok { - return nil - } - - conn, _, err := hj.Hijack() - if err != nil { - return err - } - defer conn.Close() - - var backendConn net.Conn - if hj, ok := transport.(*connHijackerTransport); ok { - backendConn = hj.Conn - if _, err := conn.Write(hj.Replay); err != nil { - return err - } - bufferPool.Put(hj.Replay) - } else { - backendConn, err = net.Dial("tcp", outreq.URL.Host) - if err != nil { - return err - } - outreq.Write(backendConn) - } - defer backendConn.Close() - - go func() { - io.Copy(backendConn, conn) // write tcp stream to backend. - }() - io.Copy(conn, backendConn) // read tcp stream from backend. - } else { - defer res.Body.Close() - for _, h := range hopHeaders { - res.Header.Del(h) - } - copyHeader(rw.Header(), res.Header) - rw.WriteHeader(res.StatusCode) - p.copyResponse(rw, res.Body) - } - - return nil -} - -func (p *ReverseProxy) copyResponse(dst io.Writer, src io.Reader) { - buf := bufferPool.Get() - defer bufferPool.Put(buf) - - if p.FlushInterval != 0 { - if wf, ok := dst.(writeFlusher); ok { - mlw := &maxLatencyWriter{ - dst: wf, - latency: p.FlushInterval, - done: make(chan bool), - } - go mlw.flushLoop() - defer mlw.stop() - dst = mlw - } - } - io.CopyBuffer(dst, src, buf.([]byte)) -} - type writeFlusher interface { io.Writer http.Flusher