diff --git a/oidc/verify.go b/oidc/verify.go index 0bca49a8..d8e80794 100644 --- a/oidc/verify.go +++ b/oidc/verify.go @@ -12,7 +12,7 @@ import ( "strings" "time" - jose "github.com/go-jose/go-jose/v3" + "github.com/go-jose/go-jose/v3" "golang.org/x/oauth2" ) @@ -33,6 +33,19 @@ func (e *TokenExpiredError) Error() string { return fmt.Sprintf("oidc: token is expired (Token Expiry: %v)", e.Expiry) } +// UnexpectedAudienceError indicates that the audience claim of the token did not match +// any of the expected audiences. +type UnexpectedAudienceError struct { + // ClientID is the client ID that was used to initialize the verifier. + ClientID string + // Audience is the audiences specified in the token. + Audience []string +} + +func (e *UnexpectedAudienceError) Error() string { + return fmt.Sprintf("oidc: expected audience %q got %q", e.ClientID, e.Audience) +} + // KeySet is a set of publc JSON Web Keys that can be used to validate the signature // of JSON web tokens. This is expected to be backed by a remote key set through // provider metadata discovery or an in-memory set of keys delivered out-of-band. @@ -274,7 +287,7 @@ func (v *IDTokenVerifier) Verify(ctx context.Context, rawIDToken string) (*IDTok if !v.config.SkipClientIDCheck { if v.config.ClientID != "" { if !contains(t.Audience, v.config.ClientID) { - return nil, fmt.Errorf("oidc: expected audience %q got %q", v.config.ClientID, t.Audience) + return nil, &UnexpectedAudienceError{v.config.ClientID, t.Audience} } } else { return nil, fmt.Errorf("oidc: invalid configuration, clientID must be provided or SkipClientIDCheck must be set") diff --git a/oidc/verify_test.go b/oidc/verify_test.go index f2e2433b..14650878 100644 --- a/oidc/verify_test.go +++ b/oidc/verify_test.go @@ -175,8 +175,8 @@ func TestVerifyAudience(t *testing.T) { ClientID: "client1", SkipExpiryCheck: true, }, - signKey: newRSAKey(t), - wantErr: true, + signKey: newRSAKey(t), + wantErrAud: true, }, { name: "multiple audiences, one matches", @@ -573,6 +573,7 @@ type verificationTest struct { config Config wantErr bool wantErrExpiry bool + wantErrAud bool } func (v verificationTest) runGetToken(t *testing.T) (*IDToken, error) { @@ -605,10 +606,10 @@ func (v verificationTest) runGetToken(t *testing.T) (*IDToken, error) { func (v verificationTest) run(t *testing.T) { _, err := v.runGetToken(t) - if err != nil && !v.wantErr && !v.wantErrExpiry { + if err != nil && !v.wantErr && !v.wantErrExpiry && !v.wantErrAud { t.Errorf("%v", err) } - if err == nil && (v.wantErr || v.wantErrExpiry) { + if err == nil && (v.wantErr || v.wantErrExpiry || v.wantErrAud) { t.Errorf("expected error") } if v.wantErrExpiry { @@ -617,4 +618,10 @@ func (v verificationTest) run(t *testing.T) { t.Errorf("expected *TokenExpiryError but got %q", err) } } + if v.wantErrAud { + var errAud *UnexpectedAudienceError + if !errors.As(err, &errAud) { + t.Errorf("expected *AudienceError but got %q", err) + } + } }