Skip to content

Commit

Permalink
Refactoring crypto code for future reuse.
Browse files Browse the repository at this point in the history
  • Loading branch information
getvictor committed Jan 3, 2025
1 parent 4c463b6 commit 666bea7
Show file tree
Hide file tree
Showing 8 changed files with 72 additions and 44 deletions.
32 changes: 32 additions & 0 deletions server/mdm/scep/cryptoutil/cryptoutil.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,19 @@
package cryptoutil

import (
"context"
"crypto"
"crypto/ecdsa"
"crypto/ed25519"
"crypto/elliptic"
"crypto/rsa"
"crypto/sha256"
"crypto/x509"
"encoding/asn1"
"encoding/pem"
"errors"

"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
)

// GenerateSubjectKeyID generates Subject Key Identifier (SKI) using SHA-256
Expand All @@ -34,3 +40,29 @@ func GenerateSubjectKeyID(pub crypto.PublicKey) ([]byte, error) {
// subjectPublicKey (excluding the tag, length, and number of unused bits).
return hash[:20], nil
}

func ParsePrivateKey(ctx context.Context, privKeyPEM []byte, keyName string) (crypto.PrivateKey, error) {
block, _ := pem.Decode(privKeyPEM)
if block == nil {
return nil, ctxerr.Errorf(ctx, "failed to decode %s", keyName)
}

// The code below is based on tls.parsePrivateKey
// https://cs.opensource.google/go/go/+/release-branch.go1.23:src/crypto/tls/tls.go;l=355-372
if key, err := x509.ParsePKCS1PrivateKey(block.Bytes); err == nil {
return key, nil
}
if key, err := x509.ParsePKCS8PrivateKey(block.Bytes); err == nil {
switch key := key.(type) {
case *rsa.PrivateKey, *ecdsa.PrivateKey, ed25519.PrivateKey:
return key, nil
default:
return nil, ctxerr.Errorf(ctx, "unmarshaled PKCS8 %s is not an RSA, ECDSA, or Ed25519 private key", keyName)
}
}
if key, err := x509.ParseECPrivateKey(block.Bytes); err == nil {
return key, nil
}

return nil, ctxerr.Errorf(ctx, "failed to parse %s of type %s", keyName, block.Type)
}
38 changes: 38 additions & 0 deletions server/mdm/scep/cryptoutil/cryptoutil_test.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
package cryptoutil

