From ae14c2fd9b11eba8d710882f0fde17782ab44d58 Mon Sep 17 00:00:00 2001 From: Duane May Date: Wed, 24 Jul 2024 18:02:33 -0400 Subject: [PATCH] feat: Handle Multiple SAML keys - Rotation Tests working - Uses keys from SamlConfig for each zone - Fall back to default keys if none set [#187994938] Signed-off-by: Duane May --- .../identity/uaa/saml/SamlKey.java | 40 +---- .../identity/uaa/zone/SamlConfig.java | 90 ++++++----- .../identity/uaa/zone/SamlConfigTest.java | 112 ++++++++++--- .../IdentityZoneConfigurationBootstrap.java | 99 +----------- ...UaaRelyingPartyRegistrationRepository.java | 32 +++- .../saml/CertificateRuntimeException.java | 9 ++ ...torRelyingPartyRegistrationRepository.java | 17 +- ...ultRelyingPartyRegistrationRepository.java | 22 ++- .../saml/RelyingPartyRegistrationBuilder.java | 46 ++++-- .../uaa/provider/saml/SamlConfigProps.java | 16 ++ .../provider/saml/SamlMetadataEndpoint.java | 44 ++--- ...amlMetadataEntityDescriptorCustomizer.java | 133 +++++++++++++++ .../uaa/provider/saml/SamlNameIdFormats.java | 151 ++++++++++++++++++ ...yingPartyRegistrationRepositoryConfig.java | 17 +- .../identity/uaa/util/KeyWithCert.java | 36 ++--- ...entityZoneConfigurationBootstrapTests.java | 8 +- ...elyingPartyRegistrationRepositoryTest.java | 128 +++++++++++---- ...elyingPartyRegistrationRepositoryTest.java | 104 ++++++++++-- .../RelyingPartyRegistrationBuilderTest.java | 9 +- .../saml/SamlMetadataEndpointTest.java | 40 +++-- ...PartyRegistrationRepositoryConfigTest.java | 20 +-- .../uaa/provider/saml/idp/SamlTestUtils.java | 43 ----- .../LoginServerSecurityIntegrationTests.java | 107 ++++++------- .../saml/SamlAuthenticationMockMvcTests.java | 3 - .../saml/SamlKeyRotationMockMvcTests.java | 148 +++++++++-------- .../mock/saml/SamlMetadataMockMvcTests.java | 30 ++-- 26 files changed, 976 insertions(+), 528 deletions(-) create mode 100644 server/src/main/java/org/cloudfoundry/identity/uaa/provider/saml/CertificateRuntimeException.java create mode 100644 server/src/main/java/org/cloudfoundry/identity/uaa/provider/saml/SamlMetadataEntityDescriptorCustomizer.java create mode 100644 server/src/main/java/org/cloudfoundry/identity/uaa/provider/saml/SamlNameIdFormats.java diff --git a/model/src/main/java/org/cloudfoundry/identity/uaa/saml/SamlKey.java b/model/src/main/java/org/cloudfoundry/identity/uaa/saml/SamlKey.java index a7d56e71e9e..30965bf371e 100644 --- a/model/src/main/java/org/cloudfoundry/identity/uaa/saml/SamlKey.java +++ b/model/src/main/java/org/cloudfoundry/identity/uaa/saml/SamlKey.java @@ -17,45 +17,17 @@ import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.annotation.JsonInclude; +import lombok.AllArgsConstructor; +import lombok.Data; +import lombok.NoArgsConstructor; +@Data +@AllArgsConstructor +@NoArgsConstructor @JsonIgnoreProperties(ignoreUnknown = true) @JsonInclude(JsonInclude.Include.NON_NULL) public class SamlKey { - private String key; private String passphrase; private String certificate; - - public SamlKey() { - } - - public SamlKey(String key, String passphrase, String certificate) { - this.key = key; - this.passphrase = passphrase; - this.certificate = certificate; - } - - public String getKey() { - return key; - } - - public void setKey(String key) { - this.key = key; - } - - public String getPassphrase() { - return passphrase; - } - - public void setPassphrase(String passphrase) { - this.passphrase = passphrase; - } - - public String getCertificate() { - return certificate; - } - - public void setCertificate(String certificate) { - this.certificate = certificate; - } } diff --git a/model/src/main/java/org/cloudfoundry/identity/uaa/zone/SamlConfig.java b/model/src/main/java/org/cloudfoundry/identity/uaa/zone/SamlConfig.java index 63a8e8da8bd..55432773bea 100644 --- a/model/src/main/java/org/cloudfoundry/identity/uaa/zone/SamlConfig.java +++ b/model/src/main/java/org/cloudfoundry/identity/uaa/zone/SamlConfig.java @@ -21,9 +21,12 @@ import lombok.Data; import org.cloudfoundry.identity.uaa.saml.SamlKey; +import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; +import java.util.List; import java.util.Map; +import java.util.Optional; import static org.springframework.util.StringUtils.hasText; @@ -55,81 +58,96 @@ public void setEntityID(String entityID) { @JsonProperty("certificate") public void setCertificate(String certificate) { - SamlKey legacyKey = keys.get(LEGACY_KEY_ID); - if (hasText(certificate) && null == legacyKey) { - legacyKey = new SamlKey(); - } - if (legacyKey != null) { - legacyKey.setCertificate(certificate); - keys.put(LEGACY_KEY_ID, legacyKey); + if (hasText(certificate)) { + keys.computeIfAbsent(LEGACY_KEY_ID, k -> new SamlKey()); } + keys.computeIfPresent(LEGACY_KEY_ID, (k, v) -> { + v.setCertificate(certificate); + return v; + }); } @JsonProperty("privateKey") public void setPrivateKey(String privateKey) { - SamlKey legacyKey = keys.get(LEGACY_KEY_ID); - if (hasText(privateKey) && null == legacyKey) { - legacyKey = new SamlKey(); - } - if (legacyKey != null) { - legacyKey.setKey(privateKey); - keys.put(LEGACY_KEY_ID, legacyKey); + if (hasText(privateKey)) { + keys.computeIfAbsent(LEGACY_KEY_ID, k -> new SamlKey()); } + keys.computeIfPresent(LEGACY_KEY_ID, (k, v) -> { + v.setKey(privateKey); + return v; + }); } @JsonProperty("privateKeyPassword") public void setPrivateKeyPassword(String privateKeyPassword) { - SamlKey legacyKey = keys.get(LEGACY_KEY_ID); - if (hasText(privateKeyPassword) && null == legacyKey) { - legacyKey = new SamlKey(); - } - if (legacyKey != null) { - legacyKey.setPassphrase(privateKeyPassword); - keys.put(LEGACY_KEY_ID, legacyKey); + if (hasText(privateKeyPassword)) { + keys.computeIfAbsent(LEGACY_KEY_ID, k -> new SamlKey()); } + keys.computeIfPresent(LEGACY_KEY_ID, (k, v) -> { + v.setPassphrase(privateKeyPassword); + return v; + }); } @JsonProperty("certificate") public String getCertificate() { - SamlKey legacyKey = keys.get(LEGACY_KEY_ID); - if (null != legacyKey) { - return legacyKey.getCertificate(); - } - return null; + return Optional.ofNullable(keys.get(LEGACY_KEY_ID)) + .map(SamlKey::getCertificate) + .orElse(null); } @JsonProperty public String getPrivateKey() { - SamlKey legacyKey = keys.get(LEGACY_KEY_ID); - if (null != legacyKey) { - return legacyKey.getKey(); - } - return null; + return Optional.ofNullable(keys.get(LEGACY_KEY_ID)) + .map(SamlKey::getKey) + .orElse(null); } @JsonProperty public String getPrivateKeyPassword() { - SamlKey legacyKey = keys.get(LEGACY_KEY_ID); - if (null != legacyKey) { - return legacyKey.getPassphrase(); - } - return null; + return Optional.ofNullable(keys.get(LEGACY_KEY_ID)) + .map(SamlKey::getPassphrase) + .orElse(null); } public String getActiveKeyId() { return hasText(activeKeyId) ? activeKeyId : hasLegacyKey() ? LEGACY_KEY_ID : null; } + @JsonIgnore + public SamlKey getActiveKey() { + String keyId = getActiveKeyId(); + return keyId != null ? keys.get(keyId) : null; + } + public void setActiveKeyId(String activeKeyId) { if (!LEGACY_KEY_ID.equals(activeKeyId)) { this.activeKeyId = activeKeyId; } } + /** + * @return a map of all keys by keyName + */ public Map getKeys() { return Collections.unmodifiableMap(keys); } + /** + * @return the list of keys, with the active key first. + */ + @JsonIgnore + public List getKeyList() { + List keyList = new ArrayList<>(); + String activeKeyId = getActiveKeyId(); + Optional.ofNullable(getActiveKey()).ifPresent(keyList::add); + keyList.addAll(keys.entrySet().stream() + .filter(e -> !e.getKey().equals(activeKeyId)) + .map(Map.Entry::getValue) + .toList()); + return Collections.unmodifiableList(keyList); + } + public void setKeys(Map keys) { this.keys = new HashMap<>(keys); } diff --git a/model/src/test/java/org/cloudfoundry/identity/uaa/zone/SamlConfigTest.java b/model/src/test/java/org/cloudfoundry/identity/uaa/zone/SamlConfigTest.java index 3a47709a494..76ccc5f3310 100644 --- a/model/src/test/java/org/cloudfoundry/identity/uaa/zone/SamlConfigTest.java +++ b/model/src/test/java/org/cloudfoundry/identity/uaa/zone/SamlConfigTest.java @@ -20,6 +20,7 @@ import org.junit.jupiter.api.Test; import java.util.Collections; +import java.util.List; import java.util.Map; import static org.assertj.core.api.Assertions.assertThat; @@ -104,7 +105,7 @@ void legacy_key_is_part_of_map() { config.setPrivateKeyPassword(passphrase); config.setCertificate(certificate); Map keys = config.getKeys(); - assertThat(keys).hasSize(1).containsKey(LEGACY_KEY_ID); + assertThat(keys).containsOnlyKeys(LEGACY_KEY_ID); assertThat(keys.get(LEGACY_KEY_ID).getKey()).isEqualTo(privateKey); assertThat(keys.get(LEGACY_KEY_ID).getPassphrase()).isEqualTo(passphrase); assertThat(keys.get(LEGACY_KEY_ID).getCertificate()).isEqualTo(certificate); @@ -116,12 +117,14 @@ void addActiveKey() { String keyId = "testKeyId"; config.addAndActivateKey(keyId, key); Map keys = config.getKeys(); - assertThat(keys).hasSize(1); + assertThat(keys).hasSize(1) + .containsKey(keyId); assertThat(config.getActiveKeyId()).isEqualTo(keyId); - assertThat(keys).containsKey(keyId); - assertThat(keys.get(keyId).getKey()).isEqualTo(privateKey); - assertThat(keys.get(keyId).getPassphrase()).isEqualTo(passphrase); - assertThat(keys.get(keyId).getCertificate()).isEqualTo(certificate); + assertThat(keys.get(keyId)).returns(privateKey, SamlKey::getKey) + .returns(passphrase, SamlKey::getPassphrase) + .returns(certificate, SamlKey::getCertificate); + assertThat(config.getActiveKey()).isSameAs(keys.get(keyId)); + assertThat(config.getKeyList()).hasSize(1).containsExactly(key); } @Test @@ -131,12 +134,44 @@ void addNonActive() { String keyId = "nonActiveKeyId"; config.addKey(keyId, key); Map keys = config.getKeys(); - assertThat(keys).hasSize(2); + assertThat(keys).hasSize(2) + .containsKey(keyId); assertThat(config.getActiveKeyId()).isNotEqualTo(keyId); - assertThat(keys).containsKey(keyId); - assertThat(keys.get(keyId).getKey()).isEqualTo(privateKey); - assertThat(keys.get(keyId).getPassphrase()).isEqualTo(passphrase); - assertThat(keys.get(keyId).getCertificate()).isEqualTo(certificate); + assertThat(keys.get(keyId)).returns(privateKey, SamlKey::getKey) + .returns(passphrase, SamlKey::getPassphrase) + .returns(certificate, SamlKey::getCertificate); + } + + @Test + void getKeyList() { + // Default is empty + assertThat(config.getKeyList()).isEmpty(); + + // Add active key, should only have that key + addActiveKey(); + SamlKey activeKey = config.getActiveKey(); + assertThat(config.getKeyList()).containsExactly(activeKey); + + // Add another key, should have both keys + SamlKey nonActiveKey = new SamlKey(privateKey, passphrase, certificate); + String nonActiveKeyId = "nonActiveKeyId"; + config.addKey(nonActiveKeyId, nonActiveKey); + assertThat(config.getKeyList()).containsExactly(activeKey, nonActiveKey); + + // add another active key, should have the new key first + SamlKey otherActiveKey = new SamlKey(privateKey, passphrase, certificate); + config.addAndActivateKey("anotherActiveKeyId", otherActiveKey); + assertThat(config.getKeyList()).hasSize(3).first().isSameAs(otherActiveKey); + + // remove the non-active key, should have other 2 keys + config.removeKey(nonActiveKeyId); + assertThat(config.getKeyList()).containsExactly(otherActiveKey, activeKey); + + // drop the current active key, should have only the remaining key... even though it is not active + config.removeKey("anotherActiveKeyId"); + assertThat(config.getActiveKey()).isNull(); + assertThat(config.getKeys()).hasSize(1); + assertThat(config.getKeyList()).containsExactly(activeKey); } @Test @@ -153,19 +188,56 @@ void testIsWantAssertionSigned() { @Test void testSetKeyAndCert() { + // Default values are null + assertThat(config).returns(null, SamlConfig::getPrivateKey) + .returns(null, SamlConfig::getPrivateKeyPassword) + .returns(null, SamlConfig::getCertificate) + .extracting(SamlConfig::getActiveKey) + .isNull(); + + // Set values to null, does not create a key + config.setPrivateKey(null); + config.setPrivateKeyPassword(null); + config.setCertificate(null); + assertThat(config).returns(null, SamlConfig::getPrivateKey) + .returns(null, SamlConfig::getPrivateKeyPassword) + .returns(null, SamlConfig::getCertificate) + .extracting(SamlConfig::getActiveKey) + .isNull(); + + // Set values to non-null, creates a key object config.setPrivateKey(privateKey); config.setPrivateKeyPassword(passphrase); config.setCertificate(certificate); - assertThat(config.getPrivateKey()).isEqualTo(privateKey); - assertThat(config.getPrivateKeyPassword()).isEqualTo(passphrase); + assertThat(config).returns(privateKey, SamlConfig::getPrivateKey) + .returns(passphrase, SamlConfig::getPrivateKeyPassword) + .returns(certificate, SamlConfig::getCertificate) + .extracting(SamlConfig::getActiveKey) + .isNotNull() + .returns(privateKey, SamlKey::getKey) + .returns(certificate, SamlKey::getCertificate) + .returns(passphrase, SamlKey::getPassphrase); + + // Set values to null, retains the key object with nulls + config.setPrivateKey(null); + config.setPrivateKeyPassword(null); + config.setCertificate(null); + assertThat(config).returns(null, SamlConfig::getPrivateKey) + .returns(null, SamlConfig::getPrivateKeyPassword) + .returns(null, SamlConfig::getCertificate) + .extracting(SamlConfig::getActiveKey) + .isNotNull() + .returns(null, SamlKey::getKey) + .returns(null, SamlKey::getCertificate) + .returns(null, SamlKey::getPassphrase); } @Test void read_old_json_works() { read_json(oldJson); - assertThat(config.getPrivateKey()).isEqualTo(privateKey); - assertThat(config.getPrivateKeyPassword()).isEqualTo(passphrase); - assertThat(config.getCertificate()).isEqualTo(certificate); + assertThat(config).returns(privateKey, SamlConfig::getPrivateKey) + .returns(passphrase, SamlConfig::getPrivateKeyPassword) + .returns(certificate, SamlConfig::getCertificate); } public void read_json(String json) { @@ -177,9 +249,9 @@ void to_json_ignores_legacy_values() { read_json(oldJson); String json = JsonUtils.writeValueAsString(config); read_json(json); - assertThat(config.getPrivateKey()).isEqualTo(privateKey); - assertThat(config.getPrivateKeyPassword()).isEqualTo(passphrase); - assertThat(config.getCertificate()).isEqualTo(certificate); + assertThat(config).returns(privateKey, SamlConfig::getPrivateKey) + .returns(passphrase, SamlConfig::getPrivateKeyPassword) + .returns(certificate, SamlConfig::getCertificate); } @Test @@ -193,8 +265,10 @@ void can_clear_keys() { read_json(oldJson); assertThat(config.getKeys()).hasSize(1); assertThat(config.getActiveKeyId()).isNotNull(); + assertThat(config.getActiveKey()).isNotNull(); config.setKeys(Collections.emptyMap()); assertThat(config.getKeys()).isEmpty(); assertThat(config.getActiveKeyId()).isNull(); + assertThat(config.getActiveKey()).isNull(); } } \ No newline at end of file diff --git a/server/src/main/java/org/cloudfoundry/identity/uaa/impl/config/IdentityZoneConfigurationBootstrap.java b/server/src/main/java/org/cloudfoundry/identity/uaa/impl/config/IdentityZoneConfigurationBootstrap.java index 4b34a56babe..a06f96cd790 100644 --- a/server/src/main/java/org/cloudfoundry/identity/uaa/impl/config/IdentityZoneConfigurationBootstrap.java +++ b/server/src/main/java/org/cloudfoundry/identity/uaa/impl/config/IdentityZoneConfigurationBootstrap.java @@ -12,8 +12,7 @@ *******************************************************************************/ package org.cloudfoundry.identity.uaa.impl.config; -import lombok.Getter; -import lombok.Setter; +import lombok.Data; import org.cloudfoundry.identity.uaa.login.Prompt; import org.cloudfoundry.identity.uaa.saml.SamlKey; import org.cloudfoundry.identity.uaa.util.JsonUtils; @@ -36,7 +35,7 @@ import static java.util.Optional.ofNullable; import static org.springframework.util.StringUtils.hasText; -@Setter +@Data public class IdentityZoneConfigurationBootstrap implements InitializingBean { private ClientSecretPolicy clientSecretPolicy; @@ -44,7 +43,6 @@ public class IdentityZoneConfigurationBootstrap implements InitializingBean { private final IdentityZoneProvisioning provisioning; private boolean selfServiceLinksEnabled = true; - @Getter private String homeRedirect = null; private Map selfServiceLinks; private List logoutRedirectWhitelist; @@ -70,7 +68,6 @@ public class IdentityZoneConfigurationBootstrap implements InitializingBean { private UserConfig defaultUserConfig; private IdentityZoneValidator validator = (config, mode) -> config; - @Getter private Map branding; public IdentityZoneConfigurationBootstrap(IdentityZoneProvisioning provisioning) { @@ -144,96 +141,4 @@ public IdentityZoneConfigurationBootstrap setActiveKeyId(String activeKeyId) { this.activeKeyId = activeKeyId != null ? activeKeyId.toLowerCase(Locale.ROOT) : null; return this; } - - public void setTokenPolicy(TokenPolicy tokenPolicy) { - this.tokenPolicy = tokenPolicy; - } - - public void setSelfServiceLinksEnabled(boolean selfServiceLinksEnabled) { - this.selfServiceLinksEnabled = selfServiceLinksEnabled; - } - - public void setHomeRedirect(String homeRedirect) { - this.homeRedirect = homeRedirect; - } - - public String getHomeRedirect() { - return homeRedirect; - } - - public void setSelfServiceLinks(Map links) { - this.selfServiceLinks = links; - } - - public void setLogoutDefaultRedirectUrl(String logoutDefaultRedirectUrl) { - this.logoutDefaultRedirectUrl = logoutDefaultRedirectUrl; - } - - public void setLogoutDisableRedirectParameter(boolean logoutDisableRedirectParameter) { - this.logoutDisableRedirectParameter = logoutDisableRedirectParameter; - } - - public void setLogoutRedirectParameterName(String logoutRedirectParameterName) { - this.logoutRedirectParameterName = logoutRedirectParameterName; - } - - public void setLogoutRedirectWhitelist(List logoutRedirectWhitelist) { - this.logoutRedirectWhitelist = logoutRedirectWhitelist; - } - - public void setPrompts(List prompts) { - this.prompts = prompts; - } - - public void setDefaultIdentityProvider(String defaultIdentityProvider) { - this.defaultIdentityProvider = defaultIdentityProvider; - } - - public void setSamlSpCertificate(String samlSpCertificate) { - this.samlSpCertificate = samlSpCertificate; - } - - public void setSamlSpPrivateKey(String samlSpPrivateKey) { - this.samlSpPrivateKey = samlSpPrivateKey; - } - - public void setSamlSpPrivateKeyPassphrase(String samlSpPrivateKeyPassphrase) { - this.samlSpPrivateKeyPassphrase = samlSpPrivateKeyPassphrase; - } - - public boolean isIdpDiscoveryEnabled() { - return idpDiscoveryEnabled; - } - - public void setIdpDiscoveryEnabled(boolean idpDiscoveryEnabled) { - this.idpDiscoveryEnabled = idpDiscoveryEnabled; - } - - public boolean isAccountChooserEnabled() { - return accountChooserEnabled; - } - - public void setAccountChooserEnabled(boolean accountChooserEnabled) { - this.accountChooserEnabled = accountChooserEnabled; - } - - public void setBranding(Map branding) { - this.branding = branding; - } - - public Map getBranding() { - return branding; - } - - public boolean isDisableSamlInResponseToCheck() { - return disableSamlInResponseToCheck; - } - - public void setDisableSamlInResponseToCheck(boolean disableSamlInResponseToCheck) { - this.disableSamlInResponseToCheck = disableSamlInResponseToCheck; - } - - public void setDefaultUserConfig(final UserConfig defaultUserConfig) { - this.defaultUserConfig = defaultUserConfig; - } } diff --git a/server/src/main/java/org/cloudfoundry/identity/uaa/provider/saml/BaseUaaRelyingPartyRegistrationRepository.java b/server/src/main/java/org/cloudfoundry/identity/uaa/provider/saml/BaseUaaRelyingPartyRegistrationRepository.java index 729a8f1d0cb..555e1f1b49b 100644 --- a/server/src/main/java/org/cloudfoundry/identity/uaa/provider/saml/BaseUaaRelyingPartyRegistrationRepository.java +++ b/server/src/main/java/org/cloudfoundry/identity/uaa/provider/saml/BaseUaaRelyingPartyRegistrationRepository.java @@ -1,5 +1,7 @@ package org.cloudfoundry.identity.uaa.provider.saml; +import lombok.extern.slf4j.Slf4j; +import org.cloudfoundry.identity.uaa.saml.SamlKey; import org.cloudfoundry.identity.uaa.util.KeyWithCert; import org.cloudfoundry.identity.uaa.zone.IdentityZone; import org.cloudfoundry.identity.uaa.zone.IdentityZoneConfiguration; @@ -7,22 +9,25 @@ import org.cloudfoundry.identity.uaa.zone.ZoneAware; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; +import java.security.cert.CertificateException; +import java.util.List; import java.util.Optional; +@Slf4j public abstract class BaseUaaRelyingPartyRegistrationRepository implements RelyingPartyRegistrationRepository, ZoneAware { - protected final KeyWithCert keyWithCert; protected final String uaaWideSamlEntityID; protected final String uaaWideSamlEntityIDAlias; + protected final List defaultKeysWithCerts; - protected BaseUaaRelyingPartyRegistrationRepository(KeyWithCert keyWithCert, String uaaWideSamlEntityID, String uaaWideSamlEntityIDAlias) { - this.keyWithCert = keyWithCert; + protected BaseUaaRelyingPartyRegistrationRepository(String uaaWideSamlEntityID, String uaaWideSamlEntityIDAlias, List defaultKeysWithCerts) { this.uaaWideSamlEntityID = uaaWideSamlEntityID; this.uaaWideSamlEntityIDAlias = uaaWideSamlEntityIDAlias; + this.defaultKeysWithCerts = defaultKeysWithCerts; } String getZoneEntityId(IdentityZone currentZone) { // for default zone, use the samlEntityID - if (currentZone.isUaa() ) { + if (currentZone.isUaa()) { return uaaWideSamlEntityID; } @@ -45,4 +50,23 @@ String getZoneEntityIdAlias(IdentityZone currentZone) { // for non-default zone, use the "zone subdomain+.+alias" return "%s.%s".formatted(currentZone.getSubdomain(), alias); } + + public List convertToKeysWithCerts(List samlKeys) { + if (samlKeys == null) { + return List.of(); + } + + try { + return samlKeys.stream().map(k -> { + try { + return new KeyWithCert(k); + } catch (CertificateException e) { + log.error("Error converting key with cert", e); + throw new CertificateRuntimeException(e); + } + }).toList(); + } catch (CertificateRuntimeException e) { + return List.of(); + } + } } diff --git a/server/src/main/java/org/cloudfoundry/identity/uaa/provider/saml/CertificateRuntimeException.java b/server/src/main/java/org/cloudfoundry/identity/uaa/provider/saml/CertificateRuntimeException.java new file mode 100644 index 00000000000..7a1d1621fd1 --- /dev/null +++ b/server/src/main/java/org/cloudfoundry/identity/uaa/provider/saml/CertificateRuntimeException.java @@ -0,0 +1,9 @@ +package org.cloudfoundry.identity.uaa.provider.saml; + +import java.security.cert.CertificateException; + +public class CertificateRuntimeException extends RuntimeException { + public CertificateRuntimeException(CertificateException e) { + super(e); + } +} diff --git a/server/src/main/java/org/cloudfoundry/identity/uaa/provider/saml/ConfiguratorRelyingPartyRegistrationRepository.java b/server/src/main/java/org/cloudfoundry/identity/uaa/provider/saml/ConfiguratorRelyingPartyRegistrationRepository.java index cda5d22034e..e8a1b4a1afa 100644 --- a/server/src/main/java/org/cloudfoundry/identity/uaa/provider/saml/ConfiguratorRelyingPartyRegistrationRepository.java +++ b/server/src/main/java/org/cloudfoundry/identity/uaa/provider/saml/ConfiguratorRelyingPartyRegistrationRepository.java @@ -4,6 +4,7 @@ import org.cloudfoundry.identity.uaa.provider.SamlIdentityProviderDefinition; import org.cloudfoundry.identity.uaa.util.KeyWithCert; import org.cloudfoundry.identity.uaa.zone.IdentityZone; +import org.cloudfoundry.identity.uaa.zone.SamlConfig; import org.cloudfoundry.identity.uaa.zone.ZoneAware; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; @@ -19,9 +20,9 @@ public class ConfiguratorRelyingPartyRegistrationRepository extends BaseUaaRelyi public ConfiguratorRelyingPartyRegistrationRepository(String uaaWideSamlEntityID, String uaaWideSamlEntityIDAlias, - KeyWithCert keyWithCert, + List defaultKeysWithCerts, SamlIdentityProviderConfigurator configurator) { - super(keyWithCert, uaaWideSamlEntityID, uaaWideSamlEntityIDAlias); + super(uaaWideSamlEntityID, uaaWideSamlEntityIDAlias, defaultKeysWithCerts); Assert.notNull(configurator, "configurator cannot be null"); this.configurator = configurator; } @@ -39,13 +40,23 @@ public RelyingPartyRegistration findByRegistrationId(String registrationId) { List identityProviderDefinitions = configurator.getIdentityProviderDefinitionsForZone(currentZone); for (SamlIdentityProviderDefinition identityProviderDefinition : identityProviderDefinitions) { if (identityProviderDefinition.getIdpEntityAlias().equals(registrationId)) { + + SamlConfig samlConfig = currentZone.getConfig().getSamlConfig(); + List keyWithCerts = null; + if (samlConfig != null) { + keyWithCerts = convertToKeysWithCerts(samlConfig.getKeyList()); + } + if (keyWithCerts == null || keyWithCerts.isEmpty()) { + keyWithCerts = defaultKeysWithCerts; + } + String zonedSamlEntityID = getZoneEntityId(currentZone); String zonedSamlEntityIDAlias = getZoneEntityIdAlias(currentZone); boolean requestSigned = currentZone.getConfig().getSamlConfig().isRequestSigned(); return RelyingPartyRegistrationBuilder.buildRelyingPartyRegistration( zonedSamlEntityID, identityProviderDefinition.getNameID(), - keyWithCert, identityProviderDefinition.getMetaDataLocation(), + keyWithCerts, identityProviderDefinition.getMetaDataLocation(), registrationId, zonedSamlEntityIDAlias, requestSigned); } } diff --git a/server/src/main/java/org/cloudfoundry/identity/uaa/provider/saml/DefaultRelyingPartyRegistrationRepository.java b/server/src/main/java/org/cloudfoundry/identity/uaa/provider/saml/DefaultRelyingPartyRegistrationRepository.java index 897d55db4e2..b826b8370dd 100644 --- a/server/src/main/java/org/cloudfoundry/identity/uaa/provider/saml/DefaultRelyingPartyRegistrationRepository.java +++ b/server/src/main/java/org/cloudfoundry/identity/uaa/provider/saml/DefaultRelyingPartyRegistrationRepository.java @@ -2,19 +2,26 @@ import org.cloudfoundry.identity.uaa.util.KeyWithCert; import org.cloudfoundry.identity.uaa.zone.IdentityZone; +import org.cloudfoundry.identity.uaa.zone.SamlConfig; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; +import java.util.List; + /** - * A {@link RelyingPartyRegistrationRepository} that always returns a default {@link RelyingPartyRegistrationRepository}. + * A ZoneAware {@link RelyingPartyRegistrationRepository} that always returns a default + * {@link RelyingPartyRegistrationRepository}. The default {@link RelyingPartyRegistration} in the + * {@link SamlRelyingPartyRegistrationRepositoryConfig} is configured to use a dummy SAML IdP metadata + * for the default zone (named example), this class also provides a dummy SAML IdP RelyingPartyRegistration + * but for non-default zones. */ public class DefaultRelyingPartyRegistrationRepository extends BaseUaaRelyingPartyRegistrationRepository { public static final String CLASSPATH_DUMMY_SAML_IDP_METADATA_XML = "classpath:dummy-saml-idp-metadata.xml"; public DefaultRelyingPartyRegistrationRepository(String uaaWideSamlEntityID, String uaaWideSamlEntityIDAlias, - KeyWithCert keyWithCert) { - super(keyWithCert, uaaWideSamlEntityID, uaaWideSamlEntityIDAlias); + List defaultKeysWithCerts) { + super(uaaWideSamlEntityID, uaaWideSamlEntityIDAlias, defaultKeysWithCerts); } /** @@ -27,18 +34,25 @@ public DefaultRelyingPartyRegistrationRepository(String uaaWideSamlEntityID, @Override public RelyingPartyRegistration findByRegistrationId(String registrationId) { IdentityZone currentZone = retrieveZone(); + List keyWithCerts = null; boolean requestSigned = true; if (currentZone.getConfig() != null && currentZone.getConfig().getSamlConfig() != null) { + SamlConfig samlConfig = currentZone.getConfig().getSamlConfig(); + keyWithCerts = convertToKeysWithCerts(samlConfig.getKeyList()); requestSigned = currentZone.getConfig().getSamlConfig().isRequestSigned(); } + if (keyWithCerts == null || keyWithCerts.isEmpty()) { + keyWithCerts = defaultKeysWithCerts; + } + String zonedSamlEntityID = getZoneEntityId(currentZone); String zonedSamlEntityIDAlias = getZoneEntityIdAlias(currentZone); return RelyingPartyRegistrationBuilder.buildRelyingPartyRegistration( zonedSamlEntityID, null, - keyWithCert, CLASSPATH_DUMMY_SAML_IDP_METADATA_XML, registrationId, + keyWithCerts, CLASSPATH_DUMMY_SAML_IDP_METADATA_XML, registrationId, zonedSamlEntityIDAlias, requestSigned); } } diff --git a/server/src/main/java/org/cloudfoundry/identity/uaa/provider/saml/RelyingPartyRegistrationBuilder.java b/server/src/main/java/org/cloudfoundry/identity/uaa/provider/saml/RelyingPartyRegistrationBuilder.java index b54140294f9..374f2fe439c 100644 --- a/server/src/main/java/org/cloudfoundry/identity/uaa/provider/saml/RelyingPartyRegistrationBuilder.java +++ b/server/src/main/java/org/cloudfoundry/identity/uaa/provider/saml/RelyingPartyRegistrationBuilder.java @@ -11,6 +11,7 @@ import java.io.ByteArrayInputStream; import java.io.InputStream; +import java.util.List; import java.util.function.UnaryOperator; @Slf4j @@ -24,10 +25,22 @@ private RelyingPartyRegistrationBuilder() { throw new java.lang.UnsupportedOperationException("This is a utility class and cannot be instantiated"); } + /** + * @param samlEntityID the entityId of the relying party + * @param samlSpNameId the nameIdFormat of the relying party + * @param keys a list of KeyWithCert objects, with the first key in the list being the active key, all keys in the + * list will be added for signing. Although it is possible to have multiple decryption keys, + * only the first one will be used to maintain parity with existing UAA + * @param metadataLocation the location or XML data of the metadata + * @param rpRegistrationId the registrationId of the relying party + * @param samlSpAlias the alias of the relying party for the SAML endpoints + * @param requestSigned whether the AuthnRequest should be signed + * @return a RelyingPartyRegistration object + */ public static RelyingPartyRegistration buildRelyingPartyRegistration( String samlEntityID, String samlSpNameId, - KeyWithCert keyWithCert, String metadataLocation, - String rpRegstrationId, String samlSpAlias, boolean requestSigned) { + List keys, String metadataLocation, + String rpRegistrationId, String samlSpAlias, boolean requestSigned) { SamlIdentityProviderDefinition.MetadataLocation type = SamlIdentityProviderDefinition.getType(metadataLocation); RelyingPartyRegistration.Builder builder; @@ -43,27 +56,32 @@ public static RelyingPartyRegistration buildRelyingPartyRegistration( } builder.entityId(samlEntityID); + if (rpRegistrationId != null) builder.registrationId(rpRegistrationId); if (samlSpNameId != null) builder.nameIdFormat(samlSpNameId); - if (rpRegstrationId != null) builder.registrationId(rpRegstrationId); + return builder + .signingX509Credentials(cred -> + keys.stream() + .map(k -> Saml2X509Credential.signing(k.getPrivateKey(), k.getCertificate())) + .forEach(cred::add) + ) + .decryptionX509Credentials(cred -> keys.stream() + .findFirst() + .map(k -> Saml2X509Credential.decryption(k.getPrivateKey(), k.getCertificate())) + .ifPresent(cred::add) + ) .assertionConsumerServiceLocation(assertionConsumerServiceLocationFunction.apply(samlSpAlias)) - .singleLogoutServiceResponseLocation(singleLogoutServiceResponseLocationFunction.apply(samlSpAlias)) + .assertionConsumerServiceBinding(Saml2MessageBinding.POST) .singleLogoutServiceLocation(singleLogoutServiceLocationFunction.apply(samlSpAlias)) .singleLogoutServiceResponseLocation(singleLogoutServiceResponseLocationFunction.apply(samlSpAlias)) // Accept both POST and REDIRECT bindings .singleLogoutServiceBindings(c -> { - c.add(Saml2MessageBinding.REDIRECT); c.add(Saml2MessageBinding.POST); + c.add(Saml2MessageBinding.REDIRECT); }) - .assertingPartyDetails(details -> details - .wantAuthnRequestsSigned(requestSigned) - ) - .signingX509Credentials(cred -> cred - .add(Saml2X509Credential.signing(keyWithCert.getPrivateKey(), keyWithCert.getCertificate())) - ) - .decryptionX509Credentials(cred -> cred - .add(Saml2X509Credential.decryption(keyWithCert.getPrivateKey(), keyWithCert.getCertificate())) - ) + // alter the default value of the APs wantAuthnRequestsSigned, + // to reflect the UAA configured desire to always sign/or-not the AuthnRequest + .assertingPartyDetails(details -> details.wantAuthnRequestsSigned(requestSigned)) .build(); } } diff --git a/server/src/main/java/org/cloudfoundry/identity/uaa/provider/saml/SamlConfigProps.java b/server/src/main/java/org/cloudfoundry/identity/uaa/provider/saml/SamlConfigProps.java index de8fc44c978..8a989cfeb46 100644 --- a/server/src/main/java/org/cloudfoundry/identity/uaa/provider/saml/SamlConfigProps.java +++ b/server/src/main/java/org/cloudfoundry/identity/uaa/provider/saml/SamlConfigProps.java @@ -1,11 +1,16 @@ package org.cloudfoundry.identity.uaa.provider.saml; import lombok.Data; +import lombok.extern.slf4j.Slf4j; import org.cloudfoundry.identity.uaa.saml.SamlKey; +import org.cloudfoundry.identity.uaa.util.KeyWithCert; import org.springframework.boot.context.properties.ConfigurationProperties; +import java.security.cert.CertificateException; +import java.util.List; import java.util.Map; +@Slf4j @Data @ConfigurationProperties(prefix = "login.saml") public class SamlConfigProps { @@ -24,4 +29,15 @@ public class SamlConfigProps { public SamlKey getActiveSamlKey() { return keys.get(activeKeyId); } + + public List getKeysWithCerts() { + return keys.values().stream().map(k -> { + try { + return new KeyWithCert(k); + } catch (CertificateException e) { + log.error("Error converting key with cert", e); + throw new CertificateRuntimeException(e); + } + }).toList(); + } } diff --git a/server/src/main/java/org/cloudfoundry/identity/uaa/provider/saml/SamlMetadataEndpoint.java b/server/src/main/java/org/cloudfoundry/identity/uaa/provider/saml/SamlMetadataEndpoint.java index c4abceb7dae..681c7f917d2 100644 --- a/server/src/main/java/org/cloudfoundry/identity/uaa/provider/saml/SamlMetadataEndpoint.java +++ b/server/src/main/java/org/cloudfoundry/identity/uaa/provider/saml/SamlMetadataEndpoint.java @@ -1,27 +1,23 @@ package org.cloudfoundry.identity.uaa.provider.saml; import org.cloudfoundry.identity.uaa.zone.IdentityZone; -import org.cloudfoundry.identity.uaa.zone.SamlConfig; import org.cloudfoundry.identity.uaa.zone.ZoneAware; import org.cloudfoundry.identity.uaa.zone.beans.IdentityZoneManager; -import org.opensaml.saml.common.xml.SAMLConstants; -import org.opensaml.saml.saml2.metadata.EntityDescriptor; -import org.opensaml.saml.saml2.metadata.SPSSODescriptor; import org.springframework.http.HttpHeaders; import org.springframework.http.ResponseEntity; import org.springframework.security.saml2.provider.service.metadata.OpenSamlMetadataResolver; import org.springframework.security.saml2.provider.service.metadata.Saml2MetadataResolver; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; -import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; +import org.springframework.security.saml2.provider.service.web.RelyingPartyRegistrationResolver; import org.springframework.util.Assert; import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.PathVariable; import org.springframework.web.bind.annotation.RestController; +import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import java.net.URLEncoder; import java.nio.charset.StandardCharsets; -import java.util.function.Consumer; @RestController public class SamlMetadataEndpoint implements ZoneAware { @@ -29,40 +25,26 @@ public class SamlMetadataEndpoint implements ZoneAware { private static final String APPLICATION_XML_CHARSET_UTF_8 = "application/xml; charset=UTF-8"; private final Saml2MetadataResolver saml2MetadataResolver; - private final IdentityZoneManager identityZoneManager; - private final RelyingPartyRegistrationRepository relyingPartyRegistrationRepository; + private final RelyingPartyRegistrationResolver relyingPartyRegistrationResolver; - public SamlMetadataEndpoint(RelyingPartyRegistrationRepository relyingPartyRegistrationRepository, + public SamlMetadataEndpoint(RelyingPartyRegistrationResolver registrationResolver, IdentityZoneManager identityZoneManager) { - Assert.notNull(relyingPartyRegistrationRepository, "relyingPartyRegistrationRepository cannot be null"); - this.relyingPartyRegistrationRepository = relyingPartyRegistrationRepository; - this.identityZoneManager = identityZoneManager; - OpenSamlMetadataResolver resolver = new OpenSamlMetadataResolver(); - this.saml2MetadataResolver = resolver; - resolver.setEntityDescriptorCustomizer(new EntityDescriptorCustomizer()); - } - - private class EntityDescriptorCustomizer implements Consumer { - @Override - public void accept(OpenSamlMetadataResolver.EntityDescriptorParameters entityDescriptorParameters) { - SamlConfig samlConfig = identityZoneManager.getCurrentIdentityZone().getConfig().getSamlConfig(); - - EntityDescriptor descriptor = entityDescriptorParameters.getEntityDescriptor(); - SPSSODescriptor spssodescriptor = descriptor.getSPSSODescriptor(SAMLConstants.SAML20P_NS); - spssodescriptor.setWantAssertionsSigned(samlConfig.isWantAssertionSigned()); - spssodescriptor.setAuthnRequestsSigned(samlConfig.isRequestSigned()); - } + Assert.notNull(registrationResolver, "registrationResolver cannot be null"); + relyingPartyRegistrationResolver = registrationResolver; + OpenSamlMetadataResolver metadataResolver = new OpenSamlMetadataResolver(); + saml2MetadataResolver = metadataResolver; + metadataResolver.setEntityDescriptorCustomizer(new SamlMetadataEntityDescriptorCustomizer(identityZoneManager)); } @GetMapping(value = "/saml/metadata", produces = APPLICATION_XML_CHARSET_UTF_8) - public ResponseEntity legacyMetadataEndpoint() { - return metadataEndpoint(DEFAULT_REGISTRATION_ID); + public ResponseEntity legacyMetadataEndpoint(HttpServletRequest request) { + return metadataEndpoint(request, DEFAULT_REGISTRATION_ID); } @GetMapping(value = "/saml/metadata/{registrationId}", produces = APPLICATION_XML_CHARSET_UTF_8) - public ResponseEntity metadataEndpoint(@PathVariable String registrationId) { - RelyingPartyRegistration relyingPartyRegistration = relyingPartyRegistrationRepository.findByRegistrationId(registrationId); + public ResponseEntity metadataEndpoint(HttpServletRequest request, @PathVariable String registrationId) { + RelyingPartyRegistration relyingPartyRegistration = relyingPartyRegistrationResolver.resolve(request, registrationId); if (relyingPartyRegistration == null) { return ResponseEntity.status(HttpServletResponse.SC_UNAUTHORIZED).build(); } diff --git a/server/src/main/java/org/cloudfoundry/identity/uaa/provider/saml/SamlMetadataEntityDescriptorCustomizer.java b/server/src/main/java/org/cloudfoundry/identity/uaa/provider/saml/SamlMetadataEntityDescriptorCustomizer.java new file mode 100644 index 00000000000..427c65eb9c8 --- /dev/null +++ b/server/src/main/java/org/cloudfoundry/identity/uaa/provider/saml/SamlMetadataEntityDescriptorCustomizer.java @@ -0,0 +1,133 @@ +package org.cloudfoundry.identity.uaa.provider.saml; + +import lombok.Value; +import org.cloudfoundry.identity.uaa.saml.SamlKey; +import org.cloudfoundry.identity.uaa.zone.SamlConfig; +import org.cloudfoundry.identity.uaa.zone.beans.IdentityZoneManager; +import org.opensaml.core.xml.XMLObjectBuilder; +import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport; +import org.opensaml.saml.common.xml.SAMLConstants; +import org.opensaml.saml.saml2.metadata.EntityDescriptor; +import org.opensaml.saml.saml2.metadata.NameIDFormat; +import org.opensaml.saml.saml2.metadata.SPSSODescriptor; +import org.opensaml.xmlsec.signature.KeyInfo; +import org.opensaml.xmlsec.signature.Signature; +import org.opensaml.xmlsec.signature.X509Certificate; +import org.opensaml.xmlsec.signature.X509Data; +import org.opensaml.xmlsec.signature.support.ContentReference; +import org.springframework.security.saml2.Saml2Exception; +import org.springframework.security.saml2.provider.service.metadata.OpenSamlMetadataResolver; + +import java.util.HashSet; +import java.util.List; +import java.util.Set; +import java.util.function.Consumer; +import java.util.stream.Collectors; + +import static org.cloudfoundry.identity.uaa.provider.saml.SamlNameIdFormats.NAMEID_FORMAT_EMAIL; +import static org.cloudfoundry.identity.uaa.provider.saml.SamlNameIdFormats.NAMEID_FORMAT_PERSISTENT; +import static org.cloudfoundry.identity.uaa.provider.saml.SamlNameIdFormats.NAMEID_FORMAT_TRANSIENT; +import static org.cloudfoundry.identity.uaa.provider.saml.SamlNameIdFormats.NAMEID_FORMAT_UNSPECIFIED; +import static org.cloudfoundry.identity.uaa.provider.saml.SamlNameIdFormats.NAMEID_FORMAT_X509SUBJECT; + +/** + * This class is used to customize the EntityDescriptor used in the Metadata call, + * it is called as part of the {@link OpenSamlMetadataResolver} after basic creation is completed. + */ +@Value +public class SamlMetadataEntityDescriptorCustomizer implements Consumer { + private static final Set NAME_ID_FORMATS = new HashSet<>(); + + static { + NAME_ID_FORMATS.add(NAMEID_FORMAT_EMAIL); + NAME_ID_FORMATS.add(NAMEID_FORMAT_TRANSIENT); + NAME_ID_FORMATS.add(NAMEID_FORMAT_PERSISTENT); + NAME_ID_FORMATS.add(NAMEID_FORMAT_UNSPECIFIED); + NAME_ID_FORMATS.add(NAMEID_FORMAT_X509SUBJECT); + } + + IdentityZoneManager identityZoneManager; + + @Override + public void accept(OpenSamlMetadataResolver.EntityDescriptorParameters entityDescriptorParameters) { + SamlConfig samlConfig = identityZoneManager.getCurrentIdentityZone().getConfig().getSamlConfig(); + + EntityDescriptor entityDescriptor = entityDescriptorParameters.getEntityDescriptor(); + entityDescriptor.setID(entityDescriptor.getEntityID()); + addSignatureElement(entityDescriptor, samlConfig); + + SPSSODescriptor spSsoDescriptor = updateSpSsoDescriptor(entityDescriptor, samlConfig); + + updateNameIdFormats(spSsoDescriptor); + } + + private static SPSSODescriptor updateSpSsoDescriptor(EntityDescriptor entityDescriptor, SamlConfig samlConfig) { + SPSSODescriptor spSsoDescriptor = entityDescriptor.getSPSSODescriptor(SAMLConstants.SAML20P_NS); + spSsoDescriptor.setWantAssertionsSigned(samlConfig.isWantAssertionSigned()); + spSsoDescriptor.setAuthnRequestsSigned(samlConfig.isRequestSigned()); + + return spSsoDescriptor; + } + + /** + * Add a signature element to the entity descriptor. + * The signature contains the active key's certificate. + * + * @param entityDescriptor + * @param samlConfig + */ + private static void addSignatureElement(EntityDescriptor entityDescriptor, SamlConfig samlConfig) { + Signature signature = entityDescriptor.getSignature(); + if (signature == null) { + signature = (Signature) XMLObjectProviderRegistrySupport.getBuilderFactory().getBuilder(Signature.DEFAULT_ELEMENT_NAME).buildObject(Signature.DEFAULT_ELEMENT_NAME); + entityDescriptor.setSignature(signature); + } + signature.setSignatureAlgorithm("http://www.w3.org/2001/04/xmldsig-more#rsa-sha256"); + signature.setCanonicalizationAlgorithm("http://www.w3.org/2001/10/xml-exc-c14n#"); + List contentReferences = signature.getContentReferences(); + // TODO: ds:DigestValue is not set + // TODO: ds:SignatureValue is not set + + KeyInfo keyInfo = signature.getKeyInfo(); + if (keyInfo == null) { + keyInfo = (KeyInfo) XMLObjectProviderRegistrySupport.getBuilderFactory().getBuilder(KeyInfo.DEFAULT_ELEMENT_NAME).buildObject(KeyInfo.DEFAULT_ELEMENT_NAME); + signature.setKeyInfo(keyInfo); + } + + List x509Datas = keyInfo.getX509Datas(); + if (x509Datas.isEmpty()) { + x509Datas.add((X509Data) XMLObjectProviderRegistrySupport.getBuilderFactory().getBuilder(X509Data.DEFAULT_ELEMENT_NAME).buildObject(X509Data.DEFAULT_ELEMENT_NAME)); + } + X509Data x509Data = x509Datas.get(0); + List x509Certificates = x509Data.getX509Certificates(); + + SamlKey activeKey = samlConfig.getActiveKey(); + if (activeKey != null) { + X509Certificate x509 = (X509Certificate) XMLObjectProviderRegistrySupport.getBuilderFactory().getBuilder(X509Certificate.DEFAULT_ELEMENT_NAME).buildObject(X509Certificate.DEFAULT_ELEMENT_NAME); + x509.setValue(bareCertData(activeKey.getCertificate())); + x509Certificates.add(x509); + } + } + + private static String bareCertData(String cert) { + return cert.replace("-----BEGIN CERTIFICATE-----", "") + .replace("-----END CERTIFICATE-----", "") + .replace("\n", ""); + } + + private void updateNameIdFormats(SPSSODescriptor spSsoDescriptor) { + // TODO: dedupe the name id formats + spSsoDescriptor.getNameIDFormats().addAll(NAME_ID_FORMATS.stream().map(this::buildNameIDFormat).collect(Collectors.toSet())); + } + + private NameIDFormat buildNameIDFormat(String value) { + XMLObjectBuilder builder = (XMLObjectBuilder) XMLObjectProviderRegistrySupport.getBuilderFactory().getBuilder(NameIDFormat.DEFAULT_ELEMENT_NAME); + if (builder == null) { + throw new Saml2Exception("Unable to resolve Builder for " + NameIDFormat.DEFAULT_ELEMENT_NAME); + } + + NameIDFormat nameIdFormat = builder.buildObject(NameIDFormat.DEFAULT_ELEMENT_NAME); + nameIdFormat.setFormat(value); // nosonar + return nameIdFormat; + } +} diff --git a/server/src/main/java/org/cloudfoundry/identity/uaa/provider/saml/SamlNameIdFormats.java b/server/src/main/java/org/cloudfoundry/identity/uaa/provider/saml/SamlNameIdFormats.java new file mode 100644 index 00000000000..0d5fff1e4d8 --- /dev/null +++ b/server/src/main/java/org/cloudfoundry/identity/uaa/provider/saml/SamlNameIdFormats.java @@ -0,0 +1,151 @@ +package org.cloudfoundry.identity.uaa.provider.saml; + +/** + * This class contains NameID format constants for SAML 1.1 and SAML 2.0. + * + * @see Saml 2.0 Doc + * Section 8.3 - Name Identifier Format Identifiers + */ +public final class SamlNameIdFormats { + + private static final String NAMEID_FORMAT_BASE = "urn:oasis:names:tc:SAML:%s:nameid-format:%s"; + + /*************************************************************************** + * SAML 1.1 NameID Formats + */ + private static final String NAMEID_VERSION_1_1 = "1.1"; + + /** + * URI: urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress + *

+ * Indicates that the content of the element is in the form of an email address, specifically "addr-spec" as + * defined in IETF RFC 2822 [RFC 2822] Section 3.4.1. An addr-spec has the form local-part@domain. Note + * that an addr-spec has no phrase (such as a common name) before it, has no comment (text surrounded + * in parentheses) after it, and is not surrounded by "<" and ">". + */ + public static final String NAMEID_FORMAT_EMAIL = NAMEID_FORMAT_BASE.formatted(NAMEID_VERSION_1_1, "emailAddress"); + + /** + * URI: urn:oasis:names:tc:SAML:1.1:nameid-format:unspecified + *

+ * The interpretation of the content of the element is left to individual implementations. + */ + public static final String NAMEID_FORMAT_UNSPECIFIED = NAMEID_FORMAT_BASE.formatted(NAMEID_VERSION_1_1, "unspecified"); + + /** + * URI: urn:oasis:names:tc:SAML:1.1:nameid-format:X509SubjectName + *

+ * Indicates that the content of the element is in the form specified for the contents of the + * element in the XML Signature Recommendation [XMLSig]. Implementors + *

+ * should note that the XML Signature specification specifies encoding rules for X.509 subject names that + * differ from the rules given in IETF RFC 2253 [RFC 2253]. + */ + public static final String NAMEID_FORMAT_X509SUBJECT = NAMEID_FORMAT_BASE.formatted(NAMEID_VERSION_1_1, "X509SubjectName"); + + /** + * URI: urn:oasis:names:tc:SAML:1.1:nameid-format:WindowsDomainQualifiedName + *

+ * Indicates that the content of the element is a Windows domain qualified name. A Windows domain + * qualified user name is a string of the form "DomainName\UserName". The domain name and "\" separator + * MAY be omitted. + */ + public static final String NAMEID_FORMAT_WINDOWS_DQN = NAMEID_FORMAT_BASE.formatted(NAMEID_VERSION_1_1, "WindowsDomainQualifiedName"); + + /*************************************************************************** + * SAML 2.0 NameID Formats + */ + private static final String NAMEID_VERSION_2_0 = "2.0"; + + /** + * URI: urn:oasis:names:tc:SAML:2.0:nameid-format:persistent + *

+ * Indicates that the content of the element is a persistent opaque identifier for a principal that is specific to + * an identity provider and a service provider or affiliation of service providers. Persistent name identifiers + * generated by identity providers MUST be constructed using pseudo-random values that have no + * discernible correspondence with the subject's actual identifier (for example, username). The intent is to + * create a non-public, pair-wise pseudonym to prevent the discovery of the subject's identity or activities. + * Persistent name identifier values MUST NOT exceed a length of 256 characters. + *

+ * The element's NameQualifier attribute, if present, MUST contain the unique identifier of the identity + * provider that generated the identifier (see Section 8.3.6). It MAY be omitted if the value can be derived + * from the context of the message containing the element, such as the issuer of a protocol message or an + * assertion containing the identifier in its subject. Note that a different system entity might later issue its own + * protocol message or assertion containing the identifier; the NameQualifier attribute does not change in + * this case, but MUST continue to identify the entity that originally created the identifier (and MUST NOT be + * omitted in such a case). + *

+ * The element's SPNameQualifier attribute, if present, MUST contain the unique identifier of the service + * provider or affiliation of providers for whom the identifier was generated (see Section 8.3.6). It MAY be + * omitted if the element is contained in a message intended only for consumption directly by the service + * provider, and the value would be the unique identifier of that service provider. + * The element's SPProvidedID attribute MUST contain the alternative identifier of the principal most + * recently set by the service provider or affiliation, if any (see Section 3.6). If no such identifier has been + * established, then the attribute MUST be omitted. + *

+ * Persistent identifiers are intended as a privacy protection mechanism; as such they MUST NOT be shared + * in clear text with providers other than the providers that have established the shared identifier. + * Furthermore, they MUST NOT appear in log files or similar locations without appropriate controls and + * protections. Deployments without such requirements are free to use other kinds of identifiers in their + * SAML exchanges, but MUST NOT overload this format with persistent but non-opaque values + *

+ * Note also that while persistent identifiers are typically used to reflect an account linking relationship + * between a pair of providers, a service provider is not obligated to recognize or make use of the long term + * nature of the persistent identifier or establish such a link. Such a "one-sided" relationship is not discernibly + * different and does not affect the behavior of the identity provider or any processing rules specific to + * persistent identifiers in the protocols defined in this specification. + *

+ * Finally, note that the NameQualifier and SPNameQualifier attributes indicate directionality of + * creation, but not of use. If a persistent identifier is created by a particular identity provider, the + * NameQualifier attribute value is permanently established at that time. If a service provider that receives + * such an identifier takes on the role of an identity provider and issues its own assertion containing that + * identifier, the NameQualifier attribute value does not change (and would of course not be omitted). It + * might alternatively choose to create its own persistent identifier to represent the principal and link the two + * values. This is a deployment decision. + */ + public static final String NAMEID_FORMAT_PERSISTENT = NAMEID_FORMAT_BASE.formatted(NAMEID_VERSION_2_0, "persistent"); + + /** + * URI: urn:oasis:names:tc:SAML:2.0:nameid-format:transient + *

+ * Indicates that the content of the element is an identifier with transient semantics and SHOULD be treated + * as an opaque and temporary value by the relying party. Transient identifier values MUST be generated in + * accordance with the rules for SAML identifiers (see Section 1.3.4), and MUST NOT exceed a length of + * 256 characters. + *

+ * The NameQualifier and SPNameQualifier attributes MAY be used to signify that the identifier + * represents a transient and temporary pair-wise identifier. In such a case, they MAY be omitted in + * accordance with the rules specified in Section 8.3.7. + */ + public static final String NAMEID_FORMAT_TRANSIENT = NAMEID_FORMAT_BASE.formatted(NAMEID_VERSION_2_0, "transient"); + + /** + * URI: urn:oasis:names:tc:SAML:2.0:nameid-format:kerberos + *

+ * Indicates that the content of the element is in the form of a Kerberos principal name using the format + * name[/instance]@REALM. The syntax, format and characters allowed for the name, instance, and + * realm are described in IETF RFC 1510 [RFC 1510]. + */ + public static final String NAMEID_FORMAT_KERBEROS = NAMEID_FORMAT_BASE.formatted(NAMEID_VERSION_2_0, "kerberos"); + + /** + * URI: urn:oasis:names:tc:SAML:2.0:nameid-format:entity + *

+ * Indicates that the content of the element is the identifier of an entity that provides SAML-based services + * (such as a SAML authority, requester, or responder) or is a participant in SAML profiles (such as a service + * provider supporting the browser SSO profile). Such an identifier can be used in the element to + * identify the issuer of a SAML request, response, or assertion, or within the element to make + * assertions about system entities that can issue SAML requests, responses, and assertions. It can also be + * used in other elements and attributes whose purpose is to identify a system entity in various protocol + * exchanges. + *

+ * The syntax of such an identifier is a URI of not more than 1024 characters in length. It is + * RECOMMENDED that a system entity use a URL containing its own domain name to identify itself. + * The NameQualifier, SPNameQualifier, and SPProvidedID attributes MUST be omitted. + */ + public static final String NAMEID_FORMAT_ENTITY = NAMEID_FORMAT_BASE.formatted(NAMEID_VERSION_2_0, "entity"); + + private SamlNameIdFormats() { + throw new UnsupportedOperationException("This is a utility class and cannot be instantiated"); + } +} diff --git a/server/src/main/java/org/cloudfoundry/identity/uaa/provider/saml/SamlRelyingPartyRegistrationRepositoryConfig.java b/server/src/main/java/org/cloudfoundry/identity/uaa/provider/saml/SamlRelyingPartyRegistrationRepositoryConfig.java index fa35a81302a..c07f79d40c1 100644 --- a/server/src/main/java/org/cloudfoundry/identity/uaa/provider/saml/SamlRelyingPartyRegistrationRepositoryConfig.java +++ b/server/src/main/java/org/cloudfoundry/identity/uaa/provider/saml/SamlRelyingPartyRegistrationRepositoryConfig.java @@ -2,7 +2,6 @@ import lombok.extern.slf4j.Slf4j; import org.cloudfoundry.identity.uaa.provider.SamlIdentityProviderDefinition; -import org.cloudfoundry.identity.uaa.saml.SamlKey; import org.cloudfoundry.identity.uaa.util.KeyWithCert; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Qualifier; @@ -15,7 +14,6 @@ import org.springframework.security.saml2.provider.service.web.DefaultRelyingPartyRegistrationResolver; import org.springframework.security.saml2.provider.service.web.RelyingPartyRegistrationResolver; -import java.security.cert.CertificateException; import java.util.ArrayList; import java.util.List; @@ -46,13 +44,10 @@ public SamlRelyingPartyRegistrationRepositoryConfig(@Qualifier("samlEntityID") S @Autowired @Bean - RelyingPartyRegistrationRepository relyingPartyRegistrationRepository(SamlIdentityProviderConfigurator samlIdentityProviderConfigurator) throws CertificateException { - - SamlKey activeSamlKey = samlConfigProps.getActiveSamlKey(); - KeyWithCert keyWithCert = new KeyWithCert(activeSamlKey.getKey(), activeSamlKey.getPassphrase(), activeSamlKey.getCertificate()); + RelyingPartyRegistrationRepository relyingPartyRegistrationRepository(SamlIdentityProviderConfigurator samlIdentityProviderConfigurator) { + List defaultKeysWithCerts = samlConfigProps.getKeysWithCerts(); List relyingPartyRegistrations = new ArrayList<>(); - String uaaWideSamlEntityIDAlias = samlConfigProps.getEntityIDAlias() != null ? samlConfigProps.getEntityIDAlias() : samlEntityID; @SuppressWarnings("java:S125") @@ -67,13 +62,13 @@ RelyingPartyRegistrationRepository relyingPartyRegistrationRepository(SamlIdenti // even when there are no SAML IDPs configured. // See relevant issue: https://github.com/spring-projects/spring-security/issues/11369 RelyingPartyRegistration exampleRelyingPartyRegistration = RelyingPartyRegistrationBuilder.buildRelyingPartyRegistration( - samlEntityID, samlSpNameID, keyWithCert, CLASSPATH_DUMMY_SAML_IDP_METADATA_XML, DEFAULT_REGISTRATION_ID, uaaWideSamlEntityIDAlias, samlConfigProps.getSignRequest()); + samlEntityID, samlSpNameID, defaultKeysWithCerts, CLASSPATH_DUMMY_SAML_IDP_METADATA_XML, DEFAULT_REGISTRATION_ID, uaaWideSamlEntityIDAlias, samlConfigProps.getSignRequest()); relyingPartyRegistrations.add(exampleRelyingPartyRegistration); for (SamlIdentityProviderDefinition samlIdentityProviderDefinition : bootstrapSamlIdentityProviderData.getIdentityProviderDefinitions()) { relyingPartyRegistrations.add( RelyingPartyRegistrationBuilder.buildRelyingPartyRegistration( - samlEntityID, samlSpNameID, keyWithCert, + samlEntityID, samlSpNameID, defaultKeysWithCerts, samlIdentityProviderDefinition.getMetaDataLocation(), samlIdentityProviderDefinition.getIdpEntityAlias(), uaaWideSamlEntityIDAlias, @@ -82,8 +77,8 @@ RelyingPartyRegistrationRepository relyingPartyRegistrationRepository(SamlIdenti } InMemoryRelyingPartyRegistrationRepository bootstrapRepo = new InMemoryRelyingPartyRegistrationRepository(relyingPartyRegistrations); - ConfiguratorRelyingPartyRegistrationRepository configuratorRepo = new ConfiguratorRelyingPartyRegistrationRepository(samlEntityID, uaaWideSamlEntityIDAlias, keyWithCert, samlIdentityProviderConfigurator); - DefaultRelyingPartyRegistrationRepository defaultRepo = new DefaultRelyingPartyRegistrationRepository(samlEntityID, uaaWideSamlEntityIDAlias, keyWithCert); + ConfiguratorRelyingPartyRegistrationRepository configuratorRepo = new ConfiguratorRelyingPartyRegistrationRepository(samlEntityID, uaaWideSamlEntityIDAlias, defaultKeysWithCerts, samlIdentityProviderConfigurator); + DefaultRelyingPartyRegistrationRepository defaultRepo = new DefaultRelyingPartyRegistrationRepository(samlEntityID, uaaWideSamlEntityIDAlias, defaultKeysWithCerts); return new DelegatingRelyingPartyRegistrationRepository(bootstrapRepo, configuratorRepo, defaultRepo); } diff --git a/server/src/main/java/org/cloudfoundry/identity/uaa/util/KeyWithCert.java b/server/src/main/java/org/cloudfoundry/identity/uaa/util/KeyWithCert.java index 40fd2c5b256..bb2d95c429f 100644 --- a/server/src/main/java/org/cloudfoundry/identity/uaa/util/KeyWithCert.java +++ b/server/src/main/java/org/cloudfoundry/identity/uaa/util/KeyWithCert.java @@ -1,5 +1,6 @@ package org.cloudfoundry.identity.uaa.util; +import lombok.Getter; import org.bouncycastle.asn1.pkcs.PrivateKeyInfo; import org.bouncycastle.cert.X509CertificateHolder; import org.bouncycastle.cert.jcajce.JcaX509CertificateConverter; @@ -10,6 +11,7 @@ import org.bouncycastle.openssl.PEMParser; import org.bouncycastle.openssl.jcajce.JcaPEMKeyConverter; import org.bouncycastle.openssl.jcajce.JcePEMDecryptorProviderBuilder; +import org.cloudfoundry.identity.uaa.saml.SamlKey; import java.io.ByteArrayInputStream; import java.io.IOException; @@ -23,12 +25,14 @@ import static org.cloudfoundry.identity.uaa.oauth.jwt.JwtAlgorithms.DEFAULT_RSA; +@Getter public class KeyWithCert { - private X509Certificate certificate; - private PrivateKey privateKey; + private final X509Certificate certificate; + private final PrivateKey privateKey; public KeyWithCert(String encodedCertificate) throws CertificateException { certificate = loadCertificate(encodedCertificate); + privateKey = null; } public KeyWithCert(String encodedPrivateKey, String passphrase, String encodedCertificate) throws CertificateException { @@ -37,7 +41,6 @@ public KeyWithCert(String encodedPrivateKey, String passphrase, String encodedCe } privateKey = loadPrivateKey(encodedPrivateKey, passphrase); - certificate = loadCertificate(encodedCertificate); if (!keysMatch(certificate.getPublicKey(), privateKey)) { @@ -45,12 +48,8 @@ public KeyWithCert(String encodedPrivateKey, String passphrase, String encodedCe } } - public X509Certificate getCertificate() { - return certificate; - } - - public PrivateKey getPrivateKey() { - return privateKey; + public KeyWithCert(SamlKey samlKey) throws CertificateException { + this(samlKey.getKey(), samlKey.getPassphrase(), samlKey.getCertificate()); } private boolean keysMatch(PublicKey publicKey, PrivateKey privateKey) { @@ -85,22 +84,21 @@ private static String getJavaAlgorithm(String publicKeyAlgorithm) { return publicKeyAlgorithm; } - private PrivateKey loadPrivateKey(String encodedPrivateKey, String passphrase) throws CertificateException { + private static PrivateKey loadPrivateKey(String encodedPrivateKey, String passphrase) throws CertificateException { PrivateKey privateKey = null; try (PEMParser pemParser = new PEMParser(new InputStreamReader(new ByteArrayInputStream(encodedPrivateKey.getBytes())))) { JcaPEMKeyConverter converter = new JcaPEMKeyConverter().setProvider(BouncyCastleFipsProvider.PROVIDER_NAME); Object object = pemParser.readObject(); - if (object instanceof PEMEncryptedKeyPair) { + if (object instanceof PEMEncryptedKeyPair pemEncryptedKeyPair) { PEMDecryptorProvider decProv = new JcePEMDecryptorProviderBuilder().build(passphrase.toCharArray()); - KeyPair keyPair = converter.getKeyPair(((PEMEncryptedKeyPair) object).decryptKeyPair(decProv)); + KeyPair keyPair = converter.getKeyPair(pemEncryptedKeyPair.decryptKeyPair(decProv)); privateKey = keyPair.getPrivate(); - } else if (object instanceof PEMKeyPair) { - KeyPair keyPair = converter.getKeyPair((PEMKeyPair) object); + } else if (object instanceof PEMKeyPair pemKeyPair) { + KeyPair keyPair = converter.getKeyPair(pemKeyPair); privateKey = keyPair.getPrivate(); - } else if (object instanceof PrivateKeyInfo) { - PrivateKeyInfo privateKeyInfo = (PrivateKeyInfo) object; + } else if (object instanceof PrivateKeyInfo privateKeyInfo) { privateKey = converter.getPrivateKey(privateKeyInfo); } } catch (IOException ex) { @@ -114,13 +112,13 @@ private PrivateKey loadPrivateKey(String encodedPrivateKey, String passphrase) t return privateKey; } - private X509Certificate loadCertificate(String encodedCertificate) throws CertificateException { + private static X509Certificate loadCertificate(String encodedCertificate) throws CertificateException { X509Certificate certificate; try (PEMParser pemParser = new PEMParser(new InputStreamReader(new ByteArrayInputStream(encodedCertificate.getBytes())))) { Object object = pemParser.readObject(); - if (object instanceof X509CertificateHolder) { - certificate = new JcaX509CertificateConverter().setProvider(BouncyCastleFipsProvider.PROVIDER_NAME).getCertificate((X509CertificateHolder) object); + if (object instanceof X509CertificateHolder x509CertificateHolder) { + certificate = new JcaX509CertificateConverter().setProvider(BouncyCastleFipsProvider.PROVIDER_NAME).getCertificate(x509CertificateHolder); } else { throw new CertificateException("Unsupported certificate type, not an X509CertificateHolder."); } diff --git a/server/src/test/java/org/cloudfoundry/identity/uaa/config/IdentityZoneConfigurationBootstrapTests.java b/server/src/test/java/org/cloudfoundry/identity/uaa/config/IdentityZoneConfigurationBootstrapTests.java index d5ba0acbfd7..aa88d68a572 100644 --- a/server/src/test/java/org/cloudfoundry/identity/uaa/config/IdentityZoneConfigurationBootstrapTests.java +++ b/server/src/test/java/org/cloudfoundry/identity/uaa/config/IdentityZoneConfigurationBootstrapTests.java @@ -145,8 +145,8 @@ void samlKeysAndSigningConfigs() throws Exception { assertThat(uaa.getConfig().getSamlConfig().getPrivateKey()).isEqualTo(SamlTestUtils.PROVIDER_PRIVATE_KEY); assertThat(uaa.getConfig().getSamlConfig().getPrivateKeyPassword()).isEqualTo(SamlTestUtils.PROVIDER_PRIVATE_KEY_PASSWORD); assertThat(uaa.getConfig().getSamlConfig().getCertificate()).isEqualTo(SamlTestUtils.PROVIDER_CERTIFICATE); - assertThat(uaa.getConfig().getSamlConfig().isWantAssertionSigned()).isEqualTo(false); - assertThat(uaa.getConfig().getSamlConfig().isRequestSigned()).isEqualTo(false); + assertThat(uaa.getConfig().getSamlConfig().isWantAssertionSigned()).isFalse(); + assertThat(uaa.getConfig().getSamlConfig().isRequestSigned()).isFalse(); } @Test @@ -220,11 +220,11 @@ void disableSelfServiceLinks() throws Exception { @Test void setHomeRedirect() throws Exception { - bootstrap.setHomeRedirect("http://some.redirect.com/redirect"); + bootstrap.setHomeRedirect("https://some.redirect.com/redirect"); bootstrap.afterPropertiesSet(); IdentityZone zone = provisioning.retrieve(IdentityZone.getUaaZoneId()); - assertThat(zone.getConfig().getLinks().getHomeRedirect()).isEqualTo("http://some.redirect.com/redirect"); + assertThat(zone.getConfig().getLinks().getHomeRedirect()).isEqualTo("https://some.redirect.com/redirect"); } @Test diff --git a/server/src/test/java/org/cloudfoundry/identity/uaa/provider/saml/ConfiguratorRelyingPartyRegistrationRepositoryTest.java b/server/src/test/java/org/cloudfoundry/identity/uaa/provider/saml/ConfiguratorRelyingPartyRegistrationRepositoryTest.java index 4079d720bd3..ce73ab3042e 100644 --- a/server/src/test/java/org/cloudfoundry/identity/uaa/provider/saml/ConfiguratorRelyingPartyRegistrationRepositoryTest.java +++ b/server/src/test/java/org/cloudfoundry/identity/uaa/provider/saml/ConfiguratorRelyingPartyRegistrationRepositoryTest.java @@ -1,19 +1,28 @@ package org.cloudfoundry.identity.uaa.provider.saml; +import org.bouncycastle.jcajce.provider.BouncyCastleFipsProvider; import org.cloudfoundry.identity.uaa.provider.SamlIdentityProviderDefinition; +import org.cloudfoundry.identity.uaa.saml.SamlKey; import org.cloudfoundry.identity.uaa.util.KeyWithCert; +import org.cloudfoundry.identity.uaa.util.KeyWithCertTest; import org.cloudfoundry.identity.uaa.zone.IdentityZone; import org.cloudfoundry.identity.uaa.zone.IdentityZoneConfiguration; import org.cloudfoundry.identity.uaa.zone.SamlConfig; +import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.junit.jupiter.params.provider.NullSource; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; import org.springframework.core.io.DefaultResourceLoader; import org.springframework.core.io.Resource; import org.springframework.core.io.ResourceLoader; import org.springframework.security.saml2.Saml2Exception; +import org.springframework.security.saml2.core.Saml2X509Credential; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; import org.springframework.util.FileCopyUtils; @@ -21,10 +30,11 @@ import java.io.InputStreamReader; import java.io.Reader; import java.io.UncheckedIOException; -import java.security.PrivateKey; -import java.security.cert.X509Certificate; +import java.security.Security; +import java.security.cert.CertificateException; import java.util.Arrays; import java.util.List; +import java.util.stream.Stream; import static java.nio.charset.StandardCharsets.UTF_8; import static org.assertj.core.api.Assertions.assertThat; @@ -44,15 +54,17 @@ class ConfiguratorRelyingPartyRegistrationRepositoryTest { private static final String ZONED_ENTITY_ID = "zoneDomain.entityId"; private static final String ZONE_SPECIFIC_ENTITY_ID = "zoneEntityId"; + private static final SamlKey samlKey1 = new SamlKey(KeyWithCertTest.encryptedKey, KeyWithCertTest.password, KeyWithCertTest.goodCert); + private static final SamlKey samlKey2 = new SamlKey(KeyWithCertTest.ecPrivateKey, KeyWithCertTest.password, KeyWithCertTest.ecCertificate); + private static KeyWithCert keyWithCert1; + private static KeyWithCert keyWithCert2; + @Mock private SamlIdentityProviderConfigurator configurator; @Mock private IdentityZone identityZone; - @Mock - private KeyWithCert keyWithCert; - @Mock private SamlIdentityProviderDefinition definition; @@ -64,16 +76,27 @@ class ConfiguratorRelyingPartyRegistrationRepositoryTest { private ConfiguratorRelyingPartyRegistrationRepository repository; + @BeforeAll + public static void addProvider() { + Security.addProvider(new BouncyCastleFipsProvider()); + try { + keyWithCert1 = new KeyWithCert(samlKey1); + keyWithCert2 = new KeyWithCert(samlKey2); + } catch (CertificateException e) { + throw new RuntimeException(e); + } + } + @BeforeEach void setUp() { - repository = spy(new ConfiguratorRelyingPartyRegistrationRepository(ENTITY_ID, ENTITY_ID_ALIAS, keyWithCert, - configurator)); + repository = spy(new ConfiguratorRelyingPartyRegistrationRepository(ENTITY_ID, ENTITY_ID_ALIAS, List.of(), configurator)); } @Test void constructorWithNullConfiguratorThrows() { + List emptyKeysWithCerts = List.of(); assertThatThrownBy(() -> new ConfiguratorRelyingPartyRegistrationRepository( - ENTITY_ID, ENTITY_ID_ALIAS, keyWithCert, null) + ENTITY_ID, ENTITY_ID_ALIAS, emptyKeysWithCerts, null) ).isInstanceOf(IllegalArgumentException.class); } @@ -83,8 +106,6 @@ void findByRegistrationIdWithMultipleInDb() { when(identityZone.isUaa()).thenReturn(true); when(identityZone.getConfig()).thenReturn(identityZoneConfiguration); when(identityZoneConfiguration.getSamlConfig()).thenReturn(samlConfig); - when(keyWithCert.getCertificate()).thenReturn(mock(X509Certificate.class)); - when(keyWithCert.getPrivateKey()).thenReturn(mock(PrivateKey.class)); //definition 1 when(definition.getIdpEntityAlias()).thenReturn(REGISTRATION_ID); @@ -127,9 +148,6 @@ void buildsCorrectRegistrationWhenMetadataXmlIsStored() { when(identityZone.isUaa()).thenReturn(true); when(identityZone.getConfig()).thenReturn(identityZoneConfiguration); when(identityZoneConfiguration.getSamlConfig()).thenReturn(samlConfig); - - when(keyWithCert.getCertificate()).thenReturn(mock(X509Certificate.class)); - when(keyWithCert.getPrivateKey()).thenReturn(mock(PrivateKey.class)); when(definition.getIdpEntityAlias()).thenReturn(REGISTRATION_ID); when(definition.getNameID()).thenReturn(NAME_ID); when(definition.getMetaDataLocation()).thenReturn(metadata); @@ -150,15 +168,80 @@ void buildsCorrectRegistrationWhenMetadataXmlIsStored() { .returns("https://idp-saml.ua3.int/simplesaml/saml2/idp/metadata.php", RelyingPartyRegistration.AssertingPartyDetails::getEntityId); } + @Test + void zoneWithCredentialsUsesCorrectValues() { + when(repository.retrieveZone()).thenReturn(identityZone); + when(identityZone.getConfig()).thenReturn(identityZoneConfiguration); + when(identityZoneConfiguration.getSamlConfig()).thenReturn(samlConfig); + when(samlConfig.getKeyList()).thenReturn(List.of(samlKey1, samlKey2)); + + when(definition.getIdpEntityAlias()).thenReturn(REGISTRATION_ID); + when(definition.getMetaDataLocation()).thenReturn("saml-sample-metadata.xml"); + when(configurator.getIdentityProviderDefinitionsForZone(identityZone)).thenReturn(List.of(definition)); + + RelyingPartyRegistration registration = repository.findByRegistrationId(REGISTRATION_ID); + + assertThat(registration.getDecryptionX509Credentials()) + .hasSize(1) + .first() + .extracting(Saml2X509Credential::getCertificate) + .isEqualTo(keyWithCert1.getCertificate()); + assertThat(registration.getSigningX509Credentials()) + .hasSize(2) + .first() + .extracting(Saml2X509Credential::getCertificate) + .isEqualTo(keyWithCert1.getCertificate()); + // Check the second element + assertThat(registration.getSigningX509Credentials()) + .element(1) + .extracting(Saml2X509Credential::getCertificate) + .isEqualTo(keyWithCert2.getCertificate()); + } + + private static Stream emptyList() { + return Stream.of(Arguments.of(List.of())); + } + + @ParameterizedTest + @NullSource + @MethodSource("emptyList") + void zoneWithoutCredentialsUsesDefault(List samlConfigKeys) { + repository = spy(new ConfiguratorRelyingPartyRegistrationRepository(ENTITY_ID, ENTITY_ID_ALIAS, List.of(keyWithCert1, keyWithCert2), configurator)); + + when(repository.retrieveZone()).thenReturn(identityZone); + when(identityZone.getConfig()).thenReturn(identityZoneConfiguration); + when(identityZoneConfiguration.getSamlConfig()).thenReturn(samlConfig); + when(samlConfig.getKeyList()).thenReturn(samlConfigKeys); + + when(definition.getIdpEntityAlias()).thenReturn(REGISTRATION_ID); + when(definition.getMetaDataLocation()).thenReturn("saml-sample-metadata.xml"); + when(configurator.getIdentityProviderDefinitionsForZone(identityZone)).thenReturn(List.of(definition)); + + RelyingPartyRegistration registration = repository.findByRegistrationId(REGISTRATION_ID); + + assertThat(registration.getDecryptionX509Credentials()) + .hasSize(1) + .first() + .extracting(Saml2X509Credential::getCertificate) + .isEqualTo(keyWithCert1.getCertificate()); + assertThat(registration.getSigningX509Credentials()) + .hasSize(2) + .first() + .extracting(Saml2X509Credential::getCertificate) + .isEqualTo(keyWithCert1.getCertificate()); + // Check the second element + assertThat(registration.getSigningX509Credentials()) + .element(1) + .extracting(Saml2X509Credential::getCertificate) + .isEqualTo(keyWithCert2.getCertificate()); + } + @Test void buildsCorrectRegistrationWhenMetadataLocationIsStored() { when(repository.retrieveZone()).thenReturn(identityZone); when(identityZone.isUaa()).thenReturn(true); when(identityZone.getConfig()).thenReturn(identityZoneConfiguration); when(identityZoneConfiguration.getSamlConfig()).thenReturn(samlConfig); - - when(keyWithCert.getCertificate()).thenReturn(mock(X509Certificate.class)); - when(keyWithCert.getPrivateKey()).thenReturn(mock(PrivateKey.class)); when(definition.getIdpEntityAlias()).thenReturn(REGISTRATION_ID_2); when(definition.getNameID()).thenReturn(NAME_ID); when(definition.getMetaDataLocation()).thenReturn("saml-sample-metadata.xml"); @@ -180,15 +263,12 @@ void buildsCorrectRegistrationWhenMetadataLocationIsStored() { @Test void fallsBackToUaaWideEntityIdWhenNoAlias() { - repository = spy(new ConfiguratorRelyingPartyRegistrationRepository(ENTITY_ID, null, keyWithCert, configurator)); + repository = spy(new ConfiguratorRelyingPartyRegistrationRepository(ENTITY_ID, null, List.of(), configurator)); when(repository.retrieveZone()).thenReturn(identityZone); when(identityZone.isUaa()).thenReturn(true); when(identityZone.getConfig()).thenReturn(identityZoneConfiguration); when(identityZoneConfiguration.getSamlConfig()).thenReturn(samlConfig); - - when(keyWithCert.getCertificate()).thenReturn(mock(X509Certificate.class)); - when(keyWithCert.getPrivateKey()).thenReturn(mock(PrivateKey.class)); when(definition.getIdpEntityAlias()).thenReturn(REGISTRATION_ID); when(definition.getNameID()).thenReturn(NAME_ID); when(definition.getMetaDataLocation()).thenReturn("saml-sample-metadata.xml"); @@ -213,9 +293,6 @@ void buildsCorrectRegistrationWhenZoneIdIsStored() { when(identityZone.getSubdomain()).thenReturn(ZONE_DOMAIN); when(identityZone.getConfig()).thenReturn(identityZoneConfiguration); when(identityZoneConfiguration.getSamlConfig()).thenReturn(samlConfig); - - when(keyWithCert.getCertificate()).thenReturn(mock(X509Certificate.class)); - when(keyWithCert.getPrivateKey()).thenReturn(mock(PrivateKey.class)); when(definition.getIdpEntityAlias()).thenReturn(REGISTRATION_ID); when(definition.getNameID()).thenReturn(NAME_ID); when(definition.getMetaDataLocation()).thenReturn("saml-sample-metadata.xml"); @@ -237,7 +314,7 @@ void buildsCorrectRegistrationWhenZoneIdIsStored() { @Test void buildsCorrectRegistrationWithZoneEntityIdSet() { - repository = spy(new ConfiguratorRelyingPartyRegistrationRepository(ENTITY_ID, null, keyWithCert, configurator)); + repository = spy(new ConfiguratorRelyingPartyRegistrationRepository(ENTITY_ID, null, List.of(), configurator)); when(repository.retrieveZone()).thenReturn(identityZone); when(identityZone.isUaa()).thenReturn(false); @@ -245,9 +322,6 @@ void buildsCorrectRegistrationWithZoneEntityIdSet() { when(identityZone.getConfig()).thenReturn(identityZoneConfiguration); when(identityZoneConfiguration.getSamlConfig()).thenReturn(samlConfig); when(samlConfig.getEntityID()).thenReturn(ZONE_SPECIFIC_ENTITY_ID); - - when(keyWithCert.getCertificate()).thenReturn(mock(X509Certificate.class)); - when(keyWithCert.getPrivateKey()).thenReturn(mock(PrivateKey.class)); when(definition.getIdpEntityAlias()).thenReturn(REGISTRATION_ID); when(definition.getNameID()).thenReturn(NAME_ID); when(definition.getMetaDataLocation()).thenReturn("saml-sample-metadata.xml"); diff --git a/server/src/test/java/org/cloudfoundry/identity/uaa/provider/saml/DefaultRelyingPartyRegistrationRepositoryTest.java b/server/src/test/java/org/cloudfoundry/identity/uaa/provider/saml/DefaultRelyingPartyRegistrationRepositoryTest.java index b837e105128..cef8da60338 100644 --- a/server/src/test/java/org/cloudfoundry/identity/uaa/provider/saml/DefaultRelyingPartyRegistrationRepositoryTest.java +++ b/server/src/test/java/org/cloudfoundry/identity/uaa/provider/saml/DefaultRelyingPartyRegistrationRepositoryTest.java @@ -1,21 +1,31 @@ package org.cloudfoundry.identity.uaa.provider.saml; +import org.bouncycastle.jcajce.provider.BouncyCastleFipsProvider; +import org.cloudfoundry.identity.uaa.saml.SamlKey; import org.cloudfoundry.identity.uaa.util.KeyWithCert; +import org.cloudfoundry.identity.uaa.util.KeyWithCertTest; import org.cloudfoundry.identity.uaa.zone.IdentityZone; import org.cloudfoundry.identity.uaa.zone.IdentityZoneConfiguration; import org.cloudfoundry.identity.uaa.zone.SamlConfig; +import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.junit.jupiter.params.provider.NullSource; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; +import org.springframework.security.saml2.core.Saml2X509Credential; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; -import java.security.PrivateKey; -import java.security.cert.X509Certificate; +import java.security.Security; +import java.security.cert.CertificateException; +import java.util.List; +import java.util.stream.Stream; import static org.assertj.core.api.Assertions.assertThat; -import static org.mockito.Mockito.mock; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.when; @@ -28,8 +38,10 @@ class DefaultRelyingPartyRegistrationRepositoryTest { private static final String REGISTRATION_ID = "registrationId"; private static final String REGISTRATION_ID_2 = "registrationId2"; - @Mock - private KeyWithCert mockKeyWithCert; + private static final SamlKey samlKey1 = new SamlKey(KeyWithCertTest.encryptedKey, KeyWithCertTest.password, KeyWithCertTest.goodCert); + private static final SamlKey samlKey2 = new SamlKey(KeyWithCertTest.ecPrivateKey, KeyWithCertTest.password, KeyWithCertTest.ecCertificate); + private static KeyWithCert keyWithCert1; + private static KeyWithCert keyWithCert2; @Mock private IdentityZone identityZone; @@ -42,15 +54,29 @@ class DefaultRelyingPartyRegistrationRepositoryTest { private DefaultRelyingPartyRegistrationRepository repository; + @BeforeAll + public static void addProvider() { + Security.addProvider(new BouncyCastleFipsProvider()); + try { + keyWithCert1 = new KeyWithCert(samlKey1); + keyWithCert2 = new KeyWithCert(samlKey2); + } catch (CertificateException e) { + throw new RuntimeException(e); + } + } + @BeforeEach void setUp() { - repository = spy(new DefaultRelyingPartyRegistrationRepository(ENTITY_ID, ENTITY_ID_ALIAS, mockKeyWithCert)); - when(mockKeyWithCert.getCertificate()).thenReturn(mock(X509Certificate.class)); - when(mockKeyWithCert.getPrivateKey()).thenReturn(mock(PrivateKey.class)); + repository = spy(new DefaultRelyingPartyRegistrationRepository(ENTITY_ID, ENTITY_ID_ALIAS, List.of())); } @Test void findByRegistrationId() { + when(repository.retrieveZone()).thenReturn(identityZone); + when(identityZone.isUaa()).thenReturn(true); + when(identityZone.getConfig()).thenReturn(identityZoneConfig); + when(identityZoneConfig.getSamlConfig()).thenReturn(samlConfig); + RelyingPartyRegistration registration = repository.findByRegistrationId(REGISTRATION_ID); assertThat(registration) @@ -110,8 +136,7 @@ void findByRegistrationIdForZoneWithoutConfig() { @Test void findByRegistrationId_NoAliasFailsOverToEntityId() { - repository = spy(new DefaultRelyingPartyRegistrationRepository(ENTITY_ID, null, mockKeyWithCert)); - + repository = spy(new DefaultRelyingPartyRegistrationRepository(ENTITY_ID, null, List.of())); when(repository.retrieveZone()).thenReturn(identityZone); when(identityZone.isUaa()).thenReturn(false); when(identityZone.getSubdomain()).thenReturn(ZONE_SUBDOMAIN); @@ -127,4 +152,63 @@ void findByRegistrationId_NoAliasFailsOverToEntityId() { .returns("{baseUrl}/saml/SSO/alias/testzone.entityId", RelyingPartyRegistration::getAssertionConsumerServiceLocation) .returns("{baseUrl}/saml/SingleLogout/alias/testzone.entityId", RelyingPartyRegistration::getSingleLogoutServiceResponseLocation); } + + @Test + void zoneWithCredentialsUsesCorrectValues() { + when(repository.retrieveZone()).thenReturn(identityZone); + when(identityZone.getConfig()).thenReturn(identityZoneConfig); + when(identityZoneConfig.getSamlConfig()).thenReturn(samlConfig); + when(samlConfig.getKeyList()).thenReturn(List.of(samlKey1, samlKey2)); + + RelyingPartyRegistration registration = repository.findByRegistrationId(REGISTRATION_ID); + + assertThat(registration.getDecryptionX509Credentials()) + .hasSize(1) + .first() + .extracting(Saml2X509Credential::getCertificate) + .isEqualTo(keyWithCert1.getCertificate()); + assertThat(registration.getSigningX509Credentials()) + .hasSize(2) + .first() + .extracting(Saml2X509Credential::getCertificate) + .isEqualTo(keyWithCert1.getCertificate()); + // Check the second element + assertThat(registration.getSigningX509Credentials()) + .element(1) + .extracting(Saml2X509Credential::getCertificate) + .isEqualTo(keyWithCert2.getCertificate()); + } + + private static Stream emptyList() { + return Stream.of(Arguments.of(List.of())); + } + + @ParameterizedTest + @NullSource + @MethodSource("emptyList") + void zoneWithoutCredentialsUsesDefault(List samlConfigKeys) { + repository = spy(new DefaultRelyingPartyRegistrationRepository(ENTITY_ID, null, List.of(keyWithCert1, keyWithCert2))); + when(repository.retrieveZone()).thenReturn(identityZone); + when(identityZone.getConfig()).thenReturn(identityZoneConfig); + when(identityZoneConfig.getSamlConfig()).thenReturn(samlConfig); + when(samlConfig.getKeyList()).thenReturn(samlConfigKeys); + + RelyingPartyRegistration registration = repository.findByRegistrationId(REGISTRATION_ID); + + assertThat(registration.getDecryptionX509Credentials()) + .hasSize(1) + .first() + .extracting(Saml2X509Credential::getCertificate) + .isEqualTo(keyWithCert1.getCertificate()); + assertThat(registration.getSigningX509Credentials()) + .hasSize(2) + .first() + .extracting(Saml2X509Credential::getCertificate) + .isEqualTo(keyWithCert1.getCertificate()); + // Check the second element + assertThat(registration.getSigningX509Credentials()) + .element(1) + .extracting(Saml2X509Credential::getCertificate) + .isEqualTo(keyWithCert2.getCertificate()); + } } \ No newline at end of file diff --git a/server/src/test/java/org/cloudfoundry/identity/uaa/provider/saml/RelyingPartyRegistrationBuilderTest.java b/server/src/test/java/org/cloudfoundry/identity/uaa/provider/saml/RelyingPartyRegistrationBuilderTest.java index 00e38908fcb..da6c2669faa 100644 --- a/server/src/test/java/org/cloudfoundry/identity/uaa/provider/saml/RelyingPartyRegistrationBuilderTest.java +++ b/server/src/test/java/org/cloudfoundry/identity/uaa/provider/saml/RelyingPartyRegistrationBuilderTest.java @@ -18,6 +18,7 @@ import java.io.UncheckedIOException; import java.security.PrivateKey; import java.security.cert.X509Certificate; +import java.util.List; import static java.nio.charset.StandardCharsets.UTF_8; import static org.assertj.core.api.Assertions.assertThat; @@ -42,7 +43,7 @@ void buildsRelyingPartyRegistrationFromLocation() { when(mockKeyWithCert.getPrivateKey()).thenReturn(mock(PrivateKey.class)); RelyingPartyRegistration registration = RelyingPartyRegistrationBuilder - .buildRelyingPartyRegistration(ENTITY_ID, NAME_ID, mockKeyWithCert, "saml-sample-metadata.xml", REGISTRATION_ID, ENTITY_ID_ALIAS, true); + .buildRelyingPartyRegistration(ENTITY_ID, NAME_ID, List.of(mockKeyWithCert), "saml-sample-metadata.xml", REGISTRATION_ID, ENTITY_ID_ALIAS, true); assertThat(registration) .returns(REGISTRATION_ID, RelyingPartyRegistration::getRegistrationId) .returns(ENTITY_ID, RelyingPartyRegistration::getEntityId) @@ -63,7 +64,7 @@ void buildsRelyingPartyRegistrationFromXML() { String metadataXml = loadResouceAsString("saml-sample-metadata.xml"); RelyingPartyRegistration registration = RelyingPartyRegistrationBuilder - .buildRelyingPartyRegistration(ENTITY_ID, NAME_ID, mockKeyWithCert, metadataXml, REGISTRATION_ID, ENTITY_ID_ALIAS,false); + .buildRelyingPartyRegistration(ENTITY_ID, NAME_ID, List.of(mockKeyWithCert), metadataXml, REGISTRATION_ID, ENTITY_ID_ALIAS, false); assertThat(registration) .returns(REGISTRATION_ID, RelyingPartyRegistration::getRegistrationId) @@ -81,9 +82,11 @@ void buildsRelyingPartyRegistrationFromXML() { @Test void failsWithInvalidXML() { String metadataXml = "\ninvalid xml"; + List keyList = List.of(mockKeyWithCert); + assertThatThrownBy(() -> RelyingPartyRegistrationBuilder.buildRelyingPartyRegistration(ENTITY_ID, NAME_ID, - mockKeyWithCert, metadataXml, REGISTRATION_ID, ENTITY_ID_ALIAS, true)) + keyList, metadataXml, REGISTRATION_ID, ENTITY_ID_ALIAS, true)) .isInstanceOf(Saml2Exception.class) .hasMessageContaining("Unsupported element"); } diff --git a/server/src/test/java/org/cloudfoundry/identity/uaa/provider/saml/SamlMetadataEndpointTest.java b/server/src/test/java/org/cloudfoundry/identity/uaa/provider/saml/SamlMetadataEndpointTest.java index 4a83a429213..f466cfd40f9 100644 --- a/server/src/test/java/org/cloudfoundry/identity/uaa/provider/saml/SamlMetadataEndpointTest.java +++ b/server/src/test/java/org/cloudfoundry/identity/uaa/provider/saml/SamlMetadataEndpointTest.java @@ -11,15 +11,21 @@ import org.mockito.junit.jupiter.MockitoExtension; import org.springframework.http.HttpHeaders; import org.springframework.http.ResponseEntity; +import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; -import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding; +import org.springframework.security.saml2.provider.service.web.RelyingPartyRegistrationResolver; import org.xmlunit.assertj.XmlAssert; import java.util.List; import static org.assertj.core.api.Assertions.assertThat; import static org.cloudfoundry.identity.uaa.provider.saml.Saml2TestUtils.xmlNamespaces; +import static org.cloudfoundry.identity.uaa.provider.saml.SamlNameIdFormats.NAMEID_FORMAT_EMAIL; +import static org.cloudfoundry.identity.uaa.provider.saml.SamlNameIdFormats.NAMEID_FORMAT_PERSISTENT; +import static org.cloudfoundry.identity.uaa.provider.saml.SamlNameIdFormats.NAMEID_FORMAT_TRANSIENT; +import static org.cloudfoundry.identity.uaa.provider.saml.SamlNameIdFormats.NAMEID_FORMAT_UNSPECIFIED; +import static org.cloudfoundry.identity.uaa.provider.saml.SamlNameIdFormats.NAMEID_FORMAT_X509SUBJECT; import static org.cloudfoundry.identity.uaa.provider.saml.TestSaml2X509Credentials.relyingPartySigningCredential; import static org.cloudfoundry.identity.uaa.provider.saml.TestSaml2X509Credentials.relyingPartyVerifyingCredential; import static org.mockito.Mockito.spy; @@ -36,7 +42,7 @@ class SamlMetadataEndpointTest { SamlMetadataEndpoint endpoint; @Mock - RelyingPartyRegistrationRepository repository; + RelyingPartyRegistrationResolver resolver; @Mock IdentityZoneManager identityZoneManager; @Mock @@ -48,9 +54,12 @@ class SamlMetadataEndpointTest { @Mock SamlConfig samlConfig; + MockHttpServletRequest request; + @BeforeEach void setUp() { - endpoint = spy(new SamlMetadataEndpoint(repository, identityZoneManager)); + request = new MockHttpServletRequest(); + endpoint = spy(new SamlMetadataEndpoint(resolver, identityZoneManager)); when(registration.getEntityId()).thenReturn(ENTITY_ID); when(registration.getSigningX509Credentials()).thenReturn(List.of(relyingPartySigningCredential())); when(registration.getDecryptionX509Credentials()).thenReturn(List.of(relyingPartyVerifyingCredential())); @@ -63,46 +72,53 @@ void setUp() { @Test void testDefaultFileName() { - when(repository.findByRegistrationId(REGISTRATION_ID)).thenReturn(registration); + when(resolver.resolve(request, REGISTRATION_ID)).thenReturn(registration); - ResponseEntity response = endpoint.metadataEndpoint(REGISTRATION_ID); + ResponseEntity response = endpoint.metadataEndpoint(request, REGISTRATION_ID); assertThat(response.getHeaders().getFirst(HttpHeaders.CONTENT_DISPOSITION)) .isEqualTo("attachment; filename=\"saml-sp.xml\"; filename*=UTF-8''saml-sp.xml"); } @Test void testZonedFileName() { - when(repository.findByRegistrationId(REGISTRATION_ID)).thenReturn(registration); + when(resolver.resolve(request, REGISTRATION_ID)).thenReturn(registration); when(identityZone.isUaa()).thenReturn(false); when(identityZone.getSubdomain()).thenReturn(TEST_ZONE); when(endpoint.retrieveZone()).thenReturn(identityZone); - ResponseEntity response = endpoint.metadataEndpoint(REGISTRATION_ID); + ResponseEntity response = endpoint.metadataEndpoint(request, REGISTRATION_ID); assertThat(response.getHeaders().getFirst(HttpHeaders.CONTENT_DISPOSITION)) .isEqualTo("attachment; filename=\"saml-%1$s-sp.xml\"; filename*=UTF-8''saml-%1$s-sp.xml".formatted(TEST_ZONE)); } @Test void testDefaultMetadataXml() { - when(repository.findByRegistrationId(REGISTRATION_ID)).thenReturn(registration); + when(resolver.resolve(request, REGISTRATION_ID)).thenReturn(registration); when(samlConfig.isWantAssertionSigned()).thenReturn(true); when(samlConfig.isRequestSigned()).thenReturn(true); - ResponseEntity response = endpoint.metadataEndpoint(REGISTRATION_ID); + ResponseEntity response = endpoint.metadataEndpoint(request, REGISTRATION_ID); XmlAssert xmlAssert = XmlAssert.assertThat(response.getBody()).withNamespaceContext(xmlNamespaces()); xmlAssert.valueByXPath("//md:EntityDescriptor/@entityID").isEqualTo(ENTITY_ID); + xmlAssert.valueByXPath("//md:EntityDescriptor/@ID").isEqualTo(ENTITY_ID); xmlAssert.valueByXPath("//md:SPSSODescriptor/@AuthnRequestsSigned").isEqualTo(true); xmlAssert.valueByXPath("//md:SPSSODescriptor/@WantAssertionsSigned").isEqualTo(true); - xmlAssert.valueByXPath("//md:AssertionConsumerService/@Location").isEqualTo(ASSERTION_CONSUMER_SERVICE); + xmlAssert.nodesByXPath("//md:AssertionConsumerService") + .extractingAttribute("Location") + .containsExactly(ASSERTION_CONSUMER_SERVICE); + xmlAssert.nodesByXPath("//md:NameIDFormat") + .extractingText() + .containsExactlyInAnyOrder(NAMEID_FORMAT_EMAIL, NAMEID_FORMAT_PERSISTENT, + NAMEID_FORMAT_TRANSIENT, NAMEID_FORMAT_UNSPECIFIED, NAMEID_FORMAT_X509SUBJECT); } @Test void testDefaultMetadataXml_alternateValues() { - when(repository.findByRegistrationId(REGISTRATION_ID)).thenReturn(registration); + when(resolver.resolve(request, REGISTRATION_ID)).thenReturn(registration); when(samlConfig.isWantAssertionSigned()).thenReturn(false); when(samlConfig.isRequestSigned()).thenReturn(false); - ResponseEntity response = endpoint.metadataEndpoint(REGISTRATION_ID); + ResponseEntity response = endpoint.metadataEndpoint(request, REGISTRATION_ID); XmlAssert xmlAssert = XmlAssert.assertThat(response.getBody()).withNamespaceContext(xmlNamespaces()); xmlAssert.valueByXPath("//md:SPSSODescriptor/@AuthnRequestsSigned").isEqualTo(false); xmlAssert.valueByXPath("//md:SPSSODescriptor/@WantAssertionsSigned").isEqualTo(false); diff --git a/server/src/test/java/org/cloudfoundry/identity/uaa/provider/saml/SamlRelyingPartyRegistrationRepositoryConfigTest.java b/server/src/test/java/org/cloudfoundry/identity/uaa/provider/saml/SamlRelyingPartyRegistrationRepositoryConfigTest.java index 02a191ca7ea..512435384fe 100644 --- a/server/src/test/java/org/cloudfoundry/identity/uaa/provider/saml/SamlRelyingPartyRegistrationRepositoryConfigTest.java +++ b/server/src/test/java/org/cloudfoundry/identity/uaa/provider/saml/SamlRelyingPartyRegistrationRepositoryConfigTest.java @@ -1,7 +1,7 @@ package org.cloudfoundry.identity.uaa.provider.saml; import org.bouncycastle.jcajce.provider.BouncyCastleFipsProvider; -import org.cloudfoundry.identity.uaa.saml.SamlKey; +import org.cloudfoundry.identity.uaa.util.KeyWithCert; import org.cloudfoundry.identity.uaa.util.KeyWithCertTest; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.BeforeEach; @@ -15,6 +15,7 @@ import java.security.Security; import java.security.cert.CertificateException; +import java.util.List; import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.Mockito.when; @@ -36,31 +37,26 @@ class SamlRelyingPartyRegistrationRepositoryConfigTest { @Mock SamlIdentityProviderConfigurator samlIdentityProviderConfigurator; - @Mock - SamlKey activeSamlKey; - @BeforeAll public static void addProvider() { Security.addProvider(new BouncyCastleFipsProvider()); } @BeforeEach - public void setup() { - when(samlConfigProps.getActiveSamlKey()).thenReturn(activeSamlKey); - when(activeSamlKey.getKey()).thenReturn(KEY); - when(activeSamlKey.getPassphrase()).thenReturn(PASSPHRASE); - when(activeSamlKey.getCertificate()).thenReturn(CERT); + public void setup() throws CertificateException { + KeyWithCert keyWithCert = new KeyWithCert(KEY, PASSPHRASE, CERT); + when(samlConfigProps.getKeysWithCerts()).thenReturn(List.of(keyWithCert)); } @Test - void relyingPartyRegistrationRepository() throws CertificateException { + void relyingPartyRegistrationRepository() { SamlRelyingPartyRegistrationRepositoryConfig config = new SamlRelyingPartyRegistrationRepositoryConfig(ENTITY_ID, samlConfigProps, bootstrapSamlIdentityProviderData, NAME_ID); RelyingPartyRegistrationRepository repository = config.relyingPartyRegistrationRepository(samlIdentityProviderConfigurator); assertThat(repository).isNotNull(); } @Test - void relyingPartyRegistrationResolver() throws CertificateException { + void relyingPartyRegistrationResolver() { SamlRelyingPartyRegistrationRepositoryConfig config = new SamlRelyingPartyRegistrationRepositoryConfig(ENTITY_ID, samlConfigProps, bootstrapSamlIdentityProviderData, NAME_ID); RelyingPartyRegistrationRepository repository = config.relyingPartyRegistrationRepository(samlIdentityProviderConfigurator); RelyingPartyRegistrationResolver resolver = config.relyingPartyRegistrationResolver(repository); @@ -69,7 +65,7 @@ void relyingPartyRegistrationResolver() throws CertificateException { } @Test - void buildsRegistrationForExample() throws CertificateException { + void buildsRegistrationForExample() { SamlRelyingPartyRegistrationRepositoryConfig config = new SamlRelyingPartyRegistrationRepositoryConfig(ENTITY_ID, samlConfigProps, bootstrapSamlIdentityProviderData, NAME_ID); RelyingPartyRegistrationRepository repository = config.relyingPartyRegistrationRepository(samlIdentityProviderConfigurator); RelyingPartyRegistration registration = repository.findByRegistrationId("example"); diff --git a/server/src/test/java/org/cloudfoundry/identity/uaa/provider/saml/idp/SamlTestUtils.java b/server/src/test/java/org/cloudfoundry/identity/uaa/provider/saml/idp/SamlTestUtils.java index 7a4a529bf48..a9906f529ee 100644 --- a/server/src/test/java/org/cloudfoundry/identity/uaa/provider/saml/idp/SamlTestUtils.java +++ b/server/src/test/java/org/cloudfoundry/identity/uaa/provider/saml/idp/SamlTestUtils.java @@ -5,25 +5,6 @@ import org.cloudfoundry.identity.uaa.login.AddBcProvider; import org.cloudfoundry.identity.uaa.provider.SamlIdentityProviderDefinition; import org.cloudfoundry.identity.uaa.zone.IdentityZone; -import org.w3c.dom.Document; -import org.w3c.dom.NodeList; -import org.xml.sax.InputSource; -import org.xml.sax.SAXException; - -import javax.xml.parsers.DocumentBuilderFactory; -import javax.xml.parsers.ParserConfigurationException; -import javax.xml.xpath.XPath; -import javax.xml.xpath.XPathConstants; -import javax.xml.xpath.XPathExpression; -import javax.xml.xpath.XPathExpressionException; -import javax.xml.xpath.XPathFactory; -import java.io.IOException; -import java.io.StringReader; -import java.util.LinkedList; -import java.util.List; - -import static org.assertj.core.api.Assertions.assertThat; -//import static org.opensaml.common.xml.SAMLConstants.SAML20P_NS; // TODO: this class seems to be used more broadly than what its location indicates (uaa as saml idp); need to move it // also remove unused code in here @@ -101,28 +82,4 @@ public static SamlIdentityProviderDefinition createLocalSamlIdpDefinition(String } return def; } - - public static List getCertificates(String metadata, String type) throws Exception { - Document doc = getMetadataDoc(metadata); - NodeList nodeList = evaluateXPathExpression(doc, "//*[local-name()='KeyDescriptor' and @*[local-name() = 'use']='" + type + "']//*[local-name()='X509Certificate']/text()"); - assertThat(nodeList).isNotNull(); - List result = new LinkedList<>(); - for (int i = 0; i < nodeList.getLength(); i++) { - result.add(nodeList.item(i).getNodeValue().replace("\n", "")); - } - return result; - } - - public static NodeList evaluateXPathExpression(Document doc, String xpath) throws XPathExpressionException { - XPath xPath = XPathFactory.newInstance().newXPath(); - XPathExpression expression = xPath.compile(xpath); - return (NodeList) expression.evaluate(doc, XPathConstants.NODESET); - } - - public static Document getMetadataDoc(String metadata) throws SAXException, IOException, ParserConfigurationException { - DocumentBuilderFactory documentBuilderFactory = DocumentBuilderFactory.newInstance(); - documentBuilderFactory.setNamespaceAware(false); - InputSource is = new InputSource(new StringReader(metadata)); - return documentBuilderFactory.newDocumentBuilder().parse(is); - } } diff --git a/uaa/src/test/java/org/cloudfoundry/identity/uaa/integration/LoginServerSecurityIntegrationTests.java b/uaa/src/test/java/org/cloudfoundry/identity/uaa/integration/LoginServerSecurityIntegrationTests.java index 2175f70fbce..1e4b7efdc29 100644 --- a/uaa/src/test/java/org/cloudfoundry/identity/uaa/integration/LoginServerSecurityIntegrationTests.java +++ b/uaa/src/test/java/org/cloudfoundry/identity/uaa/integration/LoginServerSecurityIntegrationTests.java @@ -1,4 +1,5 @@ -/******************************************************************************* +/* + * ***************************************************************************** * Cloud Foundry * Copyright (c) [2009-2016] Pivotal Software, Inc. All Rights Reserved. * @@ -49,14 +50,8 @@ import java.util.Map; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.fail; import static org.cloudfoundry.identity.uaa.constants.OriginKeys.LOGIN_SERVER; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertNotEquals; -import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertNull; -import static org.junit.Assert.assertTrue; -import static org.junit.Assert.fail; /** * Integration test to verify that the Login Server authentication channel is @@ -69,7 +64,6 @@ public class LoginServerSecurityIntegrationTests { private final String JOE = "joe" + new RandomValueStringGenerator().generate().toLowerCase(); private final String LOGIN_SERVER_JOE = "ls_joe" + new RandomValueStringGenerator().generate().toLowerCase(); - private final String userEndpoint = "/Users"; @Rule public ServerRunning serverRunning = ServerRunning.isRunning(); private ScimUser joe; @@ -81,7 +75,7 @@ public class LoginServerSecurityIntegrationTests { @Rule public OAuth2ContextSetup context = OAuth2ContextSetup.withTestAccounts(serverRunning, testAccountSetup); - private final MultiValueMap params = new LinkedMultiValueMap(); + private final MultiValueMap params = new LinkedMultiValueMap<>(); private final HttpHeaders headers = new HttpHeaders(); private ScimUser userForLoginServer; @@ -130,6 +124,7 @@ public void setUpUserAccounts() { userForLoginServer.setVerified(true); userForLoginServer.setOrigin(LOGIN_SERVER); + String userEndpoint = "/Users"; ResponseEntity newuser = client.postForEntity(serverRunning.getUrl(userEndpoint), user, ScimUser.class); userForLoginServer = client.postForEntity(serverRunning.getUrl(userEndpoint), userForLoginServer, ScimUser.class).getBody(); @@ -140,18 +135,17 @@ public void setUpUserAccounts() { PasswordChangeRequest change = new PasswordChangeRequest(); change.setPassword("Passwo3d"); - HttpHeaders headers = new HttpHeaders(); + headers.clear(); ResponseEntity result = client .exchange(serverRunning.getUrl(userEndpoint) + "/{id}/password", - HttpMethod.PUT, new HttpEntity(change, headers), + HttpMethod.PUT, new HttpEntity<>(change, headers), Void.class, joe.getId()); - assertEquals(HttpStatus.OK, result.getStatusCode()); + assertThat(result.getStatusCode()).isEqualTo(HttpStatus.OK); // The implicit grant for cf requires extra parameters in the // authorization request context.setParameters(Collections.singletonMap("credentials", testAccounts.getJsonCredentials(joe.getUserName(), "Passwo3d"))); - } @Test @@ -160,10 +154,11 @@ public void testAuthenticateReturnsUserID() { params.set("username", JOE); params.set("password", "Passwo3d"); ResponseEntity response = serverRunning.postForMap("/authenticate", params, headers); - assertEquals(HttpStatus.OK, response.getStatusCode()); - assertEquals(JOE, response.getBody().get("username")); - assertEquals(OriginKeys.UAA, response.getBody().get(OriginKeys.ORIGIN)); - assertTrue(StringUtils.hasText((String) response.getBody().get("user_id"))); + assertThat(response.getStatusCode()).isEqualTo(HttpStatus.OK); + assertThat(response.getBody()) + .containsEntry("username", JOE) + .containsEntry(OriginKeys.ORIGIN, OriginKeys.UAA); + assertThat(StringUtils.hasText((String) response.getBody().get("user_id"))).isTrue(); } @Test @@ -172,10 +167,10 @@ public void testAuthenticateMarissaReturnsUserID() { params.set("username", testAccounts.getUserName()); params.set("password", testAccounts.getPassword()); ResponseEntity response = serverRunning.postForMap("/authenticate", params, headers); - assertEquals(HttpStatus.OK, response.getStatusCode()); - assertEquals("marissa", response.getBody().get("username")); - assertEquals(OriginKeys.UAA, response.getBody().get(OriginKeys.ORIGIN)); - assertTrue(StringUtils.hasText((String) response.getBody().get("user_id"))); + assertThat(response.getStatusCode()).isEqualTo(HttpStatus.OK); + assertThat(response.getBody()).containsEntry("username", "marissa") + .containsEntry(OriginKeys.ORIGIN, OriginKeys.UAA); + assertThat(StringUtils.hasText((String) response.getBody().get("user_id"))).isTrue(); } @Test @@ -184,7 +179,7 @@ public void testAuthenticateMarissaFails() { params.set("username", testAccounts.getUserName()); params.set("password", ""); ResponseEntity response = serverRunning.postForMap("/authenticate", params, headers); - assertEquals(HttpStatus.UNAUTHORIZED, response.getStatusCode()); + assertThat(response.getStatusCode()).isEqualTo(HttpStatus.UNAUTHORIZED); } @Test @@ -192,10 +187,10 @@ public void testAuthenticateDoesNotReturnsUserID() { params.set("username", testAccounts.getUserName()); params.set("password", testAccounts.getPassword()); ResponseEntity response = serverRunning.postForMap("/authenticate", params, headers); - assertEquals(HttpStatus.OK, response.getStatusCode()); - assertEquals("marissa", response.getBody().get("username")); - assertNull(response.getBody().get(OriginKeys.ORIGIN)); - assertNull(response.getBody().get("user_id")); + assertThat(response.getStatusCode()).isEqualTo(HttpStatus.OK); + assertThat(response.getBody()).containsEntry("username", "marissa") + .doesNotContainKey(OriginKeys.ORIGIN) + .doesNotContainKey("user_id"); } @Test @@ -212,9 +207,9 @@ public void testLoginServerCanAuthenticateUserForCf() { } @SuppressWarnings("rawtypes") ResponseEntity response = serverRunning.postForMap(serverRunning.getAuthorizationUri(), params, headers); - assertEquals(HttpStatus.FOUND, response.getStatusCode()); + assertThat(response.getStatusCode()).isEqualTo(HttpStatus.FOUND); String results = response.getHeaders().getLocation().toString(); - assertTrue("There should be an access token: " + results, results.contains("access_token")); + assertThat(results).as("There should be an access token: " + results).contains("access_token"); } @Test @@ -230,11 +225,11 @@ public void testLoginServerCanAuthenticateUserForAuthorizationCode() { if (response.getStatusCode().is4xxClientError()) { fail(response.getBody().toString()); } else { - assertEquals(HttpStatus.OK, response.getStatusCode()); + assertThat(response.getStatusCode()).isEqualTo(HttpStatus.OK); @SuppressWarnings("unchecked") Map results = response.getBody(); // The approval page messaging response - assertNotNull("There should be scopes: " + results, results.get("scopes")); + assertThat(results).as("There should be scopes: " + results).containsKey("scopes"); } } @@ -250,11 +245,11 @@ public void testLoginServerCanAuthenticateUserWithIDForAuthorizationCode() { if (response.getStatusCode().is4xxClientError()) { fail(response.getBody().toString()); } else { - assertEquals(HttpStatus.OK, response.getStatusCode()); + assertThat(response.getStatusCode()).isEqualTo(HttpStatus.OK); @SuppressWarnings("unchecked") Map results = response.getBody(); // The approval page messaging response - assertNotNull("There should be scopes: " + results, results.get("scopes")); + assertThat(results).as("There should be scopes: " + results).containsKey("scopes"); } } @@ -265,10 +260,10 @@ public void testMissingUserInfoIsError() { params.remove("username"); @SuppressWarnings("rawtypes") ResponseEntity response = serverRunning.postForMap(serverRunning.getAuthorizationUri(), params, headers); - assertEquals(HttpStatus.UNAUTHORIZED, response.getStatusCode()); + assertThat(response.getStatusCode()).isEqualTo(HttpStatus.UNAUTHORIZED); @SuppressWarnings("unchecked") Map results = response.getBody(); - assertTrue("There should be an error: " + results, results.containsKey("error")); + assertThat(results).as("There should be an error: " + results).containsKey("error"); } @Test @@ -282,10 +277,10 @@ public void testMissingUsernameIsError() { params.set("given_name", "Mabel"); @SuppressWarnings("rawtypes") ResponseEntity response = serverRunning.postForMap(serverRunning.getAuthorizationUri(), params, headers); - assertEquals(HttpStatus.UNAUTHORIZED, response.getStatusCode()); + assertThat(response.getStatusCode()).isEqualTo(HttpStatus.UNAUTHORIZED); @SuppressWarnings("unchecked") Map results = response.getBody(); - assertTrue("There should be an error: " + results, results.containsKey("error")); + assertThat(results).as("There should be an error: " + results).containsKey("error"); } @Test @@ -306,9 +301,9 @@ public void testWrongUsernameIsErrorAddNewEnabled() { @SuppressWarnings("rawtypes") ResponseEntity response = serverRunning.postForMap(serverRunning.getAuthorizationUri(), params, headers); // add_new:true user accounts are automatically provisioned. - assertEquals(HttpStatus.FOUND, response.getStatusCode()); + assertThat(response.getStatusCode()).isEqualTo(HttpStatus.FOUND); String results = response.getHeaders().getLocation().getFragment(); - assertTrue("There should be an access token: " + results, results.contains("access_token")); + assertThat(results).as("There should be an access token: " + results).contains("access_token"); } @Test @@ -328,10 +323,10 @@ public void testWrongUsernameIsErrorAddNewDisabled() { } @SuppressWarnings("rawtypes") ResponseEntity response = serverRunning.postForMap(serverRunning.getAuthorizationUri(), params, headers); - assertEquals(HttpStatus.UNAUTHORIZED, response.getStatusCode()); + assertThat(response.getStatusCode()).isEqualTo(HttpStatus.UNAUTHORIZED); @SuppressWarnings("unchecked") Map results = response.getBody(); - assertTrue("There should be an error: " + results, results.containsKey("error")); + assertThat(results).as("There should be an error: " + results).containsKey("error"); } @Test @@ -348,13 +343,14 @@ public void testAddNewUserWithWrongEmailFormat() { params.set(UaaAuthenticationDetails.ADD_NEW, "true"); @SuppressWarnings("rawtypes") ResponseEntity response = serverRunning.postForMap(serverRunning.getAuthorizationUri(), params, headers); - assertNotNull(response); - assertNotEquals(HttpStatus.INTERNAL_SERVER_ERROR, response.getStatusCode()); - assertEquals(HttpStatus.FOUND, response.getStatusCode()); + assertThat(response).isNotNull() + .extracting(ResponseEntity::getStatusCode) + .isNotEqualTo(HttpStatus.INTERNAL_SERVER_ERROR) + .isEqualTo(HttpStatus.FOUND); @SuppressWarnings("unchecked") Map results = response.getBody(); if (results != null) { - assertFalse("There should not be an error: " + results, results.containsKey("error")); + assertThat(results).as("There should not be an error: " + results).doesNotContainKey("error"); } } @@ -362,7 +358,7 @@ public void testAddNewUserWithWrongEmailFormat() { @OAuth2ContextConfiguration(LoginClient.class) public void testLoginServerCfPasswordToken() { ImplicitResourceDetails resource = testAccounts.getDefaultImplicitResource(); - HttpHeaders headers = new HttpHeaders(); + headers.clear(); headers.add("Accept", MediaType.APPLICATION_JSON_VALUE); params.set("client_id", resource.getClientId()); params.set("client_secret", ""); @@ -377,17 +373,17 @@ public void testLoginServerCfPasswordToken() { } @SuppressWarnings("rawtypes") ResponseEntity response = serverRunning.postForMap(serverRunning.getAccessTokenUri(), params, headers); - assertEquals(HttpStatus.OK, response.getStatusCode()); + assertThat(response.getStatusCode()).isEqualTo(HttpStatus.OK); Map results = response.getBody(); - assertTrue("There should be a token: " + results, results.containsKey("access_token")); - assertTrue("There should be a refresh: " + results, results.containsKey("refresh_token")); + assertThat(results).as("There should be a token: " + results).containsKey("access_token") + .as("There should be a refresh: " + results).containsKey("refresh_token"); } @Test @OAuth2ContextConfiguration(LoginClient.class) public void testLoginServerWithoutBearerToken() { ImplicitResourceDetails resource = testAccounts.getDefaultImplicitResource(); - HttpHeaders headers = new HttpHeaders(); + headers.clear(); headers.add("Accept", MediaType.APPLICATION_JSON_VALUE); headers.add("Authorization", getAuthorizationEncodedValue(resource.getClientId(), "")); params.set("client_id", resource.getClientId()); @@ -401,14 +397,14 @@ public void testLoginServerWithoutBearerToken() { } @SuppressWarnings("rawtypes") ResponseEntity response = serverRunning.postForMap(serverRunning.getAccessTokenUri(), params, headers); - assertEquals(HttpStatus.UNAUTHORIZED, response.getStatusCode()); + assertThat(response.getStatusCode()).isEqualTo(HttpStatus.UNAUTHORIZED); } @Test @OAuth2ContextConfiguration(LoginClient.class) public void testLoginServerCfInvalidClientPasswordToken() { ImplicitResourceDetails resource = testAccounts.getDefaultImplicitResource(); - HttpHeaders headers = new HttpHeaders(); + headers.clear(); headers.add("Accept", MediaType.APPLICATION_JSON_VALUE); params.set("client_id", resource.getClientId()); params.set("client_secret", "bogus"); @@ -423,14 +419,14 @@ public void testLoginServerCfInvalidClientPasswordToken() { @SuppressWarnings("rawtypes") ResponseEntity response = serverRunning.postForMap(serverRunning.getAccessTokenUri(), params, headers); HttpStatus statusCode = response.getStatusCode(); - assertTrue("Status code should be 401 or 403.", statusCode == HttpStatus.FORBIDDEN || statusCode == HttpStatus.UNAUTHORIZED); + assertThat(statusCode == HttpStatus.FORBIDDEN || statusCode == HttpStatus.UNAUTHORIZED).as("Status code should be 401 or 403.").isTrue(); } @Test @OAuth2ContextConfiguration(AppClient.class) public void testLoginServerCfInvalidClientToken() { ImplicitResourceDetails resource = testAccounts.getDefaultImplicitResource(); - HttpHeaders headers = new HttpHeaders(); + headers.clear(); headers.add("Accept", MediaType.APPLICATION_JSON_VALUE); params.set("client_id", resource.getClientId()); params.set("client_secret", "bogus"); @@ -446,7 +442,7 @@ public void testLoginServerCfInvalidClientToken() { ResponseEntity response = serverRunning.postForMap(serverRunning.getAccessTokenUri(), params, headers); HttpStatus statusCode = response.getStatusCode(); - assertTrue("Status code should be 401 or 403.", statusCode == HttpStatus.FORBIDDEN || statusCode == HttpStatus.UNAUTHORIZED); + assertThat(statusCode == HttpStatus.FORBIDDEN || statusCode == HttpStatus.UNAUTHORIZED).as("Status code should be 401 or 403.").isTrue(); } private String getAuthorizationEncodedValue(String username, String password) { @@ -455,7 +451,6 @@ private String getAuthorizationEncodedValue(String username, String password) { return "Basic " + new String(encodedAuth); } - private static class LoginClient extends ClientCredentialsResourceDetails { @SuppressWarnings("unused") public LoginClient(Object target) { diff --git a/uaa/src/test/java/org/cloudfoundry/identity/uaa/mock/saml/SamlAuthenticationMockMvcTests.java b/uaa/src/test/java/org/cloudfoundry/identity/uaa/mock/saml/SamlAuthenticationMockMvcTests.java index cdc83c0c0d4..d1f4fad1ea7 100644 --- a/uaa/src/test/java/org/cloudfoundry/identity/uaa/mock/saml/SamlAuthenticationMockMvcTests.java +++ b/uaa/src/test/java/org/cloudfoundry/identity/uaa/mock/saml/SamlAuthenticationMockMvcTests.java @@ -73,7 +73,6 @@ class SamlAuthenticationMockMvcTests { private RandomValueStringGenerator generator; - private IdentityZone spZone; private IdentityZone idpZone; private String spZoneEntityId; @@ -101,7 +100,6 @@ private static void createUser( jdbcScimUserProvisioning.createUser(user, "secret", identityZone.getId()); } - @SuppressWarnings("SpringJavaInjectionPointsAutowiringInspection") @BeforeEach void createSamlRelationship( @Autowired JdbcIdentityProviderProvisioning jdbcIdentityProviderProvisioning, @@ -443,7 +441,6 @@ public void describeTo(Description description) { } @Nested - @DefaultTestContext class WithCustomLogAppender { private List logEvents; private AbstractAppender appender; diff --git a/uaa/src/test/java/org/cloudfoundry/identity/uaa/mock/saml/SamlKeyRotationMockMvcTests.java b/uaa/src/test/java/org/cloudfoundry/identity/uaa/mock/saml/SamlKeyRotationMockMvcTests.java index 2c2cb877419..b7c8a8e3a4f 100644 --- a/uaa/src/test/java/org/cloudfoundry/identity/uaa/mock/saml/SamlKeyRotationMockMvcTests.java +++ b/uaa/src/test/java/org/cloudfoundry/identity/uaa/mock/saml/SamlKeyRotationMockMvcTests.java @@ -14,57 +14,62 @@ package org.cloudfoundry.identity.uaa.mock.saml; import org.cloudfoundry.identity.uaa.DefaultTestContext; +import org.cloudfoundry.identity.uaa.client.UaaClientDetails; import org.cloudfoundry.identity.uaa.mock.util.MockMvcUtils; -import org.cloudfoundry.identity.uaa.provider.saml.idp.SamlTestUtils; +import org.cloudfoundry.identity.uaa.oauth.common.util.RandomValueStringGenerator; import org.cloudfoundry.identity.uaa.saml.SamlKey; import org.cloudfoundry.identity.uaa.zone.IdentityZone; import org.cloudfoundry.identity.uaa.zone.SamlConfig; import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Disabled; -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.ValueSource; +import org.junit.jupiter.api.Test; import org.springframework.beans.factory.annotation.Autowired; -import org.cloudfoundry.identity.uaa.oauth.common.util.RandomValueStringGenerator; -import org.cloudfoundry.identity.uaa.client.UaaClientDetails; import org.springframework.test.web.servlet.MockMvc; import org.springframework.web.context.WebApplicationContext; -import org.w3c.dom.NodeList; +import org.xmlunit.assertj.XmlAssert; +import java.util.Arrays; import java.util.HashMap; -import java.util.List; import java.util.Map; -import static org.cloudfoundry.identity.uaa.provider.saml.SamlKeyManagerFactoryTests.*; -import static org.cloudfoundry.identity.uaa.provider.saml.idp.SamlTestUtils.getCertificates; -import static org.hamcrest.Matchers.containsInAnyOrder; -import static org.junit.Assert.*; +import static org.cloudfoundry.identity.uaa.provider.saml.Saml2TestUtils.xmlNamespaces; +import static org.cloudfoundry.identity.uaa.provider.saml.SamlKeyManagerFactoryTests.certificate1; +import static org.cloudfoundry.identity.uaa.provider.saml.SamlKeyManagerFactoryTests.certificate2; +import static org.cloudfoundry.identity.uaa.provider.saml.SamlKeyManagerFactoryTests.key1; +import static org.cloudfoundry.identity.uaa.provider.saml.SamlKeyManagerFactoryTests.key2; +import static org.cloudfoundry.identity.uaa.provider.saml.SamlKeyManagerFactoryTests.legacyCertificate; +import static org.cloudfoundry.identity.uaa.provider.saml.SamlKeyManagerFactoryTests.legacyKey; +import static org.cloudfoundry.identity.uaa.provider.saml.SamlKeyManagerFactoryTests.legacyPassphrase; +import static org.cloudfoundry.identity.uaa.provider.saml.SamlKeyManagerFactoryTests.passphrase1; +import static org.cloudfoundry.identity.uaa.provider.saml.SamlKeyManagerFactoryTests.passphrase2; import static org.springframework.http.MediaType.APPLICATION_XML; import static org.springframework.restdocs.mockmvc.RestDocumentationRequestBuilders.get; +import static org.springframework.test.web.servlet.result.MockMvcResultHandlers.print; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; @DefaultTestContext class SamlKeyRotationMockMvcTests { + private static final String METADATA_URL = "/saml/metadata"; + private static final String SIGNATURE_CERTIFICATE_XPATH_FORMAT = "//ds:Signature//ds:X509Certificate"; + public static final String KEY_DESCRIPTOR_CERTIFICATE_XPATH_FORMAT = "//md:SPSSODescriptor/md:KeyDescriptor[@use='%s']//ds:X509Certificate"; private IdentityZone zone; private SamlKey samlKey2; + @Autowired private MockMvc mockMvc; - @BeforeEach - void createZone( - @Autowired WebApplicationContext webApplicationContext, - @Autowired MockMvc mockMvc - ) throws Exception { - this.mockMvc = mockMvc; + @Autowired + WebApplicationContext webApplicationContext; + @BeforeEach + void createZone() throws Exception { String id = new RandomValueStringGenerator().generate().toLowerCase(); IdentityZone identityZone = new IdentityZone(); identityZone.setId(id); identityZone.setSubdomain(id); identityZone.setName("Test Saml Key Zone"); identityZone.setDescription("Testing SAML Key Rotation"); - Map keys = new HashMap<>(); - keys.put("exampleKeyId", "s1gNiNg.K3y/t3XT"); + Map keys = Map.of("exampleKeyId", "s1gNiNg.K3y/t3XT"); identityZone.getConfig().getTokenPolicy().setKeys(keys); SamlConfig samlConfig = new SamlConfig(); samlConfig.setCertificate(legacyCertificate); @@ -77,74 +82,65 @@ void createZone( identityZone.getConfig().setSamlConfig(samlConfig); UaaClientDetails zoneAdminClient = new UaaClientDetails("admin", null, - "openid", - "client_credentials,authorization_code", - "clients.admin,scim.read,scim.write", - "http://test.redirect.com"); + "openid", + "client_credentials,authorization_code", + "clients.admin,scim.read,scim.write", + "http://test.redirect.com"); zoneAdminClient.setClientSecret("admin-secret"); MockMvcUtils.IdentityZoneCreationResult identityZoneCreationResult = MockMvcUtils - .createOtherIdentityZoneAndReturnResult(mockMvc, webApplicationContext, zoneAdminClient, identityZone, false, id); + .createOtherIdentityZoneAndReturnResult(mockMvc, webApplicationContext, zoneAdminClient, identityZone, false, id); zone = identityZoneCreationResult.getIdentityZone(); } - @ParameterizedTest - @ValueSource(strings = {"/saml/metadata"}) - @Disabled("SAML test fails") - void key_rotation(String url) throws Exception { + @Test + void key_rotation() throws Exception { //default with three keys - String metadata = getMetadata(url); - List signatureVerificationKeys = getCertificates(metadata, "signing"); - assertThat(signatureVerificationKeys, containsInAnyOrder(clean(legacyCertificate), clean(certificate1), clean(certificate2))); - List encryptionKeys = getCertificates(metadata, "encryption"); - assertThat(encryptionKeys, containsInAnyOrder(clean(legacyCertificate))); - evaluateSignatureKey(metadata, legacyCertificate); + XmlAssert metadataAssert = getMetadataAssert(); + assertThatSigningKeyHasValues(metadataAssert, legacyCertificate, certificate1, certificate2); + assertThatEncryptionKeyHasValues(metadataAssert, legacyCertificate); + assertSignatureKeyHasValue(metadataAssert, legacyCertificate); //activate key1 zone.getConfig().getSamlConfig().setActiveKeyId("key1"); zone = MockMvcUtils.updateZone(mockMvc, zone); - metadata = getMetadata(url); - signatureVerificationKeys = getCertificates(metadata, "signing"); - assertThat(signatureVerificationKeys, containsInAnyOrder(clean(legacyCertificate), clean(certificate1), clean(certificate2))); - encryptionKeys = getCertificates(metadata, "encryption"); - evaluateSignatureKey(metadata, certificate1); - assertThat(encryptionKeys, containsInAnyOrder(clean(certificate1))); + metadataAssert = getMetadataAssert(); + assertThatSigningKeyHasValues(metadataAssert, legacyCertificate, certificate1, certificate2); + assertThatEncryptionKeyHasValues(metadataAssert, certificate1); + assertSignatureKeyHasValue(metadataAssert, certificate1); //remove all but key2 zone.getConfig().getSamlConfig().setKeys(new HashMap<>()); zone.getConfig().getSamlConfig().addAndActivateKey("key2", samlKey2); zone = MockMvcUtils.updateZone(mockMvc, zone); - metadata = getMetadata(url); - signatureVerificationKeys = getCertificates(metadata, "signing"); - assertThat(signatureVerificationKeys, containsInAnyOrder(clean(certificate2))); - evaluateSignatureKey(metadata, certificate2); - encryptionKeys = getCertificates(metadata, "encryption"); - assertThat(encryptionKeys, containsInAnyOrder(clean(certificate2))); + metadataAssert = getMetadataAssert(); + assertThatSigningKeyHasValues(metadataAssert, certificate2); + assertThatEncryptionKeyHasValues(metadataAssert, certificate2); + assertSignatureKeyHasValue(metadataAssert, certificate2); } - @ParameterizedTest - @ValueSource(strings = {"/saml/metadata"}) - @Disabled("SAML test fails") - void check_metadata_signature_key(String url) throws Exception { - String metadata = getMetadata(url); - - evaluateSignatureKey(metadata, legacyCertificate); + @Test + void check_metadata_signature_key() throws Exception { + XmlAssert metadataAssert = getMetadataAssert(); + assertSignatureKeyHasValue(metadataAssert, legacyCertificate); zone.getConfig().getSamlConfig().setActiveKeyId("key1"); zone = MockMvcUtils.updateZone(mockMvc, zone); - metadata = getMetadata(url); - - evaluateSignatureKey(metadata, certificate1); + metadataAssert = getMetadataAssert(); + assertSignatureKeyHasValue(metadataAssert, certificate1); } - private String getMetadata(String uri) throws Exception { - return mockMvc.perform( - get(uri) - .header("Host", zone.getSubdomain() + ".localhost") - .accept(APPLICATION_XML) - ) + private XmlAssert getMetadataAssert() throws Exception { + String metadata = mockMvc.perform( + get(METADATA_URL) + .header("Host", zone.getSubdomain() + ".localhost") + .accept(APPLICATION_XML) + ) + .andDo(print()) .andExpect(status().isOk()) .andReturn().getResponse().getContentAsString(); + + return XmlAssert.assertThat(metadata).withNamespaceContext(xmlNamespaces()); } private String clean(String cert) { @@ -153,12 +149,26 @@ private String clean(String cert) { .replace("\n", ""); } - private void evaluateSignatureKey(String metadata, String expectedKey) throws Exception { - String xpath = "//*[local-name() = 'Signature']//*[local-name() = 'X509Certificate']/text()"; - NodeList nodeList = SamlTestUtils.evaluateXPathExpression(SamlTestUtils.getMetadataDoc(metadata), xpath); - assertNotNull(nodeList); - assertEquals(1, nodeList.getLength()); - assertEquals(clean(expectedKey), clean(nodeList.item(0).getNodeValue())); + private void assertSignatureKeyHasValue(XmlAssert metadata, String expectedKey) { + metadata.hasXPath(SIGNATURE_CERTIFICATE_XPATH_FORMAT) + .isNotEmpty() + .extractingText() + .containsOnly(clean(expectedKey)); + } + + private void assertThatSigningKeyHasValues(XmlAssert xmlAssert, String... certificates) { + assertThatXmlKeysOfTypeHasValues(xmlAssert, "signing", certificates); } + private void assertThatEncryptionKeyHasValues(XmlAssert xmlAssert, String... certificates) { + assertThatXmlKeysOfTypeHasValues(xmlAssert, "encryption", certificates); + } + + private void assertThatXmlKeysOfTypeHasValues(XmlAssert xmlAssert, String type, String... certificates) { + String[] cleanCerts = Arrays.stream(certificates).map(this::clean).toArray(String[]::new); + xmlAssert.hasXPath(KEY_DESCRIPTOR_CERTIFICATE_XPATH_FORMAT.formatted(type)) + .isNotEmpty() + .extractingText() + .containsExactlyInAnyOrder(cleanCerts); + } } diff --git a/uaa/src/test/java/org/cloudfoundry/identity/uaa/mock/saml/SamlMetadataMockMvcTests.java b/uaa/src/test/java/org/cloudfoundry/identity/uaa/mock/saml/SamlMetadataMockMvcTests.java index 2e743216e22..70b49f187e1 100644 --- a/uaa/src/test/java/org/cloudfoundry/identity/uaa/mock/saml/SamlMetadataMockMvcTests.java +++ b/uaa/src/test/java/org/cloudfoundry/identity/uaa/mock/saml/SamlMetadataMockMvcTests.java @@ -7,7 +7,6 @@ import org.cloudfoundry.identity.uaa.zone.IdentityZone; import org.cloudfoundry.identity.uaa.zone.IdentityZoneHolder; import org.cloudfoundry.identity.uaa.zone.MultitenancyFixture; -import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; import org.springframework.beans.factory.annotation.Autowired; @@ -32,22 +31,8 @@ class SamlMetadataMockMvcTests { @Autowired private MockMvc mockMvc; - private RandomValueStringGenerator generator; - private IdentityZone spZone; - @Autowired private WebApplicationContext webApplicationContext; - private UaaClientDetails adminClient; - - @BeforeEach - void setUp() throws Exception { - adminClient = new UaaClientDetails("admin", "", "", "client_credentials", "uaa.admin"); - adminClient.setClientSecret("adminsecret"); - - generator = new RandomValueStringGenerator(); - String zoneSubdomain = "testzone-" + generator.generate(); - spZone = createZone(zoneSubdomain, adminClient, false, false, zoneSubdomain + "-entity-id"); - } @Test void testSamlMetadataRootNoEndingSlash() throws Exception { @@ -93,6 +78,7 @@ void testSamlMetadataXMLValidation() throws Exception { @Test void testNonDefaultZoneSamlMetadataXMLValidation() throws Exception { + IdentityZone spZone = setupIdentityZone(true); String subdomain = spZone.getSubdomain(); mockMvc.perform(get(new URI("/saml/metadata")) @@ -137,6 +123,7 @@ void testSamlMetadataXMLValidation() throws Exception { @Test void testNonDefaultZoneSamlMetadataXMLValidation() throws Exception { + IdentityZone spZone = setupIdentityZone(true); String subdomain = spZone.getSubdomain(); mockMvc.perform(get(new URI("/saml/metadata")) @@ -157,8 +144,7 @@ void testNonDefaultZoneSamlMetadataXMLValidation() throws Exception { @Test void testNonDefaultZoneSamlMetadataXMLValidation_ZoneSamlEntityIDNotSet() throws Exception { - generator = new RandomValueStringGenerator(); - IdentityZone alternativeSpZone = createZone("testzone-" + generator.generate(), adminClient, false, false, null); + IdentityZone alternativeSpZone = setupIdentityZone(false); String zoneSubdomain = alternativeSpZone.getSubdomain(); mockMvc.perform(get(new URI("/saml/metadata")) @@ -178,6 +164,16 @@ void testNonDefaultZoneSamlMetadataXMLValidation_ZoneSamlEntityIDNotSet() throws } } + private IdentityZone setupIdentityZone(boolean hasEntityId) throws Exception { + UaaClientDetails adminClient = new UaaClientDetails("admin", "", "", "client_credentials", "uaa.admin"); + adminClient.setClientSecret("adminsecret"); + + RandomValueStringGenerator generator = new RandomValueStringGenerator(); + String zoneSubdomain = "testzone-" + generator.generate(); + String entityId = hasEntityId ? zoneSubdomain + "-entity-id" : null; + return createZone(zoneSubdomain, adminClient, false, false, entityId); + } + private IdentityZone createZone(String zoneSubdomain, UaaClientDetails adminClient, Boolean samlRequestSigned, Boolean samlWantAssertionSigned, String samlZoneEntityID) throws Exception { IdentityZone identityZone = MultitenancyFixture.identityZone(zoneSubdomain, zoneSubdomain); identityZone.getConfig().getSamlConfig().setRequestSigned(samlRequestSigned);