From feec7c5b407713df76a1d0320b92e193398d666c Mon Sep 17 00:00:00 2001 From: Matthew Holt Date: Wed, 15 Apr 2015 14:11:32 -0600 Subject: [PATCH] Virtual hosts and SNI support --- config/config.go | 8 +- main.go | 58 ++++++++++- server/server.go | 218 ++++++++++++++++++++++++------------------ server/virtualhost.go | 42 ++++++++ 4 files changed, 226 insertions(+), 100 deletions(-) create mode 100644 server/virtualhost.go diff --git a/config/config.go b/config/config.go index b36871c52..0222b000f 100644 --- a/config/config.go +++ b/config/config.go @@ -3,6 +3,7 @@ package config import ( + "net" "os" "github.com/mholt/caddy/middleware" @@ -12,13 +13,16 @@ const ( defaultHost = "localhost" defaultPort = "8080" defaultRoot = "." + + // The default configuration file to load if none is specified + DefaultConfigFile = "Caddyfile" ) // config represents a server configuration. It // is populated by parsing a config file (via the // Load function). type Config struct { - // The hostname or IP to which to bind the server + // The hostname or IP on which to serve Host string // The port to listen on @@ -51,7 +55,7 @@ type Config struct { // Address returns the host:port of c as a string. func (c Config) Address() string { - return c.Host + ":" + c.Port + return net.JoinHostPort(c.Host, c.Port) } // TLSConfig describes how TLS should be configured and used, diff --git a/main.go b/main.go index e5a953c21..8ecf316e8 100644 --- a/main.go +++ b/main.go @@ -2,7 +2,9 @@ package main import ( "flag" + "fmt" "log" + "net" "sync" "github.com/mholt/caddy/config" @@ -15,7 +17,7 @@ var ( ) func init() { - flag.StringVar(&conf, "conf", server.DefaultConfigFile, "the configuration file to use") + flag.StringVar(&conf, "conf", config.DefaultConfigFile, "the configuration file to use") flag.BoolVar(&http2, "http2", true, "enable HTTP/2 support") // temporary flag until http2 merged into std lib } @@ -24,17 +26,25 @@ func main() { flag.Parse() - vhosts, err := config.Load(conf) + // Load config from file + allConfigs, err := config.Load(conf) if err != nil { if config.IsNotFound(err) { - vhosts = config.Default() + allConfigs = config.Default() } else { log.Fatal(err) } } - for _, conf := range vhosts { - s, err := server.New(conf) + // Group by address (virtual hosting) + addresses, err := arrangeBindings(allConfigs) + if err != nil { + log.Fatal(err) + } + + // Start each server with its one or more configurations + for addr, configs := range addresses { + s, err := server.New(addr, configs, configs[0].TLS.Enabled) if err != nil { log.Fatal(err) } @@ -51,3 +61,41 @@ func main() { wg.Wait() } + +// arrangeBindings groups configurations by their bind address. For example, +// a server that should listen on localhost and another on 127.0.0.1 will +// be grouped into the same address: 127.0.0.1. It will return an error +// if the address lookup fails or if a TLS listener is configured on the +// same address as a plaintext HTTP listener. +func arrangeBindings(allConfigs []config.Config) (map[string][]config.Config, error) { + addresses := make(map[string][]config.Config) + + // Group configs by bind address + for _, conf := range allConfigs { + addr, err := net.ResolveTCPAddr("tcp", conf.Address()) + if err != nil { + return addresses, err + } + addresses[addr.String()] = append(addresses[addr.String()], conf) + } + + // Don't allow HTTP and HTTPS to be served on the same address + for _, configs := range addresses { + isTLS := configs[0].TLS.Enabled + for _, config := range configs { + if config.TLS.Enabled != isTLS { + thisConfigProto, otherConfigProto := "HTTP", "HTTP" + if config.TLS.Enabled { + thisConfigProto = "HTTPS" + } + if configs[0].TLS.Enabled { + otherConfigProto = "HTTPS" + } + return addresses, fmt.Errorf("Configuration error: Cannot multiplex %s (%s) and %s (%s) on same address", + configs[0].Address(), otherConfigProto, config.Address(), thisConfigProto) + } + } + } + + return addresses, nil +} diff --git a/server/server.go b/server/server.go index d44bc56d8..cde3d1821 100644 --- a/server/server.go +++ b/server/server.go @@ -4,9 +4,10 @@ package server import ( - "errors" + "crypto/tls" "fmt" "log" + "net" "net/http" "os" "os/signal" @@ -14,73 +15,55 @@ import ( "github.com/bradfitz/http2" "github.com/mholt/caddy/config" - "github.com/mholt/caddy/middleware" ) -// The default configuration file to load if none is specified -const DefaultConfigFile = "Caddyfile" - -// servers maintains a registry of running servers, keyed by address. -var servers = make(map[string]*Server) - // Server represents an instance of a server, which serves // static content at a particular address (host and port). type Server struct { - HTTP2 bool // temporary while http2 is not in std lib (TODO: remove flag when part of std lib) - config config.Config - fileServer middleware.Handler - stack middleware.Handler + HTTP2 bool // temporary while http2 is not in std lib (TODO: remove flag when part of std lib) + address string + tls bool + vhosts map[string]virtualHost } -// New creates a new Server and registers it with the list -// of servers created. Each server must have a unique host:port -// combination. This function does not start serving. -func New(conf config.Config) (*Server, error) { - addr := conf.Address() - - // Unique address check - if _, exists := servers[addr]; exists { - return nil, errors.New("Address " + addr + " is already in use") +// New creates a new Server which will bind to addr and serve +// the sites/hosts configured in configs. This function does +// not start serving. +func New(addr string, configs []config.Config, tls bool) (*Server, error) { + s := &Server{ + address: addr, + tls: tls, + vhosts: make(map[string]virtualHost), } - // Use all CPUs (if needed) by default - if conf.MaxCPU == 0 { - conf.MaxCPU = runtime.NumCPU() + for _, conf := range configs { + if _, exists := s.vhosts[conf.Host]; exists { + return nil, fmt.Errorf("Cannot serve %s - host already defined for address %s", conf.Address(), s.address) + } + + // Use all CPUs (if needed) by default + if conf.MaxCPU == 0 { + conf.MaxCPU = runtime.NumCPU() + } + + vh := virtualHost{config: conf} + + // Build middleware stack + err := vh.buildStack() + if err != nil { + return nil, err + } + + s.vhosts[conf.Host] = vh } - // Initialize - s := new(Server) - s.config = conf - - // Register the server - servers[addr] = s - return s, nil } // Serve starts the server. It blocks until the server quits. func (s *Server) Serve() error { - // Execute startup functions - for _, start := range s.config.Startup { - err := start() - if err != nil { - return err - } - } - - // Build middleware stack - err := s.buildStack() - if err != nil { - return err - } - - // Use highest procs value across all configurations - if s.config.MaxCPU > 0 && s.config.MaxCPU > runtime.GOMAXPROCS(0) { - runtime.GOMAXPROCS(s.config.MaxCPU) - } - server := &http.Server{ - Addr: s.config.Address(), + Addr: s.address, Handler: s, } @@ -89,28 +72,91 @@ func (s *Server) Serve() error { http2.ConfigureServer(server, nil) } - // Execute shutdown commands on exit - go func() { - interrupt := make(chan os.Signal, 1) - signal.Notify(interrupt, os.Interrupt, os.Kill) // TODO: syscall.SIGQUIT? (Ctrl+\, Unix-only) - <-interrupt - for _, shutdownFunc := range s.config.Shutdown { - err := shutdownFunc() + for _, vh := range s.vhosts { + // Execute startup functions + for _, start := range vh.config.Startup { + err := start() if err != nil { - log.Fatal(err) + return err } } - os.Exit(0) - }() - if s.config.TLS.Enabled { - return server.ListenAndServeTLS(s.config.TLS.Certificate, s.config.TLS.Key) + // Use highest procs value across all configurations + if vh.config.MaxCPU > 0 && vh.config.MaxCPU > runtime.GOMAXPROCS(0) { + runtime.GOMAXPROCS(vh.config.MaxCPU) + } + + if len(vh.config.Shutdown) > 0 { + // Execute shutdown commands on exit + go func() { + interrupt := make(chan os.Signal, 1) + signal.Notify(interrupt, os.Interrupt, os.Kill) // TODO: syscall.SIGQUIT? (Ctrl+\, Unix-only) + <-interrupt + for _, shutdownFunc := range vh.config.Shutdown { + err := shutdownFunc() + if err != nil { + log.Fatal(err) + } + } + os.Exit(0) + }() + } + } + + if s.tls { + var tlsConfigs []config.TLSConfig + for _, vh := range s.vhosts { + tlsConfigs = append(tlsConfigs, vh.config.TLS) + } + return ListenAndServeTLSWithSNI(server, tlsConfigs) } else { return server.ListenAndServe() } } -// ServeHTTP is the entry point for every request to s. +// ListenAndServeTLSWithSNI serves TLS with Server Name Indication (SNI) support, which allows +// multiple sites (different hostnames) to be served from the same address. This method is +// adapted directly from the std lib's net/http ListenAndServeTLS function, which was +// written by the Go Authors. It has been modified to support multiple certificate/key pairs. +func ListenAndServeTLSWithSNI(srv *http.Server, tlsConfigs []config.TLSConfig) error { + addr := srv.Addr + if addr == "" { + addr = ":https" + } + + config := new(tls.Config) + if srv.TLSConfig != nil { + *config = *srv.TLSConfig + } + if config.NextProtos == nil { + config.NextProtos = []string{"http/1.1"} + } + + // Here we diverge from the stdlib a bit by loading multiple certs/key pairs + // then we map the server names to their certs + var err error + config.Certificates = make([]tls.Certificate, len(tlsConfigs)) + for i, tlsConfig := range tlsConfigs { + config.Certificates[i], err = tls.LoadX509KeyPair(tlsConfig.Certificate, tlsConfig.Key) + if err != nil { + return err + } + } + config.BuildNameToCertificate() + + conn, err := net.Listen("tcp", addr) + if err != nil { + return err + } + + tlsListener := tls.NewListener(conn, config) + return srv.Serve(tlsListener) +} + +// ServeHTTP is the entry point for every request to the address that s +// is bound to. It acts as a multiplexer for the requests hostname as +// defined in the Host header so that the correct virtualhost +// (configuration and middleware stack) will handle the request. func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { defer func() { // In case the user doesn't enable error middleware, we still @@ -121,35 +167,21 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { } }() - status, _ := s.stack.ServeHTTP(w, r) + host, _, err := net.SplitHostPort(r.Host) + if err != nil { + host = r.Host // oh well + } - // Fallback error response in case error handling wasn't chained in - if status >= 400 { - w.WriteHeader(status) - fmt.Fprintf(w, "%d %s", status, http.StatusText(status)) - } -} - -// buildStack builds the server's middleware stack based -// on its config. This method should be called last before -// ListenAndServe begins. -func (s *Server) buildStack() error { - s.fileServer = FileServer(http.Dir(s.config.Root), []string{s.config.ConfigFile}) - - // TODO: We only compile middleware for the "/" scope. - // Partial support for multiple location contexts already - // exists at the parser and config levels, but until full - // support is implemented, this is all we do right here. - s.compile(s.config.Middleware["/"]) - - return nil -} - -// compile is an elegant alternative to nesting middleware function -// calls like handler1(handler2(handler3(finalHandler))). -func (s *Server) compile(layers []middleware.Middleware) { - s.stack = s.fileServer // core app layer - for i := len(layers) - 1; i >= 0; i-- { - s.stack = layers[i](s.stack) + if vh, ok := s.vhosts[host]; ok { + status, _ := vh.stack.ServeHTTP(w, r) + + // Fallback error response in case error handling wasn't chained in + if status >= 400 { + w.WriteHeader(status) + fmt.Fprintf(w, "%d %s", status, http.StatusText(status)) + } + } else { + w.WriteHeader(http.StatusNotFound) + fmt.Fprintf(w, "No such host at %s", s.address) } } diff --git a/server/virtualhost.go b/server/virtualhost.go new file mode 100644 index 000000000..57c5651cb --- /dev/null +++ b/server/virtualhost.go @@ -0,0 +1,42 @@ +package server + +import ( + "net/http" + + "github.com/mholt/caddy/config" + "github.com/mholt/caddy/middleware" +) + +// virtualHost represents a virtual host/server. While a Server +// is what actually binds to the address, a user may want to serve +// multiple sites on a single address, and what is what a +// virtualHost allows us to do. +type virtualHost struct { + config config.Config + fileServer middleware.Handler + stack middleware.Handler +} + +// buildStack builds the server's middleware stack based +// on its config. This method should be called last before +// ListenAndServe begins. +func (vh *virtualHost) buildStack() error { + vh.fileServer = FileServer(http.Dir(vh.config.Root), []string{vh.config.ConfigFile}) + + // TODO: We only compile middleware for the "/" scope. + // Partial support for multiple location contexts already + // exists at the parser and config levels, but until full + // support is implemented, this is all we do right here. + vh.compile(vh.config.Middleware["/"]) + + return nil +} + +// compile is an elegant alternative to nesting middleware function +// calls like handler1(handler2(handler3(finalHandler))). +func (vh *virtualHost) compile(layers []middleware.Middleware) { + vh.stack = vh.fileServer // core app layer + for i := len(layers) - 1; i >= 0; i-- { + vh.stack = layers[i](vh.stack) + } +}