Skip to content

Commit

Permalink
feature: Store client authentication method in JWT (#2385)
Browse files Browse the repository at this point in the history
* Store client authentication method in JWT

Why: UAA historical supported only secret based client authentication, so no need to have this information on client side.
No there is a public usage, later private_key_jwt should be supported. Maybe then tls_client_auth.

* refactor

* tests

* tests

* extend tests for client_auth_method in refresh token

* more tests

* review
  • Loading branch information
strehle authored Jul 12, 2023
1 parent 5eb4687 commit b7c4c78
Show file tree
Hide file tree
Showing 13 changed files with 252 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -55,4 +55,5 @@ public class ClaimConstants {
public static final String AMR = "amr";
public static final String ACR = "acr";
public static final String PREVIOUS_LOGON_TIME = "previous_logon_time";
public static final String CLIENT_AUTH_METHOD = "client_auth_method";
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
*/
package org.cloudfoundry.identity.uaa.authentication;

import org.cloudfoundry.identity.uaa.oauth.token.ClaimConstants;
import org.cloudfoundry.identity.uaa.util.UaaStringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand Down Expand Up @@ -129,6 +130,13 @@ private Authentication performClientAuthentication(HttpServletRequest req, Map<S
AuthorizationRequest authorizationRequest = new AuthorizationRequest(clientId, getScope(req));
authorizationRequest.setRequestParameters(getSingleValueMap(req));
authorizationRequest.setApproved(true);

if (auth.getDetails() instanceof UaaAuthenticationDetails) {
UaaAuthenticationDetails clientDetails = (UaaAuthenticationDetails) auth.getDetails();
if (clientDetails.getAuthenticationMethod() != null) {
authorizationRequest.setExtensions(Map.of(ClaimConstants.CLIENT_AUTH_METHOD, clientDetails.getAuthenticationMethod()));
}
}
//must set this to true in order for
//Authentication.isAuthenticated to return true
OAuth2Authentication result = new OAuth2Authentication(authorizationRequest.createOAuth2Request(), null);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ protected void additionalAuthenticationChecks(UserDetails userDetails, UsernameP
Object allowPublic = uaaClient.getAdditionalInformation().get(ClientConstants.ALLOW_PUBLIC);
if (allowPublic instanceof String && Boolean.TRUE.toString().equalsIgnoreCase((String)allowPublic) ||
allowPublic instanceof Boolean && Boolean.TRUE.equals(allowPublic)) {
((UaaAuthenticationDetails) authentication.getDetails()).setAuthenticationMethod("none");
break;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,15 @@ public class UaaAuthenticationDetails implements Serializable {

private String clientId;

private transient String authenticationMethod;
public String getAuthenticationMethod() {
return this.authenticationMethod;
}

public void setAuthenticationMethod(final String authenticationMethod) {
this.authenticationMethod = authenticationMethod;
}

@JsonIgnore
private Map<String,String[]> parameterMap;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import org.cloudfoundry.identity.uaa.util.JsonUtils;
import org.cloudfoundry.identity.uaa.util.TimeService;
import org.cloudfoundry.identity.uaa.util.JwtTokenSignedByThisUAA;
import org.cloudfoundry.identity.uaa.util.UaaSecurityContextUtils;
import org.cloudfoundry.identity.uaa.util.UaaTokenUtils;
import org.cloudfoundry.identity.uaa.zone.MultitenantClientServices;
import org.cloudfoundry.identity.uaa.zone.IdentityZoneHolder;
Expand Down Expand Up @@ -102,6 +103,7 @@
import static org.cloudfoundry.identity.uaa.oauth.token.ClaimConstants.AUTH_TIME;
import static org.cloudfoundry.identity.uaa.oauth.token.ClaimConstants.AZP;
import static org.cloudfoundry.identity.uaa.oauth.token.ClaimConstants.CID;
import static org.cloudfoundry.identity.uaa.oauth.token.ClaimConstants.CLIENT_AUTH_METHOD;
import static org.cloudfoundry.identity.uaa.oauth.token.ClaimConstants.CLIENT_ID;
import static org.cloudfoundry.identity.uaa.oauth.token.ClaimConstants.EMAIL;
import static org.cloudfoundry.identity.uaa.oauth.token.ClaimConstants.EXPIRY_IN_SECONDS;
Expand Down Expand Up @@ -276,6 +278,7 @@ public OAuth2AccessToken refreshAccessToken(String refreshTokenValue, TokenReque
getUserAttributes(claims.getUserId()),
claims.getNonce(),
claims.getGrantType(),
UaaSecurityContextUtils.getClientAuthenticationMethod(),
generateUniqueTokenId()
);

Expand Down Expand Up @@ -414,6 +417,11 @@ private CompositeToken createCompositeToken(String tokenId,
info.put(ADDITIONAL_AZ_ATTR, additionalAuthorizationAttributes);
}

String clientAuthentication = userAuthenticationData.clientAuth;
if (clientAuthentication != null) {
addRootClaimEntry(additionalRootClaims, CLIENT_AUTH_METHOD, clientAuthentication);
}

String nonce = userAuthenticationData.nonce;
if (nonce != null) {
info.put(NONCE, nonce);
Expand Down Expand Up @@ -459,6 +467,12 @@ private CompositeToken createCompositeToken(String tokenId,
return compositeToken;
}

private static Map<String, Object> addRootClaimEntry(Map<String, Object> additionalRootClaims, String entry, String value) {
Map<String, Object> claims = additionalRootClaims != null ? additionalRootClaims : new HashMap<>();
claims.put(entry, value);
return claims;
}

private KeyInfo getActiveKeyInfo() {
return ofNullable(keyInfoService.getActiveKey())
.orElseThrow(() -> new InternalAuthenticationServiceException("Unable to sign token, misconfigured JWT signing keys"));
Expand Down Expand Up @@ -635,6 +649,7 @@ public OAuth2AccessToken createAccessToken(OAuth2Authentication authentication)
userAttributesForIdToken,
nonce,
grantType,
ofNullable(oAuth2Request.getExtensions().get(CLIENT_AUTH_METHOD)).map(String.class::cast).orElse(null),
tokenId);

String refreshTokenValue = refreshToken != null ? refreshToken.getValue() : null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ public class UserAuthenticationData {
public final Map<String, List<String>> userAttributes;
public final String nonce;
public final String grantType;
public final String clientAuth;
public final String jti;

public UserAuthenticationData(Date authTime,
Expand All @@ -24,6 +25,7 @@ public UserAuthenticationData(Date authTime,
Map<String, List<String>> userAttributes,
String nonce,
String grantType,
String clientAuth,
String jti) {
this.authTime = authTime;
this.authenticationMethods = authenticationMethods;
Expand All @@ -33,6 +35,7 @@ public UserAuthenticationData(Date authTime,
this.userAttributes = userAttributes;
this.nonce = nonce;
this.grantType = grantType;
this.clientAuth = clientAuth;
this.jti = jti;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import org.cloudfoundry.identity.uaa.oauth.pkce.PkceValidationException;
import org.cloudfoundry.identity.uaa.oauth.pkce.PkceValidationService;
import org.cloudfoundry.identity.uaa.util.UaaSecurityContextUtils;
import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.common.exceptions.InvalidClientException;
import org.springframework.security.oauth2.common.exceptions.InvalidGrantException;
Expand Down Expand Up @@ -107,6 +108,11 @@ protected OAuth2Authentication getOAuth2Authentication(ClientDetails client, Tok

Authentication userAuth = storedAuth.getUserAuthentication();

String clientAuthentication = UaaSecurityContextUtils.getClientAuthenticationMethod();
if (clientAuthentication != null) {
finalStoredOAuth2Request.getExtensions().put(ClaimConstants.CLIENT_AUTH_METHOD, clientAuthentication);
}

return new OAuth2Authentication(finalStoredOAuth2Request, userAuth);

}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package org.cloudfoundry.identity.uaa.util;

import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.oauth2.provider.OAuth2Authentication;

import java.io.Serializable;
import java.util.Map;

import static org.cloudfoundry.identity.uaa.oauth.token.ClaimConstants.CLIENT_AUTH_METHOD;

public final class UaaSecurityContextUtils {

private UaaSecurityContextUtils() {}

public static String getClientAuthenticationMethod() {
Authentication a = SecurityContextHolder.getContext().getAuthentication();
if (!(a instanceof OAuth2Authentication)) {
return null;
}
OAuth2Authentication oAuth2Authentication = (OAuth2Authentication) a;

Map<String, Serializable> extensions = oAuth2Authentication.getOAuth2Request().getExtensions();
if (extensions.isEmpty()) {
return null;
}

return (String) extensions.get(CLIENT_AUTH_METHOD);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,17 @@
import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.security.authentication.AuthenticationManager;
import org.springframework.security.authentication.BadCredentialsException;
import org.springframework.security.core.Authentication;
import org.springframework.security.web.AuthenticationEntryPoint;

import javax.servlet.ServletException;
import java.io.IOException;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.atLeast;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;

Expand Down Expand Up @@ -56,4 +59,84 @@ public void doesNotContinueWithFilterChain_IfAuthenticationException() throws IO
verifyNoMoreInteractions(chain);
}

}
@Test
public void testStoreClientAuthenticationMethod() throws IOException, ServletException {
ClientParametersAuthenticationFilter filter = new ClientParametersAuthenticationFilter();

AuthenticationEntryPoint authenticationEntryPoint = mock(AuthenticationEntryPoint.class);
filter.setAuthenticationEntryPoint(authenticationEntryPoint);
AuthenticationManager clientAuthenticationManager = mock(AuthenticationManager.class);
filter.setClientAuthenticationManager(clientAuthenticationManager);

Authentication authentication = mock(Authentication.class);
MockHttpServletRequest request = new MockHttpServletRequest();
UaaAuthenticationDetails authenticationDetails = mock(UaaAuthenticationDetails.class);
when(clientAuthenticationManager.authenticate(Mockito.any())).thenReturn(authentication);
when(authentication.isAuthenticated()).thenReturn(true);
when(authentication.getDetails()).thenReturn(authenticationDetails);
when(authenticationDetails.getAuthenticationMethod()).thenReturn("none");

MockFilterChain chain = mock(MockFilterChain.class);
request.addHeader("Content-Type", MediaType.APPLICATION_FORM_URLENCODED_VALUE);
MockHttpServletResponse response = new MockHttpServletResponse();

filter.doFilter(request, response, chain);

verifyNoInteractions(authenticationEntryPoint);
verify(chain).doFilter(request, response);
verify(authenticationDetails, atLeast(1)).getAuthenticationMethod();
}

@Test
public void testStoreClientAuthenticationMethodNoDetails() throws IOException, ServletException {
ClientParametersAuthenticationFilter filter = new ClientParametersAuthenticationFilter();

AuthenticationEntryPoint authenticationEntryPoint = mock(AuthenticationEntryPoint.class);
filter.setAuthenticationEntryPoint(authenticationEntryPoint);
AuthenticationManager clientAuthenticationManager = mock(AuthenticationManager.class);
filter.setClientAuthenticationManager(clientAuthenticationManager);

Authentication authentication = mock(Authentication.class);
MockHttpServletRequest request = new MockHttpServletRequest();
when(clientAuthenticationManager.authenticate(Mockito.any())).thenReturn(authentication);
when(authentication.isAuthenticated()).thenReturn(true);
when(authentication.getDetails()).thenReturn(null);

MockFilterChain chain = mock(MockFilterChain.class);
request.addHeader("Content-Type", MediaType.APPLICATION_FORM_URLENCODED_VALUE);
MockHttpServletResponse response = new MockHttpServletResponse();

filter.doFilter(request, response, chain);

verifyNoInteractions(authenticationEntryPoint);
verify(chain).doFilter(request, response);
}

@Test
public void testStoreClientAuthenticationMethodNoMethod() throws IOException, ServletException {
ClientParametersAuthenticationFilter filter = new ClientParametersAuthenticationFilter();

AuthenticationEntryPoint authenticationEntryPoint = mock(AuthenticationEntryPoint.class);
filter.setAuthenticationEntryPoint(authenticationEntryPoint);
AuthenticationManager clientAuthenticationManager = mock(AuthenticationManager.class);
filter.setClientAuthenticationManager(clientAuthenticationManager);

Authentication authentication = mock(Authentication.class);
MockHttpServletRequest request = new MockHttpServletRequest();
UaaAuthenticationDetails authenticationDetails = mock(UaaAuthenticationDetails.class);
when(clientAuthenticationManager.authenticate(Mockito.any())).thenReturn(authentication);
when(authentication.isAuthenticated()).thenReturn(true);
when(authentication.getDetails()).thenReturn(authenticationDetails);
when(authenticationDetails.getAuthenticationMethod()).thenReturn(null);

MockFilterChain chain = mock(MockFilterChain.class);
request.addHeader("Content-Type", MediaType.APPLICATION_FORM_URLENCODED_VALUE);
MockHttpServletResponse response = new MockHttpServletResponse();

filter.doFilter(request, response, chain);

verifyNoInteractions(authenticationEntryPoint);
verify(chain).doFilter(request, response);
verify(authenticationDetails).getAuthenticationMethod();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
import static org.hamcrest.CoreMatchers.*;
import static org.hamcrest.core.IsCollectionContaining.hasItems;
import static org.junit.Assert.assertThat;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

Expand Down Expand Up @@ -149,6 +148,7 @@ void setup() throws Exception {
userAttributes,
nonce,
grantType,
null,
jti);
excludedClaims = new HashSet<>();

Expand Down Expand Up @@ -258,6 +258,7 @@ void create_setsRolesToNullIfRolesAreNull() throws IdTokenCreationException {
userAttributes,
nonce,
grantType,
null,
jti);

IdToken idToken = tokenCreator.create(clientDetails, user, userAuthenticationData);
Expand Down Expand Up @@ -286,6 +287,7 @@ void create_doesntSetUserAttributesIfTheyAreNull() throws IdTokenCreationExcepti
null,
nonce,
grantType,
null,
jti);

IdToken idToken = tokenCreator.create(clientDetails, user, userAuthenticationData);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,15 @@
import static org.cloudfoundry.identity.uaa.oauth.TokenTestSupport.GRANT_TYPE;
import static org.cloudfoundry.identity.uaa.oauth.token.TokenConstants.GRANT_TYPE_AUTHORIZATION_CODE;
import static org.cloudfoundry.identity.uaa.util.JwtTokenSignedByThisUAATest.CLIENT_ID;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.atMost;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

class PkceEnhancedAuthorizationCodeTokenGranterTest {
Expand Down Expand Up @@ -67,12 +71,12 @@ public void setup() {
when(clientDetailsService.loadClientByClientId(eq(requestingClient.getClientId()), anyString())).thenReturn(requestingClient);
when(authorizationCodeServices.consumeAuthorizationCode("1234")).thenReturn(authentication);
when(authentication.getOAuth2Request()).thenReturn(oAuth2Request);
when(oAuth2Request.getRequestParameters()).thenReturn(requestParameters);
requestParameters = new HashMap<>();
requestParameters.put(GRANT_TYPE, TokenConstants.GRANT_TYPE_USER_TOKEN);
requestParameters.put(CLIENT_ID, requestingClient.getClientId());
requestParameters.put("code", "1234");
requestParameters.put(PkceValidationService.CODE_VERIFIER, "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM");
when(oAuth2Request.getRequestParameters()).thenReturn(requestParameters);
tokenRequest = new UserTokenGranterTest.PublicTokenRequest();
tokenRequest.setRequestParameters(requestParameters);

Expand All @@ -84,4 +88,26 @@ void getOAuth2Authentication() throws PkceValidationException {
when(pkceValidationService.checkAndValidate(any(), any(), any())).thenReturn(false);
assertThrows(InvalidGrantException.class, () -> granter.getOAuth2Authentication(requestingClient, tokenRequest));
}

@Test
void getOAuth2AuthenticationMethod() throws PkceValidationException {
HashMap authMap = new HashMap();
authMap.put(ClaimConstants.CLIENT_AUTH_METHOD, "none");
when(pkceValidationService.checkAndValidate(any(), any(), any())).thenReturn(true);
when(oAuth2Request.getExtensions()).thenReturn(authMap);
when(oAuth2Request.createOAuth2Request(any())).thenReturn(oAuth2Request);
assertNotNull(granter.getOAuth2Authentication(requestingClient, tokenRequest));
verify(oAuth2Request, times(2)).getExtensions();
}

@Test
void getOAuth2AuthenticationNoMethod() throws PkceValidationException {
HashMap authMap = new HashMap();
authMap.put(ClaimConstants.CLIENT_AUTH_METHOD, null);
when(pkceValidationService.checkAndValidate(any(), any(), any())).thenReturn(true);
when(oAuth2Request.getExtensions()).thenReturn(authMap);
when(oAuth2Request.createOAuth2Request(any())).thenReturn(oAuth2Request);
assertNotNull(granter.getOAuth2Authentication(requestingClient, tokenRequest));
verify(oAuth2Request, atMost(1)).getExtensions();
}
}
Loading

0 comments on commit b7c4c78

Please sign in to comment.