diff --git a/pkg/jwtsigning/hasher.go b/pkg/jwtsigning/hasher.go new file mode 100644 index 0000000..af87afe --- /dev/null +++ b/pkg/jwtsigning/hasher.go @@ -0,0 +1,20 @@ +package jwtsigning + +import ( + "crypto/sha256" + "encoding/base64" +) + +const hashAlgorithm = "SHA256" + +type SHA256Hasher struct{} + +func (s *SHA256Hasher) HashMessage(body []byte) string { + sum := sha256.Sum256(body) + + return base64.RawURLEncoding.EncodeToString(sum[:]) +} + +func (s *SHA256Hasher) ToString() string { + return hashAlgorithm +} diff --git a/pkg/jwtsigning/helper_test.go b/pkg/jwtsigning/helper_test.go new file mode 100644 index 0000000..4ea7ff1 --- /dev/null +++ b/pkg/jwtsigning/helper_test.go @@ -0,0 +1,74 @@ +package jwtsigning_test + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/openkcm/common-sdk/pkg/jwtsigning" +) + +type stubPrivateKeyProvider struct { + key *rsa.PrivateKey + meta jwtsigning.KeyMetadata + err error +} + +func (s *stubPrivateKeyProvider) CurrentSigningKey(ctx context.Context) (*rsa.PrivateKey, jwtsigning.KeyMetadata, error) { + return s.key, s.meta, s.err +} + +type stubPublicKeyProvider struct { + key *rsa.PublicKey + err error + lastIss string + lastKid string +} + +func (s *stubPublicKeyProvider) VerificationKey(ctx context.Context, iss, kid string) (*rsa.PublicKey, error) { + s.lastIss = iss + + s.lastKid = kid + if s.err != nil { + return nil, s.err + } + + return s.key, nil +} + +type hasherStub struct { + hash string + alg string +} + +func (h *hasherStub) HashMessage(_ []byte) string { + return h.hash +} + +func (h *hasherStub) ToString() string { + return h.alg +} + +func generateRSAKey(tb testing.TB, bits int) *rsa.PrivateKey { + tb.Helper() + + key, err := rsa.GenerateKey(rand.Reader, bits) + assert.NoError(tb, err, "generateRSAKey", err) + + return key +} + +func signMessage(tb testing.TB, provider jwtsigning.PrivateKeyProvider, hasher jwtsigning.Hasher, body []byte) string { + tb.Helper() + + signer, err := jwtsigning.NewSigner(provider, hasher) + assert.NoError(tb, err, "NewSigner failed", err) + + token, err := signer.Sign(tb.Context(), body) + assert.NoError(tb, err, "Sign failed", err) + + return token +} diff --git a/pkg/jwtsigning/providers.go b/pkg/jwtsigning/providers.go new file mode 100644 index 0000000..060eb4e --- /dev/null +++ b/pkg/jwtsigning/providers.go @@ -0,0 +1,31 @@ +package jwtsigning + +import ( + "context" + "crypto/rsa" +) + +// KeyMetadata describes the logical identity of a signing key. +// It maps directly to the JWT "iss" (issuer) and "kid" (key ID) used by +// verifiers to look up the matching public key via the configured trust mechanism. +type KeyMetadata struct { + Iss string // typically the cluster URL that exposes .well-known/jwks.json + Kid string // uniquely identifies the signing key under that issuer +} + +// PrivateKeyProvider supplies the current RSA private key and its metadata for signing outgoing messages. +type PrivateKeyProvider interface { + // CurrentSigningKey returns the key+metadata to use for signing. + CurrentSigningKey(ctx context.Context) (*rsa.PrivateKey, KeyMetadata, error) +} + +// PublicKeyProvider resolves RSA public keys for verification of incoming message signatures. +type PublicKeyProvider interface { + // VerificationKey returns the public key for given issuer and kid. + VerificationKey(ctx context.Context, iss, kid string) (*rsa.PublicKey, error) +} + +type Hasher interface { + HashMessage(body []byte) string + ToString() string +} diff --git a/pkg/jwtsigning/signer.go b/pkg/jwtsigning/signer.go new file mode 100644 index 0000000..1511cbf --- /dev/null +++ b/pkg/jwtsigning/signer.go @@ -0,0 +1,91 @@ +package jwtsigning + +import ( + "context" + "errors" + "fmt" + + "github.com/golang-jwt/jwt/v5" +) + +var ( + ErrRSAKeyLength = errors.New("RSA key is too small") + ErrUndefinedHashingAlgorithm = errors.New("undefined hashing algorithm") + ErrUndefinedSigningAlgorithm = errors.New("undefined signing algorithm") + ErrNilKeyProvider = errors.New("keyProvider cannot be nil") +) + +const ( + jwtMapClaimIss = "iss" + jwtMapClaimKid = "kid" + jwtMapClaimHash = "hash" + jwtMapClaimHashAlgorithm = "hash-alg" + + tokenHeaderType = "typ" + tokenType = "JWT" + tokenHeaderAlgorithm = "alg" +) + +// Signer signs message bodies into JWS (JWT) tokens. +type Signer struct { + Hasher + + keys PrivateKeyProvider +} + +func NewSigner(keyProvider PrivateKeyProvider, hasher Hasher) (*Signer, error) { + if keyProvider == nil { + return nil, ErrNilKeyProvider + } + + if hasher == nil { + hasher = &SHA256Hasher{} + } + + return &Signer{ + Hasher: hasher, + keys: keyProvider, + }, nil +} + +// Sign creates a compact JWS (JWT) for the given message body using PS256. +// +// The returned string is suitable for use as a value of an HTTP or message +// header (e.g. "X-Message-Signature"). The token will contain the following +// claims and headers: +// - claims: +// iss: issuer from KeyMetadata +// kid: key ID from KeyMetadata +// hash: base64url(SHA-256(body)) +// hash-alg: "SHA256" +// - headers: +// typ: "JWT" +// alg: "PS256" +// +// Sign obtains the private key and metadata from the configured +// PrivateKeyProvider and enforces the minimum RSA key size before signing. +// It returns an error if the provider fails, if the key is too small, or if +// the token cannot be signed. +func (s *Signer) Sign(ctx context.Context, body []byte) (string, error) { + priv, meta, err := s.keys.CurrentSigningKey(ctx) + if err != nil { + return "", err + } + + if priv.N.BitLen() < 3072 { + return "", fmt.Errorf("%w: %d bits", ErrRSAKeyLength, priv.N.BitLen()) + } + + claims := jwt.MapClaims{ + jwtMapClaimIss: meta.Iss, + jwtMapClaimKid: meta.Kid, + jwtMapClaimHash: s.HashMessage(body), + jwtMapClaimHashAlgorithm: s.ToString(), + } + + token := jwt.NewWithClaims(jwt.SigningMethodPS256, claims) + token.Header[tokenHeaderType] = tokenType + token.Header[tokenHeaderAlgorithm] = jwt.SigningMethodPS256.Alg() + + return token.SignedString(priv) +} diff --git a/pkg/jwtsigning/signer_test.go b/pkg/jwtsigning/signer_test.go new file mode 100644 index 0000000..261e018 --- /dev/null +++ b/pkg/jwtsigning/signer_test.go @@ -0,0 +1,181 @@ +package jwtsigning_test + +import ( + "crypto/sha256" + "encoding/base64" + "errors" + "testing" + + "github.com/golang-jwt/jwt/v5" + "github.com/stretchr/testify/assert" + + "github.com/openkcm/common-sdk/pkg/jwtsigning" +) + +func TestNewSigner(t *testing.T) { + key := generateRSAKey(t, 3072) + + tests := []struct { + name string + provider jwtsigning.PrivateKeyProvider + hasher jwtsigning.Hasher + expectError error + }{ + { + name: "nil key provider returns ErrNilKeyProvider", + provider: nil, + hasher: &hasherStub{}, + expectError: jwtsigning.ErrNilKeyProvider, + }, + { + name: "nil hasher uses default SHA256 hasher", + provider: &stubPrivateKeyProvider{ + key: key, + meta: jwtsigning.KeyMetadata{ + Iss: "iss", + Kid: "kid", + }, + }, + hasher: nil, + expectError: nil, + }, + { + name: "custom hasher is accepted", + provider: &stubPrivateKeyProvider{ + key: key, + meta: jwtsigning.KeyMetadata{}, + }, + hasher: &hasherStub{hash: "ignored", alg: "IGNORED"}, + expectError: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + signer, err := jwtsigning.NewSigner(tt.provider, tt.hasher) + + if tt.expectError != nil { + assert.Error(t, err) + assert.Nil(t, signer) + assert.ErrorIs(t, err, tt.expectError) + } else { + assert.NoError(t, err) + assert.NotNil(t, signer) + } + }) + } +} + +func TestSignerSign_SuccessClaimsAndHeaders(t *testing.T) { + key := generateRSAKey(t, 3072) + meta := jwtsigning.KeyMetadata{ + Iss: "https://issuer.example", + Kid: "key-id-1", + } + + hasher := &hasherStub{ + hash: "fixed-hash-value", + alg: "TEST-ALG", + } + + provider := &stubPrivateKeyProvider{ + key: key, + meta: meta, + } + + signer, err := jwtsigning.NewSigner(provider, hasher) + assert.NoError(t, err) + + body := []byte("message-body") + tokenStr, err := signer.Sign(t.Context(), body) + assert.NoError(t, err) + assert.NotEmpty(t, tokenStr) + + parsed, err := jwt.Parse(tokenStr, func(tkn *jwt.Token) (any, error) { + return &key.PublicKey, nil + }) + assert.NoError(t, err) + assert.True(t, parsed.Valid) + + assert.Equal(t, "JWT", parsed.Header["typ"]) + assert.Equal(t, jwt.SigningMethodPS256.Alg(), parsed.Header["alg"]) + + claims, ok := parsed.Claims.(jwt.MapClaims) + if assert.True(t, ok) { + assert.Equal(t, meta.Iss, claims["iss"]) + assert.Equal(t, meta.Kid, claims["kid"]) + assert.Equal(t, hasher.hash, claims["hash"]) + assert.Equal(t, hasher.alg, claims["hash-alg"]) + } +} + +func TestSignerSign_KeyProviderErrorIsPropagated(t *testing.T) { + expectedErr := errors.New("provider failed") + + provider := &stubPrivateKeyProvider{ + err: expectedErr, + } + + signer, err := jwtsigning.NewSigner(provider, nil) + assert.NoError(t, err) + + tokenStr, err := signer.Sign(t.Context(), []byte("body")) + assert.Empty(t, tokenStr) + assert.Error(t, err) + assert.ErrorIs(t, err, expectedErr) +} + +func TestSignerSign_FailsForTooSmallKey(t *testing.T) { + smallKey := generateRSAKey(t, 1024) // < 3072 bits + + provider := &stubPrivateKeyProvider{ + key: smallKey, + meta: jwtsigning.KeyMetadata{ + Iss: "issuer", + Kid: "kid", + }, + } + + signer, err := jwtsigning.NewSigner(provider, nil) + assert.NoError(t, err) + + tokenStr, err := signer.Sign(t.Context(), []byte("body")) + assert.Empty(t, tokenStr) + assert.Error(t, err) + assert.ErrorIs(t, err, jwtsigning.ErrRSAKeyLength) +} + +func TestSignerSign_UsesDefaultSHA256Hasher(t *testing.T) { + key := generateRSAKey(t, 3072) + meta := jwtsigning.KeyMetadata{ + Iss: "default-iss", + Kid: "default-kid", + } + body := []byte("some-body-data") + + provider := &stubPrivateKeyProvider{ + key: key, + meta: meta, + } + + signer, err := jwtsigning.NewSigner(provider, nil) + assert.NoError(t, err) + + tokenStr, err := signer.Sign(t.Context(), body) + assert.NoError(t, err) + assert.NotEmpty(t, tokenStr) + + parsed, err := jwt.Parse(tokenStr, func(tkn *jwt.Token) (any, error) { + return &key.PublicKey, nil + }) + assert.NoError(t, err) + assert.True(t, parsed.Valid) + + claims, ok := parsed.Claims.(jwt.MapClaims) + if assert.True(t, ok) { + sum := sha256.Sum256(body) + expectedHash := base64.RawURLEncoding.EncodeToString(sum[:]) + assert.Equal(t, expectedHash, claims["hash"]) + assert.Equal(t, signer.ToString(), claims["hash-alg"]) + } +} diff --git a/pkg/jwtsigning/verifier.go b/pkg/jwtsigning/verifier.go new file mode 100644 index 0000000..5f89af3 --- /dev/null +++ b/pkg/jwtsigning/verifier.go @@ -0,0 +1,144 @@ +package jwtsigning + +import ( + "context" + "crypto/subtle" + "errors" + "fmt" + + "github.com/golang-jwt/jwt/v5" +) + +var ( + ErrUnexpectedClaimsType = errors.New("unexpected claims type") + ErrUnexpectedSigningMethod = errors.New("unexpected signing method") + ErrMissingIssOrKid = errors.New("missing iss or kid in token") + ErrUntrustedIssuer = errors.New("untrusted issuer") + ErrJWTParseFailed = errors.New("jwt parse failed") + ErrUnsupportedHashAlgorithm = errors.New("unsupported hash algorithm") + ErrSignatureInvalid = errors.New("jwt signature invalid") + ErrHashClaimMissing = errors.New("missing hash claim") + ErrMessageHashMismatch = errors.New("message hash mismatch") + ErrNilPublicKeyProvider = errors.New("publicKeyProvider cannot be nil") + ErrNoTrustedIssuers = errors.New("trusted issuers cannot be nil") +) + +// Verifier verifies signed messages represented as compact JWS (JWT) tokens. +type Verifier struct { + Hasher + + // keys resolves public keys for given (iss, kid) pairs. It must be non-nil. + keys PublicKeyProvider + + // trustedIssuers optionally restricts which issuers are accepted. If the + // map is non-empty, the verifier will only accept tokens whose "iss" + // claim appears as a key in this map. If it is nil or empty, all issuers + // resolved by Keys are accepted. + trustedIssuers map[string]struct{} +} + +func NewVerifier(publicKeyProvider PublicKeyProvider, hasher Hasher, trustedIssuers map[string]struct{}) (*Verifier, error) { + if publicKeyProvider == nil { + return nil, ErrNilPublicKeyProvider + } + + if hasher == nil { + hasher = &SHA256Hasher{} + } + + if len(trustedIssuers) == 0 { + return nil, ErrNoTrustedIssuers + } + + return &Verifier{ + Hasher: hasher, + keys: publicKeyProvider, + trustedIssuers: trustedIssuers, + }, nil +} + +// Verify checks that the given compact JWS (JWT) token is a valid signature +// for the provided message body. +// +// Verify performs the following steps: +// 1. Parses the token and enforces the PS256 signing method. +// 2. Extracts "iss" and "kid" claims and validates "iss" against +// TrustedIssuers if configured. +// 3. Resolves the corresponding RSA public key via the PublicKeyProvider. +// 4. Enforces the minimum RSA key size. +// 5. Verifies the JWS signature using the resolved public key. +// 6. Validates that the "hash-alg" claim is "SHA256". +// 7. Recomputes base64url(SHA-256(body)) and compares it to the "hash" +// claim in constant time. +// +// If any step fails, Verify returns a non-nil error and the caller should +// treat the message as untrusted. +func (v *Verifier) Verify(ctx context.Context, tokenStr string, body []byte) error { + keyFunc := func(t *jwt.Token) (any, error) { + if t.Method != jwt.SigningMethodPS256 { + return nil, fmt.Errorf("%w: %s", ErrUnexpectedSigningMethod, t.Method.Alg()) + } + + claims, ok := t.Claims.(jwt.MapClaims) + if !ok { + return nil, ErrUnexpectedClaimsType + } + + iss, _ := claims[jwtMapClaimIss].(string) + kid, _ := claims[jwtMapClaimKid].(string) + + if iss == "" || kid == "" { + return nil, ErrMissingIssOrKid + } + + if len(v.trustedIssuers) == 0 { + return nil, ErrNoTrustedIssuers + } + + if _, trusted := v.trustedIssuers[iss]; !trusted { + return nil, ErrUntrustedIssuer + } + + pub, err := v.keys.VerificationKey(ctx, iss, kid) + if err != nil { + return nil, err + } + + if pub.N.BitLen() < 3072 { + return nil, fmt.Errorf("%w: %d bits", ErrRSAKeyLength, pub.N.BitLen()) + } + + return pub, nil + } + + parsed, err := jwt.Parse(tokenStr, keyFunc) + if err != nil { + return fmt.Errorf("%w: %w", ErrJWTParseFailed, err) + } + + if !parsed.Valid { + return ErrSignatureInvalid + } + + claims, ok := parsed.Claims.(jwt.MapClaims) + if !ok { + return ErrUnexpectedClaimsType + } + + hashAlg, _ := claims[jwtMapClaimHashAlgorithm].(string) + if hashAlg != v.ToString() { + return fmt.Errorf("%w: %s", ErrUnsupportedHashAlgorithm, hashAlg) + } + + hashClaim, _ := claims[jwtMapClaimHash].(string) + if hashClaim == "" { + return ErrHashClaimMissing + } + + calc := v.HashMessage(body) + if subtle.ConstantTimeCompare([]byte(hashClaim), []byte(calc)) != 1 { + return ErrMessageHashMismatch + } + + return nil +} diff --git a/pkg/jwtsigning/verifier_test.go b/pkg/jwtsigning/verifier_test.go new file mode 100644 index 0000000..a39885c --- /dev/null +++ b/pkg/jwtsigning/verifier_test.go @@ -0,0 +1,370 @@ +package jwtsigning_test + +import ( + "errors" + "testing" + + "github.com/golang-jwt/jwt/v5" + "github.com/stretchr/testify/assert" + + "github.com/openkcm/common-sdk/pkg/jwtsigning" +) + +func TestVerifier(t *testing.T) { + defaultTrustMap := map[string]struct{}{ + "issuer": {}, + "https://issuer.example": {}, + "trusted-issuer": {}, + } + + t.Run("NewVerifier", func(t *testing.T) { + key := generateRSAKey(t, 3072) + + tests := []struct { + name string + pubProvider jwtsigning.PublicKeyProvider + hasher jwtsigning.Hasher + expectError error + }{ + { + name: "nil public key provider returns ErrNilPublicKeyProvider", + pubProvider: nil, + hasher: &hasherStub{}, + expectError: jwtsigning.ErrNilPublicKeyProvider, + }, + { + name: "nil hasher uses default SHA256 hasher", + pubProvider: &stubPublicKeyProvider{ + key: &key.PublicKey, + }, + hasher: nil, + expectError: nil, + }, + { + name: "custom hasher is accepted", + pubProvider: &stubPublicKeyProvider{ + key: &key.PublicKey, + }, + hasher: &hasherStub{hash: "ignored", alg: "IGNORED"}, + expectError: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + verifier, err := jwtsigning.NewVerifier(tt.pubProvider, tt.hasher, defaultTrustMap) + + if tt.expectError != nil { + assert.Error(t, err) + assert.Nil(t, verifier) + assert.ErrorIs(t, err, tt.expectError) + } else { + assert.NoError(t, err) + assert.NotNil(t, verifier) + } + }) + } + }) + + t.Run("Verify_Success_ResolvesKeyAndValidatesHash", func(t *testing.T) { + key := generateRSAKey(t, 3072) + meta := jwtsigning.KeyMetadata{ + Iss: "https://issuer.example", + Kid: "kid-123", + } + body := []byte("verified-message") + + privProvider := &stubPrivateKeyProvider{ + key: key, + meta: meta, + } + + token := signMessage(t, privProvider, nil, body) + + pubProvider := &stubPublicKeyProvider{ + key: &key.PublicKey, + } + + trustMap := map[string]struct{}{meta.Iss: {}} + verifier, err := jwtsigning.NewVerifier(pubProvider, nil, trustMap) + assert.NoError(t, err) + + err = verifier.Verify(t.Context(), token, body) + assert.NoError(t, err) + + assert.Equal(t, meta.Iss, pubProvider.lastIss) + assert.Equal(t, meta.Kid, pubProvider.lastKid) + }) + + t.Run("Verify_TrustedIssuersBehavior", func(t *testing.T) { + key := generateRSAKey(t, 3072) + meta := jwtsigning.KeyMetadata{ + Iss: "trusted-issuer", + Kid: "kid-1", + } + body := []byte("body") + + token := signMessage(t, &stubPrivateKeyProvider{ + key: key, + meta: meta, + }, nil, body) + + pubProvider := &stubPublicKeyProvider{ + key: &key.PublicKey, + } + + tests := []struct { + name string + trustedIssuers map[string]struct{} + expectErr error + }{ + { + name: "issuer in trust list is accepted", + trustedIssuers: map[string]struct{}{meta.Iss: {}}, + expectErr: nil, + }, + { + name: "issuer not in trust list is rejected", + trustedIssuers: map[string]struct{}{"other-issuer": {}}, + expectErr: jwtsigning.ErrUntrustedIssuer, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + verifier, err := jwtsigning.NewVerifier(pubProvider, nil, tt.trustedIssuers) + assert.NoError(t, err) + + err = verifier.Verify(t.Context(), token, body) + + if tt.expectErr == nil { + assert.NoError(t, err) + } else { + assert.Error(t, err) + assert.ErrorIs(t, err, jwtsigning.ErrJWTParseFailed) + assert.ErrorIs(t, err, tt.expectErr) + } + }) + } + }) + + t.Run("Verify_InvalidTokenString", func(t *testing.T) { + key := generateRSAKey(t, 3072) + + pubProvider := &stubPublicKeyProvider{ + key: &key.PublicKey, + } + + verifier, err := jwtsigning.NewVerifier(pubProvider, nil, defaultTrustMap) + assert.NoError(t, err) + + err = verifier.Verify(t.Context(), "not-a-jwt", []byte("body")) + assert.Error(t, err) + assert.ErrorIs(t, err, jwtsigning.ErrJWTParseFailed) + }) + + t.Run("Verify_MissingIssOrKid", func(t *testing.T) { + key := generateRSAKey(t, 3072) + body := []byte("body") + + tests := []struct { + name string + meta jwtsigning.KeyMetadata + }{ + { + name: "missing iss", + meta: jwtsigning.KeyMetadata{ + Iss: "", + Kid: "kid", + }, + }, + { + name: "missing kid", + meta: jwtsigning.KeyMetadata{ + Iss: "issuer", + Kid: "", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + token := signMessage(t, &stubPrivateKeyProvider{ + key: key, + meta: tt.meta, + }, nil, body) + + pubProvider := &stubPublicKeyProvider{ + key: &key.PublicKey, + } + + // FIX: Pass valid map + verifier, err := jwtsigning.NewVerifier(pubProvider, nil, defaultTrustMap) + assert.NoError(t, err) + + err = verifier.Verify(t.Context(), token, body) + assert.Error(t, err) + assert.ErrorIs(t, err, jwtsigning.ErrJWTParseFailed) + assert.ErrorIs(t, err, jwtsigning.ErrMissingIssOrKid) + }) + } + }) + + t.Run("Verify_UnexpectedSigningMethod", func(t *testing.T) { + claims := jwt.MapClaims{ + "iss": "issuer", + "kid": "kid", + "hash": "h", + "hash-alg": "SHA256", + } + + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + tokenStr, err := token.SignedString([]byte("secret")) + assert.NoError(t, err) + + verifier, err := jwtsigning.NewVerifier(&stubPublicKeyProvider{}, nil, defaultTrustMap) + assert.NoError(t, err) + + err = verifier.Verify(t.Context(), tokenStr, []byte("body")) + assert.Error(t, err) + assert.ErrorIs(t, err, jwtsigning.ErrJWTParseFailed) + assert.ErrorIs(t, err, jwtsigning.ErrUnexpectedSigningMethod) + }) + + t.Run("Verify_PublicKeyProviderErrorIsPropagated", func(t *testing.T) { + key := generateRSAKey(t, 3072) + meta := jwtsigning.KeyMetadata{ + Iss: "issuer", + Kid: "kid", + } + body := []byte("body") + + token := signMessage(t, &stubPrivateKeyProvider{ + key: key, + meta: meta, + }, nil, body) + + providerErr := errors.New("verification key lookup failed") + pubProvider := &stubPublicKeyProvider{ + key: nil, + err: providerErr, + } + + trustMap := map[string]struct{}{meta.Iss: {}} + verifier, err := jwtsigning.NewVerifier(pubProvider, nil, trustMap) + assert.NoError(t, err) + + err = verifier.Verify(t.Context(), token, body) + assert.Error(t, err) + assert.ErrorIs(t, err, jwtsigning.ErrJWTParseFailed) + assert.ErrorIs(t, err, providerErr) + }) + + t.Run("Verify_FailsForTooSmallPublicKey", func(t *testing.T) { + smallKey := generateRSAKey(t, 1024) + + claims := jwt.MapClaims{ + "iss": "issuer", + "kid": "kid", + "hash": "h", + "hash-alg": "SHA256", + } + + token := jwt.NewWithClaims(jwt.SigningMethodPS256, claims) + tokenStr, err := token.SignedString(smallKey) + assert.NoError(t, err) + + pubProvider := &stubPublicKeyProvider{ + key: &smallKey.PublicKey, + } + + verifier, err := jwtsigning.NewVerifier(pubProvider, nil, defaultTrustMap) + assert.NoError(t, err) + + err = verifier.Verify(t.Context(), tokenStr, []byte("body")) + assert.Error(t, err) + assert.ErrorIs(t, err, jwtsigning.ErrJWTParseFailed) + assert.ErrorIs(t, err, jwtsigning.ErrRSAKeyLength) + }) + + t.Run("Verify_HashAndAlgFailures", func(t *testing.T) { + key := generateRSAKey(t, 3072) + meta := jwtsigning.KeyMetadata{ + Iss: "issuer", + Kid: "kid", + } + body := []byte("body") + + pubProvider := &stubPublicKeyProvider{ + key: &key.PublicKey, + } + + tests := []struct { + name string + hasher *hasherStub + expectedErr error + }{ + { + name: "unsupported hash algorithm", + hasher: &hasherStub{ + hash: "some-hash", + alg: "OTHER-ALG", + }, + expectedErr: jwtsigning.ErrUnsupportedHashAlgorithm, + }, + { + name: "hash claim missing", + hasher: &hasherStub{ + hash: "", + alg: "SHA256", + }, + expectedErr: jwtsigning.ErrHashClaimMissing, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + token := signMessage(t, &stubPrivateKeyProvider{ + key: key, + meta: meta, + }, tt.hasher, body) + + trustMap := map[string]struct{}{meta.Iss: {}} + verifier, err := jwtsigning.NewVerifier(pubProvider, nil, trustMap) + assert.NoError(t, err) + + err = verifier.Verify(t.Context(), token, body) + assert.Error(t, err) + assert.ErrorIs(t, err, tt.expectedErr) + }) + } + }) + + t.Run("Verify_HashMismatch", func(t *testing.T) { + key := generateRSAKey(t, 3072) + meta := jwtsigning.KeyMetadata{ + Iss: "issuer", + Kid: "kid", + } + + originalBody := []byte("original-body") + tamperedBody := []byte("tampered-body") + + token := signMessage(t, &stubPrivateKeyProvider{ + key: key, + meta: meta, + }, nil, originalBody) + + pubProvider := &stubPublicKeyProvider{ + key: &key.PublicKey, + } + + trustMap := map[string]struct{}{meta.Iss: {}} + verifier, err := jwtsigning.NewVerifier(pubProvider, nil, trustMap) + assert.NoError(t, err) + + err = verifier.Verify(t.Context(), token, tamperedBody) + assert.Error(t, err) + assert.ErrorIs(t, err, jwtsigning.ErrMessageHashMismatch) + }) +}