Bug fixes and other improvements to TLS functions

Now attempt to staple OCSP even for certs that don't have an existing staple (issue #605). "tls off" short-circuits tls setup function. Now we call getEmail() when setting up an acme.Client that does renewals, rather than making a new account with empty email address. Check certificate expiry every 12 hours, and OCSP every hour.
This commit is contained in:
Matthew Holt 2016-02-15 23:39:04 -07:00
parent 2dba44327a
commit 1cfd960f3c
7 changed files with 168 additions and 138 deletions

View file

@ -24,7 +24,7 @@ var certCacheMu sync.RWMutex
// we can be more efficient by extracting the metadata once so it's // we can be more efficient by extracting the metadata once so it's
// just there, ready to use. // just there, ready to use.
type Certificate struct { type Certificate struct {
*tls.Certificate tls.Certificate
// Names is the list of names this certificate is written for. // Names is the list of names this certificate is written for.
// The first is the CommonName (if any), the rest are SAN. // The first is the CommonName (if any), the rest are SAN.
@ -170,7 +170,6 @@ func makeCertificate(certPEMBlock, keyPEMBlock []byte) (Certificate, error) {
if len(tlsCert.Certificate) == 0 { if len(tlsCert.Certificate) == 0 {
return cert, errors.New("certificate is empty") return cert, errors.New("certificate is empty")
} }
cert.Certificate = &tlsCert
// Parse leaf certificate and extract relevant metadata // Parse leaf certificate and extract relevant metadata
leaf, err := x509.ParseCertificate(tlsCert.Certificate[0]) leaf, err := x509.ParseCertificate(tlsCert.Certificate[0])
@ -198,6 +197,7 @@ func makeCertificate(certPEMBlock, keyPEMBlock []byte) (Certificate, error) {
cert.OCSP = ocspResp cert.OCSP = ocspResp
} }
cert.Certificate = tlsCert
return cert, nil return cert, nil
} }
@ -213,7 +213,9 @@ func makeCertificate(certPEMBlock, keyPEMBlock []byte) (Certificate, error) {
func cacheCertificate(cert Certificate) { func cacheCertificate(cert Certificate) {
certCacheMu.Lock() certCacheMu.Lock()
if _, ok := certCache[""]; !ok { if _, ok := certCache[""]; !ok {
certCache[""] = cert // use as default // use as default
certCache[""] = cert
cert.Names = append(cert.Names, "")
} }
for len(certCache)+len(cert.Names) > 10000 { for len(certCache)+len(cert.Names) > 10000 {
// for simplicity, just remove random elements // for simplicity, just remove random elements

View file

@ -23,7 +23,7 @@ import (
// This function is safe for use as a tls.Config.GetCertificate callback. // This function is safe for use as a tls.Config.GetCertificate callback.
func GetCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) { func GetCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
cert, err := getCertDuringHandshake(clientHello.ServerName, false, false) cert, err := getCertDuringHandshake(clientHello.ServerName, false, false)
return cert.Certificate, err return &cert.Certificate, err
} }
// GetOrObtainCertificate will get a certificate to satisfy clientHello, even // GetOrObtainCertificate will get a certificate to satisfy clientHello, even
@ -35,7 +35,7 @@ func GetCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error)
// This function is safe for use as a tls.Config.GetCertificate callback. // This function is safe for use as a tls.Config.GetCertificate callback.
func GetOrObtainCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) { func GetOrObtainCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
cert, err := getCertDuringHandshake(clientHello.ServerName, true, true) cert, err := getCertDuringHandshake(clientHello.ServerName, true, true)
return cert.Certificate, err return &cert.Certificate, err
} }
// getCertDuringHandshake will get a certificate for name. It first tries // getCertDuringHandshake will get a certificate for name. It first tries
@ -122,8 +122,8 @@ func checkLimitsForObtainingNewCerts(name string) error {
} }
// obtainOnDemandCertificate obtains a certificate for name for the given // obtainOnDemandCertificate obtains a certificate for name for the given
// clientHello. If another goroutine has already started obtaining a cert // name. If another goroutine has already started obtaining a cert for
// for name, it will wait and use what the other goroutine obtained. // name, it will wait and use what the other goroutine obtained.
// //
// This function is safe for use by multiple concurrent goroutines. // This function is safe for use by multiple concurrent goroutines.
func obtainOnDemandCertificate(name string) (Certificate, error) { func obtainOnDemandCertificate(name string) (Certificate, error) {
@ -248,7 +248,7 @@ func renewDynamicCertificate(name string) (Certificate, error) {
log.Printf("[INFO] Renewing certificate for %s", name) log.Printf("[INFO] Renewing certificate for %s", name)
client, err := NewACMEClient("", false) // renewals don't use email client, err := NewACMEClientGetEmail(server.Config{}, false)
if err != nil { if err != nil {
return Certificate{}, err return Certificate{}, err
} }
@ -295,7 +295,7 @@ var obtainCertWaitChansMu sync.Mutex
// OnDemandIssuedCount is the number of certificates that have been issued // OnDemandIssuedCount is the number of certificates that have been issued
// on-demand by this process. It is only safe to modify this count atomically. // on-demand by this process. It is only safe to modify this count atomically.
// If it reaches max_certs, on-demand issuances will fail. // If it reaches onDemandMaxIssue, on-demand issuances will fail.
var OnDemandIssuedCount = new(int32) var OnDemandIssuedCount = new(int32)
// onDemandMaxIssue is set based on max_certs in tls config. It specifies the // onDemandMaxIssue is set based on max_certs in tls config. It specifies the

View file

@ -12,7 +12,6 @@ import (
"net/http" "net/http"
"os" "os"
"strings" "strings"
"time"
"github.com/mholt/caddy/middleware" "github.com/mholt/caddy/middleware"
"github.com/mholt/caddy/middleware/redirect" "github.com/mholt/caddy/middleware/redirect"
@ -215,7 +214,7 @@ func hostHasOtherPort(allConfigs []server.Config, thisConfigIdx int, otherPort s
// all configs. // all configs.
func MakePlaintextRedirects(allConfigs []server.Config) []server.Config { func MakePlaintextRedirects(allConfigs []server.Config) []server.Config {
for i, cfg := range allConfigs { for i, cfg := range allConfigs {
if (cfg.TLS.Managed || cfg.TLS.OnDemand) && if cfg.TLS.Managed &&
!hostHasOtherPort(allConfigs, i, "80") && !hostHasOtherPort(allConfigs, i, "80") &&
(cfg.Port == "443" || !hostHasOtherPort(allConfigs, i, "443")) { (cfg.Port == "443" || !hostHasOtherPort(allConfigs, i, "443")) {
allConfigs = append(allConfigs, redirPlaintextHost(cfg)) allConfigs = append(allConfigs, redirPlaintextHost(cfg))
@ -233,15 +232,16 @@ func MakePlaintextRedirects(allConfigs []server.Config) []server.Config {
// setting up the config may make it look like it // setting up the config may make it look like it
// doesn't qualify even though it originally did. // doesn't qualify even though it originally did.
func ConfigQualifies(cfg server.Config) bool { func ConfigQualifies(cfg server.Config) bool {
return !cfg.TLS.Manual && // user can provide own cert and key return (!cfg.TLS.Manual || cfg.TLS.OnDemand) && // user might provide own cert and key
// user can force-disable automatic HTTPS for this host // user can force-disable automatic HTTPS for this host
cfg.Scheme != "http" && cfg.Scheme != "http" &&
cfg.Port != "80" && cfg.Port != "80" &&
cfg.TLS.LetsEncryptEmail != "off" && cfg.TLS.LetsEncryptEmail != "off" &&
// we get can't certs for some kinds of hostnames // we get can't certs for some kinds of hostnames, but
HostQualifies(cfg.Host) // on-demand TLS allows empty hostnames at startup
(HostQualifies(cfg.Host) || cfg.TLS.OnDemand)
} }
// HostQualifies returns true if the hostname alone // HostQualifies returns true if the hostname alone
@ -387,20 +387,11 @@ var (
CAUrl string CAUrl string
) )
// Some essential values related to the Let's Encrypt process // AlternatePort is the port on which the acme client will open a
const ( // listener and solve the CA's challenges. If this alternate port
// AlternatePort is the port on which the acme client will open a // is used instead of the default port (80 or 443), then the
// listener and solve the CA's challenges. If this alternate port // default port for the challenge must be forwarded to this one.
// is used instead of the default port (80 or 443), then the const AlternatePort = "5033"
// default port for the challenge must be forwarded to this one.
AlternatePort = "5033"
// RenewInterval is how often to check certificates for renewal.
RenewInterval = 6 * time.Hour
// OCSPInterval is how often to check if OCSP stapling needs updating.
OCSPInterval = 1 * time.Hour
)
// KeySize represents the length of a key in bits. // KeySize represents the length of a key in bits.
type KeySize int type KeySize int

View file

@ -4,9 +4,19 @@ import (
"log" "log"
"time" "time"
"github.com/mholt/caddy/server"
"golang.org/x/crypto/ocsp" "golang.org/x/crypto/ocsp"
) )
const (
// RenewInterval is how often to check certificates for renewal.
RenewInterval = 12 * time.Hour
// OCSPInterval is how often to check if OCSP stapling needs updating.
OCSPInterval = 1 * time.Hour
)
// maintainAssets is a permanently-blocking function // maintainAssets is a permanently-blocking function
// that loops indefinitely and, on a regular schedule, checks // that loops indefinitely and, on a regular schedule, checks
// certificates for expiration and initiates a renewal of certs // certificates for expiration and initiates a renewal of certs
@ -28,7 +38,7 @@ func maintainAssets(stopChan chan struct{}) {
log.Println("[INFO] Done checking certificates") log.Println("[INFO] Done checking certificates")
case <-ocspTicker.C: case <-ocspTicker.C:
log.Println("[INFO] Scanning for stale OCSP staples") log.Println("[INFO] Scanning for stale OCSP staples")
updatePreloadedOCSPStaples() updateOCSPStaples()
log.Println("[INFO] Done checking OCSP staples") log.Println("[INFO] Done checking OCSP staples")
case <-stopChan: case <-stopChan:
renewalTicker.Stop() renewalTicker.Stop()
@ -70,7 +80,7 @@ func renewManagedCertificates(allowPrompts bool) (err error) {
log.Printf("[INFO] Certificate for %v expires in %v; attempting renewal", cert.Names, timeLeft) log.Printf("[INFO] Certificate for %v expires in %v; attempting renewal", cert.Names, timeLeft)
if client == nil { if client == nil {
client, err = NewACMEClient("", allowPrompts) // renewals don't use email client, err = NewACMEClientGetEmail(server.Config{}, allowPrompts)
if err != nil { if err != nil {
return err return err
} }
@ -116,42 +126,66 @@ func renewManagedCertificates(allowPrompts bool) (err error) {
return nil return nil
} }
func updatePreloadedOCSPStaples() { func updateOCSPStaples() {
// Create a temporary place to store updates // Create a temporary place to store updates
// until we release the potentially slow read // until we release the potentially long-lived
// lock so we can use a quick write lock. // read lock and use a short-lived write lock.
type ocspUpdate struct { type ocspUpdate struct {
rawBytes []byte rawBytes []byte
parsedResponse *ocsp.Response parsed *ocsp.Response
} }
updated := make(map[string]ocspUpdate) updated := make(map[string]ocspUpdate)
// A single SAN certificate maps to multiple names, so we use this
// set to make sure we don't waste cycles checking OCSP for the same
// certificate multiple times.
visited := make(map[string]struct{})
certCacheMu.RLock() certCacheMu.RLock()
for name, cert := range certCache { for name, cert := range certCache {
// we update OCSP for managed and un-managed certs here, but only // skip this certificate if we've already visited it,
// if it has OCSP stapled and only for pre-loaded certificates // and if not, mark all the names as visited
if cert.OnDemand || cert.OCSP == nil { if _, ok := visited[name]; ok {
continue
}
for _, n := range cert.Names {
visited[n] = struct{}{}
}
// no point in updating OCSP for expired certificates
if time.Now().After(cert.NotAfter) {
continue continue
} }
// start checking OCSP staple about halfway through validity period for good measure var lastNextUpdate time.Time
oldNextUpdate := cert.OCSP.NextUpdate if cert.OCSP != nil {
refreshTime := cert.OCSP.ThisUpdate.Add(oldNextUpdate.Sub(cert.OCSP.ThisUpdate) / 2) // start checking OCSP staple about halfway through validity period for good measure
lastNextUpdate = cert.OCSP.NextUpdate
refreshTime := cert.OCSP.ThisUpdate.Add(lastNextUpdate.Sub(cert.OCSP.ThisUpdate) / 2)
// only check for updated OCSP validity window if the refresh time is // since OCSP is already stapled, we need only check if we're in that "refresh window"
// in the past and the certificate is not expired if time.Now().Before(refreshTime) {
if time.Now().After(refreshTime) && time.Now().Before(cert.NotAfter) {
err := stapleOCSP(&cert, nil)
if err != nil {
log.Printf("[ERROR] Checking OCSP for %s: %v", name, err)
continue continue
} }
}
// if the OCSP response has been updated, we use it err := stapleOCSP(&cert, nil)
if oldNextUpdate != cert.OCSP.NextUpdate { if err != nil {
log.Printf("[INFO] Moving validity period of OCSP staple for %s from %v to %v", if cert.OCSP != nil {
name, oldNextUpdate, cert.OCSP.NextUpdate) // if it was no staple before, that's fine, otherwise we should log the error
updated[name] = ocspUpdate{rawBytes: cert.Certificate.OCSPStaple, parsedResponse: cert.OCSP} log.Printf("[ERROR] Checking OCSP for %s: %v", name, err)
}
continue
}
// By this point, we've obtained the latest OCSP response.
// If there was no staple before, or if the response is updated, make
// sure we apply the update to all names on the certificate.
if lastNextUpdate.IsZero() || lastNextUpdate != cert.OCSP.NextUpdate {
log.Printf("[INFO] Advancing OCSP staple for %v from %s to %s",
cert.Names, lastNextUpdate, cert.OCSP.NextUpdate)
for _, n := range cert.Names {
updated[n] = ocspUpdate{rawBytes: cert.Certificate.OCSPStaple, parsed: cert.OCSP}
} }
} }
} }
@ -161,7 +195,7 @@ func updatePreloadedOCSPStaples() {
certCacheMu.Lock() certCacheMu.Lock()
for name, update := range updated { for name, update := range updated {
cert := certCache[name] cert := certCache[name]
cert.OCSP = update.parsedResponse cert.OCSP = update.parsed
cert.Certificate.OCSPStaple = update.rawBytes cert.Certificate.OCSPStaple = update.rawBytes
certCache[name] = cert certCache[name] = cert
} }

View file

@ -20,12 +20,12 @@ import (
// are specified by the user in the config file. All the automatic HTTPS // are specified by the user in the config file. All the automatic HTTPS
// stuff comes later outside of this function. // stuff comes later outside of this function.
func Setup(c *setup.Controller) (middleware.Middleware, error) { func Setup(c *setup.Controller) (middleware.Middleware, error) {
if c.Scheme == "http" { if c.Port == "80" || c.Scheme == "http" {
c.TLS.Enabled = false c.TLS.Enabled = false
log.Printf("[WARNING] TLS disabled for %s://%s.", c.Scheme, c.Address()) log.Printf("[WARNING] TLS disabled for %s://%s.", c.Scheme, c.Address())
} else { return nil, nil
c.TLS.Enabled = true
} }
c.TLS.Enabled = true
for c.Next() { for c.Next() {
var certificateFile, keyFile, loadDir, maxCerts string var certificateFile, keyFile, loadDir, maxCerts string
@ -38,6 +38,7 @@ func Setup(c *setup.Controller) (middleware.Middleware, error) {
// user can force-disable managed TLS this way // user can force-disable managed TLS this way
if c.TLS.LetsEncryptEmail == "off" { if c.TLS.LetsEncryptEmail == "off" {
c.TLS.Enabled = false c.TLS.Enabled = false
return nil, nil
} }
case 2: case 2:
certificateFile = args[0] certificateFile = args[0]
@ -120,78 +121,8 @@ func Setup(c *setup.Controller) (middleware.Middleware, error) {
} }
// load a directory of certificates, if specified // load a directory of certificates, if specified
// modeled after haproxy: https://cbonte.github.io/haproxy-dconv/configuration-1.5.html#5.1-crt
if loadDir != "" { if loadDir != "" {
err := filepath.Walk(loadDir, func(path string, info os.FileInfo, err error) error { err := loadCertsInDir(c, loadDir)
if err != nil {
log.Printf("[WARNING] Unable to traverse into %s; skipping", path)
return nil
}
if info.IsDir() {
return nil
}
if strings.HasSuffix(strings.ToLower(info.Name()), ".pem") {
certBuilder, keyBuilder := new(bytes.Buffer), new(bytes.Buffer)
var foundKey bool
bundle, err := ioutil.ReadFile(path)
if err != nil {
return err
}
for {
// Decode next block so we can see what type it is
var derBlock *pem.Block
derBlock, bundle = pem.Decode(bundle)
if derBlock == nil {
break
}
if derBlock.Type == "CERTIFICATE" {
// Re-encode certificate as PEM, appending to certificate chain
pem.Encode(certBuilder, derBlock)
} else if derBlock.Type == "EC PARAMETERS" {
// EC keys are composed of two blocks: parameters and key
// (parameter block should come first)
if !foundKey {
// Encode parameters
pem.Encode(keyBuilder, derBlock)
// Key must immediately follow
derBlock, bundle = pem.Decode(bundle)
if derBlock == nil || derBlock.Type != "EC PRIVATE KEY" {
return c.Errf("%s: expected elliptic private key to immediately follow EC parameters", path)
}
pem.Encode(keyBuilder, derBlock)
foundKey = true
}
} else if derBlock.Type == "PRIVATE KEY" || strings.HasSuffix(derBlock.Type, " PRIVATE KEY") {
// RSA key
if !foundKey {
pem.Encode(keyBuilder, derBlock)
foundKey = true
}
} else {
return c.Errf("%s: unrecognized PEM block type: %s", path, derBlock.Type)
}
}
certPEMBytes, keyPEMBytes := certBuilder.Bytes(), keyBuilder.Bytes()
if len(certPEMBytes) == 0 {
return c.Errf("%s: failed to parse PEM data", path)
}
if len(keyPEMBytes) == 0 {
return c.Errf("%s: no private key block found", path)
}
err = cacheUnmanagedCertificatePEMBytes(certPEMBytes, keyPEMBytes)
if err != nil {
return c.Errf("%s: failed to load cert and key for %s: %v", path, c.Host, err)
}
log.Printf("[INFO] Successfully loaded TLS assets from %s", path)
}
return nil
})
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -203,6 +134,86 @@ func Setup(c *setup.Controller) (middleware.Middleware, error) {
return nil, nil return nil, nil
} }
// loadCertsInDir loads all the certificates/keys in dir, as long as
// the file ends with .pem. This method of loading certificates is
// modeled after haproxy, which expects the certificate and key to
// be bundled into the same file:
// https://cbonte.github.io/haproxy-dconv/configuration-1.5.html#5.1-crt
//
// This function may write to the log as it walks the directory tree.
func loadCertsInDir(c *setup.Controller, dir string) error {
return filepath.Walk(dir, func(path string, info os.FileInfo, err error) error {
if err != nil {
log.Printf("[WARNING] Unable to traverse into %s; skipping", path)
return nil
}
if info.IsDir() {
return nil
}
if strings.HasSuffix(strings.ToLower(info.Name()), ".pem") {
certBuilder, keyBuilder := new(bytes.Buffer), new(bytes.Buffer)
var foundKey bool // use only the first key in the file
bundle, err := ioutil.ReadFile(path)
if err != nil {
return err
}
for {
// Decode next block so we can see what type it is
var derBlock *pem.Block
derBlock, bundle = pem.Decode(bundle)
if derBlock == nil {
break
}
if derBlock.Type == "CERTIFICATE" {
// Re-encode certificate as PEM, appending to certificate chain
pem.Encode(certBuilder, derBlock)
} else if derBlock.Type == "EC PARAMETERS" {
// EC keys generated from openssl can be composed of two blocks:
// parameters and key (parameter block should come first)
if !foundKey {
// Encode parameters
pem.Encode(keyBuilder, derBlock)
// Key must immediately follow
derBlock, bundle = pem.Decode(bundle)
if derBlock == nil || derBlock.Type != "EC PRIVATE KEY" {
return c.Errf("%s: expected elliptic private key to immediately follow EC parameters", path)
}
pem.Encode(keyBuilder, derBlock)
foundKey = true
}
} else if derBlock.Type == "PRIVATE KEY" || strings.HasSuffix(derBlock.Type, " PRIVATE KEY") {
// RSA key
if !foundKey {
pem.Encode(keyBuilder, derBlock)
foundKey = true
}
} else {
return c.Errf("%s: unrecognized PEM block type: %s", path, derBlock.Type)
}
}
certPEMBytes, keyPEMBytes := certBuilder.Bytes(), keyBuilder.Bytes()
if len(certPEMBytes) == 0 {
return c.Errf("%s: failed to parse PEM data", path)
}
if len(keyPEMBytes) == 0 {
return c.Errf("%s: no private key block found", path)
}
err = cacheUnmanagedCertificatePEMBytes(certPEMBytes, keyPEMBytes)
if err != nil {
return c.Errf("%s: failed to load cert and key for %s: %v", path, c.Host, err)
}
log.Printf("[INFO] Successfully loaded TLS assets from %s", path)
}
return nil
})
}
// setDefaultTLSParams sets the default TLS cipher suites, protocol versions, // setDefaultTLSParams sets the default TLS cipher suites, protocol versions,
// and server preferences of a server.Config if they were not previously set // and server preferences of a server.Config if they were not previously set
// (it does not overwrite; only fills in missing values). It will also set the // (it does not overwrite; only fills in missing values). It will also set the
@ -231,7 +242,7 @@ func setDefaultTLSParams(c *server.Config) {
// Default TLS port is 443; only use if port is not manually specified, // Default TLS port is 443; only use if port is not manually specified,
// TLS is enabled, and the host is not localhost // TLS is enabled, and the host is not localhost
if c.Port == "" && c.TLS.Enabled && !c.TLS.Manual && c.Host != "localhost" { if c.Port == "" && c.TLS.Enabled && (!c.TLS.Manual || c.TLS.OnDemand) && c.Host != "localhost" {
c.Port = "443" c.Port = "443"
} }
} }

View file

@ -68,7 +68,7 @@ type TLSConfig struct {
Enabled bool // will be set to true if TLS is enabled Enabled bool // will be set to true if TLS is enabled
LetsEncryptEmail string LetsEncryptEmail string
Manual bool // will be set to true if user provides own certs and keys Manual bool // will be set to true if user provides own certs and keys
Managed bool // will be set to true if config qualifies for automatic/managed HTTPS Managed bool // will be set to true if config qualifies for implicit automatic/managed HTTPS
OnDemand bool // will be set to true if user enables on-demand TLS (obtain certs during handshakes) OnDemand bool // will be set to true if user enables on-demand TLS (obtain certs during handshakes)
Ciphers []uint16 Ciphers []uint16
ProtocolMinVersion uint16 ProtocolMinVersion uint16

View file

@ -63,15 +63,7 @@ func New(addr string, configs []Config, gracefulTimeout time.Duration) (*Server,
var useTLS, useOnDemandTLS bool var useTLS, useOnDemandTLS bool
if len(configs) > 0 { if len(configs) > 0 {
useTLS = configs[0].TLS.Enabled useTLS = configs[0].TLS.Enabled
if useTLS { useOnDemandTLS = configs[0].TLS.OnDemand
host, _, err := net.SplitHostPort(addr)
if err != nil {
host = addr
}
if host == "" && configs[0].TLS.OnDemand {
useOnDemandTLS = true
}
}
} }
s := &Server{ s := &Server{