From b7c4c78cf2224ce8b6b7917b30d9971330faf70c Mon Sep 17 00:00:00 2001 From: Markus Strehle <11627201+strehle@users.noreply.github.com> Date: Wed, 12 Jul 2023 06:48:02 +0200 Subject: [PATCH] feature: Store client authentication method in JWT (#2385) * Store client authentication method in JWT Why: UAA historical supported only secret based client authentication, so no need to have this information on client side. No there is a public usage, later private_key_jwt should be supported. Maybe then tls_client_auth. * refactor * tests * tests * extend tests for client_auth_method in refresh token * more tests * review --- .../uaa/oauth/token/ClaimConstants.java | 1 + ...tClientParametersAuthenticationFilter.java | 8 ++ .../ClientDetailsAuthenticationProvider.java | 1 + .../UaaAuthenticationDetails.java | 9 ++ .../identity/uaa/oauth/UaaTokenServices.java | 15 ++++ .../oauth/openid/UserAuthenticationData.java | 3 + ...EnhancedAuthorizationCodeTokenGranter.java | 6 ++ .../uaa/util/UaaSecurityContextUtils.java | 31 +++++++ ...entParametersAuthenticationFilterTest.java | 85 ++++++++++++++++++- .../uaa/oauth/openid/IdTokenCreatorTest.java | 4 +- ...ncedAuthorizationCodeTokenGranterTest.java | 28 +++++- .../uaa/util/UaaSecurityContextUtilsTest.java | 47 ++++++++++ .../uaa/oauth/UaaTokenServicesTests.java | 17 ++++ 13 files changed, 252 insertions(+), 3 deletions(-) create mode 100644 server/src/main/java/org/cloudfoundry/identity/uaa/util/UaaSecurityContextUtils.java create mode 100644 server/src/test/java/org/cloudfoundry/identity/uaa/util/UaaSecurityContextUtilsTest.java diff --git a/model/src/main/java/org/cloudfoundry/identity/uaa/oauth/token/ClaimConstants.java b/model/src/main/java/org/cloudfoundry/identity/uaa/oauth/token/ClaimConstants.java index ccee9736ef5..8f3a81adffa 100644 --- a/model/src/main/java/org/cloudfoundry/identity/uaa/oauth/token/ClaimConstants.java +++ b/model/src/main/java/org/cloudfoundry/identity/uaa/oauth/token/ClaimConstants.java @@ -55,4 +55,5 @@ public class ClaimConstants { public static final String AMR = "amr"; public static final String ACR = "acr"; public static final String PREVIOUS_LOGON_TIME = "previous_logon_time"; + public static final String CLIENT_AUTH_METHOD = "client_auth_method"; } diff --git a/server/src/main/java/org/cloudfoundry/identity/uaa/authentication/AbstractClientParametersAuthenticationFilter.java b/server/src/main/java/org/cloudfoundry/identity/uaa/authentication/AbstractClientParametersAuthenticationFilter.java index e2a36ce0a7b..c2a7de66538 100644 --- a/server/src/main/java/org/cloudfoundry/identity/uaa/authentication/AbstractClientParametersAuthenticationFilter.java +++ b/server/src/main/java/org/cloudfoundry/identity/uaa/authentication/AbstractClientParametersAuthenticationFilter.java @@ -14,6 +14,7 @@ */ package org.cloudfoundry.identity.uaa.authentication; +import org.cloudfoundry.identity.uaa.oauth.token.ClaimConstants; import org.cloudfoundry.identity.uaa.util.UaaStringUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -129,6 +130,13 @@ private Authentication performClientAuthentication(HttpServletRequest req, Map parameterMap; diff --git a/server/src/main/java/org/cloudfoundry/identity/uaa/oauth/UaaTokenServices.java b/server/src/main/java/org/cloudfoundry/identity/uaa/oauth/UaaTokenServices.java index 2e936ee1c81..ebb1a4a3136 100644 --- a/server/src/main/java/org/cloudfoundry/identity/uaa/oauth/UaaTokenServices.java +++ b/server/src/main/java/org/cloudfoundry/identity/uaa/oauth/UaaTokenServices.java @@ -40,6 +40,7 @@ import org.cloudfoundry.identity.uaa.util.JsonUtils; import org.cloudfoundry.identity.uaa.util.TimeService; import org.cloudfoundry.identity.uaa.util.JwtTokenSignedByThisUAA; +import org.cloudfoundry.identity.uaa.util.UaaSecurityContextUtils; import org.cloudfoundry.identity.uaa.util.UaaTokenUtils; import org.cloudfoundry.identity.uaa.zone.MultitenantClientServices; import org.cloudfoundry.identity.uaa.zone.IdentityZoneHolder; @@ -102,6 +103,7 @@ import static org.cloudfoundry.identity.uaa.oauth.token.ClaimConstants.AUTH_TIME; import static org.cloudfoundry.identity.uaa.oauth.token.ClaimConstants.AZP; import static org.cloudfoundry.identity.uaa.oauth.token.ClaimConstants.CID; +import static org.cloudfoundry.identity.uaa.oauth.token.ClaimConstants.CLIENT_AUTH_METHOD; import static org.cloudfoundry.identity.uaa.oauth.token.ClaimConstants.CLIENT_ID; import static org.cloudfoundry.identity.uaa.oauth.token.ClaimConstants.EMAIL; import static org.cloudfoundry.identity.uaa.oauth.token.ClaimConstants.EXPIRY_IN_SECONDS; @@ -276,6 +278,7 @@ public OAuth2AccessToken refreshAccessToken(String refreshTokenValue, TokenReque getUserAttributes(claims.getUserId()), claims.getNonce(), claims.getGrantType(), + UaaSecurityContextUtils.getClientAuthenticationMethod(), generateUniqueTokenId() ); @@ -414,6 +417,11 @@ private CompositeToken createCompositeToken(String tokenId, info.put(ADDITIONAL_AZ_ATTR, additionalAuthorizationAttributes); } + String clientAuthentication = userAuthenticationData.clientAuth; + if (clientAuthentication != null) { + addRootClaimEntry(additionalRootClaims, CLIENT_AUTH_METHOD, clientAuthentication); + } + String nonce = userAuthenticationData.nonce; if (nonce != null) { info.put(NONCE, nonce); @@ -459,6 +467,12 @@ private CompositeToken createCompositeToken(String tokenId, return compositeToken; } + private static Map addRootClaimEntry(Map additionalRootClaims, String entry, String value) { + Map claims = additionalRootClaims != null ? additionalRootClaims : new HashMap<>(); + claims.put(entry, value); + return claims; + } + private KeyInfo getActiveKeyInfo() { return ofNullable(keyInfoService.getActiveKey()) .orElseThrow(() -> new InternalAuthenticationServiceException("Unable to sign token, misconfigured JWT signing keys")); @@ -635,6 +649,7 @@ public OAuth2AccessToken createAccessToken(OAuth2Authentication authentication) userAttributesForIdToken, nonce, grantType, + ofNullable(oAuth2Request.getExtensions().get(CLIENT_AUTH_METHOD)).map(String.class::cast).orElse(null), tokenId); String refreshTokenValue = refreshToken != null ? refreshToken.getValue() : null; diff --git a/server/src/main/java/org/cloudfoundry/identity/uaa/oauth/openid/UserAuthenticationData.java b/server/src/main/java/org/cloudfoundry/identity/uaa/oauth/openid/UserAuthenticationData.java index e24868cbee0..3fbf5c4ac55 100644 --- a/server/src/main/java/org/cloudfoundry/identity/uaa/oauth/openid/UserAuthenticationData.java +++ b/server/src/main/java/org/cloudfoundry/identity/uaa/oauth/openid/UserAuthenticationData.java @@ -14,6 +14,7 @@ public class UserAuthenticationData { public final Map> userAttributes; public final String nonce; public final String grantType; + public final String clientAuth; public final String jti; public UserAuthenticationData(Date authTime, @@ -24,6 +25,7 @@ public UserAuthenticationData(Date authTime, Map> userAttributes, String nonce, String grantType, + String clientAuth, String jti) { this.authTime = authTime; this.authenticationMethods = authenticationMethods; @@ -33,6 +35,7 @@ public UserAuthenticationData(Date authTime, this.userAttributes = userAttributes; this.nonce = nonce; this.grantType = grantType; + this.clientAuth = clientAuth; this.jti = jti; } } diff --git a/server/src/main/java/org/cloudfoundry/identity/uaa/oauth/token/PkceEnhancedAuthorizationCodeTokenGranter.java b/server/src/main/java/org/cloudfoundry/identity/uaa/oauth/token/PkceEnhancedAuthorizationCodeTokenGranter.java index 112133bb434..22dc40c2c4e 100644 --- a/server/src/main/java/org/cloudfoundry/identity/uaa/oauth/token/PkceEnhancedAuthorizationCodeTokenGranter.java +++ b/server/src/main/java/org/cloudfoundry/identity/uaa/oauth/token/PkceEnhancedAuthorizationCodeTokenGranter.java @@ -5,6 +5,7 @@ import org.cloudfoundry.identity.uaa.oauth.pkce.PkceValidationException; import org.cloudfoundry.identity.uaa.oauth.pkce.PkceValidationService; +import org.cloudfoundry.identity.uaa.util.UaaSecurityContextUtils; import org.springframework.security.core.Authentication; import org.springframework.security.oauth2.common.exceptions.InvalidClientException; import org.springframework.security.oauth2.common.exceptions.InvalidGrantException; @@ -107,6 +108,11 @@ protected OAuth2Authentication getOAuth2Authentication(ClientDetails client, Tok Authentication userAuth = storedAuth.getUserAuthentication(); + String clientAuthentication = UaaSecurityContextUtils.getClientAuthenticationMethod(); + if (clientAuthentication != null) { + finalStoredOAuth2Request.getExtensions().put(ClaimConstants.CLIENT_AUTH_METHOD, clientAuthentication); + } + return new OAuth2Authentication(finalStoredOAuth2Request, userAuth); } diff --git a/server/src/main/java/org/cloudfoundry/identity/uaa/util/UaaSecurityContextUtils.java b/server/src/main/java/org/cloudfoundry/identity/uaa/util/UaaSecurityContextUtils.java new file mode 100644 index 00000000000..dfe5be3947d --- /dev/null +++ b/server/src/main/java/org/cloudfoundry/identity/uaa/util/UaaSecurityContextUtils.java @@ -0,0 +1,31 @@ +package org.cloudfoundry.identity.uaa.util; + +import org.springframework.security.core.Authentication; +import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.oauth2.provider.OAuth2Authentication; + +import java.io.Serializable; +import java.util.Map; + +import static org.cloudfoundry.identity.uaa.oauth.token.ClaimConstants.CLIENT_AUTH_METHOD; + +public final class UaaSecurityContextUtils { + + private UaaSecurityContextUtils() {} + + public static String getClientAuthenticationMethod() { + Authentication a = SecurityContextHolder.getContext().getAuthentication(); + if (!(a instanceof OAuth2Authentication)) { + return null; + } + OAuth2Authentication oAuth2Authentication = (OAuth2Authentication) a; + + Map extensions = oAuth2Authentication.getOAuth2Request().getExtensions(); + if (extensions.isEmpty()) { + return null; + } + + return (String) extensions.get(CLIENT_AUTH_METHOD); + } + +} diff --git a/server/src/test/java/org/cloudfoundry/identity/uaa/authentication/ClientParametersAuthenticationFilterTest.java b/server/src/test/java/org/cloudfoundry/identity/uaa/authentication/ClientParametersAuthenticationFilterTest.java index f8f0296a960..13497ebf619 100644 --- a/server/src/test/java/org/cloudfoundry/identity/uaa/authentication/ClientParametersAuthenticationFilterTest.java +++ b/server/src/test/java/org/cloudfoundry/identity/uaa/authentication/ClientParametersAuthenticationFilterTest.java @@ -19,14 +19,17 @@ import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.security.authentication.AuthenticationManager; import org.springframework.security.authentication.BadCredentialsException; +import org.springframework.security.core.Authentication; import org.springframework.security.web.AuthenticationEntryPoint; import javax.servlet.ServletException; import java.io.IOException; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.atLeast; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.when; @@ -56,4 +59,84 @@ public void doesNotContinueWithFilterChain_IfAuthenticationException() throws IO verifyNoMoreInteractions(chain); } -} \ No newline at end of file + @Test + public void testStoreClientAuthenticationMethod() throws IOException, ServletException { + ClientParametersAuthenticationFilter filter = new ClientParametersAuthenticationFilter(); + + AuthenticationEntryPoint authenticationEntryPoint = mock(AuthenticationEntryPoint.class); + filter.setAuthenticationEntryPoint(authenticationEntryPoint); + AuthenticationManager clientAuthenticationManager = mock(AuthenticationManager.class); + filter.setClientAuthenticationManager(clientAuthenticationManager); + + Authentication authentication = mock(Authentication.class); + MockHttpServletRequest request = new MockHttpServletRequest(); + UaaAuthenticationDetails authenticationDetails = mock(UaaAuthenticationDetails.class); + when(clientAuthenticationManager.authenticate(Mockito.any())).thenReturn(authentication); + when(authentication.isAuthenticated()).thenReturn(true); + when(authentication.getDetails()).thenReturn(authenticationDetails); + when(authenticationDetails.getAuthenticationMethod()).thenReturn("none"); + + MockFilterChain chain = mock(MockFilterChain.class); + request.addHeader("Content-Type", MediaType.APPLICATION_FORM_URLENCODED_VALUE); + MockHttpServletResponse response = new MockHttpServletResponse(); + + filter.doFilter(request, response, chain); + + verifyNoInteractions(authenticationEntryPoint); + verify(chain).doFilter(request, response); + verify(authenticationDetails, atLeast(1)).getAuthenticationMethod(); + } + + @Test + public void testStoreClientAuthenticationMethodNoDetails() throws IOException, ServletException { + ClientParametersAuthenticationFilter filter = new ClientParametersAuthenticationFilter(); + + AuthenticationEntryPoint authenticationEntryPoint = mock(AuthenticationEntryPoint.class); + filter.setAuthenticationEntryPoint(authenticationEntryPoint); + AuthenticationManager clientAuthenticationManager = mock(AuthenticationManager.class); + filter.setClientAuthenticationManager(clientAuthenticationManager); + + Authentication authentication = mock(Authentication.class); + MockHttpServletRequest request = new MockHttpServletRequest(); + when(clientAuthenticationManager.authenticate(Mockito.any())).thenReturn(authentication); + when(authentication.isAuthenticated()).thenReturn(true); + when(authentication.getDetails()).thenReturn(null); + + MockFilterChain chain = mock(MockFilterChain.class); + request.addHeader("Content-Type", MediaType.APPLICATION_FORM_URLENCODED_VALUE); + MockHttpServletResponse response = new MockHttpServletResponse(); + + filter.doFilter(request, response, chain); + + verifyNoInteractions(authenticationEntryPoint); + verify(chain).doFilter(request, response); + } + + @Test + public void testStoreClientAuthenticationMethodNoMethod() throws IOException, ServletException { + ClientParametersAuthenticationFilter filter = new ClientParametersAuthenticationFilter(); + + AuthenticationEntryPoint authenticationEntryPoint = mock(AuthenticationEntryPoint.class); + filter.setAuthenticationEntryPoint(authenticationEntryPoint); + AuthenticationManager clientAuthenticationManager = mock(AuthenticationManager.class); + filter.setClientAuthenticationManager(clientAuthenticationManager); + + Authentication authentication = mock(Authentication.class); + MockHttpServletRequest request = new MockHttpServletRequest(); + UaaAuthenticationDetails authenticationDetails = mock(UaaAuthenticationDetails.class); + when(clientAuthenticationManager.authenticate(Mockito.any())).thenReturn(authentication); + when(authentication.isAuthenticated()).thenReturn(true); + when(authentication.getDetails()).thenReturn(authenticationDetails); + when(authenticationDetails.getAuthenticationMethod()).thenReturn(null); + + MockFilterChain chain = mock(MockFilterChain.class); + request.addHeader("Content-Type", MediaType.APPLICATION_FORM_URLENCODED_VALUE); + MockHttpServletResponse response = new MockHttpServletResponse(); + + filter.doFilter(request, response, chain); + + verifyNoInteractions(authenticationEntryPoint); + verify(chain).doFilter(request, response); + verify(authenticationDetails).getAuthenticationMethod(); + } +} diff --git a/server/src/test/java/org/cloudfoundry/identity/uaa/oauth/openid/IdTokenCreatorTest.java b/server/src/test/java/org/cloudfoundry/identity/uaa/oauth/openid/IdTokenCreatorTest.java index b7482471d58..7a2e1ec2cbc 100644 --- a/server/src/test/java/org/cloudfoundry/identity/uaa/oauth/openid/IdTokenCreatorTest.java +++ b/server/src/test/java/org/cloudfoundry/identity/uaa/oauth/openid/IdTokenCreatorTest.java @@ -29,7 +29,6 @@ import static org.hamcrest.CoreMatchers.*; import static org.hamcrest.core.IsCollectionContaining.hasItems; import static org.junit.Assert.assertThat; -import static org.junit.jupiter.api.Assertions.assertThrows; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -149,6 +148,7 @@ void setup() throws Exception { userAttributes, nonce, grantType, + null, jti); excludedClaims = new HashSet<>(); @@ -258,6 +258,7 @@ void create_setsRolesToNullIfRolesAreNull() throws IdTokenCreationException { userAttributes, nonce, grantType, + null, jti); IdToken idToken = tokenCreator.create(clientDetails, user, userAuthenticationData); @@ -286,6 +287,7 @@ void create_doesntSetUserAttributesIfTheyAreNull() throws IdTokenCreationExcepti null, nonce, grantType, + null, jti); IdToken idToken = tokenCreator.create(clientDetails, user, userAuthenticationData); diff --git a/server/src/test/java/org/cloudfoundry/identity/uaa/oauth/token/PkceEnhancedAuthorizationCodeTokenGranterTest.java b/server/src/test/java/org/cloudfoundry/identity/uaa/oauth/token/PkceEnhancedAuthorizationCodeTokenGranterTest.java index 49fdc4e9c55..84d0cf067e3 100644 --- a/server/src/test/java/org/cloudfoundry/identity/uaa/oauth/token/PkceEnhancedAuthorizationCodeTokenGranterTest.java +++ b/server/src/test/java/org/cloudfoundry/identity/uaa/oauth/token/PkceEnhancedAuthorizationCodeTokenGranterTest.java @@ -21,11 +21,15 @@ import static org.cloudfoundry.identity.uaa.oauth.TokenTestSupport.GRANT_TYPE; import static org.cloudfoundry.identity.uaa.oauth.token.TokenConstants.GRANT_TYPE_AUTHORIZATION_CODE; import static org.cloudfoundry.identity.uaa.util.JwtTokenSignedByThisUAATest.CLIENT_ID; +import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.atMost; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; class PkceEnhancedAuthorizationCodeTokenGranterTest { @@ -67,12 +71,12 @@ public void setup() { when(clientDetailsService.loadClientByClientId(eq(requestingClient.getClientId()), anyString())).thenReturn(requestingClient); when(authorizationCodeServices.consumeAuthorizationCode("1234")).thenReturn(authentication); when(authentication.getOAuth2Request()).thenReturn(oAuth2Request); - when(oAuth2Request.getRequestParameters()).thenReturn(requestParameters); requestParameters = new HashMap<>(); requestParameters.put(GRANT_TYPE, TokenConstants.GRANT_TYPE_USER_TOKEN); requestParameters.put(CLIENT_ID, requestingClient.getClientId()); requestParameters.put("code", "1234"); requestParameters.put(PkceValidationService.CODE_VERIFIER, "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM"); + when(oAuth2Request.getRequestParameters()).thenReturn(requestParameters); tokenRequest = new UserTokenGranterTest.PublicTokenRequest(); tokenRequest.setRequestParameters(requestParameters); @@ -84,4 +88,26 @@ void getOAuth2Authentication() throws PkceValidationException { when(pkceValidationService.checkAndValidate(any(), any(), any())).thenReturn(false); assertThrows(InvalidGrantException.class, () -> granter.getOAuth2Authentication(requestingClient, tokenRequest)); } + + @Test + void getOAuth2AuthenticationMethod() throws PkceValidationException { + HashMap authMap = new HashMap(); + authMap.put(ClaimConstants.CLIENT_AUTH_METHOD, "none"); + when(pkceValidationService.checkAndValidate(any(), any(), any())).thenReturn(true); + when(oAuth2Request.getExtensions()).thenReturn(authMap); + when(oAuth2Request.createOAuth2Request(any())).thenReturn(oAuth2Request); + assertNotNull(granter.getOAuth2Authentication(requestingClient, tokenRequest)); + verify(oAuth2Request, times(2)).getExtensions(); + } + + @Test + void getOAuth2AuthenticationNoMethod() throws PkceValidationException { + HashMap authMap = new HashMap(); + authMap.put(ClaimConstants.CLIENT_AUTH_METHOD, null); + when(pkceValidationService.checkAndValidate(any(), any(), any())).thenReturn(true); + when(oAuth2Request.getExtensions()).thenReturn(authMap); + when(oAuth2Request.createOAuth2Request(any())).thenReturn(oAuth2Request); + assertNotNull(granter.getOAuth2Authentication(requestingClient, tokenRequest)); + verify(oAuth2Request, atMost(1)).getExtensions(); + } } \ No newline at end of file diff --git a/server/src/test/java/org/cloudfoundry/identity/uaa/util/UaaSecurityContextUtilsTest.java b/server/src/test/java/org/cloudfoundry/identity/uaa/util/UaaSecurityContextUtilsTest.java new file mode 100644 index 00000000000..56c2526fdc6 --- /dev/null +++ b/server/src/test/java/org/cloudfoundry/identity/uaa/util/UaaSecurityContextUtilsTest.java @@ -0,0 +1,47 @@ +package org.cloudfoundry.identity.uaa.util; + +import org.cloudfoundry.identity.uaa.oauth.token.ClaimConstants; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.oauth2.provider.OAuth2Authentication; +import org.springframework.security.oauth2.provider.OAuth2Request; + +import java.util.HashMap; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +class UaaSecurityContextUtilsTest { + + private OAuth2Request auth2Request; + + @BeforeEach + void setUp() { + OAuth2Authentication authentication = mock(OAuth2Authentication.class); + SecurityContextHolder.getContext().setAuthentication(authentication); + auth2Request = mock(OAuth2Request.class); + when(auth2Request.getExtensions()).thenReturn(new HashMap<>()); + when(authentication.getOAuth2Request()).thenReturn(auth2Request); + } + + @Test + void getNoClientAuthenticationMethod() { + assertNull(UaaSecurityContextUtils.getClientAuthenticationMethod()); + } + + @Test + void getNullClientAuthenticationMethod() { + SecurityContextHolder.getContext().setAuthentication(null); + assertNull(UaaSecurityContextUtils.getClientAuthenticationMethod()); + } + + @Test + void getClientAuthenticationMethod() { + when(auth2Request.getExtensions()).thenReturn(Map.of(ClaimConstants.CLIENT_AUTH_METHOD, "none")); + assertEquals("none", UaaSecurityContextUtils.getClientAuthenticationMethod()); + } +} diff --git a/uaa/src/test/java/org/cloudfoundry/identity/uaa/oauth/UaaTokenServicesTests.java b/uaa/src/test/java/org/cloudfoundry/identity/uaa/oauth/UaaTokenServicesTests.java index 96128a0a1b1..019224a5632 100644 --- a/uaa/src/test/java/org/cloudfoundry/identity/uaa/oauth/UaaTokenServicesTests.java +++ b/uaa/src/test/java/org/cloudfoundry/identity/uaa/oauth/UaaTokenServicesTests.java @@ -33,10 +33,12 @@ import org.junit.jupiter.params.provider.ValueSource; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.security.core.GrantedAuthority; +import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.oauth2.common.OAuth2AccessToken; import org.springframework.security.oauth2.provider.AuthorizationRequest; import org.springframework.security.oauth2.provider.NoSuchClientException; import org.springframework.security.oauth2.provider.OAuth2Authentication; +import org.springframework.security.oauth2.provider.OAuth2Request; import org.springframework.security.oauth2.provider.TokenRequest; import org.springframework.test.annotation.DirtiesContext; import org.springframework.test.context.TestPropertySource; @@ -64,11 +66,14 @@ import static org.hamcrest.CoreMatchers.nullValue; import static org.hamcrest.Matchers.containsInAnyOrder; import static org.hamcrest.Matchers.greaterThan; +import static org.hamcrest.Matchers.hasEntry; import static org.hamcrest.Matchers.hasItem; import static org.hamcrest.Matchers.hasKey; import static org.junit.Assert.assertThat; import static org.junit.jupiter.api.Assertions.assertAll; import static org.junit.jupiter.api.Assumptions.assumeTrue; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; @DisplayName("Uaa Token Services Tests") @DefaultTestContext @@ -257,6 +262,11 @@ class WhenRefreshGrant { private CompositeExpiringOAuth2RefreshToken refreshToken; + @AfterEach + void cleanup() { + SecurityContextHolder.clearContext(); + } + @Test void happyCase() { assumeTrue(waitForClient("jku_test", 5), "Test client needs to be setup for this test"); @@ -275,9 +285,16 @@ void happyCase() { UaaUser uaaUser = jdbcUaaUserDatabase.retrieveUserByName("admin", "uaa"); refreshToken = refreshTokenCreator.createRefreshToken(uaaUser, refreshTokenRequestData, null); assertThat(refreshToken, is(notNullValue())); + OAuth2Authentication authentication = mock(OAuth2Authentication.class); + SecurityContextHolder.getContext().setAuthentication(authentication); + OAuth2Request auth2Request = mock(OAuth2Request.class); + when(authentication.getOAuth2Request()).thenReturn(auth2Request); + when(auth2Request.getExtensions()).thenReturn(Map.of(ClaimConstants.CLIENT_AUTH_METHOD, "none")); OAuth2AccessToken refreshedToken = tokenServices.refreshAccessToken(this.refreshToken.getValue(), new TokenRequest(new HashMap<>(), "jku_test", Lists.newArrayList("openid", "user_attributes"), GRANT_TYPE_REFRESH_TOKEN)); assertThat(refreshedToken, is(notNullValue())); + Map claims = UaaTokenUtils.getClaims(refreshedToken.getValue()); + assertThat(claims, hasEntry(ClaimConstants.CLIENT_AUTH_METHOD, "none")); } @MethodSource("org.cloudfoundry.identity.uaa.oauth.UaaTokenServicesTests#dates")