From 3fd082fb8832ac48e691947417e8552442da3337 Mon Sep 17 00:00:00 2001 From: Sergey Beryozkin Date: Sat, 6 Jul 2024 18:44:52 +0100 Subject: [PATCH] Support list of TokenCustomizers --- .../io/quarkus/oidc/runtime/OidcProvider.java | 31 ++++++++++--------- .../oidc/runtime/TenantFeatureFinder.java | 22 ++++--------- .../oidc/runtime/OidcProviderTest.java | 4 +-- 3 files changed, 24 insertions(+), 33 deletions(-) diff --git a/extensions/oidc/runtime/src/main/java/io/quarkus/oidc/runtime/OidcProvider.java b/extensions/oidc/runtime/src/main/java/io/quarkus/oidc/runtime/OidcProvider.java index dde4e3d77d34d..6d32ab3644e0c 100644 --- a/extensions/oidc/runtime/src/main/java/io/quarkus/oidc/runtime/OidcProvider.java +++ b/extensions/oidc/runtime/src/main/java/io/quarkus/oidc/runtime/OidcProvider.java @@ -72,7 +72,7 @@ public class OidcProvider implements Closeable { final RefreshableVerificationKeyResolver asymmetricKeyResolver; final DynamicVerificationKeyResolver keyResolverProvider; final OidcTenantConfig oidcConfig; - final TokenCustomizer tokenCustomizer; + final List tokenCustomizers; final String issuer; final String[] audience; final Map requiredClaims; @@ -85,10 +85,10 @@ public OidcProvider(OidcProviderClient client, OidcTenantConfig oidcConfig, Json } public OidcProvider(OidcProviderClient client, OidcTenantConfig oidcConfig, JsonWebKeySet jwks, - TokenCustomizer tokenCustomizer, Key tokenDecryptionKey, List customValidators) { + List tokenCustomizers, Key tokenDecryptionKey, List customValidators) { this.client = client; this.oidcConfig = oidcConfig; - this.tokenCustomizer = tokenCustomizer; + this.tokenCustomizers = tokenCustomizers; if (jwks != null) { this.asymmetricKeyResolver = new JsonWebKeyResolver(jwks, oidcConfig.token.forcedJwkRefreshInterval); } else if (oidcConfig != null && oidcConfig.certificateChain.trustStoreFile.isPresent()) { @@ -113,7 +113,7 @@ public OidcProvider(OidcProviderClient client, OidcTenantConfig oidcConfig, Json public OidcProvider(String publicKeyEnc, OidcTenantConfig oidcConfig, Key tokenDecryptionKey) { this.client = null; this.oidcConfig = oidcConfig; - this.tokenCustomizer = TenantFeatureFinder.find(oidcConfig); + this.tokenCustomizers = TenantFeatureFinder.find(oidcConfig); if (publicKeyEnc != null) { this.asymmetricKeyResolver = new LocalPublicKeyResolver(publicKeyEnc); } else if (oidcConfig.certificateChain.trustStoreFile.isPresent()) { @@ -274,17 +274,18 @@ private TokenVerificationResult verifyJwtTokenInternal(String token, } private String customizeJwtToken(String token) { - if (tokenCustomizer != null) { - JsonObject headers = AbstractJsonObjectResponse.toJsonObject( - OidcUtils.decodeJwtHeadersAsString(token)); - headers = tokenCustomizer.customizeHeaders(headers); - if (headers != null) { - String newHeaders = new String( - Base64.getUrlEncoder().withoutPadding().encode(headers.toString().getBytes()), - StandardCharsets.UTF_8); - int dotIndex = token.indexOf('.'); - String newToken = newHeaders + token.substring(dotIndex); - return newToken; + if (tokenCustomizers != null) { + for (TokenCustomizer tokenCustomizer : tokenCustomizers) { + JsonObject headers = AbstractJsonObjectResponse.toJsonObject(OidcUtils.decodeJwtHeadersAsString(token)); + headers = tokenCustomizer.customizeHeaders(headers); + if (headers != null) { + String newHeaders = new String( + Base64.getUrlEncoder().withoutPadding().encode(headers.toString().getBytes()), + StandardCharsets.UTF_8); + int dotIndex = token.indexOf('.'); + String newToken = newHeaders + token.substring(dotIndex); + return newToken; + } } } return token; diff --git a/extensions/oidc/runtime/src/main/java/io/quarkus/oidc/runtime/TenantFeatureFinder.java b/extensions/oidc/runtime/src/main/java/io/quarkus/oidc/runtime/TenantFeatureFinder.java index 603e0b9e2ac32..7fbc3ae2d1580 100644 --- a/extensions/oidc/runtime/src/main/java/io/quarkus/oidc/runtime/TenantFeatureFinder.java +++ b/extensions/oidc/runtime/src/main/java/io/quarkus/oidc/runtime/TenantFeatureFinder.java @@ -21,9 +21,9 @@ private TenantFeatureFinder() { } - public static TokenCustomizer find(OidcTenantConfig oidcConfig) { + public static List find(OidcTenantConfig oidcConfig) { if (oidcConfig == null) { - return null; + return List.of(); } ArcContainer container = Arc.container(); if (container != null) { @@ -31,25 +31,15 @@ public static TokenCustomizer find(OidcTenantConfig oidcConfig) { if (customizerName != null && !customizerName.isEmpty()) { InstanceHandle tokenCustomizer = container.instance(customizerName); if (tokenCustomizer.isAvailable()) { - return tokenCustomizer.get(); + return List.of(tokenCustomizer.get()); } else { throw new OIDCException("Unable to find TokenCustomizer " + customizerName); } - } else if (oidcConfig.tenantId.isPresent()) { - String tenantId = oidcConfig.tenantId.get(); - List list = findTenantFeaturesByTenantId(TokenCustomizer.class, tenantId, container); - if (!list.isEmpty()) { - if (list.size() >= 2) { - throw new OIDCException( - "Found multiple TokenCustomizers that are annotated with @TenantFeature that has tenantId (" - + tenantId + ")"); - } - return list.get(0); - } - + } else { + return find(oidcConfig, TokenCustomizer.class); } } - return null; + return List.of(); } public static List find(OidcTenantConfig oidcTenantConfig, Class tenantFeatureClass) { diff --git a/extensions/oidc/runtime/src/test/java/io/quarkus/oidc/runtime/OidcProviderTest.java b/extensions/oidc/runtime/src/test/java/io/quarkus/oidc/runtime/OidcProviderTest.java index fe13364f4e63b..79cc983cdc9d9 100644 --- a/extensions/oidc/runtime/src/test/java/io/quarkus/oidc/runtime/OidcProviderTest.java +++ b/extensions/oidc/runtime/src/test/java/io/quarkus/oidc/runtime/OidcProviderTest.java @@ -54,14 +54,14 @@ public void testAlgorithmCustomizer() throws Exception { } } - try (OidcProvider provider = new OidcProvider(null, oidcConfig, jwkSet, new TokenCustomizer() { + try (OidcProvider provider = new OidcProvider(null, oidcConfig, jwkSet, List.of(new TokenCustomizer() { @Override public JsonObject customizeHeaders(JsonObject headers) { return Json.createObjectBuilder(headers).add("alg", "RS256").build(); } - }, null, null)) { + }), null, null)) { TokenVerificationResult result = provider.verifyJwtToken(newToken, false, false, null); assertEquals("http://keycloak/realm", result.localVerificationResult.getString("iss")); }