diff --git a/middleware/proxy/proxy_test.go b/middleware/proxy/proxy_test.go new file mode 100644 index 000000000..409df4630 --- /dev/null +++ b/middleware/proxy/proxy_test.go @@ -0,0 +1,151 @@ +package proxy + +import ( + "bufio" + "bytes" + "io" + "net" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + "time" + + "golang.org/x/net/websocket" +) + +func TestWebSocketReverseProxyServeHTTPHandler(t *testing.T) { + // No-op websocket backend simply allows the WS connection to be + // accepted then it will be immediately closed. Perfect for testing. + wsNop := httptest.NewServer(websocket.Handler(func(ws *websocket.Conn) {})) + defer wsNop.Close() + + // Get proxy to use for the test + p := newWebSocketTestProxy(wsNop.URL) + + // Create client request + r, err := http.NewRequest("GET", "/", nil) + if err != nil { + t.Fatalf("Failed to create request: %v", err) + } + r.Header = http.Header{ + "Connection": {"Upgrade"}, + "Upgrade": {"websocket"}, + "Origin": {wsNop.URL}, + "Sec-WebSocket-Key": {"x3JJHMbDL1EzLkh9GBhXDw=="}, + "Sec-WebSocket-Version": {"13"}, + } + + // Capture the request + w := &recorderHijacker{httptest.NewRecorder(), new(fakeConn)} + + // Booya! Do the test. + p.ServeHTTP(w, r) + + // Make sure the backend accepted the WS connection. + // Mostly interested in the Upgrade and Connection response headers + // and the 101 status code. + expected := []byte("HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: HSmrc0sMlYUkAGmm5OPpG2HaGWk=\r\n\r\n") + actual := w.fakeConn.writeBuf.Bytes() + if !bytes.Equal(actual, expected) { + t.Errorf("Expected backend to accept response:\n'%s'\nActually got:\n'%s'", expected, actual) + } +} + +func TestWebSocketReverseProxyFromWSClient(t *testing.T) { + // Echo server allows us to test that socket bytes are properly + // being proxied. + wsEcho := httptest.NewServer(websocket.Handler(func(ws *websocket.Conn) { + io.Copy(ws, ws) + })) + defer wsEcho.Close() + + // Get proxy to use for the test + p := newWebSocketTestProxy(wsEcho.URL) + + // This is a full end-end test, so the proxy handler + // has to be part of a server listening on a port. Our + // WS client will connect to this test server, not + // the echo client directly. + echoProxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + p.ServeHTTP(w, r) + })) + defer echoProxy.Close() + + // Set up WebSocket client + url := strings.Replace(echoProxy.URL, "http://", "ws://", 1) + ws, err := websocket.Dial(url, "", echoProxy.URL) + if err != nil { + t.Fatal(err) + } + defer ws.Close() + + // Send test message + trialMsg := "Is it working?" + websocket.Message.Send(ws, trialMsg) + + // It should be echoed back to us + var actualMsg string + websocket.Message.Receive(ws, &actualMsg) + if actualMsg != trialMsg { + t.Errorf("Expected '%s' but got '%s' instead", trialMsg, actualMsg) + } +} + +// newWebSocketTestProxy returns a test proxy that will +// redirect to the specified backendAddr. The function +// also sets up the rules/environment for testing WebSocket +// proxy. +func newWebSocketTestProxy(backendAddr string) *Proxy { + proxyHeaders = http.Header{ + "Connection": {"{>Connection}"}, + "Upgrade": {"{>Upgrade}"}, + } + + return &Proxy{ + Upstreams: []Upstream{&fakeUpstream{name: backendAddr}}, + } +} + +type fakeUpstream struct { + name string +} + +func (u *fakeUpstream) From() string { + return "/" +} + +func (u *fakeUpstream) Select() *UpstreamHost { + uri, _ := url.Parse(u.name) + return &UpstreamHost{ + Name: u.name, + ReverseProxy: NewSingleHostReverseProxy(uri), + ExtraHeaders: proxyHeaders, + } +} + +// recorderHijacker is a ResponseRecorder that can +// be hijacked. +type recorderHijacker struct { + *httptest.ResponseRecorder + fakeConn *fakeConn +} + +func (rh *recorderHijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) { + return rh.fakeConn, nil, nil +} + +type fakeConn struct { + readBuf bytes.Buffer + writeBuf bytes.Buffer +} + +func (c *fakeConn) LocalAddr() net.Addr { return nil } +func (c *fakeConn) RemoteAddr() net.Addr { return nil } +func (c *fakeConn) SetDeadline(t time.Time) error { return nil } +func (c *fakeConn) SetReadDeadline(t time.Time) error { return nil } +func (c *fakeConn) SetWriteDeadline(t time.Time) error { return nil } +func (c *fakeConn) Close() error { return nil } +func (c *fakeConn) Read(b []byte) (int, error) { return c.readBuf.Read(b) } +func (c *fakeConn) Write(b []byte) (int, error) { return c.writeBuf.Write(b) } diff --git a/middleware/proxy/reverseproxy.go b/middleware/proxy/reverseproxy.go index 15350a993..0ffe54a11 100644 --- a/middleware/proxy/reverseproxy.go +++ b/middleware/proxy/reverseproxy.go @@ -1,8 +1,13 @@ -// Copyright 2011 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// HTTP reverse proxy handler +// This file is adapted from code in the net/http/httputil +// package of the Go standard library, which is by the +// Go Authors, and bears this copyright and license info: +// +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. +// +// This file has been modified from the standard lib to +// meet the needs of the application. package proxy @@ -157,22 +162,20 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request, extr if err != nil { return err } + defer conn.Close() backendConn, err := net.Dial("tcp", outreq.Host) if err != nil { - conn.Close() return err } + defer backendConn.Close() outreq.Write(backendConn) go func() { io.Copy(backendConn, conn) // write tcp stream to backend. - backendConn.Close() }() - io.Copy(conn, backendConn) // read tcp stream from backend. - conn.Close() } else { for _, h := range hopHeaders { res.Header.Del(h)