Skip to content

Commit

Permalink
Changes to upgrade spring security
Browse files Browse the repository at this point in the history
  • Loading branch information
arcshiftsolutions committed Oct 5, 2024
1 parent fca967a commit a668640
Show file tree
Hide file tree
Showing 12 changed files with 177 additions and 85 deletions.
18 changes: 14 additions & 4 deletions docker/keycloak/extensions-24/services/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -146,12 +146,22 @@
<dependency>
<groupId>org.springframework</groupId>
<artifactId>spring-web</artifactId>
<version>5.2.10.RELEASE</version>
<version>6.1.13</version>
</dependency>
<dependency>
<groupId>org.springframework.security.oauth</groupId>
<artifactId>spring-security-oauth2</artifactId>
<version>2.5.0.RELEASE</version>
<groupId>org.springframework</groupId>
<artifactId>spring-webflux</artifactId>
<version>6.0.14</version>
</dependency>
<dependency>
<groupId>org.springframework.security</groupId>
<artifactId>spring-security-oauth2-client</artifactId>
<version>6.3.3</version>
</dependency>
<dependency>
<groupId>io.projectreactor.netty</groupId>
<artifactId>reactor-netty-http</artifactId>
<version>1.1.13</version>
</dependency>
<dependency>
<groupId>org.mockito</groupId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ public class SoamFirstTimeLoginAuthenticator extends AbstractIdpAuthenticator {

private static Logger logger = Logger.getLogger(SoamFirstTimeLoginAuthenticator.class);

private SoamRestUtils soamRestUtils = new SoamRestUtils();


@Override
protected void actionImpl(AuthenticationFlowContext context, SerializedBrokeredIdentityContext serializedCtx, BrokeredIdentityContext brokerContext) {
Expand Down Expand Up @@ -139,7 +141,7 @@ protected void createOrUpdateUser(String guid, String accountType, String credTy
logger.debug("SOAM: performing login for " + accountType + " user: " + guid);

try {
SoamRestUtils.getInstance().performLogin(credType, guid, guid, servicesCard);
soamRestUtils.performLogin(credType, guid, guid, servicesCard);
} catch (Exception e) {
logger.error("Exception occurred within SOAM while processing login" + e.getMessage());
throw new SoamRuntimeException("Exception occurred within SOAM while processing login, check downstream logs for SOAM API service");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

package com.github.bcgov.keycloak.authenticators;

import com.github.bcgov.keycloak.authenticators.SoamFirstTimeLoginAuthenticator;
import org.keycloak.Config;
import org.keycloak.authentication.Authenticator;
import org.keycloak.authentication.AuthenticatorFactory;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ public class SoamPostLoginAuthenticator extends AbstractIdpAuthenticator {

private static Logger logger = Logger.getLogger(SoamPostLoginAuthenticator.class);

private SoamRestUtils soamRestUtils = new SoamRestUtils();

@Override
protected void actionImpl(AuthenticationFlowContext context, SerializedBrokeredIdentityContext serializedCtx, BrokeredIdentityContext brokerContext) {
Expand Down Expand Up @@ -124,7 +125,7 @@ protected void updateUserInfo(String guid, String accountType, String credType,
logger.debug("SOAM: performing login for " + accountType + " user: " + guid);

try {
SoamRestUtils.getInstance().performLogin(credType, guid, guid, servicesCard);
soamRestUtils.performLogin(credType, guid, guid, servicesCard);
} catch (Exception e) {
logger.error("Exception occurred within SOAM while processing login" + e.getMessage());
throw new SoamRuntimeException("Exception occurred within SOAM while processing login, check downstream logs for SOAM API service");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

package com.github.bcgov.keycloak.authenticators;

import com.github.bcgov.keycloak.authenticators.SoamPostLoginAuthenticator;
import org.keycloak.Config;
import org.keycloak.authentication.Authenticator;
import org.keycloak.authentication.AuthenticatorFactory;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,20 @@

import com.fasterxml.jackson.databind.ObjectMapper;
import org.jboss.logging.Logger;
import org.springframework.stereotype.Component;

/**
* Class holds all application properties
*
* @author Marco Villeneuve
*
*/
@Component
public class ApplicationProperties {
public static final ObjectMapper mapper = new ObjectMapper();
private static Logger logger = Logger.getLogger(ApplicationProperties.class);
public static final String CORRELATION_ID = "correlationID";
public static final String SOAM = "SOAM";
private String soamApiURL;
private String tokenURL;
private String clientID;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ public class SoamProtocolMapper extends AbstractOIDCProtocolMapper

private static Logger logger = Logger.getLogger(SoamProtocolMapper.class);
private static final List<ProviderConfigProperty> configProperties = new ArrayList<ProviderConfigProperty>();
private SoamRestUtils soamRestUtils = new SoamRestUtils();

static {
// OIDCAttributeMapperHelper.addTokenClaimNameConfig(configProperties);
Expand Down Expand Up @@ -80,7 +81,7 @@ private SoamLoginEntity fetchSoamLoginEntity(String type, String userGUID) {
return loginDetailCache.get(userGUID);
}
logger.debug("SOAM Fetching " + type + " Claims for UserGUID: " + userGUID);
SoamLoginEntity soamLoginEntity = SoamRestUtils.getInstance().getSoamLoginEntity(type, userGUID);
SoamLoginEntity soamLoginEntity = soamRestUtils.getSoamLoginEntity(type, userGUID);
loginDetailCache.put(userGUID, soamLoginEntity);

return soamLoginEntity;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
package com.github.bcgov.keycloak.rest;

import com.github.bcgov.keycloak.common.properties.ApplicationProperties;
import com.github.bcgov.keycloak.util.LogHelper;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.ComponentScan;
import org.springframework.context.annotation.Configuration;
import org.springframework.http.client.reactive.ClientHttpConnector;
import org.springframework.http.client.reactive.ReactorClientHttpConnector;
import org.springframework.security.oauth2.client.AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager;
import org.springframework.security.oauth2.client.InMemoryReactiveOAuth2AuthorizedClientService;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.InMemoryReactiveClientRegistrationRepository;
import org.springframework.security.oauth2.client.web.reactive.function.client.ServerOAuth2AuthorizedClientExchangeFilterFunction;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.web.reactive.function.client.ExchangeFilterFunction;
import org.springframework.web.reactive.function.client.WebClient;
import org.springframework.web.util.DefaultUriBuilderFactory;
import reactor.netty.http.client.HttpClient;

/**
* The type Rest web client.
*/
@Configuration
@ComponentScan(basePackages={"com.github.bcgov.keycloak"})
public class RestWebClient {
private final DefaultUriBuilderFactory factory;
private final ClientHttpConnector connector;
/**
* The Props.
*/
private final ApplicationProperties props = new ApplicationProperties();

public RestWebClient() {
this.factory = new DefaultUriBuilderFactory();
this.factory.setEncodingMode(DefaultUriBuilderFactory.EncodingMode.NONE);
final HttpClient client = HttpClient.create().compress(true);
client.warmup()
.block();
this.connector = new ReactorClientHttpConnector(client);
}

/**
* Web client web client.
*
* @return the web client
*/
@Bean
public WebClient webClient() {
InMemoryReactiveClientRegistrationRepository clientRegistryRepo = new InMemoryReactiveClientRegistrationRepository(ClientRegistration
.withRegistrationId(this.props.getClientID())
.tokenUri(this.props.getTokenURL())
.clientId(this.props.getClientID())
.clientSecret(this.props.getClientSecret())
.authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS)
.build());
InMemoryReactiveOAuth2AuthorizedClientService clientService = new InMemoryReactiveOAuth2AuthorizedClientService(clientRegistryRepo);
AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager authorizedClientManager =
new AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager(clientRegistryRepo, clientService);
ServerOAuth2AuthorizedClientExchangeFilterFunction oauthFilter = new ServerOAuth2AuthorizedClientExchangeFilterFunction(authorizedClientManager);
oauthFilter.setDefaultClientRegistrationId(this.props.getClientID());
return WebClient.builder()
.defaultHeader("X-Client-Name", ApplicationProperties.SOAM)
.codecs(configurer -> configurer
.defaultCodecs()
.maxInMemorySize(100 * 1024 * 1024))
.filter(this.log())
.clientConnector(this.connector)
.uriBuilderFactory(this.factory)
.filter(oauthFilter)
.build();
}

private ExchangeFilterFunction log() {
return (clientRequest, next) ->
next
.exchange(clientRequest)
.doOnNext((clientResponse -> LogHelper.logClientHttpReqResponseDetails(clientRequest.method(), clientRequest.url().toString(), clientResponse.statusCode() != null ? clientResponse.statusCode().value() : 400, clientRequest.headers().get(ApplicationProperties.CORRELATION_ID))));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,12 @@
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.security.oauth2.client.DefaultOAuth2ClientContext;
import org.springframework.security.oauth2.client.OAuth2RestTemplate;
import org.springframework.security.oauth2.client.token.grant.client.ClientCredentialsResourceDetails;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.web.client.HttpClientErrorException;
import org.springframework.web.client.RestTemplate;
import reactor.core.publisher.Mono;

import java.util.Collections;
import java.util.List;
import java.util.UUID;

/**
Expand All @@ -31,37 +27,18 @@ public class SoamRestUtils {

private static Logger logger = Logger.getLogger(SoamRestUtils.class);

private static SoamRestUtils soamRestUtilsInstance;

private static ApplicationProperties props;

private SoamRestUtils() {
props = new ApplicationProperties();
}

public static SoamRestUtils getInstance() {
if (soamRestUtilsInstance == null) {
soamRestUtilsInstance = new SoamRestUtils();
}
return soamRestUtilsInstance;
}
private RestWebClient restWebClient;

public RestTemplate getRestTemplate(List<String> scopes) {
logger.debug("Calling get token method");
ClientCredentialsResourceDetails resourceDetails = new ClientCredentialsResourceDetails();
resourceDetails.setClientId(props.getClientID());
resourceDetails.setClientSecret(props.getClientSecret());
resourceDetails.setAccessTokenUri(props.getTokenURL());
if (scopes != null) {
resourceDetails.setScope(scopes);
}
return new OAuth2RestTemplate(resourceDetails, new DefaultOAuth2ClientContext());
public SoamRestUtils() {
this.restWebClient = new RestWebClient();
props = new ApplicationProperties();
}

public void performLogin(String identifierType, String identifierValue, String userID, SoamServicesCard servicesCard) {
String url = props.getSoamApiURL() + "/login";
final String correlationID = logAndGetCorrelationID(identifierValue, url, HttpMethod.POST.toString());
RestTemplate restTemplate = getRestTemplate(null);
HttpHeaders headers = new HttpHeaders();
headers.setContentType(MediaType.APPLICATION_FORM_URLENCODED);
headers.add("correlationID", correlationID);
Expand All @@ -86,7 +63,13 @@ public void performLogin(String identifierType, String identifierValue, String u
HttpEntity<MultiValueMap<String, String>> request = new HttpEntity<MultiValueMap<String, String>>(map, headers);

try {
restTemplate.postForEntity(url, request, SoamLoginEntity.class);
this.restWebClient.webClient().post()
.uri(url)
.headers(httpHeadersOnWebClientBeingBuilt -> httpHeadersOnWebClientBeingBuilt.addAll( headers ))
.body(Mono.just(request), HttpEntity.class)
.retrieve()
.bodyToMono(SoamLoginEntity.class)
.block();
} catch (final HttpClientErrorException e) {
throw new RuntimeException("Could not complete login call: " + e.getMessage());
}
Expand All @@ -95,31 +78,21 @@ public void performLogin(String identifierType, String identifierValue, String u
public SoamLoginEntity getSoamLoginEntity(String identifierType, String identifierValue) {
String url = props.getSoamApiURL() + "/" + identifierType + "/" + identifierValue;
final String correlationID = logAndGetCorrelationID(identifierValue, url, HttpMethod.GET.toString());
RestTemplate restTemplate = getRestTemplate(null);
HttpHeaders headers = new HttpHeaders();
headers.setAccept(Collections.singletonList(MediaType.APPLICATION_JSON));
headers.add("correlationID", correlationID);
try {
return restTemplate.exchange(url, HttpMethod.GET, new HttpEntity<>("parameters", headers), SoamLoginEntity.class).getBody();
return this.restWebClient.webClient().get()
.uri(url)
.headers(httpHeadersOnWebClientBeingBuilt -> httpHeadersOnWebClientBeingBuilt.addAll( headers ))
.retrieve()
.bodyToMono(SoamLoginEntity.class)
.block();
} catch (final HttpClientErrorException e) {
throw new RuntimeException("Could not complete getSoamLoginEntity call: " + e.getMessage());
}
}

public List<String> getSTSRoles(String identifierValue) {
String url = props.getSoamApiURL() + "/" + identifierValue + "/" + "sts-user-roles";
final String correlationID = logAndGetCorrelationID(identifierValue, url, HttpMethod.GET.toString());
RestTemplate restTemplate = getRestTemplate(null);
HttpHeaders headers = new HttpHeaders();
headers.setAccept(Collections.singletonList(MediaType.APPLICATION_JSON));
headers.add("correlationID", correlationID);
try {
return restTemplate.exchange(url, HttpMethod.GET, new HttpEntity<>("parameters", headers), List.class).getBody();
} catch (final HttpClientErrorException e) {
throw new RuntimeException("Could not complete getSTSRoles call: " + e.getMessage());
}
}

private String logAndGetCorrelationID(String identifierValue, String url, String httpMethod) {
final String correlationID = UUID.randomUUID().toString();
MDC.put("correlation_id", correlationID);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ public class TenantProtocolMapper extends AbstractOIDCProtocolMapper

private static Logger logger = Logger.getLogger(TenantProtocolMapper.class);
private static final List<ProviderConfigProperty> configProperties = new ArrayList<ProviderConfigProperty>();
private TenantRestUtils tenantRestUtils = new TenantRestUtils();

static {
// OIDCAttributeMapperHelper.addTokenClaimNameConfig(configProperties);
Expand Down Expand Up @@ -78,7 +79,7 @@ private TenantAccess fetchTenantAccessEntity(String clientID, String tenantID) {
return loginDetailCache.get(tenantID);
}
logger.debug("Tenant Access Fetching by Tenant ID: " + tenantID + " and Client ID: " + clientID);
TenantAccess tenantAccess = TenantRestUtils.getInstance().checkForValidTenant(clientID, tenantID);
TenantAccess tenantAccess = tenantRestUtils.checkForValidTenant(clientID, tenantID);
loginDetailCache.put(tenantID, tenantAccess);

return tenantAccess;
Expand Down
Loading

0 comments on commit a668640

Please sign in to comment.