mirror of
https://github.com/caddyserver/caddy.git
synced 2025-01-22 16:46:53 +01:00
caddyhttp: Serve http2 when listener wrapper doesn't return *tls.Conn (#4929)
* Serve http2 when listener wrapper doesn't return *tls.Conn * close conn when h2server serveConn returns * merge from upstream * rebase from latest * run New and Closed ConnState hook for h2 conns * go fmt * fix lint * Add comments * reorder import
This commit is contained in:
parent
f8b59e77f8
commit
d8d87a378f
3 changed files with 153 additions and 5 deletions
|
@ -357,6 +357,14 @@ func (app *App) Start() error {
|
||||||
MaxHeaderBytes: srv.MaxHeaderBytes,
|
MaxHeaderBytes: srv.MaxHeaderBytes,
|
||||||
Handler: srv,
|
Handler: srv,
|
||||||
ErrorLog: serverLogger,
|
ErrorLog: serverLogger,
|
||||||
|
ConnContext: func(ctx context.Context, c net.Conn) context.Context {
|
||||||
|
return context.WithValue(ctx, ConnCtxKey, c)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
h2server := &http2.Server{
|
||||||
|
NewWriteScheduler: func() http2.WriteScheduler {
|
||||||
|
return http2.NewPriorityWriteScheduler(nil)
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
// disable HTTP/2, which we enabled by default during provisioning
|
// disable HTTP/2, which we enabled by default during provisioning
|
||||||
|
@ -378,6 +386,9 @@ func (app *App) Start() error {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
//nolint:errcheck
|
||||||
|
http2.ConfigureServer(srv.server, h2server)
|
||||||
}
|
}
|
||||||
|
|
||||||
// this TLS config is used by the std lib to choose the actual TLS config for connections
|
// this TLS config is used by the std lib to choose the actual TLS config for connections
|
||||||
|
@ -387,9 +398,6 @@ func (app *App) Start() error {
|
||||||
|
|
||||||
// enable H2C if configured
|
// enable H2C if configured
|
||||||
if srv.protocol("h2c") {
|
if srv.protocol("h2c") {
|
||||||
h2server := &http2.Server{
|
|
||||||
IdleTimeout: time.Duration(srv.IdleTimeout),
|
|
||||||
}
|
|
||||||
srv.server.Handler = h2c.NewHandler(srv, h2server)
|
srv.server.Handler = h2c.NewHandler(srv, h2server)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -456,6 +464,17 @@ func (app *App) Start() error {
|
||||||
ln = srv.listenerWrappers[i].WrapListener(ln)
|
ln = srv.listenerWrappers[i].WrapListener(ln)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// handle http2 if use tls listener wrapper
|
||||||
|
if useTLS {
|
||||||
|
http2lnWrapper := &http2Listener{
|
||||||
|
Listener: ln,
|
||||||
|
server: srv.server,
|
||||||
|
h2server: h2server,
|
||||||
|
}
|
||||||
|
srv.h2listeners = append(srv.h2listeners, http2lnWrapper)
|
||||||
|
ln = http2lnWrapper
|
||||||
|
}
|
||||||
|
|
||||||
// if binding to port 0, the OS chooses a port for us;
|
// if binding to port 0, the OS chooses a port for us;
|
||||||
// but the user won't know the port unless we print it
|
// but the user won't know the port unless we print it
|
||||||
if !listenAddr.IsUnixNetwork() && listenAddr.StartPort == 0 && listenAddr.EndPort == 0 {
|
if !listenAddr.IsUnixNetwork() && listenAddr.StartPort == 0 && listenAddr.EndPort == 0 {
|
||||||
|
@ -585,12 +604,25 @@ func (app *App) Stop() error {
|
||||||
zap.Strings("addresses", server.Listen))
|
zap.Strings("addresses", server.Listen))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
stopH2Listener := func(server *Server) {
|
||||||
|
defer finishedShutdown.Done()
|
||||||
|
startedShutdown.Done()
|
||||||
|
|
||||||
|
for i, s := range server.h2listeners {
|
||||||
|
if err := s.Shutdown(ctx); err != nil {
|
||||||
|
app.logger.Error("http2 listener shutdown",
|
||||||
|
zap.Error(err),
|
||||||
|
zap.Int("index", i))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
for _, server := range app.Servers {
|
for _, server := range app.Servers {
|
||||||
startedShutdown.Add(2)
|
startedShutdown.Add(3)
|
||||||
finishedShutdown.Add(2)
|
finishedShutdown.Add(3)
|
||||||
go stopServer(server)
|
go stopServer(server)
|
||||||
go stopH3Server(server)
|
go stopH3Server(server)
|
||||||
|
go stopH2Listener(server)
|
||||||
}
|
}
|
||||||
|
|
||||||
// block until all the goroutines have been run by the scheduler;
|
// block until all the goroutines have been run by the scheduler;
|
||||||
|
|
102
modules/caddyhttp/http2listener.go
Normal file
102
modules/caddyhttp/http2listener.go
Normal file
|
@ -0,0 +1,102 @@
|
||||||
|
package caddyhttp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
weakrand "math/rand"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/net/http2"
|
||||||
|
)
|
||||||
|
|
||||||
|
// http2Listener wraps the listener to solve the following problems:
|
||||||
|
// 1. server h2 natively without using h2c hack when listener handles tls connection but
|
||||||
|
// don't return *tls.Conn
|
||||||
|
// 2. graceful shutdown. the shutdown logic is copied from stdlib http.Server, it's an extra maintenance burden but
|
||||||
|
// whatever, the shutdown logic maybe extracted to be used with h2c graceful shutdown. http2.Server supports graceful shutdown
|
||||||
|
// sending GO_AWAY frame to connected clients, but doesn't track connection status. It requires explicit call of http2.ConfigureServer
|
||||||
|
type http2Listener struct {
|
||||||
|
cnt uint64
|
||||||
|
net.Listener
|
||||||
|
server *http.Server
|
||||||
|
h2server *http2.Server
|
||||||
|
}
|
||||||
|
|
||||||
|
type connectionStateConn interface {
|
||||||
|
net.Conn
|
||||||
|
ConnectionState() tls.ConnectionState
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *http2Listener) Accept() (net.Conn, error) {
|
||||||
|
for {
|
||||||
|
conn, err := h.Listener.Accept()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if csc, ok := conn.(connectionStateConn); ok {
|
||||||
|
// *tls.Conn will return empty string because it's only populated after handshake is complete
|
||||||
|
if csc.ConnectionState().NegotiatedProtocol == http2.NextProtoTLS {
|
||||||
|
go h.serveHttp2(csc)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return conn, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *http2Listener) serveHttp2(csc connectionStateConn) {
|
||||||
|
atomic.AddUint64(&h.cnt, 1)
|
||||||
|
h.runHook(csc, http.StateNew)
|
||||||
|
defer func() {
|
||||||
|
csc.Close()
|
||||||
|
atomic.AddUint64(&h.cnt, ^uint64(0))
|
||||||
|
h.runHook(csc, http.StateClosed)
|
||||||
|
}()
|
||||||
|
h.h2server.ServeConn(csc, &http2.ServeConnOpts{
|
||||||
|
Context: h.server.ConnContext(context.Background(), csc),
|
||||||
|
BaseConfig: h.server,
|
||||||
|
Handler: h.server.Handler,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
const shutdownPollIntervalMax = 500 * time.Millisecond
|
||||||
|
|
||||||
|
func (h *http2Listener) Shutdown(ctx context.Context) error {
|
||||||
|
pollIntervalBase := time.Millisecond
|
||||||
|
nextPollInterval := func() time.Duration {
|
||||||
|
// Add 10% jitter.
|
||||||
|
//nolint:gosec
|
||||||
|
interval := pollIntervalBase + time.Duration(weakrand.Intn(int(pollIntervalBase/10)))
|
||||||
|
// Double and clamp for next time.
|
||||||
|
pollIntervalBase *= 2
|
||||||
|
if pollIntervalBase > shutdownPollIntervalMax {
|
||||||
|
pollIntervalBase = shutdownPollIntervalMax
|
||||||
|
}
|
||||||
|
return interval
|
||||||
|
}
|
||||||
|
|
||||||
|
timer := time.NewTimer(nextPollInterval())
|
||||||
|
defer timer.Stop()
|
||||||
|
for {
|
||||||
|
if atomic.LoadUint64(&h.cnt) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
case <-timer.C:
|
||||||
|
timer.Reset(nextPollInterval())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *http2Listener) runHook(conn net.Conn, state http.ConnState) {
|
||||||
|
if h.server.ConnState != nil {
|
||||||
|
h.server.ConnState(conn, state)
|
||||||
|
}
|
||||||
|
}
|
|
@ -198,6 +198,7 @@ type Server struct {
|
||||||
server *http.Server
|
server *http.Server
|
||||||
h3server *http3.Server
|
h3server *http3.Server
|
||||||
h3listeners []net.PacketConn // TODO: we have to hold these because quic-go won't close listeners it didn't create
|
h3listeners []net.PacketConn // TODO: we have to hold these because quic-go won't close listeners it didn't create
|
||||||
|
h2listeners []*http2Listener
|
||||||
addresses []caddy.NetworkAddress
|
addresses []caddy.NetworkAddress
|
||||||
|
|
||||||
trustedProxies IPRangeSource
|
trustedProxies IPRangeSource
|
||||||
|
@ -213,6 +214,16 @@ type Server struct {
|
||||||
|
|
||||||
// ServeHTTP is the entry point for all HTTP requests.
|
// ServeHTTP is the entry point for all HTTP requests.
|
||||||
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// If there are listener wrappers that process tls connections but don't return a *tls.Conn, this field will be nil.
|
||||||
|
// Can be removed if https://github.com/golang/go/pull/56110 is ever merged.
|
||||||
|
if r.TLS == nil {
|
||||||
|
conn := r.Context().Value(ConnCtxKey).(net.Conn)
|
||||||
|
if csc, ok := conn.(connectionStateConn); ok {
|
||||||
|
r.TLS = new(tls.ConnectionState)
|
||||||
|
*r.TLS = csc.ConnectionState()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
w.Header().Set("Server", "Caddy")
|
w.Header().Set("Server", "Caddy")
|
||||||
|
|
||||||
// advertise HTTP/3, if enabled
|
// advertise HTTP/3, if enabled
|
||||||
|
@ -870,6 +881,9 @@ const (
|
||||||
// originally came into the server's entry handler
|
// originally came into the server's entry handler
|
||||||
OriginalRequestCtxKey caddy.CtxKey = "original_request"
|
OriginalRequestCtxKey caddy.CtxKey = "original_request"
|
||||||
|
|
||||||
|
// For referencing underlying net.Conn
|
||||||
|
ConnCtxKey caddy.CtxKey = "conn"
|
||||||
|
|
||||||
// For tracking whether the client is a trusted proxy
|
// For tracking whether the client is a trusted proxy
|
||||||
TrustedProxyVarKey string = "trusted_proxy"
|
TrustedProxyVarKey string = "trusted_proxy"
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue