From ccd3e55b328bddf85f3b33f870d64d2df676d3cf Mon Sep 17 00:00:00 2001 From: Austin Date: Mon, 1 Jun 2015 10:23:57 -0700 Subject: [PATCH] changes as noted in PR --- middleware/proxy/reverseproxy.go | 16 +++++++--------- middleware/proxy/upstream.go | 16 ++++------------ 2 files changed, 11 insertions(+), 21 deletions(-) diff --git a/middleware/proxy/reverseproxy.go b/middleware/proxy/reverseproxy.go index f3a0390b5..a49d1080c 100644 --- a/middleware/proxy/reverseproxy.go +++ b/middleware/proxy/reverseproxy.go @@ -16,8 +16,6 @@ import ( "time" ) -const HTTPSwitchingProtocols = 101 - // onExitFlushLoop is a callback set by tests to detect the state of the // flushLoop() goroutine. var onExitFlushLoop func() @@ -149,13 +147,7 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request, extr } defer res.Body.Close() - for _, h := range hopHeaders { - res.Header.Del(h) - } - - copyHeader(rw.Header(), res.Header) - - if res.StatusCode == HTTPSwitchingProtocols && outreq.Header.Get("Upgrade") == "websocket" { + if res.StatusCode == http.StatusSwitchingProtocols && outreq.Header.Get("Upgrade") == "websocket" { hj, ok := rw.(http.Hijacker) if !ok { return nil @@ -182,6 +174,12 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request, extr io.Copy(conn, backendConn) // read tcp stream from backend. conn.Close() } else { + for _, h := range hopHeaders { + res.Header.Del(h) + } + + copyHeader(rw.Header(), res.Header) + rw.WriteHeader(res.StatusCode) p.copyResponse(rw, res.Body) } diff --git a/middleware/proxy/upstream.go b/middleware/proxy/upstream.go index 4c1b9fff7..011a58b86 100644 --- a/middleware/proxy/upstream.go +++ b/middleware/proxy/upstream.go @@ -14,7 +14,7 @@ import ( var ( supportedPolicies map[string]func() Policy = make(map[string]func() Policy) - proxyHeaders http.Header + proxyHeaders http.Header = make(http.Header) ) type staticUpstream struct { @@ -100,10 +100,10 @@ func NewStaticUpstreams(c parse.Dispenser) ([]Upstream, error) { if !c.Args(&header, &value) { return upstreams, c.ArgErr() } - addProxyHeader(header, value) + proxyHeaders.Add(header, value) case "websocket": - addProxyHeader("Connection", "{>Connection}") - addProxyHeader("Upgrade", "{>Upgrade}") + proxyHeaders.Add("Connection", "{>Connection}") + proxyHeaders.Add("Upgrade", "{>Upgrade}") } } @@ -153,14 +153,6 @@ func RegisterPolicy(name string, policy func() Policy) { supportedPolicies[name] = policy } -// AddProxyHeader adds a proxy header. -func addProxyHeader(header, value string) { - if proxyHeaders == nil { - proxyHeaders = make(map[string][]string) - } - proxyHeaders.Add(header, value) -} - func (u *staticUpstream) From() string { return u.from }