Skip to content

Commit

Permalink
added tests for multiple certs in metadata based flows
Browse files Browse the repository at this point in the history
  • Loading branch information
johanlundberg committed Dec 11, 2024
1 parent 8f61896 commit ea7b4d6
Show file tree
Hide file tree
Showing 5 changed files with 106 additions and 14 deletions.
7 changes: 7 additions & 0 deletions src/auth_server/tests/data/test_mdq.xml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,13 @@
</ds:X509Data>
</ds:KeyInfo>
</md:KeyDescriptor>
<md:KeyDescriptor use="signing">
<ds:KeyInfo>
<ds:X509Data>
<ds:X509Certificate>MIIFOTCCAyGgAwIBAgIUFfCwL9eeKjTqY5RZCuLLnPvYxdgwDQYJKoZIhvcNAQELBQAwTzELMAkGA1UEBhMCU0UxCTAHBgNVBAgMADENMAsGA1UEBwwEVGVzdDENMAsGA1UECgwEVGVzdDEXMBUGA1UEAwwOdGVzdC5sb2NhbGhvc3QwHhcNMjQxMjExMTMzNTI3WhcNMjQxMjEyMTMzNTI3WjBPMQswCQYDVQQGEwJTRTEJMAcGA1UECAwAMQ0wCwYDVQQHDARUZXN0MQ0wCwYDVQQKDARUZXN0MRcwFQYDVQQDDA50ZXN0LmxvY2FsaG9zdDCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBALougAZhSedNXRcPVYMpCZKKHscY5l8Kb1pLk14++Ktz5olIZKdfY9SWfYkZpAmshubEQ13n0PJFzZohEvZ/xDczbK7xrCAjYZuCFzVLgUj1E3rBm7yN5D1wTSKzmhmGs2JFSDxo5a+NAJDEuZXvi2ypOuWn/KZzmY+aZY9e/L7jTz7e8kT9xZN8n4Nd7Uc50S1RB89zkmbc4M/sRLkFypv7rO8BGStEn+KnaPfAVCsyiPDjoIss5Qm1KdDAl+7g/gmYch+u/ilv+52jkUecDo7cyoipvNcSIawH3pIM7S3tmF7PuUl/Ko7qotNG9OxIQJCSlkyIO3F/hFKWe60tG+Gxh6PnDangOhAt6kvCYUpemELqFQwVjB8KveqkddlQx3TUPM1x5oJ/p6JKklgUxbzWrC9oMrxR9gsjs4jd2384WuADM3C5UxDoPQLEirLUB50Gj9Xkx3dPtEM3kqpAxOh4SKMMvN7vGE8iCAcga7HZzDekwn/R8gUzxLpY7qSvMwgADX7GW+Cb+z9wrPg8gg9vRAFV1XCBMH1+1l4m6+ZaWE+rKFTlLT8YPLKBjBlzZ5MjX3XWhdvQesRu1SlgA+mR7GrAj9xF3BMvUE2Vn2hbQqgJYBSFasP5PvkLITClbR4uMUfeskLcllogHQt2a4Pj71pyRsN8s7SlDRLviAElAgMBAAGjDTALMAkGA1UdEQQCMAAwDQYJKoZIhvcNAQELBQADggIBAAHo8UTXtytQmf0Q6c2pRsn96uVxlxP4+tQ6J1GXAtGq511SpqAR/BnBYbMw6VOwPjfZxKN2HK43dKX6us2wz4vD5RV7rt7ssZwysSn0kCJGqmH8/vRewQrKceamnRsF3Y+PUdXWhqDTJsLnYev/XnkpFQjhKs/1ALY7D7PaH8UoQCNrwa0ZQPKUJaCqZ08E43wbvOlk4Gwosa+HN3eMMsmCj4nURxGV8IpSc445GWHzMGw3JrfWwENFcVp4He9CB3Uem0MqUnU6H4FlFpbiOYGS3oH6fnfqAmTa4aLm0Hg75t5xc/nXPPNZXmwlWzG91QgP/AFv/PpFvc4HdmDIl7kgSYol7SPvwC9Stvw2nXXcc4Vg/ceeYxmbcZWB4bAy8oYPNqq/+GWOQeC2SFlie2H2NtYBRqFEJhlspYpjRR79cU+98syWe76ccDYw2w7+RhX5NEdE3/+VDmlPIePhy0iPXueLjL0VgGvIRWmcxcZ2ZaF/hQ8yTqP7f92igU7Y6ynej+mzPcDzQhXA1wDNSD3cBM2E56/MLQTKmgbeFGgr/MsGOiSpUMYR9Dh1nao1itlBhkvcLkdKy8Ulx4RqsnCohtbexSW3Qu1ObLGOabafL069DzcHL9JmainO3UwFpp/z+SFfyq/ZgRz4I34AXDg/x7BtLIKO/c8Rkzhr3fF4</ds:X509Certificate>
</ds:X509Data>
</ds:KeyInfo>
</md:KeyDescriptor>
<md:NameIDFormat>urn:oasis:names:tc:SAML:2.0:nameid-format:transient</md:NameIDFormat>
<md:SingleSignOnService Binding="urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect" Location="https://test.localhost/Saml2SP/sso/redirect"/>
<md:SingleSignOnService Binding="urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST" Location="https://test.localhost/Saml2SP/sso/post"/>
Expand Down
57 changes: 48 additions & 9 deletions src/auth_server/tests/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
from auth_server.models.status import Status
from auth_server.saml2 import AuthnInfo, NameID, SAMLAttributes, SessionInfo
from auth_server.testing import MongoTemporaryInstance
from auth_server.tests.utils import create_tls_fed_metadata, tls_fed_metadata_to_jws
from auth_server.tests.utils import create_cert, create_tls_fed_metadata, tls_fed_metadata_to_jws
from auth_server.time_utils import utc_now
from auth_server.tls_fed_auth import get_tls_fed_metadata
from auth_server.utils import get_hash_by_name, get_signing_key, hash_with, load_jwks
Expand Down Expand Up @@ -503,8 +503,14 @@ def test_mdq_flow(self, mock_mdq):
assert claims["scopes"] == ["localhost"]
assert claims["source"] == "http://www.swamid.se/"

