diff --git a/oauth2/fosite_store_helpers_test.go b/oauth2/fosite_store_helpers_test.go index 8c6ed0d6881..107df0e0051 100644 --- a/oauth2/fosite_store_helpers_test.go +++ b/oauth2/fosite_store_helpers_test.go @@ -1250,6 +1250,31 @@ func testFositeJWTBearerGrantStorage(x *driver.RegistrySQL) func(t *testing.T) { require.NotEmpty(t, jwks.Keys) }) + t.Run("case=does not found expired grant", func(t *testing.T) { + keySet, err := jwk.GenerateJWK(jose.RS256, uuid.Must(uuid.NewV4()).String(), "sig") + require.NoError(t, err) + + publicKey := keySet.Keys[0].Public() + issuer := uuid.Must(uuid.NewV4()).String() + subject := uuid.Must(uuid.NewV4()).String() + grant := trust.Grant{ + ID: uuid.Must(uuid.NewV4()), + Issuer: issuer, + Subject: subject, + AllowAnySubject: true, + Scope: []string{"openid", "offline"}, + PublicKey: trust.PublicKey{Set: issuer, KeyID: publicKey.KeyID}, + CreatedAt: time.Now().UTC().Round(time.Second), + ExpiresAt: time.Now().UTC().Round(time.Second).AddDate(-1, 0, 0), + } + + require.NoError(t, grantManager.CreateGrant(ctx, grant, publicKey)) + + key, err := grantStorage.GetPublicKey(ctx, issuer, subject, publicKey.KeyID) + require.Error(t, err) + assert.Nil(t, key) + }) + t.Run("case=does not return expired values", func(t *testing.T) { keySet, err := jwk.GenerateJWK(jose.RS256, uuid.Must(uuid.NewV4()).String(), "sig") require.NoError(t, err) diff --git a/persistence/sql/persister_grant_jwk.go b/persistence/sql/persister_grant_jwk.go index 001167ad0b7..29dd47f1074 100644 --- a/persistence/sql/persister_grant_jwk.go +++ b/persistence/sql/persister_grant_jwk.go @@ -175,7 +175,12 @@ func (p *Persister) GetPublicKey(ctx context.Context, issuer string, subject str tableName += "@hydra_oauth2_trusted_jwt_bearer_issuer_nid_uq_idx" } - sql := fmt.Sprintf(`SELECT key_set FROM %s WHERE key_id = ? AND nid = ? AND issuer = ? AND (subject = ? OR allow_any_subject IS TRUE) LIMIT 1`, tableName) + expiresAt := "expires_at > NOW()" + if p.Connection(ctx).Dialect.Name() == "sqlite3" { + expiresAt = "expires_at > datetime('now')" + } + + sql := fmt.Sprintf(`SELECT key_set FROM %s WHERE key_id = ? AND nid = ? AND issuer = ? AND (subject = ? OR allow_any_subject IS TRUE) AND %s LIMIT 1`, tableName, expiresAt) query := p.Connection(ctx).RawQuery(sql, keyId, p.NetworkID(ctx), issuer, subject, )