diff --git a/certificates.go b/certificates.go index e430ecc..6029a86 100644 --- a/certificates.go +++ b/certificates.go @@ -22,6 +22,7 @@ package crypto11 import ( + "bytes" "crypto/tls" "crypto/x509" "encoding/asn1" @@ -31,26 +32,7 @@ import ( "github.com/pkg/errors" ) -// FindCertificate retrieves a previously imported certificate. Any combination of id, label -// and serial can be provided. An error is return if all are nil. func findCertificate(session *pkcs11Session, id []byte, label []byte, serial *big.Int) (cert *x509.Certificate, err error) { - - rawCertificate, err := findRawCertificate(session, id, label, serial) - if err != nil { - return nil, err - } - - if rawCertificate != nil { - cert, err = x509.ParseCertificate(rawCertificate) - if err != nil { - return nil, err - } - } - - return cert, err -} - -func findRawCertificate(session *pkcs11Session, id []byte, label []byte, serial *big.Int) (rawCertificate []byte, err error) { if id == nil && label == nil && serial == nil { return nil, errors.New("id, label and serial cannot all be nil") } @@ -72,6 +54,36 @@ func findRawCertificate(session *pkcs11Session, id []byte, label []byte, serial template = append(template, pkcs11.NewAttribute(pkcs11.CKA_SERIAL_NUMBER, derSerial)) } + handles, err := findCertificatesWithAttributes(session, template) + if err != nil { + return nil, err + } + + if len(handles) == 0 { + return nil, nil + } + + return getX509Certificate(session, handles[0]) +} + +func getX509Certificate(session *pkcs11Session, handle pkcs11.ObjectHandle) (cert *x509.Certificate, err error) { + attributes := []*pkcs11.Attribute{ + pkcs11.NewAttribute(pkcs11.CKA_VALUE, 0), + } + + if attributes, err = session.ctx.GetAttributeValue(session.handle, handle, attributes); err != nil { + return nil, err + } + + cert, err = x509.ParseCertificate(attributes[0].Value) + if err != nil { + return nil, err + } + + return cert, nil +} + +func findCertificatesWithAttributes(session *pkcs11Session, template []*pkcs11.Attribute) (handles []pkcs11.ObjectHandle, err error) { template = append(template, pkcs11.NewAttribute(pkcs11.CKA_CLASS, pkcs11.CKO_CERTIFICATE)) if err = session.ctx.FindObjectsInit(session.handle, template); err != nil { @@ -84,25 +96,81 @@ func findRawCertificate(session *pkcs11Session, id []byte, label []byte, serial } }() - handles, _, err := session.ctx.FindObjects(session.handle, 1) + for { + newhandles, _, err := session.ctx.FindObjects(session.handle, maxHandlePerFind) + if err != nil { + return nil, err + } + + if len(newhandles) == 0 { + break + } + + handles = append(handles, newhandles...) + } + + return handles, nil +} + +func findCertificateByKeyID(session *pkcs11Session, keyID []byte) (cert *x509.Certificate, err error) { + handles, err := findCertificatesWithAttributes(session, nil) if err != nil { return nil, err } - if len(handles) == 0 { + + for _, handle := range handles { + if cert, err = getX509Certificate(session, handle); err != nil { + return nil, err + } + + if bytes.Equal(cert.SubjectKeyId, keyID) { + return cert, nil + } + } + + return nil, errors.New("no certificate with required subject key ID found") +} + +func findCertificateChain(session *pkcs11Session, cert *x509.Certificate) (certs []*x509.Certificate, err error) { + if len(cert.RawIssuer) == 0 || bytes.Equal(cert.RawIssuer, cert.RawSubject) { return nil, nil } - attributes := []*pkcs11.Attribute{ - pkcs11.NewAttribute(pkcs11.CKA_VALUE, 0), + template := []*pkcs11.Attribute{pkcs11.NewAttribute(pkcs11.CKA_SUBJECT, cert.RawIssuer)} + + handles, err := findCertificatesWithAttributes(session, template) + if err != nil { + return nil, err + } + + if len(handles) == 0 { + if cert, err = findCertificateByKeyID(session, cert.AuthorityKeyId); err != nil { + return nil, err + } + } else { + if cert, err = getX509Certificate(session, handles[0]); err != nil { + return nil, err + } } - if attributes, err = session.ctx.GetAttributeValue(session.handle, handles[0], attributes); err != nil { + for _, foundCert := range certs { + if bytes.Equal(cert.RawSubject, foundCert.RawSubject) { + return certs, nil + } + } + + certs = append(certs, cert) + + certChain, err := findCertificateChain(session, cert) + if err != nil { return nil, err } - rawCertificate = attributes[0].Value + if len(certChain) != 0 { + certs = append(certs, certChain...) + } - return + return certs, nil } // FindCertificate retrieves a previously imported certificate. Any combination of id, label @@ -122,6 +190,69 @@ func (c *Context) FindCertificate(id []byte, label []byte, serial *big.Int) (*x5 return cert, err } +// FindCertificateWithAttributes retrieves a previously imported certificate with selected attributes. +func (c *Context) FindCertificateWithAttributes(template AttributeSet) (*x509.Certificate, error) { + if c.closed.Get() { + return nil, errClosed + } + + var cert *x509.Certificate + err := c.withSession(func(session *pkcs11Session) (err error) { + handles, err := findCertificatesWithAttributes(session, template.ToSlice()) + if err != nil { + return err + } + + if len(handles) == 0 { + return nil + } + + if cert, err = getX509Certificate(session, handles[0]); err != nil { + return err + } + + return nil + }) + + return cert, err +} + +// FindCertificateChain retrieves a previously imported certificate chain. Any combination of id, label +// and serial can be provided. An error is return if all are nil. +func (c *Context) FindCertificateChain(id []byte, label []byte, serial *big.Int) (certs []*x509.Certificate, err error) { + if c.closed.Get() { + return nil, errClosed + } + + err = c.withSession(func(session *pkcs11Session) (err error) { + cert, err := findCertificate(session, id, label, serial) + if err != nil { + return err + } + + if cert == nil { + return nil + } + + certs = append(certs, cert) + + certChain, err := findCertificateChain(session, cert) + if err != nil { + return err + } + + if len(certChain) == 0 { + return nil + } + + certs = append(certs, certChain...) + + return nil + }) + + return certs, err +} + func (c *Context) FindAllPairedCertificates() (certificates []tls.Certificate, err error) { if c.closed.Get() { return nil, errClosed @@ -258,40 +389,48 @@ func (c *Context) DeleteCertificate(id []byte, label []byte, serial *big.Int) er return errors.New("id, label and serial cannot all be nil") } - template := []*pkcs11.Attribute{ - pkcs11.NewAttribute(pkcs11.CKA_CLASS, pkcs11.CKO_CERTIFICATE), - } + template := NewAttributeSet() if id != nil { - template = append(template, pkcs11.NewAttribute(pkcs11.CKA_ID, id)) + if err := template.Set(pkcs11.CKA_ID, id); err != nil { + return err + } } if label != nil { - template = append(template, pkcs11.NewAttribute(pkcs11.CKA_LABEL, label)) + if err := template.Set(pkcs11.CKA_LABEL, label); err != nil { + return err + } } if serial != nil { asn1Serial, err := asn1.Marshal(serial) if err != nil { return err } - template = append(template, pkcs11.NewAttribute(pkcs11.CKA_SERIAL_NUMBER, asn1Serial)) - } - - err := c.withSession(func(session *pkcs11Session) error { - err := session.ctx.FindObjectsInit(session.handle, template) - if err != nil { + if err := template.Set(pkcs11.CKA_SERIAL_NUMBER, asn1Serial); err != nil { return err } - handles, _, err := session.ctx.FindObjects(session.handle, 1) - finalErr := session.ctx.FindObjectsFinal(session.handle) + } + + return c.DeleteCertificateWithAttributes(template) +} + +// DeleteCertificateWithAttributes destroys a previously imported certificate by selected attributes. +// It will return nil if succeeds or if the certificate does not exist. +func (c *Context) DeleteCertificateWithAttributes(template AttributeSet) error { + if c.closed.Get() { + return errClosed + } + + err := c.withSession(func(session *pkcs11Session) (err error) { + handles, err := findCertificatesWithAttributes(session, template.ToSlice()) if err != nil { return err } - if finalErr != nil { - return finalErr - } + if len(handles) == 0 { return nil } + return session.ctx.DestroyObject(session.handle, handles[0]) }) diff --git a/certificates_test.go b/certificates_test.go index 31aa856..69c43dc 100644 --- a/certificates_test.go +++ b/certificates_test.go @@ -48,7 +48,7 @@ func TestCertificate(t *testing.T) { id := randomBytes() label := randomBytes() - cert := generateRandomCert(t) + cert := generateRandomCert(t, nil, "Foo", nil, nil) err = ctx.ImportCertificateWithLabel(id, label, cert) require.NoError(t, err) @@ -81,7 +81,7 @@ func TestCertificateAttributes(t *testing.T) { require.NoError(t, ctx.Close()) }() - cert := generateRandomCert(t) + cert := generateRandomCert(t, nil, "Foo", nil, nil) // We import this with a different serial number, to test this is obeyed ourSerial := new(big.Int) @@ -116,7 +116,7 @@ func TestCertificateRequiredArgs(t *testing.T) { require.NoError(t, ctx.Close()) }() - cert := generateRandomCert(t) + cert := generateRandomCert(t, nil, "Foo", nil, nil) val := randomBytes() @@ -143,7 +143,7 @@ func TestDeleteCertificate(t *testing.T) { randomCert := func() ([]byte, []byte, *x509.Certificate) { id := randomBytes() label := randomBytes() - cert := generateRandomCert(t) + cert := generateRandomCert(t, nil, "Foo", nil, nil) return id, label, cert } importCertificate := func() ([]byte, []byte, *big.Int) { @@ -207,14 +207,98 @@ func TestDeleteCertificate(t *testing.T) { require.Nil(t, cert) } -func generateRandomCert(t *testing.T) *x509.Certificate { +func TestCertificateChain(t *testing.T) { + skipTest(t, skipTestCert) + + ctx, err := ConfigureFromFile("config") + require.NoError(t, err) + + defer func() { + require.NoError(t, ctx.Close()) + }() + + certNames := []string{"Cert0", "Cert1", "Cert2"} + + var ( + parent *x509.Certificate + originCertChain []*x509.Certificate + authorityKeyId, subjectKeyID []byte + ids [][]byte + ) + + for _, name := range certNames { + subjectKeyID = randomBytes() + + cert := generateRandomCert(t, parent, name, authorityKeyId, subjectKeyID) + + id := randomBytes() + ids = append([][]byte{id}, ids...) + + err = ctx.ImportCertificate(id, cert) + require.NoError(t, err) + + originCertChain = append([]*x509.Certificate{cert}, originCertChain...) + + parent = cert + authorityKeyId = subjectKeyID + } + + foundCertChain, err := ctx.FindCertificateChain(ids[0], nil, nil) + require.NoError(t, err) + require.NotNil(t, foundCertChain) + + assert.Equal(t, len(foundCertChain), len(originCertChain)) + + for i := 0; i < len(foundCertChain); i++ { + assert.Equal(t, foundCertChain[i].Signature, originCertChain[i].Signature) + } + + err = ctx.DeleteCertificate(ids[len(ids)-1], nil, nil) + require.NoError(t, err) + + oldCert := originCertChain[len(originCertChain)-1] + newCert := generateRandomCert(t, nil, "NewCert", oldCert.AuthorityKeyId, oldCert.SubjectKeyId) + + originCertChain[len(originCertChain)-1] = newCert + + id := randomBytes() + + err = ctx.ImportCertificate(id, newCert) + require.NoError(t, err) + + ids[len(ids)-1] = id + + foundCertChain, err = ctx.FindCertificateChain(ids[0], nil, nil) + require.NoError(t, err) + require.NotNil(t, foundCertChain) + + assert.Equal(t, len(foundCertChain), len(originCertChain)) + + for i := 0; i < len(foundCertChain); i++ { + assert.Equal(t, foundCertChain[i].Signature, originCertChain[i].Signature) + } + + for _, id := range ids { + err = ctx.DeleteCertificate(id, nil, nil) + require.NoError(t, err) + } + + foundCertChain, err = ctx.FindCertificateChain([]byte("test2"), nil, nil) + require.NoError(t, err) + assert.Nil(t, foundCertChain) +} + +func generateRandomCert(t *testing.T, parent *x509.Certificate, commonName string, + authorityKeyId, subjectKeyID []byte) *x509.Certificate { serial, err := rand.Int(rand.Reader, big.NewInt(20000)) require.NoError(t, err) - ca := &x509.Certificate{ + template := &x509.Certificate{ Subject: pkix.Name{ - CommonName: "Foo", + CommonName: commonName, }, + AuthorityKeyId: authorityKeyId, + SubjectKeyId: subjectKeyID, SerialNumber: serial, NotAfter: time.Now().Add(365 * 24 * time.Hour), IsCA: true, @@ -223,11 +307,15 @@ func generateRandomCert(t *testing.T) *x509.Certificate { BasicConstraintsValid: true, } + if parent == nil { + parent = template + } + key, err := rsa.GenerateKey(rand.Reader, 4096) require.NoError(t, err) csr := &key.PublicKey - certBytes, err := x509.CreateCertificate(rand.Reader, ca, ca, csr, key) + certBytes, err := x509.CreateCertificate(rand.Reader, template, parent, csr, key) require.NoError(t, err) cert, err := x509.ParseCertificate(certBytes) diff --git a/close_test.go b/close_test.go index c1acfc9..27761b8 100644 --- a/close_test.go +++ b/close_test.go @@ -85,7 +85,7 @@ func TestErrorAfterClosed(t *testing.T) { _, err = ctx.NewRandomReader() assert.Equal(t, errClosed, err) - cert := generateRandomCert(t) + cert := generateRandomCert(t, nil, "Foo", nil, nil) err = ctx.ImportCertificate(bytes, cert) assert.Equal(t, errClosed, err)