Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions pkg/jwtsigning/hasher.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package jwtsigning

import (
"crypto/sha256"
"encoding/base64"
)

const hashAlgorithm = "SHA256"

type SHA256Hasher struct{}

func (s *SHA256Hasher) HashMessage(body []byte) string {
sum := sha256.Sum256(body)

return base64.RawURLEncoding.EncodeToString(sum[:])
}

func (s *SHA256Hasher) ToString() string {
return hashAlgorithm
}
74 changes: 74 additions & 0 deletions pkg/jwtsigning/helper_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
package jwtsigning_test

import (
"context"
"crypto/rand"
"crypto/rsa"
"testing"

"github.com/stretchr/testify/assert"

"github.com/openkcm/common-sdk/pkg/jwtsigning"
)

type stubPrivateKeyProvider struct {
key *rsa.PrivateKey
meta jwtsigning.KeyMetadata
err error
}

func (s *stubPrivateKeyProvider) CurrentSigningKey(ctx context.Context) (*rsa.PrivateKey, jwtsigning.KeyMetadata, error) {
return s.key, s.meta, s.err
}

type stubPublicKeyProvider struct {
key *rsa.PublicKey
err error
lastIss string
lastKid string
}

func (s *stubPublicKeyProvider) VerificationKey(ctx context.Context, iss, kid string) (*rsa.PublicKey, error) {
s.lastIss = iss

s.lastKid = kid
if s.err != nil {
return nil, s.err
}

return s.key, nil
}

type hasherStub struct {
hash string
alg string
}

func (h *hasherStub) HashMessage(_ []byte) string {
return h.hash
}

func (h *hasherStub) ToString() string {
return h.alg
}

func generateRSAKey(tb testing.TB, bits int) *rsa.PrivateKey {
tb.Helper()

key, err := rsa.GenerateKey(rand.Reader, bits)
assert.NoError(tb, err, "generateRSAKey", err)

return key
}

func signMessage(tb testing.TB, provider jwtsigning.PrivateKeyProvider, hasher jwtsigning.Hasher, body []byte) string {
tb.Helper()

signer, err := jwtsigning.NewSigner(provider, hasher)
assert.NoError(tb, err, "NewSigner failed", err)

token, err := signer.Sign(tb.Context(), body)
assert.NoError(tb, err, "Sign failed", err)

return token
}
31 changes: 31 additions & 0 deletions pkg/jwtsigning/providers.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package jwtsigning

import (
"context"
"crypto/rsa"
)

// KeyMetadata describes the logical identity of a signing key.
// It maps directly to the JWT "iss" (issuer) and "kid" (key ID) used by
// verifiers to look up the matching public key via the configured trust mechanism.
type KeyMetadata struct {
Iss string // typically the cluster URL that exposes .well-known/jwks.json
Kid string // uniquely identifies the signing key under that issuer
}

// PrivateKeyProvider supplies the current RSA private key and its metadata for signing outgoing messages.
type PrivateKeyProvider interface {
// CurrentSigningKey returns the key+metadata to use for signing.
CurrentSigningKey(ctx context.Context) (*rsa.PrivateKey, KeyMetadata, error)
}

// PublicKeyProvider resolves RSA public keys for verification of incoming message signatures.
type PublicKeyProvider interface {
// VerificationKey returns the public key for given issuer and kid.
VerificationKey(ctx context.Context, iss, kid string) (*rsa.PublicKey, error)
}

type Hasher interface {
HashMessage(body []byte) string
ToString() string
}
91 changes: 91 additions & 0 deletions pkg/jwtsigning/signer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
package jwtsigning

import (
"context"
"errors"
"fmt"

"github.com/golang-jwt/jwt/v5"
)

var (
ErrRSAKeyLength = errors.New("RSA key is too small")
ErrUndefinedHashingAlgorithm = errors.New("undefined hashing algorithm")
ErrUndefinedSigningAlgorithm = errors.New("undefined signing algorithm")
ErrNilKeyProvider = errors.New("keyProvider cannot be nil")
)

const (
jwtMapClaimIss = "iss"
jwtMapClaimKid = "kid"
jwtMapClaimHash = "hash"
jwtMapClaimHashAlgorithm = "hash-alg"

tokenHeaderType = "typ"
tokenType = "JWT"
tokenHeaderAlgorithm = "alg"
)

// Signer signs message bodies into JWS (JWT) tokens.
type Signer struct {
Hasher

keys PrivateKeyProvider
}

func NewSigner(keyProvider PrivateKeyProvider, hasher Hasher) (*Signer, error) {
if keyProvider == nil {
return nil, ErrNilKeyProvider
}

if hasher == nil {
hasher = &SHA256Hasher{}
}

return &Signer{
Hasher: hasher,
keys: keyProvider,
}, nil
}

// Sign creates a compact JWS (JWT) for the given message body using PS256.
//
// The returned string is suitable for use as a value of an HTTP or message
// header (e.g. "X-Message-Signature"). The token will contain the following
// claims and headers:
// - claims:
// iss: issuer from KeyMetadata
// kid: key ID from KeyMetadata
// hash: base64url(SHA-256(body))
// hash-alg: "SHA256"
// - headers:
// typ: "JWT"
// alg: "PS256"
//
// Sign obtains the private key and metadata from the configured
// PrivateKeyProvider and enforces the minimum RSA key size before signing.
// It returns an error if the provider fails, if the key is too small, or if
// the token cannot be signed.
func (s *Signer) Sign(ctx context.Context, body []byte) (string, error) {
priv, meta, err := s.keys.CurrentSigningKey(ctx)
if err != nil {
return "", err
}

if priv.N.BitLen() < 3072 {
return "", fmt.Errorf("%w: %d bits", ErrRSAKeyLength, priv.N.BitLen())
}

claims := jwt.MapClaims{
jwtMapClaimIss: meta.Iss,
jwtMapClaimKid: meta.Kid,
jwtMapClaimHash: s.HashMessage(body),
jwtMapClaimHashAlgorithm: s.ToString(),
}

token := jwt.NewWithClaims(jwt.SigningMethodPS256, claims)
token.Header[tokenHeaderType] = tokenType
token.Header[tokenHeaderAlgorithm] = jwt.SigningMethodPS256.Alg()

return token.SignedString(priv)
}
Loading