import (
"context"
"crypto"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/rsa"
"math/big"
"os"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestGenerateSubjectKeyID(t *testing.T) {
Expand Down Expand Up @@ -56,3 +61,36 @@ func testSKIEq(a, b []byte) bool {

return true
}

func TestParsePrivateKey(t *testing.T) {
t.Parallel()
// nil block not allowed
ctx := context.Background()
_, err := ParsePrivateKey(ctx, nil, "APNS private key")
assert.ErrorContains(t, err, "failed to decode")

// encrypted pkcs8 not supported
pkcs8Encrypted, err := os.ReadFile("testdata/pkcs8-encrypted.key")
require.NoError(t, err)
_, err = ParsePrivateKey(ctx, pkcs8Encrypted, "APNS private key")
assert.ErrorContains(t, err, "failed to parse APNS private key of type ENCRYPTED PRIVATE KEY")

// X25519 pkcs8 not supported
pkcs8Encrypted, err = os.ReadFile("testdata/pkcs8-x25519.key")
require.NoError(t, err)
_, err = ParsePrivateKey(ctx, pkcs8Encrypted, "APNS private key")
assert.ErrorContains(t, err, "unmarshaled PKCS8 APNS private key is not")

// In this test, the pkcs1 key and pkcs8 keys are the same key, just different formats
pkcs1, err := os.ReadFile("testdata/pkcs1.key")
require.NoError(t, err)
pkcs1Key, err := ParsePrivateKey(ctx, pkcs1, "APNS private key")
require.NoError(t, err)

pkcs8, err := os.ReadFile("testdata/pkcs8-rsa.key")
require.NoError(t, err)
pkcs8Key, err := ParsePrivateKey(ctx, pkcs8, "APNS private key")
require.NoError(t, err)

assert.Equal(t, pkcs1Key, pkcs8Key)
}
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
4 changes: 2 additions & 2 deletions server/service/mdm.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ import (
apple_mdm "github.com/fleetdm/fleet/v4/server/mdm/apple"
"github.com/fleetdm/fleet/v4/server/mdm/assets"
nanomdm "github.com/fleetdm/fleet/v4/server/mdm/nanomdm/mdm"
"github.com/fleetdm/fleet/v4/server/mdm/scep/cryptoutil"
"github.com/fleetdm/fleet/v4/server/ptr"
"github.com/go-kit/log/level"
"github.com/go-sql-driver/mysql"
Expand Down Expand Up @@ -2496,8 +2497,7 @@ func (svc *Service) GetMDMAppleCSR(ctx context.Context) ([]byte, error) {
}
} else {
rawApnsKey := savedAssets[fleet.MDMAssetAPNSKey]
block, _ := pem.Decode(rawApnsKey.Value)
apnsKey, err = parseAPNSPrivateKey(ctx, block)
apnsKey, err = cryptoutil.ParsePrivateKey(ctx, rawApnsKey.Value, "APNS private key")
if err != nil {
return nil, err
}
Expand Down
42 changes: 0 additions & 42 deletions server/service/mdm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"crypto/x509"
"crypto/x509/pkix"
"database/sql"
"encoding/pem"
"errors"
"math/big"
"net/http"
Expand Down Expand Up @@ -2185,44 +2184,3 @@ func TestBatchSetMDMProfilesLabels(t *testing.T) {
assert.Equal(t, ProfileLabels{IncludeAny: true}, *profileLabels["DIncAny"])
assert.Equal(t, ProfileLabels{ExcludeAny: true}, *profileLabels["DExclAny"])
}

func TestParseAPNSPrivateKey(t *testing.T) {
t.Parallel()
// nil block not allowed
ctx := context.Background()
_, err := parseAPNSPrivateKey(ctx, nil)
assert.ErrorContains(t, err, "failed to decode")

// encrypted pkcs8 not supported
pkcs8Encrypted, err := os.ReadFile("testdata/pkcs8-encrypted.key")
require.NoError(t, err)
block, _ := pem.Decode(pkcs8Encrypted)
assert.NotNil(t, block)
_, err = parseAPNSPrivateKey(ctx, block)
assert.ErrorContains(t, err, "failed to parse APNS private key of type ENCRYPTED PRIVATE KEY")

// X25519 pkcs8 not supported
pkcs8Encrypted, err = os.ReadFile("testdata/pkcs8-x25519.key")
require.NoError(t, err)
block, _ = pem.Decode(pkcs8Encrypted)
assert.NotNil(t, block)
_, err = parseAPNSPrivateKey(ctx, block)
assert.ErrorContains(t, err, "unmarshaled PKCS8 APNS key is not")

// In this test, the pkcs1 key and pkcs8 keys are the same key, just different formats
pkcs1, err := os.ReadFile("testdata/pkcs1.key")
require.NoError(t, err)
block, _ = pem.Decode(pkcs1)
assert.NotNil(t, block)
pkcs1Key, err := parseAPNSPrivateKey(ctx, block)
require.NoError(t, err)

pkcs8, err := os.ReadFile("testdata/pkcs8-rsa.key")
require.NoError(t, err)
block, _ = pem.Decode(pkcs8)
assert.NotNil(t, block)
pkcs8Key, err := parseAPNSPrivateKey(ctx, block)
require.NoError(t, err)

assert.Equal(t, pkcs1Key, pkcs8Key)
}

0 comments on commit 666bea7

Please sign in to comment.