diff --git a/modules/caddypki/ca.go b/modules/caddypki/ca.go index 6d25b8f76..e57f0d5fc 100644 --- a/modules/caddypki/ca.go +++ b/modules/caddypki/ca.go @@ -16,7 +16,9 @@ package caddypki import ( "crypto" + "crypto/rand" "crypto/x509" + "crypto/x509/pkix" "encoding/json" "errors" "fmt" @@ -432,40 +434,61 @@ func (ca CA) generateCSR(csrReq csrRequest) (csr *x509.CertificateRequest, err e csrKeyPEM, err := ca.storage.Load(ca.ctx, ca.storageKeyCSRKey(csrReq.ID)) if err != nil { if !errors.Is(err, fs.ErrNotExist) { - return nil, fmt.Errorf("loading csr key '%s': %v", csrReq.ID, err) + return csr, fmt.Errorf("loading csr key '%s': %v", csrReq.ID, err) } if csrReq.Key == nil { signer, err = keyutil.GenerateDefaultSigner() if err != nil { - return nil, err + return csr, err } } else { signer, err = keyutil.GenerateSigner(csrReq.Key.Type.String(), csrReq.Key.Curve.String(), csrReq.Key.Size) if err != nil { - return nil, err + return csr, err } } csrKeyPEM, err = certmagic.PEMEncodePrivateKey(signer) if err != nil { - return nil, fmt.Errorf("encoding csr key: %v", err) + return csr, fmt.Errorf("encoding csr key: %v", err) } if err := ca.storage.Store(ca.ctx, ca.storageKeyCSRKey(csrReq.ID), csrKeyPEM); err != nil { - return nil, fmt.Errorf("saving csr key: %v", err) + return csr, fmt.Errorf("saving csr key: %v", err) } } if signer == nil { signer, err = certmagic.PEMDecodePrivateKey(csrKeyPEM) if err != nil { - return nil, fmt.Errorf("decoding csr key: %v", err) + return csr, fmt.Errorf("decoding csr key: %v", err) } } - csr, err = x509util.CreateCertificateRequest("", csrReq.SANs, signer) - if err != nil { - return nil, err + var subject pkix.Name + if csrReq.Request != nil && csrReq.Request.Subject != nil { + subject = pkix.Name{ + Country: csrReq.Request.Subject.Country, + Organization: csrReq.Request.Subject.Organization, + OrganizationalUnit: csrReq.Request.Subject.OrganizationalUnit, + Locality: csrReq.Request.Subject.Locality, + Province: csrReq.Request.Subject.Province, + StreetAddress: csrReq.Request.Subject.StreetAddress, + PostalCode: csrReq.Request.Subject.PostalCode, + CommonName: csrReq.Request.Subject.CommonName, + } } - return csr, nil + dnsNames, ips, emails, uris := x509util.SplitSANs(csrReq.Request.SANs) + + csrBytes, err := x509.CreateCertificateRequest(rand.Reader, &x509.CertificateRequest{ + Subject: subject, + DNSNames: dnsNames, + IPAddresses: ips, + EmailAddresses: emails, + URIs: uris, + }, signer) + if err != nil { + return csr, err + } + return x509.ParseCertificateRequest(csrBytes) } // AuthorityConfig is used to help a CA configure diff --git a/modules/caddypki/csr.go b/modules/caddypki/csr.go index 2f379ef9d..d00c001ed 100644 --- a/modules/caddypki/csr.go +++ b/modules/caddypki/csr.go @@ -2,7 +2,9 @@ package caddypki import ( "encoding/json" + "errors" "fmt" + "strings" ) // The key type to be used for signing the CSR. The possible types are: @@ -150,10 +152,44 @@ type csrRequest struct { // The values are case-sensitive. Key *keyParameters `json:"key,omitempty"` - // SANs is a list of subject alternative names for the certificate. - SANs []string `json:"sans"` + Request *requestParameters `json:"request,omitempty"` } func (c csrRequest) validate() error { + if !c.Request.valid() { + return errors.New("the 'request' field is not valid") + } return c.Key.validate() } + +type requestParameters struct { + Subject *subject `json:"subject,omitempty"` + + // SANs is a list of subject alternative names for the certificate. + SANs []string `json:"sans,omitempty"` +} + +type subject struct { + CommonName string `json:"cn,omitempty"` + Country []string `json:"c,omitempty"` + Organization []string `json:"o,omitempty"` + OrganizationalUnit []string `json:"ou,omitempty"` + Locality []string `json:"l,omitempty"` + Province []string `json:"s,omitempty"` + StreetAddress []string `json:"street_address,omitempty"` + PostalCode []string `json:"postal_code,omitempty"` +} + +func (rp *requestParameters) valid() bool { + if rp == nil || (len(rp.SANs) == 0 && rp.Subject == nil) { + return false + } + if len(rp.SANs) > 0 { + for _, san := range rp.SANs { + if strings.TrimSpace(san) == "" { + return false + } + } + } + return rp.Subject == nil || (rp.Subject != nil && len(strings.TrimSpace(rp.Subject.CommonName)) > 0) +} diff --git a/modules/caddypki/csr_test.go b/modules/caddypki/csr_test.go index 19a56d65c..09870e658 100644 --- a/modules/caddypki/csr_test.go +++ b/modules/caddypki/csr_test.go @@ -2,7 +2,6 @@ package caddypki import ( "encoding/json" - "reflect" "testing" ) @@ -19,12 +18,12 @@ func TestParseKeyType(t *testing.T) { expected: keyTypeEC, }, { - name: "lowercase EC is recognized", + name: "lowercase EC is rejected", input: `"ec"`, err: "unknown key type: ec", }, { - name: "mixed case EC is recognized", + name: "mixed case EC is rejected", input: `"eC"`, err: "unknown key type: eC", }, @@ -34,12 +33,12 @@ func TestParseKeyType(t *testing.T) { expected: keyTypeRSA, }, { - name: "lowercase rsa is not accepted", + name: "lowercase rsa is rejected", input: `"rsa"`, err: "unknown key type: rsa", }, { - name: "mixed case RSA is not accepted", + name: "mixed case RSA is rejected", input: `"RsA"`, err: "unknown key type: RsA", }, @@ -49,17 +48,17 @@ func TestParseKeyType(t *testing.T) { expected: keyTypeOKP, }, { - name: "lowercase OKP is not accepted", + name: "lowercase OKP is rejected", input: `"okp"`, err: "unknown key type: okp", }, { - name: "mixed case OKP is not accepted", + name: "mixed case OKP is rejected", input: `"OkP"`, err: "unknown key type: OkP", }, { - name: "unknown key type is an error", + name: "unknown key type is rejected", input: `"foo"`, err: "unknown key type: foo", }, @@ -89,7 +88,7 @@ func TestParseKeyType(t *testing.T) { } } -func TestCSRRequestValidate(t *testing.T) { +func TestCSRKeyParameterValidate(t *testing.T) { tests := []struct { name string key *keyParameters @@ -221,71 +220,154 @@ func TestCSRRequestValidate(t *testing.T) { wantErr: true, }, } - for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - c := csrRequest{ - Key: tt.key, - } - if err := c.validate(); (err != nil) != tt.wantErr { - t.Errorf("csrRequest.validate() error = %v, wantErr %v", err, tt.wantErr) + if err := tt.key.validate(); (err != nil) != tt.wantErr { + t.Errorf("keyParameter.validate() error = %v, wantErr %v", err, tt.wantErr) } }) } } -func TestCSRRequestUnmarshalJSON(t *testing.T) { +func TestParseCurve(t *testing.T) { tests := []struct { - name string - request string - want csrRequest - err string + name string + input string + expected curve + err string }{ { - name: "empty request is valid", - request: "{}", - want: csrRequest{ - Key: nil, - }, + name: "Ed25519 is recognized", + input: `"Ed25519"`, + expected: curveEd25519, }, { - name: "RSA with size 2048 is valid", - request: `{"key":{"type":"RSA","size":2048}}`, - want: csrRequest{ - Key: &keyParameters{ - Type: keyTypeRSA, - Size: 2048, - }, - }, + name: "ed25519 is rejected", + input: `"ed25519"`, + err: "unknown curve: ed25519", }, { - name: "EC key with curve P-256 is valid", - request: `{"key":{"type":"EC","curve":"P-256"}}`, - want: csrRequest{ - Key: &keyParameters{ - Type: keyTypeEC, - Curve: "P-256", - }, - }, + name: "eD25519 is rejected", + input: `"eD25519"`, + err: "unknown curve: eD25519", + }, + { + name: "X25519 is recognized", + input: `"X25519"`, + expected: curveX25519, + }, + { + name: "x25519 is rejected", + input: `"x25519"`, + err: "unknown curve: x25519", + }, + { + name: "P-256 is recognized", + input: `"P-256"`, + expected: curveP256, + }, + { + name: "p-256 is rejected", + input: `"p-256"`, + err: "unknown curve: p-256", + }, + + { + name: "P-384 is recognized", + input: `"P-384"`, + expected: curveP384, + }, + { + name: "p-384 is rejected", + input: `"p-384"`, + err: "unknown curve: p-384", + }, + + { + name: "P-521 is recognized", + input: `"P-521"`, + expected: curveP521, + }, + { + name: "p-521 is rejected", + input: `"p-521"`, + err: "unknown curve: p-521", }, } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - var c csrRequest - err := json.Unmarshal([]byte(tt.request), &c) - if tt.err != "" { + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + var kt curve + + err := json.Unmarshal([]byte(test.input), &kt) + if test.err != "" { if err == nil { - t.Errorf("expected error %q, but got nil", tt.err) + t.Errorf("expected error %q, but got nil", test.err) } - if err.Error() != tt.err { - t.Errorf("expected error %q, but got %q", tt.err, err.Error()) + if err.Error() != test.err { + t.Errorf("expected error %q, but got %q", test.err, err.Error()) } + return } if err != nil { t.Errorf("expected no error, but got %q", err.Error()) + return } - if !reflect.DeepEqual(c, tt.want) { - t.Errorf("csrRequest.unmarshalJSON() = %v, want %v", c, tt.want) + if kt != test.expected { + t.Errorf("expected %v, but got %v", test.expected, kt) + } + }) + } +} + +func TestRequestParametersValidation(t *testing.T) { + tests := []struct { + name string + req *requestParameters + want bool + }{ + { + name: "nil request is invalid", + req: nil, + want: false, + }, + { + name: "empty request is invalid", + req: &requestParameters{}, + want: false, + }, + { + name: "request containing empty SAN value is invalid", + req: &requestParameters{ + SANs: []string{"example.com", "", "foo.com"}, + }, + want: false, + }, + { + name: "request with SANs is valid", + req: &requestParameters{ + SANs: []string{"example.com"}, + }, + want: true, + }, + { + name: "request with non-empty CommonName is valid", + req: &requestParameters{ + Subject: &subject{CommonName: "example.com"}, + }, + want: true, + }, + { + name: "request with empty-space CommonName is invalid", + req: &requestParameters{ + Subject: &subject{CommonName: " "}, + }, + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.req.valid(); got != tt.want { + t.Errorf("requestParameters.valid() = %v, want %v", got, tt.want) } }) }