Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add finding certificate chain #83

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
223 changes: 181 additions & 42 deletions certificates.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
package crypto11

import (
"bytes"
"crypto/tls"
"crypto/x509"
"encoding/asn1"
Expand All @@ -31,26 +32,7 @@ import (
"github.com/pkg/errors"
)

// FindCertificate retrieves a previously imported certificate. Any combination of id, label
// and serial can be provided. An error is return if all are nil.
func findCertificate(session *pkcs11Session, id []byte, label []byte, serial *big.Int) (cert *x509.Certificate, err error) {

rawCertificate, err := findRawCertificate(session, id, label, serial)
if err != nil {
return nil, err
}

if rawCertificate != nil {
cert, err = x509.ParseCertificate(rawCertificate)
if err != nil {
return nil, err
}
}

return cert, err
}

func findRawCertificate(session *pkcs11Session, id []byte, label []byte, serial *big.Int) (rawCertificate []byte, err error) {
if id == nil && label == nil && serial == nil {
return nil, errors.New("id, label and serial cannot all be nil")
}
Expand All @@ -72,6 +54,36 @@ func findRawCertificate(session *pkcs11Session, id []byte, label []byte, serial
template = append(template, pkcs11.NewAttribute(pkcs11.CKA_SERIAL_NUMBER, derSerial))
}

handles, err := findCertificatesWithAttributes(session, template)
if err != nil {
return nil, err
}

if len(handles) == 0 {
return nil, nil
}

return getX509Certificate(session, handles[0])
}

func getX509Certificate(session *pkcs11Session, handle pkcs11.ObjectHandle) (cert *x509.Certificate, err error) {
attributes := []*pkcs11.Attribute{
pkcs11.NewAttribute(pkcs11.CKA_VALUE, 0),
}

if attributes, err = session.ctx.GetAttributeValue(session.handle, handle, attributes); err != nil {
return nil, err
}

cert, err = x509.ParseCertificate(attributes[0].Value)
if err != nil {
return nil, err
}

return cert, nil
}

func findCertificatesWithAttributes(session *pkcs11Session, template []*pkcs11.Attribute) (handles []pkcs11.ObjectHandle, err error) {
template = append(template, pkcs11.NewAttribute(pkcs11.CKA_CLASS, pkcs11.CKO_CERTIFICATE))

if err = session.ctx.FindObjectsInit(session.handle, template); err != nil {
Expand All @@ -84,25 +96,81 @@ func findRawCertificate(session *pkcs11Session, id []byte, label []byte, serial
}
}()

handles, _, err := session.ctx.FindObjects(session.handle, 1)
for {
newhandles, _, err := session.ctx.FindObjects(session.handle, maxHandlePerFind)
if err != nil {
return nil, err
}

if len(newhandles) == 0 {
break
}

handles = append(handles, newhandles...)
}

return handles, nil
}

func findCertificateByKeyID(session *pkcs11Session, keyID []byte) (cert *x509.Certificate, err error) {
handles, err := findCertificatesWithAttributes(session, nil)
if err != nil {
return nil, err
}
if len(handles) == 0 {

for _, handle := range handles {
if cert, err = getX509Certificate(session, handle); err != nil {
return nil, err
}

if bytes.Equal(cert.SubjectKeyId, keyID) {
return cert, nil
}
}

return nil, errors.New("no certificate with required subject key ID found")
}

func findCertificateChain(session *pkcs11Session, cert *x509.Certificate) (certs []*x509.Certificate, err error) {
if len(cert.RawIssuer) == 0 || bytes.Equal(cert.RawIssuer, cert.RawSubject) {
return nil, nil
}

attributes := []*pkcs11.Attribute{
pkcs11.NewAttribute(pkcs11.CKA_VALUE, 0),
template := []*pkcs11.Attribute{pkcs11.NewAttribute(pkcs11.CKA_SUBJECT, cert.RawIssuer)}

handles, err := findCertificatesWithAttributes(session, template)
if err != nil {
return nil, err
}

if len(handles) == 0 {
if cert, err = findCertificateByKeyID(session, cert.AuthorityKeyId); err != nil {
return nil, err
}
} else {
if cert, err = getX509Certificate(session, handles[0]); err != nil {
return nil, err
}
}

if attributes, err = session.ctx.GetAttributeValue(session.handle, handles[0], attributes); err != nil {
for _, foundCert := range certs {
if bytes.Equal(cert.RawSubject, foundCert.RawSubject) {
return certs, nil
}
}

certs = append(certs, cert)

certChain, err := findCertificateChain(session, cert)
if err != nil {
return nil, err
}

rawCertificate = attributes[0].Value
if len(certChain) != 0 {
certs = append(certs, certChain...)
}

return
return certs, nil
}

// FindCertificate retrieves a previously imported certificate. Any combination of id, label
Expand All @@ -122,6 +190,69 @@ func (c *Context) FindCertificate(id []byte, label []byte, serial *big.Int) (*x5
return cert, err
}

// FindCertificateWithAttributes retrieves a previously imported certificate with selected attributes.
func (c *Context) FindCertificateWithAttributes(template AttributeSet) (*x509.Certificate, error) {
if c.closed.Get() {
return nil, errClosed
}

var cert *x509.Certificate
err := c.withSession(func(session *pkcs11Session) (err error) {
handles, err := findCertificatesWithAttributes(session, template.ToSlice())
if err != nil {
return err
}

if len(handles) == 0 {
return nil
}

if cert, err = getX509Certificate(session, handles[0]); err != nil {
return err
}

return nil
})

return cert, err
}

// FindCertificateChain retrieves a previously imported certificate chain. Any combination of id, label
// and serial can be provided. An error is return if all are nil.
func (c *Context) FindCertificateChain(id []byte, label []byte, serial *big.Int) (certs []*x509.Certificate, err error) {
if c.closed.Get() {
return nil, errClosed
}

err = c.withSession(func(session *pkcs11Session) (err error) {
cert, err := findCertificate(session, id, label, serial)
if err != nil {
return err
}

if cert == nil {
return nil
}

certs = append(certs, cert)

certChain, err := findCertificateChain(session, cert)
if err != nil {
return err
}

if len(certChain) == 0 {
return nil
}

certs = append(certs, certChain...)

return nil
})

return certs, err
}

func (c *Context) FindAllPairedCertificates() (certificates []tls.Certificate, err error) {
if c.closed.Get() {
return nil, errClosed
Expand Down Expand Up @@ -258,40 +389,48 @@ func (c *Context) DeleteCertificate(id []byte, label []byte, serial *big.Int) er
return errors.New("id, label and serial cannot all be nil")
}

template := []*pkcs11.Attribute{
pkcs11.NewAttribute(pkcs11.CKA_CLASS, pkcs11.CKO_CERTIFICATE),
}
template := NewAttributeSet()

if id != nil {
template = append(template, pkcs11.NewAttribute(pkcs11.CKA_ID, id))
if err := template.Set(pkcs11.CKA_ID, id); err != nil {
return err
}
}
if label != nil {
template = append(template, pkcs11.NewAttribute(pkcs11.CKA_LABEL, label))
if err := template.Set(pkcs11.CKA_LABEL, label); err != nil {
return err
}
}
if serial != nil {
asn1Serial, err := asn1.Marshal(serial)
if err != nil {
return err
}
template = append(template, pkcs11.NewAttribute(pkcs11.CKA_SERIAL_NUMBER, asn1Serial))
}

err := c.withSession(func(session *pkcs11Session) error {
err := session.ctx.FindObjectsInit(session.handle, template)
if err != nil {
if err := template.Set(pkcs11.CKA_SERIAL_NUMBER, asn1Serial); err != nil {
return err
}
handles, _, err := session.ctx.FindObjects(session.handle, 1)
finalErr := session.ctx.FindObjectsFinal(session.handle)
}

return c.DeleteCertificateWithAttributes(template)
}

// DeleteCertificateWithAttributes destroys a previously imported certificate by selected attributes.
// It will return nil if succeeds or if the certificate does not exist.
func (c *Context) DeleteCertificateWithAttributes(template AttributeSet) error {
if c.closed.Get() {
return errClosed
}

err := c.withSession(func(session *pkcs11Session) (err error) {
handles, err := findCertificatesWithAttributes(session, template.ToSlice())
if err != nil {
return err
}
if finalErr != nil {
return finalErr
}

if len(handles) == 0 {
return nil
}

return session.ctx.DestroyObject(session.handle, handles[0])
})

Expand Down
Loading