diff --git a/caddy/caddy.go b/caddy/caddy.go index 5d8ceddd8..da6975496 100644 --- a/caddy/caddy.go +++ b/caddy/caddy.go @@ -26,6 +26,7 @@ import ( "path" "strings" "sync" + "sync/atomic" "time" "github.com/mholt/caddy/caddy/https" @@ -317,6 +318,7 @@ func LoadCaddyfile(loader func() (Input, error)) (cdyfile Input, err error) { return nil, err } cdyfile = loadedGob.Caddyfile + atomic.StoreInt32(https.OnDemandIssuedCount, loadedGob.OnDemandTLSCertsIssued) } // Try user's loader diff --git a/caddy/helpers.go b/caddy/helpers.go index 0165573ac..0a2299dfc 100644 --- a/caddy/helpers.go +++ b/caddy/helpers.go @@ -63,10 +63,12 @@ var signalParentOnce sync.Once // caddyfileGob maps bind address to index of the file descriptor // in the Files array passed to the child process. It also contains -// the caddyfile contents. Used only during graceful restarts. +// the caddyfile contents and other state needed by the new process. +// Used only during graceful restarts where a new process is spawned. type caddyfileGob struct { - ListenerFds map[string]uintptr - Caddyfile Input + ListenerFds map[string]uintptr + Caddyfile Input + OnDemandTLSCertsIssued int32 } // IsRestart returns whether this process is, according diff --git a/caddy/https/handler.go b/caddy/https/handler.go index 446539296..f3139f54e 100644 --- a/caddy/https/handler.go +++ b/caddy/https/handler.go @@ -3,7 +3,6 @@ package https import ( "crypto/tls" "log" - "net" "net/http" "net/http/httputil" "net/url" @@ -23,21 +22,16 @@ func RequestCallback(w http.ResponseWriter, r *http.Request) bool { scheme = "https" } - hostname, _, err := net.SplitHostPort(r.Host) - if err != nil { - hostname = r.Host - } - - upstream, err := url.Parse(scheme + "://" + hostname + ":" + AlternatePort) + upstream, err := url.Parse(scheme + "://localhost:" + AlternatePort) if err != nil { w.WriteHeader(http.StatusInternalServerError) - log.Printf("[ERROR] letsencrypt handler: %v", err) + log.Printf("[ERROR] ACME proxy handler: %v", err) return true } proxy := httputil.NewSingleHostReverseProxy(upstream) proxy.Transport = &http.Transport{ - TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, // client would use self-signed cert + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, // solver uses self-signed certs } proxy.ServeHTTP(w, r) diff --git a/caddy/https/handshake.go b/caddy/https/handshake.go index e06e7d0da..e535cf6f4 100644 --- a/caddy/https/handshake.go +++ b/caddy/https/handshake.go @@ -7,7 +7,9 @@ import ( "errors" "fmt" "log" + "strings" "sync" + "sync/atomic" "time" "github.com/mholt/caddy/server" @@ -15,11 +17,12 @@ import ( ) // GetCertificate gets a certificate to satisfy clientHello as long as -// the certificate is already cached in memory. +// the certificate is already cached in memory. It will not be loaded +// from disk or obtained from the CA during the handshake. // // This function is safe for use as a tls.Config.GetCertificate callback. func GetCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) { - cert, err := getCertDuringHandshake(clientHello.ServerName, false) + cert, err := getCertDuringHandshake(clientHello.ServerName, false, false) return cert.Certificate, err } @@ -31,45 +34,60 @@ func GetCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) // // This function is safe for use as a tls.Config.GetCertificate callback. func GetOrObtainCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) { - cert, err := getCertDuringHandshake(clientHello.ServerName, true) + cert, err := getCertDuringHandshake(clientHello.ServerName, true, true) return cert.Certificate, err } // getCertDuringHandshake will get a certificate for name. It first tries -// the in-memory cache, then, if obtainIfNecessary is true, it goes to disk, -// then asks the CA for a certificate if necessary. +// the in-memory cache. If no certificate for name is in the cach and if +// loadIfNecessary == true, it goes to disk to load it into the cache and +// serve it. If it's not on disk and if obtainIfNecessary == true, the +// certificate will be obtained from the CA, cached, and served. If +// obtainIfNecessary is true, then loadIfNecessary must also be set to true. // // This function is safe for concurrent use. -func getCertDuringHandshake(name string, obtainIfNecessary bool) (Certificate, error) { +func getCertDuringHandshake(name string, loadIfNecessary, obtainIfNecessary bool) (Certificate, error) { // First check our in-memory cache to see if we've already loaded it cert, ok := getCertificate(name) if ok { return cert, nil } - if obtainIfNecessary { - // TODO: Mitigate abuse! + if loadIfNecessary { var err error // Then check to see if we have one on disk - cert, err := cacheManagedCertificate(name, true) - if err != nil { - return cert, err - } else if cert.Certificate != nil { - cert, err := handshakeMaintenance(name, cert) + cert, err = cacheManagedCertificate(name, true) + if err == nil { + cert, err = handshakeMaintenance(name, cert) if err != nil { log.Printf("[ERROR] Maintaining newly-loaded certificate for %s: %v", name, err) } - return cert, err + return cert, nil } - // Only option left is to get one from LE, but the name has to qualify first - if !HostQualifies(name) { - return cert, errors.New("hostname '" + name + "' does not qualify for certificate") - } + if obtainIfNecessary { + name = strings.ToLower(name) - // By this point, we need to obtain one from the CA. - return obtainOnDemandCertificate(name) + // Make sure aren't over any applicable limits + if onDemandMaxIssue > 0 && atomic.LoadInt32(OnDemandIssuedCount) >= onDemandMaxIssue { + return Certificate{}, fmt.Errorf("%s: maximum certificates issued (%d)", name, onDemandMaxIssue) + } + failedIssuanceMu.RLock() + when, ok := failedIssuance[name] + failedIssuanceMu.RUnlock() + if ok { + return Certificate{}, fmt.Errorf("%s: throttled; refusing to issue cert since last attempt on %s failed", name, when.String()) + } + + // Only option left is to get one from LE, but the name has to qualify first + if !HostQualifies(name) { + return cert, errors.New("hostname '" + name + "' does not qualify for certificate") + } + + // By this point, we need to obtain one from the CA. + return obtainOnDemandCertificate(name) + } } return Certificate{}, nil @@ -89,7 +107,7 @@ func obtainOnDemandCertificate(name string) (Certificate, error) { // wait for it to finish obtaining the cert and then we'll use it. obtainCertWaitChansMu.Unlock() <-wait - return getCertDuringHandshake(name, false) // passing in true might result in infinite loop if obtain failed + return getCertDuringHandshake(name, true, false) } // looks like it's up to us to do all the work and obtain the cert @@ -115,11 +133,24 @@ func obtainOnDemandCertificate(name string) (Certificate, error) { client.Configure("") // TODO: which BindHost? err = client.Obtain([]string{name}) if err != nil { + // Failed to solve challenge, so don't allow another on-demand + // issue for this name to be attempted for a little while. + failedIssuanceMu.Lock() + failedIssuance[name] = time.Now() + go func(name string) { + time.Sleep(5 * time.Minute) + failedIssuanceMu.Lock() + delete(failedIssuance, name) + failedIssuanceMu.Unlock() + }(name) + failedIssuanceMu.Unlock() return Certificate{}, err } + atomic.AddInt32(OnDemandIssuedCount, 1) + // The certificate is on disk; now just start over to load it and serve it - return getCertDuringHandshake(name, false) // pass in false as a fail-safe from infinite-looping + return getCertDuringHandshake(name, true, false) } // handshakeMaintenance performs a check on cert for expiration and OCSP @@ -127,12 +158,6 @@ func obtainOnDemandCertificate(name string) (Certificate, error) { // // This function is safe for use by multiple concurrent goroutines. func handshakeMaintenance(name string, cert Certificate) (Certificate, error) { - // fmt.Println("ON-DEMAND CERT?", cert.OnDemand) - // if !cert.OnDemand { - // return cert, nil - // } - fmt.Println("Checking expiration of cert; on-demand:", cert.OnDemand) - // Check cert expiration timeLeft := cert.NotAfter.Sub(time.Now().UTC()) if timeLeft < renewDurationBefore { @@ -173,7 +198,7 @@ func renewDynamicCertificate(name string) (Certificate, error) { // wait for it to finish, then we'll use the new one. obtainCertWaitChansMu.Unlock() <-wait - return getCertDuringHandshake(name, false) + return getCertDuringHandshake(name, true, false) } // looks like it's up to us to do all the work and renew the cert @@ -201,7 +226,7 @@ func renewDynamicCertificate(name string) (Certificate, error) { return Certificate{}, err } - return getCertDuringHandshake(name, false) + return getCertDuringHandshake(name, true, false) } // stapleOCSP staples OCSP information to cert for hostname name. @@ -235,3 +260,20 @@ func stapleOCSP(cert *Certificate, pemBundle []byte) error { // obtainCertWaitChans is used to coordinate obtaining certs for each hostname. var obtainCertWaitChans = make(map[string]chan struct{}) var obtainCertWaitChansMu sync.Mutex + +// OnDemandIssuedCount is the number of certificates that have been issued +// on-demand by this process. It is only safe to modify this count atomically. +// If it reaches max_certs, on-demand issuances will fail. +var OnDemandIssuedCount = new(int32) + +// onDemandMaxIssue is set based on max_certs in tls config. It specifies the +// maximum number of certificates that can be issued. +// TODO: This applies globally, but we should probably make a server-specific +// way to keep track of these limits and counts... +var onDemandMaxIssue int32 + +// failedIssuance is a set of names that we recently failed to get a +// certificate for from the ACME CA. They are removed after some time. +// When a name is in this map, do not issue a certificate for it. +var failedIssuance = make(map[string]time.Time) +var failedIssuanceMu sync.RWMutex diff --git a/caddy/https/setup.go b/caddy/https/setup.go index 592dfee59..2500bce07 100644 --- a/caddy/https/setup.go +++ b/caddy/https/setup.go @@ -8,6 +8,7 @@ import ( "log" "os" "path/filepath" + "strconv" "strings" "github.com/mholt/caddy/caddy/setup" @@ -27,7 +28,7 @@ func Setup(c *setup.Controller) (middleware.Middleware, error) { } for c.Next() { - var certificateFile, keyFile, loadDir string + var certificateFile, keyFile, loadDir, maxCerts string args := c.RemainingArgs() switch len(args) { @@ -80,6 +81,8 @@ func Setup(c *setup.Controller) (middleware.Middleware, error) { case "load": c.Args(&loadDir) c.TLS.Manual = true + case "max_certs": + c.Args(&maxCerts) default: return nil, c.Errf("Unknown keyword '%s'", c.Val()) } @@ -90,6 +93,20 @@ func Setup(c *setup.Controller) (middleware.Middleware, error) { return nil, c.ArgErr() } + if c.TLS.Manual && maxCerts != "" { + return nil, c.Err("Cannot limit certificate count (max_certs) for manual TLS configurations") + } + + if maxCerts != "" { + maxCertsNum, err := strconv.Atoi(maxCerts) + if err != nil || maxCertsNum < 0 { + return nil, c.Err("max_certs must be a positive integer") + } + if onDemandMaxIssue == 0 || int32(maxCertsNum) < onDemandMaxIssue { // keep the minimum; TODO: This is global; should be per-server or per-vhost... + onDemandMaxIssue = int32(maxCertsNum) + } + } + // don't load certificates unless we're supposed to if !c.TLS.Enabled || !c.TLS.Manual { continue diff --git a/caddy/restart.go b/caddy/restart.go index c8dc8c7e2..255f9cd7c 100644 --- a/caddy/restart.go +++ b/caddy/restart.go @@ -11,6 +11,7 @@ import ( "os" "os/exec" "path" + "sync/atomic" "github.com/mholt/caddy/caddy/https" ) @@ -55,8 +56,9 @@ func Restart(newCaddyfile Input) error { // Prepare our payload to the child process cdyfileGob := caddyfileGob{ - ListenerFds: make(map[string]uintptr), - Caddyfile: newCaddyfile, + ListenerFds: make(map[string]uintptr), + Caddyfile: newCaddyfile, + OnDemandTLSCertsIssued: atomic.LoadInt32(https.OnDemandIssuedCount), } // Prepare a pipe to the fork's stdin so it can get the Caddyfile diff --git a/caddy/setup/basicauth_test.go b/caddy/setup/basicauth_test.go index a94d6e695..186a3e97e 100644 --- a/caddy/setup/basicauth_test.go +++ b/caddy/setup/basicauth_test.go @@ -118,7 +118,7 @@ md5:$apr1$l42y8rex$pOA2VJ0x/0TwaFeAF9nX61` } if !actualRule.Password(pwd) || actualRule.Password(test.password+"!") { t.Errorf("Test %d, rule %d: Expected password '%v', got '%v'", - i, j, test.password, actualRule.Password) + i, j, test.password, actualRule.Password("")) } expectedRes := fmt.Sprintf("%v", expectedRule.Resources)