From 87f63b125bd286bd5882568368fae1131db592d5 Mon Sep 17 00:00:00 2001
From: WeidiDeng <weidi_deng@icloud.com>
Date: Mon, 20 Nov 2023 20:31:36 +0800
Subject: [PATCH] httpredirectlistener: Only set read limit for when request is
 HTTP (#5917)

---
 caddytest/integration/listener_test.go    |  94 +++++++++++++++++++
 modules/caddyhttp/httpredirectlistener.go | 105 ++++++++++++----------
 2 files changed, 150 insertions(+), 49 deletions(-)
 create mode 100644 caddytest/integration/listener_test.go

diff --git a/caddytest/integration/listener_test.go b/caddytest/integration/listener_test.go
new file mode 100644
index 000000000..30642b1ae
--- /dev/null
+++ b/caddytest/integration/listener_test.go
@@ -0,0 +1,94 @@
+package integration
+
+import (
+	"bytes"
+	"fmt"
+	"math/rand"
+	"net"
+	"net/http"
+	"strings"
+	"testing"
+
+	"github.com/caddyserver/caddy/v2/caddytest"
+)
+
+func setupListenerWrapperTest(t *testing.T, handlerFunc http.HandlerFunc) *caddytest.Tester {
+	l, err := net.Listen("tcp", "127.0.0.1:0")
+	if err != nil {
+		t.Fatalf("failed to listen: %s", err)
+	}
+
+	mux := http.NewServeMux()
+	mux.Handle("/", handlerFunc)
+	srv := &http.Server{
+		Handler: mux,
+	}
+	go srv.Serve(l)
+	t.Cleanup(func() {
+		_ = srv.Close()
+		_ = l.Close()
+	})
+	tester := caddytest.NewTester(t)
+	tester.InitServer(fmt.Sprintf(`
+	{
+		skip_install_trust
+		admin localhost:2999
+		http_port     9080
+		https_port    9443
+		local_certs
+		servers :9443 {
+			listener_wrappers {
+				http_redirect
+				tls
+			}
+		}
+	}
+	localhost {
+		reverse_proxy %s
+	}
+  `, l.Addr().String()), "caddyfile")
+	return tester
+}
+
+func TestHTTPRedirectWrapperWithLargeUpload(t *testing.T) {
+	const uploadSize = (1024 * 1024) + 1 // 1 MB + 1 byte
+	// 1 more than an MB
+	body := make([]byte, uploadSize)
+	rand.New(rand.NewSource(0)).Read(body)
+
+	tester := setupListenerWrapperTest(t, func(writer http.ResponseWriter, request *http.Request) {
+		buf := new(bytes.Buffer)
+		_, err := buf.ReadFrom(request.Body)
+		if err != nil {
+			t.Fatalf("failed to read body: %s", err)
+		}
+
+		if !bytes.Equal(buf.Bytes(), body) {
+			t.Fatalf("body not the same")
+		}
+
+		writer.WriteHeader(http.StatusNoContent)
+	})
+	resp, err := tester.Client.Post("https://localhost:9443", "application/octet-stream", bytes.NewReader(body))
+	if err != nil {
+		t.Fatalf("failed to post: %s", err)
+	}
+
+	if resp.StatusCode != http.StatusNoContent {
+		t.Fatalf("unexpected status: %d != %d", resp.StatusCode, http.StatusNoContent)
+	}
+}
+
+func TestLargeHttpRequest(t *testing.T) {
+	tester := setupListenerWrapperTest(t, func(writer http.ResponseWriter, request *http.Request) {
+		t.Fatal("not supposed to handle a request")
+	})
+
+	// We never read the body in any way, set an extra long header instead.
+	req, _ := http.NewRequest("POST", "http://localhost:9443", nil)
+	req.Header.Set("Long-Header", strings.Repeat("X", 1024*1024))
+	_, err := tester.Client.Do(req)
+	if err == nil {
+		t.Fatal("not supposed to succeed")
+	}
+}
diff --git a/modules/caddyhttp/httpredirectlistener.go b/modules/caddyhttp/httpredirectlistener.go
index 082dc7ce8..ce9ac0308 100644
--- a/modules/caddyhttp/httpredirectlistener.go
+++ b/modules/caddyhttp/httpredirectlistener.go
@@ -16,11 +16,11 @@ package caddyhttp
 
 import (
 	"bufio"
+	"bytes"
 	"fmt"
 	"io"
 	"net"
 	"net/http"
-	"sync"
 
 	"github.com/caddyserver/caddy/v2"
 	"github.com/caddyserver/caddy/v2/caddyconfig/caddyfile"
@@ -86,15 +86,17 @@ func (l *httpRedirectListener) Accept() (net.Conn, error) {
 	}
 
 	return &httpRedirectConn{
-		Conn: c,
-		r:    bufio.NewReader(io.LimitReader(c, maxHeaderBytes)),
+		Conn:  c,
+		limit: maxHeaderBytes,
+		r:     bufio.NewReader(c),
 	}, nil
 }
 
 type httpRedirectConn struct {
 	net.Conn
-	once sync.Once
-	r    *bufio.Reader
+	once  bool
+	limit int64
+	r     *bufio.Reader
 }
 
 // Read tries to peek at the first few bytes of the request, and if we get
@@ -102,53 +104,58 @@ type httpRedirectConn struct {
 // like an HTTP request, then we perform a HTTP->HTTPS redirect on the same
 // port as the original connection.
 func (c *httpRedirectConn) Read(p []byte) (int, error) {
-	var errReturn error
-	c.once.Do(func() {
-		firstBytes, err := c.r.Peek(5)
-		if err != nil {
-			return
-		}
+	if c.once {
+		return c.r.Read(p)
+	}
+	// no need to use sync.Once - net.Conn is not read from concurrently.
+	c.once = true
 
-		// If the request doesn't look like HTTP, then it's probably
-		// TLS bytes and we don't need to do anything.
-		if !firstBytesLookLikeHTTP(firstBytes) {
-			return
-		}
-
-		// Parse the HTTP request, so we can get the Host and URL to redirect to.
-		req, err := http.ReadRequest(c.r)
-		if err != nil {
-			return
-		}
-
-		// Build the redirect response, using the same Host and URL,
-		// but replacing the scheme with https.
-		headers := make(http.Header)
-		headers.Add("Location", "https://"+req.Host+req.URL.String())
-		resp := &http.Response{
-			Proto:      "HTTP/1.0",
-			Status:     "308 Permanent Redirect",
-			StatusCode: 308,
-			ProtoMajor: 1,
-			ProtoMinor: 0,
-			Header:     headers,
-		}
-
-		err = resp.Write(c.Conn)
-		if err != nil {
-			errReturn = fmt.Errorf("couldn't write HTTP->HTTPS redirect")
-			return
-		}
-
-		errReturn = fmt.Errorf("redirected HTTP request on HTTPS port")
-		c.Conn.Close()
-	})
-
-	if errReturn != nil {
-		return 0, errReturn
+	firstBytes, err := c.r.Peek(5)
+	if err != nil {
+		return 0, err
 	}
 
-	return c.r.Read(p)
+	// If the request doesn't look like HTTP, then it's probably
+	// TLS bytes, and we don't need to do anything.
+	if !firstBytesLookLikeHTTP(firstBytes) {
+		return c.r.Read(p)
+	}
+
+	// From now on, we can be almost certain the request is HTTP.
+	// The returned error will be non nil and caller are expected to
+	// close the connection.
+
+	// Set the read limit, io.MultiReader is needed because
+	// when resetting, *bufio.Reader discards buffered data.
+	buffered, _ := c.r.Peek(c.r.Buffered())
+	mr := io.MultiReader(bytes.NewReader(buffered), c.Conn)
+	c.r.Reset(io.LimitReader(mr, c.limit))
+
+	// Parse the HTTP request, so we can get the Host and URL to redirect to.
+	req, err := http.ReadRequest(c.r)
+	if err != nil {
+		return 0, fmt.Errorf("couldn't read HTTP request")
+	}
+
+	// Build the redirect response, using the same Host and URL,
+	// but replacing the scheme with https.
+	headers := make(http.Header)
+	headers.Add("Location", "https://"+req.Host+req.URL.String())
+	resp := &http.Response{
+		Proto:      "HTTP/1.0",
+		Status:     "308 Permanent Redirect",
+		StatusCode: 308,
+		ProtoMajor: 1,
+		ProtoMinor: 0,
+		Header:     headers,
+	}
+
+	err = resp.Write(c.Conn)
+	if err != nil {
+		return 0, fmt.Errorf("couldn't write HTTP->HTTPS redirect")
+	}
+
+	return 0, fmt.Errorf("redirected HTTP request on HTTPS port")
 }
 
 // firstBytesLookLikeHTTP reports whether a TLS record header