@mock.patch("aiohttp.ClientSession.get", new_callable=AsyncMock)
def test_tls_fed_flow_remote_metadata(self, mock_metadata):
def _setup_remote_tls_fed_test(
self, entity_id: str, scopes: list[str] | None = None, client_certs: list[str] | None = None
) -> bytes:
if scopes is None:
scopes = ["test.localhost"]
if client_certs is None:
client_certs = [self.client_cert_str]

self.config["auth_flows"] = json.dumps(["TestFlow", "TLSFEDFlow"])
self.config["tls_fed_metadata"] = json.dumps(
[{"remote": "https://metadata.example.com/metadata.jws", "jwks": f"{self.datadir}/tls_fed_jwks.json"}]
Expand All @@ -516,17 +522,20 @@ def test_tls_fed_flow_remote_metadata(self, mock_metadata):
tls_fed_jwks = jwk.JWKSet()
tls_fed_jwks.import_keyset(f.read())

entity_id = "https://test.localhost"
metadata = create_tls_fed_metadata(
entity_id=entity_id, scopes=["test.localhost"], client_cert=self.client_cert_str
)
metadata = create_tls_fed_metadata(entity_id=entity_id, scopes=scopes, client_certs=client_certs)
metadata_jws = tls_fed_metadata_to_jws(
metadata,
key=tls_fed_jwks.get_key("metadata_signing_key_id"),
issuer="metadata.example.com",
expires=timedelta(days=14),
alg=SupportedAlgorithms.ES256,
)
return metadata_jws

@mock.patch("aiohttp.ClientSession.get", new_callable=AsyncMock)
def test_tls_fed_flow_remote_metadata(self, mock_metadata):
entity_id = "https://test.localhost"
metadata_jws = self._setup_remote_tls_fed_test(entity_id=entity_id)
mock_metadata.return_value = MockResponse(content=metadata_jws)

# Start transaction
Expand All @@ -550,6 +559,36 @@ def test_tls_fed_flow_remote_metadata(self, mock_metadata):
assert claims["organization_id"] == "SE0123456789"
assert claims["source"] == "metadata.example.com"

@mock.patch("aiohttp.ClientSession.get", new_callable=AsyncMock)
def test_tls_fed_flow_remote_metadata_multi_certs(self, mock_metadata):
entity_id = "https://test.localhost"
new_client_key, new_client_cert = create_cert(common_name="test.localhost")
new_client_cert_str = serialize_certificate(cert=new_client_cert)
client_certs = [new_client_cert_str, self.client_cert_str]
metadata_jws = self._setup_remote_tls_fed_test(entity_id=entity_id, client_certs=client_certs)
mock_metadata.return_value = MockResponse(content=metadata_jws)

# Start transaction
req = GrantRequest(
client=Client(key=entity_id),
access_token=[AccessTokenRequest(flags=[AccessTokenFlags.BEARER])],
)
client_header = {"Client-Cert": new_client_cert_str}
response = self.client.post("/transaction", json=req.model_dump(exclude_none=True), headers=client_header)
assert response.status_code == 200
assert "access_token" in response.json()
access_token = response.json()["access_token"]
assert AccessTokenFlags.BEARER.value in access_token["flags"]
assert access_token["value"] is not None

# Verify token and check claims
claims = self._get_access_token_claims(access_token=access_token, client=self.client)
assert claims["auth_source"] == AuthSource.TLSFED
assert claims["entity_id"] == "https://test.localhost"
assert claims["scopes"] == ["test.localhost"]
assert claims["organization_id"] == "SE0123456789"
assert claims["source"] == "metadata.example.com"

def test_tls_fed_flow_local_metadata(self):
# Create metadata jws and save it as a temporary file
with open(f"{self.datadir}/tls_fed_jwks.json", "r") as f:
Expand All @@ -558,7 +597,7 @@ def test_tls_fed_flow_local_metadata(self):

entity_id = "https://test.localhost"
metadata = create_tls_fed_metadata(
entity_id=entity_id, scopes=["test.localhost"], client_cert=self.client_cert_str
entity_id=entity_id, scopes=["test.localhost"], client_certs=[self.client_cert_str]
)
metadata_jws = tls_fed_metadata_to_jws(
metadata,
Expand Down Expand Up @@ -613,7 +652,7 @@ def test_tls_fed_flow_expired_entity(self, mock_metadata):
tls_fed_jwks.import_keyset(f.read())

entity_id = "https://test.localhost"
metadata = create_tls_fed_metadata(entity_id=entity_id, client_cert=self.client_cert_str)
metadata = create_tls_fed_metadata(entity_id=entity_id, client_certs=[self.client_cert_str])
metadata_jws = tls_fed_metadata_to_jws(
metadata,
key=tls_fed_jwks.get_key("metadata_signing_key_id"),
Expand Down
6 changes: 3 additions & 3 deletions src/auth_server/tests/test_tls_fed_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ async def _load_metadata(
entity_id=self.entity_id,
cache_ttl=self.cache_ttl.seconds,
scopes=self.scopes,
client_cert=self.client_cert_str,
client_certs=[self.client_cert_str],
)
metadata_jws = tls_fed_metadata_to_jws(
metadata,
Expand Down Expand Up @@ -85,7 +85,7 @@ async def test_parse_faulty_metadata(self):
entity_id=self.entity_id,
cache_ttl=self.cache_ttl.seconds,
scopes=self.scopes,
client_cert=self.client_cert_str,
client_certs=[self.client_cert_str],
).json(by_alias=True)
deserialized_metadata = json.loads(serialized_metadata)
entity = deserialized_metadata["entities"][0]
Expand Down Expand Up @@ -117,7 +117,7 @@ async def test_parse_unregistered_extension_in_metadata(self):
entity_id=self.entity_id,
cache_ttl=self.cache_ttl.seconds,
scopes=self.scopes,
client_cert=self.client_cert_str,
client_certs=[self.client_cert_str],
).model_dump_json(by_alias=True)
deserialized_metadata = json.loads(serialized_metadata)

Expand Down
45 changes: 43 additions & 2 deletions src/auth_server/tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,12 @@
from datetime import datetime, timedelta
from typing import List, Optional, Union

from cryptography import x509
from cryptography.hazmat._oid import NameOID
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey
from cryptography.x509 import Certificate
from jwcrypto import jwk, jws

from auth_server.models.jose import SupportedAlgorithms
Expand Down Expand Up @@ -46,7 +52,7 @@ def tls_fed_metadata_to_jws(

def create_tls_fed_metadata(
entity_id: str,
client_cert: str,
client_certs: list[str],
cache_ttl: int = 3600,
organization_id: str = "SE0123456789",
scopes: Optional[List[str]] = None,
Expand All @@ -59,8 +65,43 @@ def create_tls_fed_metadata(
entity_id=entity_id,
organization="Test Org",
organization_id=organization_id,
issuers=[CertIssuers(x509certificate=client_cert)],
issuers=[CertIssuers(x509certificate=client_cert) for client_cert in client_certs],
extensions=Extensions(saml_scope=SAMLScopeExtension(scope=scopes)),
)
]
return TLSFEDMetadata(version="1.0.0", cache_ttl=cache_ttl, entities=entities)


def create_cert(
common_name: str, alt_names: list[str] | None = None, days_valid: int = 1
) -> tuple[RSAPrivateKey, Certificate]:
if alt_names is None:
alt_names = list()
key = rsa.generate_private_key(public_exponent=65537, key_size=4096)
subject = issuer = x509.Name(
[
x509.NameAttribute(NameOID.COUNTRY_NAME, "SE"),
x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, ""),
x509.NameAttribute(NameOID.LOCALITY_NAME, "Test"),
x509.NameAttribute(NameOID.ORGANIZATION_NAME, "Test"),
x509.NameAttribute(NameOID.COMMON_NAME, common_name),
]
)
_alt_names = [x509.DNSName(alt_name) for alt_name in alt_names]
now = utc_now()
cert = (
x509.CertificateBuilder()
.subject_name(subject)
.issuer_name(issuer)
.public_key(key.public_key())
.serial_number(x509.random_serial_number())
.not_valid_before(now)
.not_valid_after(now + timedelta(days=days_valid))
.add_extension(
x509.SubjectAlternativeName(_alt_names),
critical=False,
# Sign our certificate with our private key
)
.sign(key, hashes.SHA256())
)
return key, cert
5 changes: 5 additions & 0 deletions src/auth_server/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import json
import logging
from base64 import urlsafe_b64encode
from datetime import datetime, timezone
from functools import lru_cache
from typing import Any, Callable, Generator, Mapping, Sequence, Union
from uuid import uuid4
Expand All @@ -20,6 +21,10 @@
logger = logging.getLogger(__name__)


def utc_now() -> datetime:
return datetime.now(tz=timezone.utc)


@lru_cache()
def load_jwks() -> jwk.JWKSet:
config = load_config()
Expand Down

0 comments on commit ea7b4d6

Please sign in to comment.