Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Identity provider key caching behavior configurable #2920

Merged
merged 3 commits into from
Jun 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ public enum OAuthGroupMappingMode {
private boolean performRpInitiatedLogout = true;
@JsonInclude(JsonInclude.Include.NON_NULL)
private String authMethod;
private boolean cacheJwks = true;
peterhaochen47 marked this conversation as resolved.
Show resolved Hide resolved

public T setAuthUrl(URL authUrl) {
this.authUrl = authUrl;
Expand Down Expand Up @@ -143,6 +144,11 @@ public T setGroupMappingMode(OAuthGroupMappingMode externalGroupMappingMode) {
return (T) this;
}

public T setCacheJwks(final boolean cacheJwks) {
this.cacheJwks = cacheJwks;
return (T) this;
}

public void setPkce(final boolean pkce) {
this.pkce = pkce;
}
Expand Down Expand Up @@ -194,6 +200,7 @@ public boolean equals(Object o) {
if (pkce != that.pkce) return false;
if (performRpInitiatedLogout != that.performRpInitiatedLogout) return false;
if (!Objects.equals(authMethod, that.authMethod)) return false;
if (cacheJwks != that.cacheJwks) return false;
return Objects.equals(responseType, that.responseType);

}
Expand All @@ -220,6 +227,7 @@ public int hashCode() {
result = 31 * result + (pkce ? 1 : 0);
result = 31 * result + (performRpInitiatedLogout ? 1 : 0);
result = 31 * result + (authMethod != null ? authMethod.hashCode() : 0);
result = 31 * result + (cacheJwks ? 1 : 0);
return result;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,9 @@ protected void setCommonProperties(Map<String, Object> idpDefinitionMap, Abstrac
if (idpDefinitionMap.get("performRpInitiatedLogout") instanceof Boolean) {
idpDefinition.setPerformRpInitiatedLogout((boolean)idpDefinitionMap.get("performRpInitiatedLogout"));
}
if (idpDefinitionMap.get("cacheJwks") instanceof Boolean) {
idpDefinition.setCacheJwks((boolean)idpDefinitionMap.get("cacheJwks"));
}
}

private static Map<String, String> parseAdditionalParameters(Map<String, Object> idpDefinitionMap) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
import org.cloudfoundry.identity.uaa.util.JsonUtils;
import org.springframework.http.HttpEntity;
import org.springframework.http.HttpMethod;
import org.springframework.http.HttpStatus;
import org.springframework.http.ResponseEntity;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.web.client.RestTemplate;
Expand Down Expand Up @@ -47,13 +49,13 @@ public void fetchMetadataAndUpdateDefinition(OIDCIdentityProviderDefinition defi
}
}

public JsonWebKeySet<JsonWebKey> fetchWebKeySet(AbstractExternalOAuthIdentityProviderDefinition config)
public JsonWebKeySet<JsonWebKey> fetchWebKeySet(AbstractExternalOAuthIdentityProviderDefinition<?> config)
throws OidcMetadataFetchingException {
URL tokenKeyUrl = config.getTokenKeyUrl();
if (tokenKeyUrl == null || !org.springframework.util.StringUtils.hasText(tokenKeyUrl.toString())) {
return new JsonWebKeySet<>(Collections.emptyList());
}
byte[] rawContents = getJsonBody(tokenKeyUrl.toString(), config.isSkipSslValidation(), getClientAuthHeader(config));
byte[] rawContents = getJsonBody(tokenKeyUrl.toString(), config.isSkipSslValidation(), config.isCacheJwks(), getClientAuthHeader(config));
if (rawContents == null || rawContents.length == 0) {
throw new OidcMetadataFetchingException("Unable to fetch verification keys");
}
Expand All @@ -68,7 +70,7 @@ public JsonWebKeySet<JsonWebKey> fetchWebKeySet(ClientJwtConfiguration clientJwt
if (clientJwtConfiguration.getJwkSet() != null) {
return clientJwtConfiguration.getJwkSet();
} else if (clientJwtConfiguration.getJwksUri() != null) {
byte[] rawContents = getJsonBody(clientJwtConfiguration.getJwksUri(), false, null);
byte[] rawContents = getJsonBody(clientJwtConfiguration.getJwksUri(), false, true, null);
if (rawContents != null && rawContents.length > 0) {
ClientJwtConfiguration clientKeys = ClientJwtConfiguration.parse(null, new String(rawContents, StandardCharsets.UTF_8));
if (clientKeys != null && clientKeys.getJwkSet() != null) {
Expand All @@ -79,21 +81,44 @@ public JsonWebKeySet<JsonWebKey> fetchWebKeySet(ClientJwtConfiguration clientJwt
throw new OidcMetadataFetchingException("Unable to fetch verification keys");
}

private byte[] getJsonBody(String uri, boolean isSkipSslValidation, String authorizationValue) {
private byte[] getJsonBody(String uri, boolean isSkipSslValidation, boolean isCached, String authorizationValue) {
MultiValueMap<String, String> headers = new LinkedMultiValueMap<>();
if (authorizationValue != null) {
headers.add("Authorization", authorizationValue);
}
headers.add("Accept", "application/json");
HttpEntity tokenKeyRequest = new HttpEntity<>(null, headers);
HttpEntity<Object> tokenKeyRequest = new HttpEntity<>(null, headers);
if (isCached) {
return getCachedResponse(uri, isSkipSslValidation, HttpMethod.GET, tokenKeyRequest);
} else {
return getResponse(uri, isSkipSslValidation, HttpMethod.GET, tokenKeyRequest);
}
}

private byte[] getResponse(String uri, boolean isSkipSslValidation, HttpMethod method, HttpEntity<Object> header) {
ResponseEntity<byte[]> responseEntity;
if (isSkipSslValidation) {
responseEntity = trustingRestTemplate.exchange(uri, method, header, byte[].class);
} else {
responseEntity = nonTrustingRestTemplate.exchange(uri, method, header, byte[].class);
}
if (responseEntity.getStatusCode() == HttpStatus.OK) {
return responseEntity.getBody();
} else {
throw new IllegalArgumentException(
"Unable to fetch content, status:" + responseEntity.getStatusCode().getReasonPhrase());
}
}

private byte[] getCachedResponse(String uri, boolean isSkipSslValidation, HttpMethod method, HttpEntity<Object> header) {
if (isSkipSslValidation) {
return contentCache.getUrlContent(uri, trustingRestTemplate, HttpMethod.GET, tokenKeyRequest);
return contentCache.getUrlContent(uri, trustingRestTemplate, method, header);
} else {
return contentCache.getUrlContent(uri, nonTrustingRestTemplate, HttpMethod.GET, tokenKeyRequest);
return contentCache.getUrlContent(uri, nonTrustingRestTemplate, method, header);
}
}

private String getClientAuthHeader(AbstractExternalOAuthIdentityProviderDefinition config) {
private String getClientAuthHeader(AbstractExternalOAuthIdentityProviderDefinition<?> config) {
if (config.getRelyingPartySecret() == null) {
return null;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ public void setup() {
idpDefinitionMap.put("tokenKeyUrl", "http://token-key.url");
idpDefinitionMap.put("logoutUrl", "http://logout.url");
idpDefinitionMap.put("clientAuthInBody", false);
idpDefinitionMap.put("cacheJwks", true);
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
import org.mockito.Answers;
import org.springframework.http.HttpEntity;
import org.springframework.http.HttpMethod;
import org.springframework.http.HttpStatus;
import org.springframework.http.ResponseEntity;
import org.springframework.web.client.RestTemplate;

import java.net.MalformedURLException;
Expand Down Expand Up @@ -128,6 +130,49 @@ void shouldPerformTokenKeyUrlUsingCache() throws OidcMetadataFetchingException,
any(), any(), any(), any()
);
}

@Test
void shouldPerformTokenKeyUrlNoCacheUsed() throws OidcMetadataFetchingException, MalformedURLException {
definition.setTokenKeyUrl(new URL("http://should.be.updated"));
definition.setSkipSslValidation(false);
definition.setCacheJwks(false);

ResponseEntity<byte[]> responseEntity = mock(ResponseEntity.class);
when(restTemplate.exchange(anyString(), any(HttpMethod.class), any(HttpEntity.class), any(Class.class)))
.thenReturn(responseEntity);
when(responseEntity.getStatusCode()).thenReturn(HttpStatus.OK);
when(responseEntity.getBody()).thenReturn("{\"keys\":[{\"alg\":\"RS256\",\"e\":\"e\",\"kid\":\"id\",\"kty\":\"RSA\",\"n\":\"n\"}]}".getBytes());

metadataDiscoverer.fetchWebKeySet(definition);
definition.setSkipSslValidation(true);
metadataDiscoverer.fetchWebKeySet(definition);

verify(urlContentCache, times(0))
.getUrlContent(
any(), any(), any(), any()
);
verify(restTemplate, times(2)).exchange(anyString(), any(HttpMethod.class), any(HttpEntity.class), any(Class.class));
}

@Test
void shouldPerformTokenKeyUrlNoCacheUsedError() throws MalformedURLException {
definition.setTokenKeyUrl(new URL("http://should.be.updated"));
definition.setSkipSslValidation(false);
definition.setCacheJwks(false);

ResponseEntity<byte[]> responseEntity = mock(ResponseEntity.class);
when(restTemplate.exchange(anyString(), any(HttpMethod.class), any(HttpEntity.class), any(Class.class)))
.thenReturn(responseEntity);
when(responseEntity.getStatusCode()).thenReturn(HttpStatus.FORBIDDEN);

assertThrows(IllegalArgumentException.class, () -> metadataDiscoverer.fetchWebKeySet(definition));

verify(urlContentCache, times(0))
.getUrlContent(
any(), any(), any(), any()
);
verify(restTemplate, times(1)).exchange(anyString(), any(HttpMethod.class), any(HttpEntity.class), any(Class.class));
}
}

@Nested
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -575,6 +575,7 @@ void createOAuthIdentityProvider() throws Exception {
definition.setAttributeMappings(getAttributeMappingMap());
definition.setUserPropagationParameter("username");
definition.setPkce(true);
definition.setCacheJwks(true);
definition.setPerformRpInitiatedLogout(true);
identityProvider.setConfig(definition);
identityProvider.setSerializeConfigRaw(true);
Expand All @@ -584,7 +585,8 @@ void createOAuthIdentityProvider() throws Exception {
fieldWithPath("originKey").required().description("A unique alias for a OAuth provider"),
fieldWithPath("config.authUrl").required().type(STRING).description("The OAuth 2.0 authorization endpoint URL"),
fieldWithPath("config.tokenUrl").required().type(STRING).description("The OAuth 2.0 token endpoint URL"),
fieldWithPath("config.tokenKeyUrl").optional(null).type(STRING).description("The URL of the token key endpoint which renders a verification key for validating token signatures"),
fieldWithPath("config.tokenKeyUrl").optional(null).type(STRING).description("The URL of the token key endpoint which renders the JWKS (verification key for validating token signatures)."),
fieldWithPath("config.cacheJwks").optional(true).type(BOOLEAN).description("<small><mark>UAA 77.11.0</mark></small>. Option to enable caching for the JWKS (verification key for validating token signatures). Setting it to `true` increases UAA performance and is hence recommended. Setting it to `false` forces UAA to fetch the remote JWKS at each token validation, which impacts performance but may be required for when the remote JWKS changes very frequently.").attributes(new Attributes.Attribute("constraints", "Used only if `discoveryUrl` or `tokenKeyUrl` is set.")),
fieldWithPath("config.tokenKey").optional(null).type(STRING).description("A verification key for validating token signatures, set to null if a `tokenKeyUrl` is provided."),
fieldWithPath("config.userInfoUrl").optional(null).type(STRING).description("A URL for fetching user info attributes when queried with the obtained token authorization."),
fieldWithPath("config.showLinkText").optional(true).type(BOOLEAN).description("A flag controlling whether a link to this provider's login will be shown on the UAA login page"),
Expand Down Expand Up @@ -674,6 +676,7 @@ void createOidcIdentityProvider() throws Exception {
definition.setRelyingPartySecret("secret");
definition.setShowLinkText(false);
definition.setPkce(true);
definition.setCacheJwks(true);
definition.setPerformRpInitiatedLogout(true);
definition.setAttributeMappings(getAttributeMappingMap());
definition.setUserPropagationParameter("username");
Expand All @@ -691,7 +694,8 @@ void createOidcIdentityProvider() throws Exception {
fieldWithPath("config.discoveryUrl").optional(null).type(STRING).description("The OpenID Connect Discovery URL, typically ends with /.well-known/openid-configurationmit "),
fieldWithPath("config.authUrl").optional().type(STRING).description("The OIDC 1.0 authorization endpoint URL. This can be left blank if a discovery URL is provided. If both are provided, this property overrides the discovery URL.").attributes(new Attributes.Attribute("constraints", "Required unless `discoveryUrl` is set.")),
fieldWithPath("config.tokenUrl").optional().type(STRING).description("The OIDC 1.0 token endpoint URL. This can be left blank if a discovery URL is provided. If both are provided, this property overrides the discovery URL.").attributes(new Attributes.Attribute("constraints", "Required unless `discoveryUrl` is set.")),
fieldWithPath("config.tokenKeyUrl").optional(null).type(STRING).description("The URL of the token key endpoint which renders a verification key for validating token signatures. This can be left blank if a discovery URL is provided. If both are provided, this property overrides the discovery URL.").attributes(new Attributes.Attribute("constraints", "Required unless `discoveryUrl` is set.")),
fieldWithPath("config.tokenKeyUrl").optional(null).type(STRING).description("The URL of the token key endpoint which renders the JWKS (verification key for validating token signatures). This can be left blank if a discovery URL is provided. If both are provided, this property overrides the discovery URL.").attributes(new Attributes.Attribute("constraints", "Required unless `discoveryUrl` is set.")),
fieldWithPath("config.cacheJwks").optional(true).type(BOOLEAN).description("<small><mark>UAA 77.11.0</mark></small>. Option to enable caching for the JWKS (verification key for validating token signatures). Setting it to `true` increases UAA performance and is hence recommended. Setting it to `false` forces UAA to fetch the remote JWKS at each token validation, which impacts performance but may be required for when the remote JWKS changes very frequently.").attributes(new Attributes.Attribute("constraints", "Used only if `discoveryUrl` or `tokenKeyUrl` is set.")),
fieldWithPath("config.tokenKey").optional(null).type(STRING).description("A verification key for validating token signatures. We recommend not setting this as it will not allow for key rotation. This can be left blank if a discovery URL is provided. If both are provided, this property overrides the discovery URL.").attributes(new Attributes.Attribute("constraints", "Required unless `discoveryUrl` is set.")),
fieldWithPath("config.showLinkText").optional(true).type(BOOLEAN).description("A flag controlling whether a link to this provider's login will be shown on the UAA login page"),
fieldWithPath("config.linkText").optional(null).type(STRING).description("Text to use for the login link to the provider"),
Expand Down
Loading