diff --git a/cypress/helpers/index.js b/cypress/helpers/index.js index e8f11a9251f..4c37fb32e38 100644 --- a/cypress/helpers/index.js +++ b/cypress/helpers/index.js @@ -98,3 +98,20 @@ const deleteGrant = (id) => "DELETE", Cypress.env("admin_url") + "/trust/grants/jwt-bearer/issuers/" + id, ) + +export const validateJwt = (jwt) => + cy + .request({ + method: "POST", + url: `${Cypress.env("client_url")}/oauth2/validate-jwt`, + form: true, + body: { jwt }, + }) + .then(({ body }) => body) + +export const rotateJwks = (set) => + cy + .request("POST", `${Cypress.env("admin_url")}/keys/${set}`, { + alg: "RS256", + }) + .then(({ body }) => body) diff --git a/cypress/integration/oauth2/jwt.js b/cypress/integration/oauth2/jwt.js index 8b9d0fe78f8..5057d78f376 100644 --- a/cypress/integration/oauth2/jwt.js +++ b/cypress/integration/oauth2/jwt.js @@ -1,7 +1,7 @@ // Copyright © 2022 Ory Corp // SPDX-License-Identifier: Apache-2.0 -import { createClient, prng } from "../../helpers" +import { createClient, prng, validateJwt } from "../../helpers" const accessTokenStrategies = ["opaque", "jwt"] @@ -44,15 +44,12 @@ describe("OAuth 2.0 JSON Web Token Access Tokens", () => { expect(token.refresh_token).to.not.be.empty expect(token.access_token.split(".").length).to.equal(3) expect(token.refresh_token.split(".").length).to.equal(2) - }) - cy.request(`${Cypress.env("client_url")}/oauth2/validate-jwt`) - .its("body") - .then((body) => { - console.log(body) - expect(body.sub).to.eq("foo@bar.com") - expect(body.client_id).to.eq(client.client_id) - expect(body.jti).to.not.be.empty + validateJwt(token.access_token).then(({ payload }) => { + expect(payload.sub).to.eq("foo@bar.com") + expect(payload.client_id).to.eq(client.client_id) + expect(payload.jti).to.not.be.empty + }) }) }) }) diff --git a/cypress/integration/oauth2/refresh_token.js b/cypress/integration/oauth2/refresh_token.js index 2ddf7d30f19..dcd9cacbf19 100644 --- a/cypress/integration/oauth2/refresh_token.js +++ b/cypress/integration/oauth2/refresh_token.js @@ -1,7 +1,7 @@ // Copyright © 2022 Ory Corp // SPDX-License-Identifier: Apache-2.0 -import { createClient, prng } from "../../helpers" +import { createClient, prng, rotateJwks, validateJwt } from "../../helpers" const accessTokenStrategies = ["opaque", "jwt"] @@ -100,6 +100,61 @@ describe("The OAuth 2.0 Refresh Token Grant", function () { }) }) }) + + const validateJwtAndGetKid = (token) => + validateJwt(token).then(({ header }) => header.kid) + + it("should refresh the Access and ID Token with newly rotated keys", function () { + if ( + accessTokenStrategy === "opaque" || + (Cypress.env("jwt_enabled") !== "true" && + !Boolean(Cypress.env("jwt_enabled"))) + ) { + this.skip() + } + + const referrer = `${Cypress.env("client_url")}/empty` + cy.visit(referrer, { + failOnStatusCode: false, + }) + + createClient({ + scope: "offline_access openid", + redirect_uris: [referrer], + grant_types: ["authorization_code", "refresh_token"], + response_types: ["code"], + token_endpoint_auth_method: "none", + }).then((client) => { + cy.authCodeFlowBrowser(client, { + consent: { + scope: ["offline_access", "openid"], + }, + createClient: false, + }).then(({ body: tokensBefore }) => { + const kidsBefore = { + accessToken: validateJwtAndGetKid(tokensBefore.access_token), + idToken: validateJwtAndGetKid(tokensBefore.id_token), + } + + rotateJwks("hydra.jwt.access-token") + rotateJwks("hydra.openid.id-token") + + cy.refreshTokenBrowser(client, tokensBefore.refresh_token).then( + ({ body: tokensAfter }) => { + const kidsAfter = { + accessToken: validateJwtAndGetKid(tokensAfter.access_token), + idToken: validateJwtAndGetKid(tokensAfter.id_token), + } + + expect(kidsAfter.accessToken).to.not.equal( + kidsBefore.accessToken, + ) + expect(kidsAfter.idToken).to.not.equal(kidsBefore.idToken) + }, + ) + }) + }) + }) }) }) }) diff --git a/cypress/support/commands.js b/cypress/support/commands.js index 2f75293404d..3c2d610e1e5 100644 --- a/cypress/support/commands.js +++ b/cypress/support/commands.js @@ -196,10 +196,9 @@ Cypress.Commands.add( }) } if (doCreateClient) { - createClient(client).then(run) - return + return createClient(client).then(run) } - run(client) + return run(client) }, ) diff --git a/oauth2/handler.go b/oauth2/handler.go index 3f1a633038d..0a4ea1564ea 100644 --- a/oauth2/handler.go +++ b/oauth2/handler.go @@ -1021,6 +1021,32 @@ func (h *Handler) oauth2TokenExchange(w http.ResponseWriter, r *http.Request) { } } + if accessRequest.GetGrantTypes().ExactOne(string(fosite.GrantTypeRefreshToken)) { + var accessTokenKeyID string + if h.c.AccessTokenStrategy(ctx, client.AccessTokenStrategySource(accessRequest.GetClient())) == "jwt" { + accessTokenKeyID, err = h.r.AccessTokenJWTStrategy().GetPublicKeyID(ctx) + if err != nil { + h.logOrAudit(err, r) + h.r.OAuth2Provider().WriteAccessError(ctx, w, accessRequest, err) + events.Trace(ctx, events.TokenExchangeError, events.WithRequest(accessRequest)) + return + } + } + + openIDKeyID, err := h.r.OpenIDJWTStrategy().GetPublicKeyID(ctx) + if err != nil { + h.logOrAudit(err, r) + h.r.OAuth2Provider().WriteAccessError(ctx, w, accessRequest, err) + events.Trace(ctx, events.TokenExchangeError, events.WithRequest(accessRequest)) + return + } + + if sess, ok := accessRequest.GetSession().(*Session); ok { + sess.KID = accessTokenKeyID + sess.DefaultSession.Headers.Add("kid", openIDKeyID) + } + } + for _, hook := range h.r.AccessRequestHooks() { if err = hook(ctx, accessRequest); err != nil { h.logOrAudit(err, r) diff --git a/test/e2e/oauth2-client/src/index.js b/test/e2e/oauth2-client/src/index.js index b27512bedeb..7a211a758ab 100644 --- a/test/e2e/oauth2-client/src/index.js +++ b/test/e2e/oauth2-client/src/index.js @@ -179,23 +179,24 @@ app.get("/oauth2/revoke", (req, res) => { }) }) -app.get("/oauth2/validate-jwt", (req, res) => { +app.post("/oauth2/validate-jwt", (req, res) => { const client = jwksClient({ jwksUri: new URL("/.well-known/jwks.json", config.public).toString(), }) jwt.verify( - req.session.oauth2_flow.token.access_token, + req.body.jwt, (header, callback) => { client.getSigningKey(header.kid, function (err, key) { const signingKey = key.publicKey || key.rsaPublicKey callback(null, signingKey) }) }, + { complete: true }, (err, decoded) => { if (err) { console.error(err) - res.send(400) + res.status(400).send(JSON.stringify({ error: err.toString() })) return }