From 5ef4a2e5652a1d23f009d0bc6d7822437b2c4d9e Mon Sep 17 00:00:00 2001 From: Oleksandr Grytsov Date: Thu, 3 Jun 2021 18:14:02 +0300 Subject: [PATCH 1/3] Make private certificate functions generic Current private certificate functions implement limited functionality to find x509 certificate by id and/or label and/or serial. They can't be reused to implement more generic find certificates API. This commit contains following changes: * implement generic findCertificatesWithAttributes function which allows to find certificates by defined templates. This function returns slice of pkcs11.ObjectHandle. Handles can be used to retrieve any information from pkcs11 object; * reimplement findCertificate function to use generic certificate functions. Signed-off-by: Oleksandr Grytsov --- certificates.go | 90 ++++++++++++++++++++++++------------------------- 1 file changed, 44 insertions(+), 46 deletions(-) diff --git a/certificates.go b/certificates.go index e430ecc..b9c8225 100644 --- a/certificates.go +++ b/certificates.go @@ -31,26 +31,7 @@ import ( "github.com/pkg/errors" ) -// FindCertificate retrieves a previously imported certificate. Any combination of id, label -// and serial can be provided. An error is return if all are nil. func findCertificate(session *pkcs11Session, id []byte, label []byte, serial *big.Int) (cert *x509.Certificate, err error) { - - rawCertificate, err := findRawCertificate(session, id, label, serial) - if err != nil { - return nil, err - } - - if rawCertificate != nil { - cert, err = x509.ParseCertificate(rawCertificate) - if err != nil { - return nil, err - } - } - - return cert, err -} - -func findRawCertificate(session *pkcs11Session, id []byte, label []byte, serial *big.Int) (rawCertificate []byte, err error) { if id == nil && label == nil && serial == nil { return nil, errors.New("id, label and serial cannot all be nil") } @@ -72,6 +53,36 @@ func findRawCertificate(session *pkcs11Session, id []byte, label []byte, serial template = append(template, pkcs11.NewAttribute(pkcs11.CKA_SERIAL_NUMBER, derSerial)) } + handles, err := findCertificatesWithAttributes(session, template) + if err != nil { + return nil, err + } + + if len(handles) == 0 { + return nil, nil + } + + return getX509Certificate(session, handles[0]) +} + +func getX509Certificate(session *pkcs11Session, handle pkcs11.ObjectHandle) (cert *x509.Certificate, err error) { + attributes := []*pkcs11.Attribute{ + pkcs11.NewAttribute(pkcs11.CKA_VALUE, 0), + } + + if attributes, err = session.ctx.GetAttributeValue(session.handle, handle, attributes); err != nil { + return nil, err + } + + cert, err = x509.ParseCertificate(attributes[0].Value) + if err != nil { + return nil, err + } + + return cert, nil +} + +func findCertificatesWithAttributes(session *pkcs11Session, template []*pkcs11.Attribute) (handles []pkcs11.ObjectHandle, err error) { template = append(template, pkcs11.NewAttribute(pkcs11.CKA_CLASS, pkcs11.CKO_CERTIFICATE)) if err = session.ctx.FindObjectsInit(session.handle, template); err != nil { @@ -84,25 +95,20 @@ func findRawCertificate(session *pkcs11Session, id []byte, label []byte, serial } }() - handles, _, err := session.ctx.FindObjects(session.handle, 1) - if err != nil { - return nil, err - } - if len(handles) == 0 { - return nil, nil - } + for { + newhandles, _, err := session.ctx.FindObjects(session.handle, maxHandlePerFind) + if err != nil { + return nil, err + } - attributes := []*pkcs11.Attribute{ - pkcs11.NewAttribute(pkcs11.CKA_VALUE, 0), - } + if len(newhandles) == 0 { + break + } - if attributes, err = session.ctx.GetAttributeValue(session.handle, handles[0], attributes); err != nil { - return nil, err + handles = append(handles, newhandles...) } - rawCertificate = attributes[0].Value - - return + return handles, nil } // FindCertificate retrieves a previously imported certificate. Any combination of id, label @@ -258,9 +264,7 @@ func (c *Context) DeleteCertificate(id []byte, label []byte, serial *big.Int) er return errors.New("id, label and serial cannot all be nil") } - template := []*pkcs11.Attribute{ - pkcs11.NewAttribute(pkcs11.CKA_CLASS, pkcs11.CKO_CERTIFICATE), - } + var template []*pkcs11.Attribute if id != nil { template = append(template, pkcs11.NewAttribute(pkcs11.CKA_ID, id)) @@ -277,21 +281,15 @@ func (c *Context) DeleteCertificate(id []byte, label []byte, serial *big.Int) er } err := c.withSession(func(session *pkcs11Session) error { - err := session.ctx.FindObjectsInit(session.handle, template) - if err != nil { - return err - } - handles, _, err := session.ctx.FindObjects(session.handle, 1) - finalErr := session.ctx.FindObjectsFinal(session.handle) + handles, err := findCertificatesWithAttributes(session, template) if err != nil { return err } - if finalErr != nil { - return finalErr - } + if len(handles) == 0 { return nil } + return session.ctx.DestroyObject(session.handle, handles[0]) }) From 1bd9b578e5e76f09f2c10d68a2c69f4be1fe2e84 Mon Sep 17 00:00:00 2001 From: Oleksandr Grytsov Date: Fri, 4 Jun 2021 19:37:35 +0300 Subject: [PATCH 2/3] Implement FindCertificateChain function to find certificate chain Certificate chain is found by following algorithm: * find first certificate either by id or/and label or/and serial (same as existing FindCertificate does); * if issuer is not nil, find next certificate by CKA_SUBJECT (issuer should be equal subject); * if certificate with required subject not found then read all certificates and try to find next certificate by AuthorityKeyId (AuthorityKeyId should be equal to SubjectKeyId); * finding stops if last found certificate is selfsigned (issuer is nil or equals to subject). Signed-off-by: Oleksandr Grytsov --- certificates.go | 98 ++++++++++++++++++++++++++++++++++++++++ certificates_test.go | 104 +++++++++++++++++++++++++++++++++++++++---- close_test.go | 2 +- 3 files changed, 195 insertions(+), 9 deletions(-) diff --git a/certificates.go b/certificates.go index b9c8225..54a1cb0 100644 --- a/certificates.go +++ b/certificates.go @@ -22,6 +22,7 @@ package crypto11 import ( + "bytes" "crypto/tls" "crypto/x509" "encoding/asn1" @@ -111,6 +112,67 @@ func findCertificatesWithAttributes(session *pkcs11Session, template []*pkcs11.A return handles, nil } +func findCertificateByKeyID(session *pkcs11Session, keyID []byte) (cert *x509.Certificate, err error) { + handles, err := findCertificatesWithAttributes(session, nil) + if err != nil { + return nil, err + } + + for _, handle := range handles { + if cert, err = getX509Certificate(session, handle); err != nil { + return nil, err + } + + if bytes.Equal(cert.SubjectKeyId, keyID) { + return cert, nil + } + } + + return nil, errors.New("no certificate with required subject key ID found") +} + +func findCertificateChain(session *pkcs11Session, cert *x509.Certificate) (certs []*x509.Certificate, err error) { + if len(cert.RawIssuer) == 0 || bytes.Equal(cert.RawIssuer, cert.RawSubject) { + return nil, nil + } + + template := []*pkcs11.Attribute{pkcs11.NewAttribute(pkcs11.CKA_SUBJECT, cert.RawIssuer)} + + handles, err := findCertificatesWithAttributes(session, template) + if err != nil { + return nil, err + } + + if len(handles) == 0 { + if cert, err = findCertificateByKeyID(session, cert.AuthorityKeyId); err != nil { + return nil, err + } + } else { + if cert, err = getX509Certificate(session, handles[0]); err != nil { + return nil, err + } + } + + for _, foundCert := range certs { + if bytes.Equal(cert.RawSubject, foundCert.RawSubject) { + return certs, nil + } + } + + certs = append(certs, cert) + + certChain, err := findCertificateChain(session, cert) + if err != nil { + return nil, err + } + + if len(certChain) != 0 { + certs = append(certs, certChain...) + } + + return certs, nil +} + // FindCertificate retrieves a previously imported certificate. Any combination of id, label // and serial can be provided. An error is return if all are nil. func (c *Context) FindCertificate(id []byte, label []byte, serial *big.Int) (*x509.Certificate, error) { @@ -128,6 +190,42 @@ func (c *Context) FindCertificate(id []byte, label []byte, serial *big.Int) (*x5 return cert, err } +// FindCertificateChain retrieves a previously imported certificate chain. Any combination of id, label +// and serial can be provided. An error is return if all are nil. +func (c *Context) FindCertificateChain(id []byte, label []byte, serial *big.Int) (certs []*x509.Certificate, err error) { + if c.closed.Get() { + return nil, errClosed + } + + err = c.withSession(func(session *pkcs11Session) (err error) { + cert, err := findCertificate(session, id, label, serial) + if err != nil { + return err + } + + if cert == nil { + return nil + } + + certs = append(certs, cert) + + certChain, err := findCertificateChain(session, cert) + if err != nil { + return err + } + + if len(certChain) == 0 { + return nil + } + + certs = append(certs, certChain...) + + return nil + }) + + return certs, err +} + func (c *Context) FindAllPairedCertificates() (certificates []tls.Certificate, err error) { if c.closed.Get() { return nil, errClosed diff --git a/certificates_test.go b/certificates_test.go index 31aa856..69c43dc 100644 --- a/certificates_test.go +++ b/certificates_test.go @@ -48,7 +48,7 @@ func TestCertificate(t *testing.T) { id := randomBytes() label := randomBytes() - cert := generateRandomCert(t) + cert := generateRandomCert(t, nil, "Foo", nil, nil) err = ctx.ImportCertificateWithLabel(id, label, cert) require.NoError(t, err) @@ -81,7 +81,7 @@ func TestCertificateAttributes(t *testing.T) { require.NoError(t, ctx.Close()) }() - cert := generateRandomCert(t) + cert := generateRandomCert(t, nil, "Foo", nil, nil) // We import this with a different serial number, to test this is obeyed ourSerial := new(big.Int) @@ -116,7 +116,7 @@ func TestCertificateRequiredArgs(t *testing.T) { require.NoError(t, ctx.Close()) }() - cert := generateRandomCert(t) + cert := generateRandomCert(t, nil, "Foo", nil, nil) val := randomBytes() @@ -143,7 +143,7 @@ func TestDeleteCertificate(t *testing.T) { randomCert := func() ([]byte, []byte, *x509.Certificate) { id := randomBytes() label := randomBytes() - cert := generateRandomCert(t) + cert := generateRandomCert(t, nil, "Foo", nil, nil) return id, label, cert } importCertificate := func() ([]byte, []byte, *big.Int) { @@ -207,14 +207,98 @@ func TestDeleteCertificate(t *testing.T) { require.Nil(t, cert) } -func generateRandomCert(t *testing.T) *x509.Certificate { +func TestCertificateChain(t *testing.T) { + skipTest(t, skipTestCert) + + ctx, err := ConfigureFromFile("config") + require.NoError(t, err) + + defer func() { + require.NoError(t, ctx.Close()) + }() + + certNames := []string{"Cert0", "Cert1", "Cert2"} + + var ( + parent *x509.Certificate + originCertChain []*x509.Certificate + authorityKeyId, subjectKeyID []byte + ids [][]byte + ) + + for _, name := range certNames { + subjectKeyID = randomBytes() + + cert := generateRandomCert(t, parent, name, authorityKeyId, subjectKeyID) + + id := randomBytes() + ids = append([][]byte{id}, ids...) + + err = ctx.ImportCertificate(id, cert) + require.NoError(t, err) + + originCertChain = append([]*x509.Certificate{cert}, originCertChain...) + + parent = cert + authorityKeyId = subjectKeyID + } + + foundCertChain, err := ctx.FindCertificateChain(ids[0], nil, nil) + require.NoError(t, err) + require.NotNil(t, foundCertChain) + + assert.Equal(t, len(foundCertChain), len(originCertChain)) + + for i := 0; i < len(foundCertChain); i++ { + assert.Equal(t, foundCertChain[i].Signature, originCertChain[i].Signature) + } + + err = ctx.DeleteCertificate(ids[len(ids)-1], nil, nil) + require.NoError(t, err) + + oldCert := originCertChain[len(originCertChain)-1] + newCert := generateRandomCert(t, nil, "NewCert", oldCert.AuthorityKeyId, oldCert.SubjectKeyId) + + originCertChain[len(originCertChain)-1] = newCert + + id := randomBytes() + + err = ctx.ImportCertificate(id, newCert) + require.NoError(t, err) + + ids[len(ids)-1] = id + + foundCertChain, err = ctx.FindCertificateChain(ids[0], nil, nil) + require.NoError(t, err) + require.NotNil(t, foundCertChain) + + assert.Equal(t, len(foundCertChain), len(originCertChain)) + + for i := 0; i < len(foundCertChain); i++ { + assert.Equal(t, foundCertChain[i].Signature, originCertChain[i].Signature) + } + + for _, id := range ids { + err = ctx.DeleteCertificate(id, nil, nil) + require.NoError(t, err) + } + + foundCertChain, err = ctx.FindCertificateChain([]byte("test2"), nil, nil) + require.NoError(t, err) + assert.Nil(t, foundCertChain) +} + +func generateRandomCert(t *testing.T, parent *x509.Certificate, commonName string, + authorityKeyId, subjectKeyID []byte) *x509.Certificate { serial, err := rand.Int(rand.Reader, big.NewInt(20000)) require.NoError(t, err) - ca := &x509.Certificate{ + template := &x509.Certificate{ Subject: pkix.Name{ - CommonName: "Foo", + CommonName: commonName, }, + AuthorityKeyId: authorityKeyId, + SubjectKeyId: subjectKeyID, SerialNumber: serial, NotAfter: time.Now().Add(365 * 24 * time.Hour), IsCA: true, @@ -223,11 +307,15 @@ func generateRandomCert(t *testing.T) *x509.Certificate { BasicConstraintsValid: true, } + if parent == nil { + parent = template + } + key, err := rsa.GenerateKey(rand.Reader, 4096) require.NoError(t, err) csr := &key.PublicKey - certBytes, err := x509.CreateCertificate(rand.Reader, ca, ca, csr, key) + certBytes, err := x509.CreateCertificate(rand.Reader, template, parent, csr, key) require.NoError(t, err) cert, err := x509.ParseCertificate(certBytes) diff --git a/close_test.go b/close_test.go index c1acfc9..27761b8 100644 --- a/close_test.go +++ b/close_test.go @@ -85,7 +85,7 @@ func TestErrorAfterClosed(t *testing.T) { _, err = ctx.NewRandomReader() assert.Equal(t, errClosed, err) - cert := generateRandomCert(t) + cert := generateRandomCert(t, nil, "Foo", nil, nil) err = ctx.ImportCertificate(bytes, cert) assert.Equal(t, errClosed, err) From 7564f425b7042afcbbc3fceed0f3f94c0e7fca59 Mon Sep 17 00:00:00 2001 From: Oleksandr Grytsov Date: Thu, 17 Feb 2022 18:19:19 +0200 Subject: [PATCH 3/3] Add FindCertificateWithAttributes, DeleteCertificateWithAttributes Add methods to find and delete certificates with custom attributes. Signed-off-by: Oleksandr Grytsov --- certificates.go | 55 +++++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 49 insertions(+), 6 deletions(-) diff --git a/certificates.go b/certificates.go index 54a1cb0..6029a86 100644 --- a/certificates.go +++ b/certificates.go @@ -190,6 +190,33 @@ func (c *Context) FindCertificate(id []byte, label []byte, serial *big.Int) (*x5 return cert, err } +// FindCertificateWithAttributes retrieves a previously imported certificate with selected attributes. +func (c *Context) FindCertificateWithAttributes(template AttributeSet) (*x509.Certificate, error) { + if c.closed.Get() { + return nil, errClosed + } + + var cert *x509.Certificate + err := c.withSession(func(session *pkcs11Session) (err error) { + handles, err := findCertificatesWithAttributes(session, template.ToSlice()) + if err != nil { + return err + } + + if len(handles) == 0 { + return nil + } + + if cert, err = getX509Certificate(session, handles[0]); err != nil { + return err + } + + return nil + }) + + return cert, err +} + // FindCertificateChain retrieves a previously imported certificate chain. Any combination of id, label // and serial can be provided. An error is return if all are nil. func (c *Context) FindCertificateChain(id []byte, label []byte, serial *big.Int) (certs []*x509.Certificate, err error) { @@ -362,24 +389,40 @@ func (c *Context) DeleteCertificate(id []byte, label []byte, serial *big.Int) er return errors.New("id, label and serial cannot all be nil") } - var template []*pkcs11.Attribute + template := NewAttributeSet() if id != nil { - template = append(template, pkcs11.NewAttribute(pkcs11.CKA_ID, id)) + if err := template.Set(pkcs11.CKA_ID, id); err != nil { + return err + } } if label != nil { - template = append(template, pkcs11.NewAttribute(pkcs11.CKA_LABEL, label)) + if err := template.Set(pkcs11.CKA_LABEL, label); err != nil { + return err + } } if serial != nil { asn1Serial, err := asn1.Marshal(serial) if err != nil { return err } - template = append(template, pkcs11.NewAttribute(pkcs11.CKA_SERIAL_NUMBER, asn1Serial)) + if err := template.Set(pkcs11.CKA_SERIAL_NUMBER, asn1Serial); err != nil { + return err + } + } + + return c.DeleteCertificateWithAttributes(template) +} + +// DeleteCertificateWithAttributes destroys a previously imported certificate by selected attributes. +// It will return nil if succeeds or if the certificate does not exist. +func (c *Context) DeleteCertificateWithAttributes(template AttributeSet) error { + if c.closed.Get() { + return errClosed } - err := c.withSession(func(session *pkcs11Session) error { - handles, err := findCertificatesWithAttributes(session, template) + err := c.withSession(func(session *pkcs11Session) (err error) { + handles, err := findCertificatesWithAttributes(session, template.ToSlice()) if err != nil { return err }