Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions cypress/helpers/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -98,3 +98,20 @@ const deleteGrant = (id) =>
"DELETE",
Cypress.env("admin_url") + "/trust/grants/jwt-bearer/issuers/" + id,
)

export const validateJwt = (jwt) =>
cy
.request({
method: "POST",
url: `${Cypress.env("client_url")}/oauth2/validate-jwt`,
form: true,
body: { jwt },
})
.then(({ body }) => body)

export const rotateJwks = (set) =>
cy
.request("POST", `${Cypress.env("admin_url")}/keys/${set}`, {
alg: "RS256",
})
.then(({ body }) => body)
15 changes: 6 additions & 9 deletions cypress/integration/oauth2/jwt.js
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// Copyright © 2022 Ory Corp
// SPDX-License-Identifier: Apache-2.0

import { createClient, prng } from "../../helpers"
import { createClient, prng, validateJwt } from "../../helpers"

const accessTokenStrategies = ["opaque", "jwt"]

Expand Down Expand Up @@ -44,15 +44,12 @@ describe("OAuth 2.0 JSON Web Token Access Tokens", () => {
expect(token.refresh_token).to.not.be.empty
expect(token.access_token.split(".").length).to.equal(3)
expect(token.refresh_token.split(".").length).to.equal(2)
})

cy.request(`${Cypress.env("client_url")}/oauth2/validate-jwt`)
.its("body")
.then((body) => {
console.log(body)
expect(body.sub).to.eq("foo@bar.com")
expect(body.client_id).to.eq(client.client_id)
expect(body.jti).to.not.be.empty
validateJwt(token.access_token).then(({ payload }) => {
expect(payload.sub).to.eq("foo@bar.com")
expect(payload.client_id).to.eq(client.client_id)
expect(payload.jti).to.not.be.empty
})
})
})
})
Expand Down
57 changes: 56 additions & 1 deletion cypress/integration/oauth2/refresh_token.js
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// Copyright © 2022 Ory Corp
// SPDX-License-Identifier: Apache-2.0

import { createClient, prng } from "../../helpers"
import { createClient, prng, rotateJwks, validateJwt } from "../../helpers"

const accessTokenStrategies = ["opaque", "jwt"]

Expand Down Expand Up @@ -100,6 +100,61 @@ describe("The OAuth 2.0 Refresh Token Grant", function () {
})
})
})

const validateJwtAndGetKid = (token) =>
validateJwt(token).then(({ header }) => header.kid)

it("should refresh the Access and ID Token with newly rotated keys", function () {
if (
accessTokenStrategy === "opaque" ||
(Cypress.env("jwt_enabled") !== "true" &&
!Boolean(Cypress.env("jwt_enabled")))
) {
this.skip()
}

const referrer = `${Cypress.env("client_url")}/empty`
cy.visit(referrer, {
failOnStatusCode: false,
})

createClient({
scope: "offline_access openid",
redirect_uris: [referrer],
grant_types: ["authorization_code", "refresh_token"],
response_types: ["code"],
token_endpoint_auth_method: "none",
}).then((client) => {
cy.authCodeFlowBrowser(client, {
consent: {
scope: ["offline_access", "openid"],
},
createClient: false,
}).then(({ body: tokensBefore }) => {
const kidsBefore = {
accessToken: validateJwtAndGetKid(tokensBefore.access_token),
idToken: validateJwtAndGetKid(tokensBefore.id_token),
}

rotateJwks("hydra.jwt.access-token")
rotateJwks("hydra.openid.id-token")

cy.refreshTokenBrowser(client, tokensBefore.refresh_token).then(
({ body: tokensAfter }) => {
const kidsAfter = {
accessToken: validateJwtAndGetKid(tokensAfter.access_token),
idToken: validateJwtAndGetKid(tokensAfter.id_token),
}

expect(kidsAfter.accessToken).to.not.equal(
kidsBefore.accessToken,
)
expect(kidsAfter.idToken).to.not.equal(kidsBefore.idToken)
},
)
})
})
})
})
})
})
5 changes: 2 additions & 3 deletions cypress/support/commands.js
Original file line number Diff line number Diff line change
Expand Up @@ -196,10 +196,9 @@ Cypress.Commands.add(
})
}
if (doCreateClient) {
createClient(client).then(run)
return
return createClient(client).then(run)
}
run(client)
return run(client)
},
)

Expand Down
26 changes: 26 additions & 0 deletions oauth2/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -1021,6 +1021,32 @@ func (h *Handler) oauth2TokenExchange(w http.ResponseWriter, r *http.Request) {
}
}

if accessRequest.GetGrantTypes().ExactOne(string(fosite.GrantTypeRefreshToken)) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe a short comment explaining why we do this here?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if accessRequest.GetGrantTypes().ExactOne(string(fosite.GrantTypeRefreshToken)) {
// When refreshing tokens, we want to ensure to use the latest key-id available for signing the
// potentially JWT-formatted tokens.
if accessRequest.GetGrantTypes().ExactOne(string(fosite.GrantTypeRefreshToken)) {

var accessTokenKeyID string
if h.c.AccessTokenStrategy(ctx, client.AccessTokenStrategySource(accessRequest.GetClient())) == "jwt" {
accessTokenKeyID, err = h.r.AccessTokenJWTStrategy().GetPublicKeyID(ctx)
if err != nil {
h.logOrAudit(err, r)
h.r.OAuth2Provider().WriteAccessError(ctx, w, accessRequest, err)
events.Trace(ctx, events.TokenExchangeError, events.WithRequest(accessRequest))
return
}
}

openIDKeyID, err := h.r.OpenIDJWTStrategy().GetPublicKeyID(ctx)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you please add an if condition to only fetch this key if the flow is an openid flow. I think it should be enough to check if GrantedScope contains openid. This will reduce DB load in cases where we don't have an OpenID flow.

if err != nil {
h.logOrAudit(err, r)
h.r.OAuth2Provider().WriteAccessError(ctx, w, accessRequest, err)
events.Trace(ctx, events.TokenExchangeError, events.WithRequest(accessRequest))
return
}

if sess, ok := accessRequest.GetSession().(*Session); ok {
sess.KID = accessTokenKeyID
sess.DefaultSession.Headers.Add("kid", openIDKeyID)
}
}

for _, hook := range h.r.AccessRequestHooks() {
if err = hook(ctx, accessRequest); err != nil {
h.logOrAudit(err, r)
Expand Down
7 changes: 4 additions & 3 deletions test/e2e/oauth2-client/src/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -179,23 +179,24 @@
})
})

app.get("/oauth2/validate-jwt", (req, res) => {
app.post("/oauth2/validate-jwt", (req, res) => {
const client = jwksClient({
jwksUri: new URL("/.well-known/jwks.json", config.public).toString(),
})

jwt.verify(
req.session.oauth2_flow.token.access_token,
req.body.jwt,
(header, callback) => {
client.getSigningKey(header.kid, function (err, key) {
const signingKey = key.publicKey || key.rsaPublicKey
callback(null, signingKey)
})
},
{ complete: true },
(err, decoded) => {
if (err) {
console.error(err)
res.send(400)
res.status(400).send(JSON.stringify({ error: err.toString() }))
return
}

Expand Down
Loading