diff --git a/server/mdm/scep/cryptoutil/cryptoutil.go b/server/mdm/scep/cryptoutil/cryptoutil.go index 6512c6154cc5..6b0dae05a6bd 100644 --- a/server/mdm/scep/cryptoutil/cryptoutil.go +++ b/server/mdm/scep/cryptoutil/cryptoutil.go @@ -1,13 +1,19 @@ package cryptoutil import ( + "context" "crypto" "crypto/ecdsa" + "crypto/ed25519" "crypto/elliptic" "crypto/rsa" "crypto/sha256" + "crypto/x509" "encoding/asn1" + "encoding/pem" "errors" + + "github.com/fleetdm/fleet/v4/server/contexts/ctxerr" ) // GenerateSubjectKeyID generates Subject Key Identifier (SKI) using SHA-256 @@ -34,3 +40,29 @@ func GenerateSubjectKeyID(pub crypto.PublicKey) ([]byte, error) { // subjectPublicKey (excluding the tag, length, and number of unused bits). return hash[:20], nil } + +func ParsePrivateKey(ctx context.Context, privKeyPEM []byte, keyName string) (crypto.PrivateKey, error) { + block, _ := pem.Decode(privKeyPEM) + if block == nil { + return nil, ctxerr.Errorf(ctx, "failed to decode %s", keyName) + } + + // The code below is based on tls.parsePrivateKey + // https://cs.opensource.google/go/go/+/release-branch.go1.23:src/crypto/tls/tls.go;l=355-372 + if key, err := x509.ParsePKCS1PrivateKey(block.Bytes); err == nil { + return key, nil + } + if key, err := x509.ParsePKCS8PrivateKey(block.Bytes); err == nil { + switch key := key.(type) { + case *rsa.PrivateKey, *ecdsa.PrivateKey, ed25519.PrivateKey: + return key, nil + default: + return nil, ctxerr.Errorf(ctx, "unmarshaled PKCS8 %s is not an RSA, ECDSA, or Ed25519 private key", keyName) + } + } + if key, err := x509.ParseECPrivateKey(block.Bytes); err == nil { + return key, nil + } + + return nil, ctxerr.Errorf(ctx, "failed to parse %s of type %s", keyName, block.Type) +} diff --git a/server/mdm/scep/cryptoutil/cryptoutil_test.go b/server/mdm/scep/cryptoutil/cryptoutil_test.go index 53a73ee9b36f..bd8859056524 100644 --- a/server/mdm/scep/cryptoutil/cryptoutil_test.go +++ b/server/mdm/scep/cryptoutil/cryptoutil_test.go @@ -1,13 +1,18 @@ package cryptoutil import ( + "context" "crypto" "crypto/ecdsa" "crypto/elliptic" "crypto/rand" "crypto/rsa" "math/big" + "os" "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestGenerateSubjectKeyID(t *testing.T) { @@ -56,3 +61,36 @@ func testSKIEq(a, b []byte) bool { return true } + +func TestParsePrivateKey(t *testing.T) { + t.Parallel() + // nil block not allowed + ctx := context.Background() + _, err := ParsePrivateKey(ctx, nil, "APNS private key") + assert.ErrorContains(t, err, "failed to decode") + + // encrypted pkcs8 not supported + pkcs8Encrypted, err := os.ReadFile("testdata/pkcs8-encrypted.key") + require.NoError(t, err) + _, err = ParsePrivateKey(ctx, pkcs8Encrypted, "APNS private key") + assert.ErrorContains(t, err, "failed to parse APNS private key of type ENCRYPTED PRIVATE KEY") + + // X25519 pkcs8 not supported + pkcs8Encrypted, err = os.ReadFile("testdata/pkcs8-x25519.key") + require.NoError(t, err) + _, err = ParsePrivateKey(ctx, pkcs8Encrypted, "APNS private key") + assert.ErrorContains(t, err, "unmarshaled PKCS8 APNS private key is not") + + // In this test, the pkcs1 key and pkcs8 keys are the same key, just different formats + pkcs1, err := os.ReadFile("testdata/pkcs1.key") + require.NoError(t, err) + pkcs1Key, err := ParsePrivateKey(ctx, pkcs1, "APNS private key") + require.NoError(t, err) + + pkcs8, err := os.ReadFile("testdata/pkcs8-rsa.key") + require.NoError(t, err) + pkcs8Key, err := ParsePrivateKey(ctx, pkcs8, "APNS private key") + require.NoError(t, err) + + assert.Equal(t, pkcs1Key, pkcs8Key) +} diff --git a/server/service/testdata/pkcs1.key b/server/mdm/scep/cryptoutil/testdata/pkcs1.key similarity index 100% rename from server/service/testdata/pkcs1.key rename to server/mdm/scep/cryptoutil/testdata/pkcs1.key diff --git a/server/service/testdata/pkcs8-encrypted.key b/server/mdm/scep/cryptoutil/testdata/pkcs8-encrypted.key similarity index 100% rename from server/service/testdata/pkcs8-encrypted.key rename to server/mdm/scep/cryptoutil/testdata/pkcs8-encrypted.key diff --git a/server/service/testdata/pkcs8-rsa.key b/server/mdm/scep/cryptoutil/testdata/pkcs8-rsa.key similarity index 100% rename from server/service/testdata/pkcs8-rsa.key rename to server/mdm/scep/cryptoutil/testdata/pkcs8-rsa.key diff --git a/server/service/testdata/pkcs8-x25519.key b/server/mdm/scep/cryptoutil/testdata/pkcs8-x25519.key similarity index 100% rename from server/service/testdata/pkcs8-x25519.key rename to server/mdm/scep/cryptoutil/testdata/pkcs8-x25519.key diff --git a/server/service/mdm.go b/server/service/mdm.go index 406c79441777..ceb7e6c3da23 100644 --- a/server/service/mdm.go +++ b/server/service/mdm.go @@ -35,6 +35,7 @@ import ( apple_mdm "github.com/fleetdm/fleet/v4/server/mdm/apple" "github.com/fleetdm/fleet/v4/server/mdm/assets" nanomdm "github.com/fleetdm/fleet/v4/server/mdm/nanomdm/mdm" + "github.com/fleetdm/fleet/v4/server/mdm/scep/cryptoutil" "github.com/fleetdm/fleet/v4/server/ptr" "github.com/go-kit/log/level" "github.com/go-sql-driver/mysql" @@ -2496,8 +2497,7 @@ func (svc *Service) GetMDMAppleCSR(ctx context.Context) ([]byte, error) { } } else { rawApnsKey := savedAssets[fleet.MDMAssetAPNSKey] - block, _ := pem.Decode(rawApnsKey.Value) - apnsKey, err = parseAPNSPrivateKey(ctx, block) + apnsKey, err = cryptoutil.ParsePrivateKey(ctx, rawApnsKey.Value, "APNS private key") if err != nil { return nil, err } diff --git a/server/service/mdm_test.go b/server/service/mdm_test.go index d7cf53fca5d3..aa978e37b68d 100644 --- a/server/service/mdm_test.go +++ b/server/service/mdm_test.go @@ -8,7 +8,6 @@ import ( "crypto/x509" "crypto/x509/pkix" "database/sql" - "encoding/pem" "errors" "math/big" "net/http" @@ -2185,44 +2184,3 @@ func TestBatchSetMDMProfilesLabels(t *testing.T) { assert.Equal(t, ProfileLabels{IncludeAny: true}, *profileLabels["DIncAny"]) assert.Equal(t, ProfileLabels{ExcludeAny: true}, *profileLabels["DExclAny"]) } - -func TestParseAPNSPrivateKey(t *testing.T) { - t.Parallel() - // nil block not allowed - ctx := context.Background() - _, err := parseAPNSPrivateKey(ctx, nil) - assert.ErrorContains(t, err, "failed to decode") - - // encrypted pkcs8 not supported - pkcs8Encrypted, err := os.ReadFile("testdata/pkcs8-encrypted.key") - require.NoError(t, err) - block, _ := pem.Decode(pkcs8Encrypted) - assert.NotNil(t, block) - _, err = parseAPNSPrivateKey(ctx, block) - assert.ErrorContains(t, err, "failed to parse APNS private key of type ENCRYPTED PRIVATE KEY") - - // X25519 pkcs8 not supported - pkcs8Encrypted, err = os.ReadFile("testdata/pkcs8-x25519.key") - require.NoError(t, err) - block, _ = pem.Decode(pkcs8Encrypted) - assert.NotNil(t, block) - _, err = parseAPNSPrivateKey(ctx, block) - assert.ErrorContains(t, err, "unmarshaled PKCS8 APNS key is not") - - // In this test, the pkcs1 key and pkcs8 keys are the same key, just different formats - pkcs1, err := os.ReadFile("testdata/pkcs1.key") - require.NoError(t, err) - block, _ = pem.Decode(pkcs1) - assert.NotNil(t, block) - pkcs1Key, err := parseAPNSPrivateKey(ctx, block) - require.NoError(t, err) - - pkcs8, err := os.ReadFile("testdata/pkcs8-rsa.key") - require.NoError(t, err) - block, _ = pem.Decode(pkcs8) - assert.NotNil(t, block) - pkcs8Key, err := parseAPNSPrivateKey(ctx, block) - require.NoError(t, err) - - assert.Equal(t, pkcs1Key, pkcs8Key) -}