Skip to content

Commit

Permalink
wip: Zoned Login
Browse files Browse the repository at this point in the history
[#187902333]

Signed-off-by: Duane May <duane.may@broadcom.com>
Signed-off-by: Peter Chen <peter-h.chen@broadcom.com>
  • Loading branch information
duanemay authored and peterhaochen47 committed Jul 12, 2024
1 parent 748f5f2 commit 6233e58
Show file tree
Hide file tree
Showing 11 changed files with 269 additions and 82 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,17 @@ public class ConfiguratorRelyingPartyRegistrationRepository
private final KeyWithCert keyWithCert;
private final String samlEntityID;

public ConfiguratorRelyingPartyRegistrationRepository(@Qualifier("samlEntityID") String samlEntityID,
private final String samlEntityIDAlias;

public ConfiguratorRelyingPartyRegistrationRepository(String samlEntityID,
String samlEntityIDAlias,
KeyWithCert keyWithCert,
SamlIdentityProviderConfigurator configurator) {
Assert.notNull(configurator, "configurator cannot be null");
this.configurator = configurator;
this.keyWithCert = keyWithCert;
this.samlEntityID = samlEntityID;
this.samlEntityIDAlias = samlEntityIDAlias;
}

/**
Expand All @@ -43,34 +47,14 @@ public RelyingPartyRegistration findByRegistrationId(String registrationId) {
if (identityProviderDefinition.getIdpEntityAlias().equals(registrationId)) {

IdentityZone zone = retrieveZone();
String zonedSamlEntityID = zone.isUaa() ? samlEntityID : zone.getConfig().getSamlConfig().getEntityID();

return RelyingPartyRegistrationBuilder.buildRelyingPartyRegistration(
samlEntityID, identityProviderDefinition.getNameID(),
zonedSamlEntityID, identityProviderDefinition.getNameID(),
keyWithCert, identityProviderDefinition.getMetaDataLocation(),
registrationId, zone.getConfig().getSamlConfig().isRequestSigned());
registrationId, samlEntityIDAlias, zone.getConfig().getSamlConfig().isRequestSigned());
}
}
return buildDefaultRelyingPartyRegistration();
}

private RelyingPartyRegistration buildDefaultRelyingPartyRegistration() {
String samlEntityID, samlServiceUri;
IdentityZone zone = retrieveZone();
if (zone.isUaa()) {
samlEntityID = this.samlEntityID;
samlServiceUri = this.samlEntityID;
}
else if (zone.getConfig() != null && zone.getConfig().getSamlConfig() != null) {

samlEntityID = zone.getConfig().getSamlConfig().getEntityID();
samlServiceUri = zone.getSubdomain() + "." + this.samlEntityID;
}
else {
return null;
}

return RelyingPartyRegistrationBuilder.buildRelyingPartyRegistration(
samlEntityID, null,
keyWithCert, "dummy-saml-idp-metadata.xml", null,
samlServiceUri, zone.getConfig().getSamlConfig().isRequestSigned());
return null;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package org.cloudfoundry.identity.uaa.provider.saml;

import org.cloudfoundry.identity.uaa.util.KeyWithCert;
import org.cloudfoundry.identity.uaa.zone.IdentityZone;
import org.cloudfoundry.identity.uaa.zone.ZoneAware;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;

/**
* A {@link RelyingPartyRegistrationRepository} that always returns a default {@link RelyingPartyRegistrationRepository}.
*/
public class DefaultRelyingPartyRegistrationRepository implements RelyingPartyRegistrationRepository, ZoneAware {
public static final String CLASSPATH_DUMMY_SAML_IDP_METADATA_XML = "classpath:dummy-saml-idp-metadata.xml";

private final KeyWithCert keyWithCert;
private final String samlEntityID;

public DefaultRelyingPartyRegistrationRepository(String samlEntityID,
KeyWithCert keyWithCert) {
this.keyWithCert = keyWithCert;
this.samlEntityID = samlEntityID;
}

/**
* Returns the relying party registration identified by the provided
* {@code registrationId}, or {@code null} if not found.
*
* @param registrationId the registration identifier
* @return the {@link RelyingPartyRegistration} if found, otherwise {@code null}
*/
@Override
public RelyingPartyRegistration findByRegistrationId(String registrationId) {
IdentityZone zone = retrieveZone();

String zonedSamlEntityID;
if (!zone.isUaa() && zone.getConfig() != null && zone.getConfig().getSamlConfig() != null && zone.getConfig().getSamlConfig().getEntityID() != null) {
zonedSamlEntityID = zone.getConfig().getSamlConfig().getEntityID();
} else {
zonedSamlEntityID = this.samlEntityID;
}

return RelyingPartyRegistrationBuilder.buildRelyingPartyRegistration(
zonedSamlEntityID, null,
keyWithCert, CLASSPATH_DUMMY_SAML_IDP_METADATA_XML, registrationId,
zonedSamlEntityID, zone.getConfig().getSamlConfig().isRequestSigned());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,10 @@ private RelyingPartyRegistrationBuilder() {
throw new java.lang.UnsupportedOperationException("This is a utility class and cannot be instantiated");
}

public static RelyingPartyRegistration buildRelyingPartyRegistration(
String samlEntityID, String samlSpNameId,
KeyWithCert keyWithCert,
String metadataLocation, String rpRegstrationId, boolean requestSigned) {
return buildRelyingPartyRegistration(samlEntityID, samlSpNameId,
keyWithCert, metadataLocation, rpRegstrationId,
samlEntityID, requestSigned);
}

public static RelyingPartyRegistration buildRelyingPartyRegistration(
String samlEntityID, String samlSpNameId,
KeyWithCert keyWithCert, String metadataLocation,
String rpRegstrationId, String samlServiceUri, boolean requestSigned) {
String rpRegstrationId, String samlSpAlias, boolean requestSigned) {
SamlIdentityProviderDefinition.MetadataLocation type = SamlIdentityProviderDefinition.getType(metadataLocation);

RelyingPartyRegistration.Builder builder;
Expand All @@ -51,14 +42,17 @@ public static RelyingPartyRegistration buildRelyingPartyRegistration(
builder = RelyingPartyRegistrations.fromMetadataLocation(metadataLocation);
}

// fallback to entityId if alias is not provided
samlSpAlias = samlSpAlias == null ? samlEntityID : samlSpAlias;

builder.entityId(samlEntityID);
if (samlSpNameId != null) builder.nameIdFormat(samlSpNameId);
if (rpRegstrationId != null) builder.registrationId(rpRegstrationId);
return builder
.assertionConsumerServiceLocation(assertionConsumerServiceLocationFunction.apply(samlServiceUri))
.singleLogoutServiceResponseLocation(singleLogoutServiceResponseLocationFunction.apply(samlServiceUri))
.singleLogoutServiceLocation(singleLogoutServiceLocationFunction.apply(samlServiceUri))
.singleLogoutServiceResponseLocation(singleLogoutServiceResponseLocationFunction.apply(samlServiceUri))
.assertionConsumerServiceLocation(assertionConsumerServiceLocationFunction.apply(samlSpAlias))
.singleLogoutServiceResponseLocation(singleLogoutServiceResponseLocationFunction.apply(samlSpAlias))
.singleLogoutServiceLocation(singleLogoutServiceLocationFunction.apply(samlSpAlias))
.singleLogoutServiceResponseLocation(singleLogoutServiceResponseLocationFunction.apply(samlSpAlias))
// Accept both POST and REDIRECT bindings
.singleLogoutServiceBindings(c -> {
c.add(Saml2MessageBinding.REDIRECT);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,12 @@ public class SamlConfigProps {

private String activeKeyId;

private String entityIDAlias;

private Map<String, SamlKey> keys;

private Boolean wantAssertionSigned = true;

private Boolean signRequest = true;

public SamlKey getActiveSamlKey() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ RelyingPartyRegistrationRepository relyingPartyRegistrationRepository(SamlIdenti
// even when there are no SAML IDPs configured.
// See relevant issue: https://github.com/spring-projects/spring-security/issues/11369
RelyingPartyRegistration defaultRelyingPartyRegistration = RelyingPartyRegistrationBuilder.buildRelyingPartyRegistration(
samlEntityID, samlSpNameID, keyWithCert, CLASSPATH_DUMMY_SAML_IDP_METADATA_XML, DEFAULT_REGISTRATION_ID, samlConfigProps.getSignRequest());
samlEntityID, samlSpNameID, keyWithCert, CLASSPATH_DUMMY_SAML_IDP_METADATA_XML, DEFAULT_REGISTRATION_ID, samlConfigProps.getEntityIDAlias(), samlConfigProps.getSignRequest());
relyingPartyRegistrations.add(defaultRelyingPartyRegistration);

for (SamlIdentityProviderDefinition samlIdentityProviderDefinition : bootstrapSamlIdentityProviderData.getIdentityProviderDefinitions()) {
Expand All @@ -74,13 +74,15 @@ RelyingPartyRegistrationRepository relyingPartyRegistrationRepository(SamlIdenti
samlEntityID, samlSpNameID, keyWithCert,
samlIdentityProviderDefinition.getMetaDataLocation(),
samlIdentityProviderDefinition.getIdpEntityAlias(),
samlConfigProps.getEntityIDAlias(),
samlConfigProps.getSignRequest())
);
}

InMemoryRelyingPartyRegistrationRepository bootstrapRepo = new InMemoryRelyingPartyRegistrationRepository(relyingPartyRegistrations);
ConfiguratorRelyingPartyRegistrationRepository configuratorRepo = new ConfiguratorRelyingPartyRegistrationRepository(samlEntityID, keyWithCert, samlIdentityProviderConfigurator);
return new DelegatingRelyingPartyRegistrationRepository(bootstrapRepo, configuratorRepo);
ConfiguratorRelyingPartyRegistrationRepository configuratorRepo = new ConfiguratorRelyingPartyRegistrationRepository(samlEntityID, samlConfigProps.getEntityIDAlias(), keyWithCert, samlIdentityProviderConfigurator);
DefaultRelyingPartyRegistrationRepository defaultRepo = new DefaultRelyingPartyRegistrationRepository(samlEntityID, keyWithCert);
return new DelegatingRelyingPartyRegistrationRepository(bootstrapRepo, configuratorRepo, defaultRepo);
}

@Autowired
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import org.cloudfoundry.identity.uaa.provider.SamlIdentityProviderDefinition;
import org.cloudfoundry.identity.uaa.util.KeyWithCert;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
Expand Down Expand Up @@ -33,7 +32,9 @@
@ExtendWith(MockitoExtension.class)
class ConfiguratorRelyingPartyRegistrationRepositoryTest {
private static final String ENTITY_ID = "entityId";
private static final String ENTITY_ID_ALIAS = "entityIdAlias";
private static final String REGISTRATION_ID = "registrationId";
private static final String REGISTRATION_ID_2 = "registrationId2";
private static final String NAME_ID = "name1";

@Mock
Expand All @@ -46,14 +47,14 @@ class ConfiguratorRelyingPartyRegistrationRepositoryTest {

@BeforeEach
void setUp() {
repository = new ConfiguratorRelyingPartyRegistrationRepository(ENTITY_ID, mockKeyWithCert,
repository = new ConfiguratorRelyingPartyRegistrationRepository(ENTITY_ID, ENTITY_ID_ALIAS, mockKeyWithCert,
mockConfigurator);
}

@Test
void constructorWithNullConfiguratorThrows() {
assertThatThrownBy(() -> new ConfiguratorRelyingPartyRegistrationRepository(
ENTITY_ID, mockKeyWithCert, null)
ENTITY_ID, ENTITY_ID_ALIAS, mockKeyWithCert, null)
).isInstanceOf(IllegalArgumentException.class);
}

Expand Down Expand Up @@ -81,15 +82,14 @@ void findByRegistrationIdWithMultipleInDb() {
.returns(ENTITY_ID, RelyingPartyRegistration::getEntityId)
.returns(NAME_ID, RelyingPartyRegistration::getNameIdFormat)
// from functions
.returns("{baseUrl}/saml/SSO/alias/entityId", RelyingPartyRegistration::getAssertionConsumerServiceLocation)
.returns("{baseUrl}/saml/SingleLogout/alias/entityId", RelyingPartyRegistration::getSingleLogoutServiceResponseLocation)
.returns("{baseUrl}/saml/SSO/alias/entityIdAlias", RelyingPartyRegistration::getAssertionConsumerServiceLocation)
.returns("{baseUrl}/saml/SingleLogout/alias/entityIdAlias", RelyingPartyRegistration::getSingleLogoutServiceResponseLocation)
// from xml
.extracting(RelyingPartyRegistration::getAssertingPartyDetails)
.returns("https://idp-saml.ua3.int/simplesaml/saml2/idp/metadata.php", RelyingPartyRegistration.AssertingPartyDetails::getEntityId);
}

@Test
@Disabled("Test not valid because ConfiguratorRelyingPartyRegistrationRepository now returns default RelyingPartyRegistration when none found")
void findByRegistrationIdWhenNoneFound() {
SamlIdentityProviderDefinition definition = mock(SamlIdentityProviderDefinition.class);
when(definition.getIdpEntityAlias()).thenReturn(REGISTRATION_ID);
Expand All @@ -104,21 +104,21 @@ void buildsCorrectRegistrationWhenMetadataXmlIsStored() {
when(mockKeyWithCert.getCertificate()).thenReturn(mock(X509Certificate.class));
when(mockKeyWithCert.getPrivateKey()).thenReturn(mock(PrivateKey.class));
SamlIdentityProviderDefinition definition = mock(SamlIdentityProviderDefinition.class);
when(definition.getIdpEntityAlias()).thenReturn("no_slos");
when(definition.getIdpEntityAlias()).thenReturn(REGISTRATION_ID);
when(definition.getNameID()).thenReturn(NAME_ID);
when(definition.getMetaDataLocation()).thenReturn(metadata);
when(mockConfigurator.getIdentityProviderDefinitions()).thenReturn(List.of(definition));

RelyingPartyRegistration registration = repository.findByRegistrationId("no_slos");
RelyingPartyRegistration registration = repository.findByRegistrationId(REGISTRATION_ID);

assertThat(registration)
// from definition
.returns("no_slos", RelyingPartyRegistration::getRegistrationId)
.returns(REGISTRATION_ID, RelyingPartyRegistration::getRegistrationId)
.returns(ENTITY_ID, RelyingPartyRegistration::getEntityId)
.returns(NAME_ID, RelyingPartyRegistration::getNameIdFormat)
// from functions
.returns("{baseUrl}/saml/SSO/alias/entityId", RelyingPartyRegistration::getAssertionConsumerServiceLocation)
.returns("{baseUrl}/saml/SingleLogout/alias/entityId", RelyingPartyRegistration::getSingleLogoutServiceResponseLocation)
.returns("{baseUrl}/saml/SSO/alias/entityIdAlias", RelyingPartyRegistration::getAssertionConsumerServiceLocation)
.returns("{baseUrl}/saml/SingleLogout/alias/entityIdAlias", RelyingPartyRegistration::getSingleLogoutServiceResponseLocation)
// from xml
.extracting(RelyingPartyRegistration::getAssertingPartyDetails)
.returns("http://uaa-acceptance.cf-app.com/saml-idp", RelyingPartyRegistration.AssertingPartyDetails::getEntityId);
Expand All @@ -129,20 +129,20 @@ void buildsCorrectRegistrationWhenMetadataLocationIsStored() {
when(mockKeyWithCert.getCertificate()).thenReturn(mock(X509Certificate.class));
when(mockKeyWithCert.getPrivateKey()).thenReturn(mock(PrivateKey.class));
SamlIdentityProviderDefinition definition = mock(SamlIdentityProviderDefinition.class);
when(definition.getIdpEntityAlias()).thenReturn("no_slos");
when(definition.getIdpEntityAlias()).thenReturn(REGISTRATION_ID_2);
when(definition.getNameID()).thenReturn(NAME_ID);
when(definition.getMetaDataLocation()).thenReturn("no_single_logout_service-metadata.xml");
when(mockConfigurator.getIdentityProviderDefinitions()).thenReturn(List.of(definition));

RelyingPartyRegistration registration = repository.findByRegistrationId("no_slos");
RelyingPartyRegistration registration = repository.findByRegistrationId(REGISTRATION_ID_2);
assertThat(registration)
// from definition
.returns("no_slos", RelyingPartyRegistration::getRegistrationId)
.returns(REGISTRATION_ID_2, RelyingPartyRegistration::getRegistrationId)
.returns(ENTITY_ID, RelyingPartyRegistration::getEntityId)
.returns(NAME_ID, RelyingPartyRegistration::getNameIdFormat)
// from functions
.returns("{baseUrl}/saml/SSO/alias/entityId", RelyingPartyRegistration::getAssertionConsumerServiceLocation)
.returns("{baseUrl}/saml/SingleLogout/alias/entityId", RelyingPartyRegistration::getSingleLogoutServiceResponseLocation)
.returns("{baseUrl}/saml/SSO/alias/entityIdAlias", RelyingPartyRegistration::getAssertionConsumerServiceLocation)
.returns("{baseUrl}/saml/SingleLogout/alias/entityIdAlias", RelyingPartyRegistration::getSingleLogoutServiceResponseLocation)
// from xml
.extracting(RelyingPartyRegistration::getAssertingPartyDetails)
.returns("http://uaa-acceptance.cf-app.com/saml-idp", RelyingPartyRegistration.AssertingPartyDetails::getEntityId);
Expand Down
Loading

0 comments on commit 6233e58

Please sign in to comment.