diff --git a/caddyhttp/gzip/gzip.go b/caddyhttp/gzip/gzip.go index ed95156a3..ddaac4fb3 100644 --- a/caddyhttp/gzip/gzip.go +++ b/caddyhttp/gzip/gzip.go @@ -5,7 +5,6 @@ package gzip import ( "bufio" "compress/gzip" - "fmt" "io" "io/ioutil" "net" @@ -144,7 +143,7 @@ func (w *gzipResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { if hj, ok := w.ResponseWriter.(http.Hijacker); ok { return hj.Hijack() } - return nil, nil, fmt.Errorf("not a Hijacker") + return nil, nil, httpserver.NonHijackerError{Underlying: w.ResponseWriter} } // Flush implements http.Flusher. It simply wraps the underlying @@ -153,7 +152,7 @@ func (w *gzipResponseWriter) Flush() { if f, ok := w.ResponseWriter.(http.Flusher); ok { f.Flush() } else { - panic("not a Flusher") // should be recovered at the beginning of middleware stack + panic(httpserver.NonFlusherError{Underlying: w.ResponseWriter}) // should be recovered at the beginning of middleware stack } } @@ -163,5 +162,5 @@ func (w *gzipResponseWriter) CloseNotify() <-chan bool { if cn, ok := w.ResponseWriter.(http.CloseNotifier); ok { return cn.CloseNotify() } - panic("not a CloseNotifier") + panic(httpserver.NonCloseNotifierError{Underlying: w.ResponseWriter}) } diff --git a/caddyhttp/header/header.go b/caddyhttp/header/header.go index 121264d48..dff925d46 100644 --- a/caddyhttp/header/header.go +++ b/caddyhttp/header/header.go @@ -4,6 +4,8 @@ package header import ( + "bufio" + "net" "net/http" "strings" @@ -113,3 +115,12 @@ func (rww *responseWriterWrapper) setHeader(key, value string) { h.Set(key, value) }) } + +// Hijack implements http.Hijacker. It simply wraps the underlying +// ResponseWriter's Hijack method if there is one, or returns an error. +func (rww *responseWriterWrapper) Hijack() (net.Conn, *bufio.ReadWriter, error) { + if hj, ok := rww.w.(http.Hijacker); ok { + return hj.Hijack() + } + return nil, nil, httpserver.NonHijackerError{Underlying: rww.w} +} diff --git a/caddyhttp/httpserver/error.go b/caddyhttp/httpserver/error.go new file mode 100644 index 000000000..2cfa530e6 --- /dev/null +++ b/caddyhttp/httpserver/error.go @@ -0,0 +1,44 @@ +package httpserver + +import ( + "fmt" +) + +var ( + _ error = NonHijackerError{} + _ error = NonFlusherError{} + _ error = NonCloseNotifierError{} +) + +// NonHijackerError is more descriptive error caused by a non hijacker +type NonHijackerError struct { + // underlying type which doesn't implement Hijack + Underlying interface{} +} + +// Implement Error +func (h NonHijackerError) Error() string { + return fmt.Sprintf("%T is not a hijacker", h.Underlying) +} + +// NonFlusherError is more descriptive error caused by a non flusher +type NonFlusherError struct { + // underlying type which doesn't implement Flush + Underlying interface{} +} + +// Implement Error +func (f NonFlusherError) Error() string { + return fmt.Sprintf("%T is not a flusher", f.Underlying) +} + +// NonCloseNotifierError is more descriptive error caused by a non closeNotifier +type NonCloseNotifierError struct { + // underlying type which doesn't implement CloseNotify + Underlying interface{} +} + +// Implement Error +func (c NonCloseNotifierError) Error() string { + return fmt.Sprintf("%T is not a closeNotifier", c.Underlying) +} diff --git a/caddyhttp/httpserver/recorder.go b/caddyhttp/httpserver/recorder.go index 5788ab44b..c893f51fa 100644 --- a/caddyhttp/httpserver/recorder.go +++ b/caddyhttp/httpserver/recorder.go @@ -2,7 +2,6 @@ package httpserver import ( "bufio" - "errors" "net" "net/http" "time" @@ -75,7 +74,7 @@ func (r *ResponseRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) { if hj, ok := r.ResponseWriter.(http.Hijacker); ok { return hj.Hijack() } - return nil, nil, errors.New("not a Hijacker") + return nil, nil, NonHijackerError{Underlying: r.ResponseWriter} } // Flush implements http.Flusher. It simply wraps the underlying @@ -84,7 +83,7 @@ func (r *ResponseRecorder) Flush() { if f, ok := r.ResponseWriter.(http.Flusher); ok { f.Flush() } else { - panic("not a Flusher") // should be recovered at the beginning of middleware stack + panic(NonFlusherError{Underlying: r.ResponseWriter}) // should be recovered at the beginning of middleware stack } } @@ -94,5 +93,5 @@ func (r *ResponseRecorder) CloseNotify() <-chan bool { if cn, ok := r.ResponseWriter.(http.CloseNotifier); ok { return cn.CloseNotify() } - panic("not a CloseNotifier") + panic(NonCloseNotifierError{Underlying: r.ResponseWriter}) } diff --git a/caddyhttp/proxy/proxy_test.go b/caddyhttp/proxy/proxy_test.go index ccdb068bb..81c457352 100644 --- a/caddyhttp/proxy/proxy_test.go +++ b/caddyhttp/proxy/proxy_test.go @@ -97,6 +97,39 @@ func TestReverseProxyInsecureSkipVerify(t *testing.T) { } } +func TestWebSocketReverseProxyNonHijackerPanic(t *testing.T) { + // Capture the expected panic + defer func() { + r := recover() + if _, ok := r.(httpserver.NonHijackerError); !ok { + t.Error("not get the expected panic") + } + }() + + var connCount int32 + wsNop := httptest.NewServer(websocket.Handler(func(ws *websocket.Conn) { atomic.AddInt32(&connCount, 1) })) + defer wsNop.Close() + + // Get proxy to use for the test + p := newWebSocketTestProxy(wsNop.URL) + + // Create client request + r, err := http.NewRequest("GET", "/", nil) + if err != nil { + t.Fatalf("Failed to create request: %v", err) + } + r.Header = http.Header{ + "Connection": {"Upgrade"}, + "Upgrade": {"websocket"}, + "Origin": {wsNop.URL}, + "Sec-WebSocket-Key": {"x3JJHMbDL1EzLkh9GBhXDw=="}, + "Sec-WebSocket-Version": {"13"}, + } + + nonHijacker := httptest.NewRecorder() + p.ServeHTTP(nonHijacker, r) +} + func TestWebSocketReverseProxyServeHTTPHandler(t *testing.T) { // No-op websocket backend simply allows the WS connection to be // accepted then it will be immediately closed. Perfect for testing. diff --git a/caddyhttp/proxy/reverseproxy.go b/caddyhttp/proxy/reverseproxy.go index c537f7c98..03b637866 100644 --- a/caddyhttp/proxy/reverseproxy.go +++ b/caddyhttp/proxy/reverseproxy.go @@ -21,6 +21,8 @@ import ( "strings" "sync" "time" + + "github.com/mholt/caddy/caddyhttp/httpserver" ) var bufferPool = sync.Pool{New: createBuffer} @@ -195,7 +197,7 @@ func (rp *ReverseProxy) ServeHTTP(rw http.ResponseWriter, outreq *http.Request, res.Body.Close() hj, ok := rw.(http.Hijacker) if !ok { - return nil + panic(httpserver.NonHijackerError{Underlying: rw}) } conn, _, err := hj.Hijack()