Skip to content

Commit

Permalink
Refactoring crypto code for future reuse. (#25148)
Browse files Browse the repository at this point in the history
Refactoring crypto code for future reuse for #24869. No functional
changes.
  • Loading branch information
getvictor authored Jan 7, 2025
1 parent 721b732 commit cbe44ee
Show file tree
Hide file tree
Showing 18 changed files with 170 additions and 173 deletions.
2 changes: 1 addition & 1 deletion pkg/mdm/mdmtest/apple.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ import (
"github.com/fleetdm/fleet/v4/server/fleet"
apple_mdm "github.com/fleetdm/fleet/v4/server/mdm/apple"
"github.com/fleetdm/fleet/v4/server/mdm/nanomdm/mdm"
"github.com/fleetdm/fleet/v4/server/mdm/scep/cryptoutil/x509util"
scepserver "github.com/fleetdm/fleet/v4/server/mdm/scep/server"
"github.com/fleetdm/fleet/v4/server/mdm/scep/x509util"
httptransport "github.com/go-kit/kit/transport/http"
"github.com/go-kit/log"
kitlog "github.com/go-kit/log"
Expand Down
68 changes: 68 additions & 0 deletions server/mdm/cryptoutil/cryptoutil.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
package cryptoutil

import (
"crypto"
"crypto/ecdsa"
"crypto/ed25519"
"crypto/elliptic"
"crypto/rsa"
"crypto/sha256"
"crypto/x509"
"encoding/asn1"
"encoding/pem"
"errors"
"fmt"
)

// GenerateSubjectKeyID generates Subject Key Identifier (SKI) using SHA-256
// hash of the public key bytes according to RFC 7093 section 2.
func GenerateSubjectKeyID(pub crypto.PublicKey) ([]byte, error) {
var pubBytes []byte
var err error
switch pub := pub.(type) {
case *rsa.PublicKey:
pubBytes, err = asn1.Marshal(*pub)
if err != nil {
return nil, err
}
case *ecdsa.PublicKey:
pubBytes = elliptic.Marshal(pub.Curve, pub.X, pub.Y)
default:
return nil, errors.New("only ECDSA and RSA public keys are supported")
}

hash := sha256.Sum256(pubBytes)

// According to RFC 7093, The keyIdentifier is composed of the leftmost
// 160-bits of the SHA-256 hash of the value of the BIT STRING
// subjectPublicKey (excluding the tag, length, and number of unused bits).
return hash[:20], nil
}

// ParsePrivateKey parses a PEM encoded private key and returns a crypto.PrivateKey.
// It can be used for private keys passed in from environment variables or command line or files.
func ParsePrivateKey(privKeyPEM []byte, keyName string) (crypto.PrivateKey, error) {
block, _ := pem.Decode(privKeyPEM)
if block == nil {
return nil, fmt.Errorf("failed to decode %s", keyName)
}

// The code below is based on tls.parsePrivateKey
// https://cs.opensource.google/go/go/+/release-branch.go1.23:src/crypto/tls/tls.go;l=355-372
if key, err := x509.ParsePKCS1PrivateKey(block.Bytes); err == nil {
return key, nil
}
if key, err := x509.ParsePKCS8PrivateKey(block.Bytes); err == nil {
switch key := key.(type) {
case *rsa.PrivateKey, *ecdsa.PrivateKey, ed25519.PrivateKey:
return key, nil
default:
return nil, fmt.Errorf("unmarshaled PKCS8 %s is not an RSA, ECDSA, or Ed25519 private key", keyName)
}
}
if key, err := x509.ParseECPrivateKey(block.Bytes); err == nil {
return key, nil
}

return nil, fmt.Errorf("failed to parse %s of type %s", keyName, block.Type)
}
94 changes: 94 additions & 0 deletions server/mdm/cryptoutil/cryptoutil_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
package cryptoutil

import (
"crypto"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/rsa"
"math/big"
"os"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestGenerateSubjectKeyID(t *testing.T) {
ecKey, err := ecdsa.GenerateKey(elliptic.P224(), rand.Reader)
if err != nil {
t.Fatal(err)
}
for _, test := range []struct {
testName string
pub crypto.PublicKey
}{
{"RSA", &rsa.PublicKey{N: big.NewInt(123), E: 65537}},
{"ECDSA", ecKey.Public()},
} {
test := test
t.Run(test.testName, func(t *testing.T) {
t.Parallel()
ski, err := GenerateSubjectKeyID(test.pub)
if err != nil {
t.Fatal(err)
}
if len(ski) != 20 {
t.Fatalf("unexpected subject public key identifier length: %d", len(ski))
}
ski2, err := GenerateSubjectKeyID(test.pub)
if err != nil {
t.Fatal(err)
}
if !testSKIEq(ski, ski2) {
t.Fatal("subject key identifier generation is not deterministic")
}
})
}
}

func testSKIEq(a, b []byte) bool {
if len(a) != len(b) {
return false
}

for i := range a {
if a[i] != b[i] {
return false
}
}

return true
}

func TestParsePrivateKey(t *testing.T) {
t.Parallel()
// nil block not allowed
_, err := ParsePrivateKey(nil, "APNS private key")
assert.ErrorContains(t, err, "failed to decode")

// encrypted pkcs8 not supported
pkcs8Encrypted, err := os.ReadFile("testdata/pkcs8-encrypted.key")
require.NoError(t, err)
_, err = ParsePrivateKey(pkcs8Encrypted, "APNS private key")
assert.ErrorContains(t, err, "failed to parse APNS private key of type ENCRYPTED PRIVATE KEY")

// X25519 pkcs8 not supported
pkcs8Encrypted, err = os.ReadFile("testdata/pkcs8-x25519.key")
require.NoError(t, err)
_, err = ParsePrivateKey(pkcs8Encrypted, "APNS private key")
assert.ErrorContains(t, err, "unmarshaled PKCS8 APNS private key is not")

// In this test, the pkcs1 key and pkcs8 keys are the same key, just different formats
pkcs1, err := os.ReadFile("testdata/pkcs1.key")
require.NoError(t, err)
pkcs1Key, err := ParsePrivateKey(pkcs1, "APNS private key")
require.NoError(t, err)

pkcs8, err := os.ReadFile("testdata/pkcs8-rsa.key")
require.NoError(t, err)
pkcs8Key, err := ParsePrivateKey(pkcs8, "APNS private key")
require.NoError(t, err)

assert.Equal(t, pkcs1Key, pkcs8Key)
}
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
2 changes: 1 addition & 1 deletion server/mdm/scep/cmd/scepclient/csr.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (
"io/ioutil"
"os"

"github.com/fleetdm/fleet/v4/server/mdm/scep/cryptoutil/x509util"
"github.com/fleetdm/fleet/v4/server/mdm/scep/x509util"
)

const (
Expand Down
36 changes: 0 additions & 36 deletions server/mdm/scep/cryptoutil/cryptoutil.go

This file was deleted.

58 changes: 0 additions & 58 deletions server/mdm/scep/cryptoutil/cryptoutil_test.go

This file was deleted.

2 changes: 1 addition & 1 deletion server/mdm/scep/depot/cacert.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import (
"math/big"
"time"

"github.com/fleetdm/fleet/v4/server/mdm/scep/cryptoutil"
"github.com/fleetdm/fleet/v4/server/mdm/cryptoutil"
)

// CACert represents a new self-signed CA certificate
Expand Down
2 changes: 1 addition & 1 deletion server/mdm/scep/depot/signer.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import (
"crypto/x509"
"time"

"github.com/fleetdm/fleet/v4/server/mdm/scep/cryptoutil"
"github.com/fleetdm/fleet/v4/server/mdm/cryptoutil"
"github.com/smallstep/scep"
)

Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
35 changes: 3 additions & 32 deletions server/service/mdm.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,7 @@ import (
"bytes"
"context"
"crypto"
"crypto/ecdsa"
"crypto/ed25519"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"encoding/json"
"encoding/pem"
"errors"
Expand All @@ -34,6 +30,7 @@ import (
"github.com/fleetdm/fleet/v4/server/mdm"
apple_mdm "github.com/fleetdm/fleet/v4/server/mdm/apple"
"github.com/fleetdm/fleet/v4/server/mdm/assets"
"github.com/fleetdm/fleet/v4/server/mdm/cryptoutil"
nanomdm "github.com/fleetdm/fleet/v4/server/mdm/nanomdm/mdm"
"github.com/fleetdm/fleet/v4/server/ptr"
"github.com/go-kit/log/level"
Expand Down Expand Up @@ -2496,10 +2493,9 @@ func (svc *Service) GetMDMAppleCSR(ctx context.Context) ([]byte, error) {
}
} else {
rawApnsKey := savedAssets[fleet.MDMAssetAPNSKey]
block, _ := pem.Decode(rawApnsKey.Value)
apnsKey, err = parseAPNSPrivateKey(ctx, block)
apnsKey, err = cryptoutil.ParsePrivateKey(rawApnsKey.Value, "APNS private key")
if err != nil {
return nil, err
return nil, ctxerr.Wrap(ctx, err, "parse APNS private key")
}
}

Expand Down Expand Up @@ -2546,31 +2542,6 @@ func (svc *Service) GetMDMAppleCSR(ctx context.Context) ([]byte, error) {
return signedCSRB64, nil
}

func parseAPNSPrivateKey(ctx context.Context, block *pem.Block) (crypto.PrivateKey, error) {
if block == nil {
return nil, ctxerr.New(ctx, "failed to decode saved APNS key")
}

// The code below is based on tls.parsePrivateKey
// https://cs.opensource.google/go/go/+/release-branch.go1.23:src/crypto/tls/tls.go;l=355-372
if key, err := x509.ParsePKCS1PrivateKey(block.Bytes); err == nil {
return key, nil
}
if key, err := x509.ParsePKCS8PrivateKey(block.Bytes); err == nil {
switch key := key.(type) {
case *rsa.PrivateKey, *ecdsa.PrivateKey, ed25519.PrivateKey:
return key, nil
default:
return nil, errors.New("unmarshaled PKCS8 APNS key is not an RSA, ECDSA, or Ed25519 private key")
}
}
if key, err := x509.ParseECPrivateKey(block.Bytes); err == nil {
return key, nil
}

return nil, ctxerr.New(ctx, fmt.Sprintf("failed to parse APNS private key of type %s", block.Type))
}

////////////////////////////////////////////////////////////////////////////////
// POST /mdm/apple/apns_certificate
////////////////////////////////////////////////////////////////////////////////
Expand Down
Loading

0 comments on commit cbe44ee

Please sign in to comment.