diff --git a/src/auth_server/db/transaction_state.py b/src/auth_server/db/transaction_state.py index 1c4b67b..d730129 100644 --- a/src/auth_server/db/transaction_state.py +++ b/src/auth_server/db/transaction_state.py @@ -11,7 +11,7 @@ from auth_server.db.client import BaseDB, get_motor_client from auth_server.mdq import MDQData -from auth_server.models.gnap import Access, GrantRequest, GrantResponse, SubjectRequest +from auth_server.models.gnap import Access, GrantRequest, GrantResponse, Key, SubjectRequest from auth_server.saml2 import SessionInfo from auth_server.time_utils import utc_now from auth_server.tls_fed_auth import MetadataEntity @@ -87,12 +87,16 @@ class ConfigState(TransactionState): config_claims: Dict[str, Any] = Field(default_factory=dict) -class MDQState(TransactionState): +class MetadataState(TransactionState): + keys_from_metadata: List[Key] = Field(default_factory=list) + + +class MDQState(MetadataState): auth_source: AuthSource = AuthSource.MDQ mdq_data: Optional[MDQData] = None -class TLSFEDState(TransactionState): +class TLSFEDState(MetadataState): auth_source: AuthSource = AuthSource.TLSFED entity: Optional[MetadataEntity] = None diff --git a/src/auth_server/flows.py b/src/auth_server/flows.py index 5eeb176..a6a15c4 100644 --- a/src/auth_server/flows.py +++ b/src/auth_server/flows.py @@ -32,7 +32,7 @@ TLSFEDState, get_transaction_state_db, ) -from auth_server.mdq import mdq_data_to_key, xml_mdq_get +from auth_server.mdq import mdq_data_to_keys, xml_mdq_get from auth_server.models.claims import CAClaims, Claims, ConfigClaims, MDQClaims, SAMLAssertionClaims, TLSFEDClaims from auth_server.models.gnap import ( AccessTokenFlags, @@ -58,7 +58,7 @@ from auth_server.proof.jws import check_jws_proof, check_jwsd_proof from auth_server.proof.mtls import check_mtls_proof from auth_server.time_utils import utc_now -from auth_server.tls_fed_auth import entity_to_key, get_entity +from auth_server.tls_fed_auth import entity_to_keys, get_entity from auth_server.utils import get_hex_uuid4, get_values __author__ = "lundberg" @@ -519,7 +519,23 @@ async def handle_interaction(self) -> Optional[GrantResponse]: return None -class MDQFlow(OnlyMTLSProofFlow): +class MetadataFlow(OnlyMTLSProofFlow): + # Used to handle multiple keys in metadata when rolling out new a new key + async def validate_proof(self) -> Optional[GrantResponse]: + for client_key in self.state.keys_from_metadata: + self.state.grant_request.client.key = client_key + try: + await super().validate_proof() + except NextFlowException: + pass + if self.state.proof_ok: + break + if not self.state.proof_ok: + raise NextFlowException(status_code=401, detail="no client certificate found") + return None + + +class MDQFlow(MetadataFlow): @classmethod def load_state(cls, state: Mapping[str, Any]) -> MDQState: return MDQState.from_dict(state=state) @@ -541,11 +557,12 @@ async def lookup_client_key(self) -> Optional[GrantResponse]: # Look for a key using mdq logger.info(f"Trying to load key from mdq") self.state.mdq_data = await xml_mdq_get(entity_id=key_id, mdq_url=self.config.mdq_server) - client_key = await mdq_data_to_key(self.state.mdq_data) + client_keys = await mdq_data_to_keys(self.state.mdq_data) - if not client_key: + if not client_keys: raise NextFlowException(status_code=400, detail=f"no client key found for {key_id}") - self.state.grant_request.client.key = client_key + + self.state.keys_from_metadata = client_keys return None async def create_claims(self) -> MDQClaims: @@ -578,7 +595,7 @@ async def create_claims(self) -> MDQClaims: return MDQClaims(**base_claims.model_dump(exclude_none=True), entity_id=entity_id, scopes=scopes, source=source) -class TLSFEDFlow(OnlyMTLSProofFlow): +class TLSFEDFlow(MetadataFlow): @classmethod def load_state(cls, state: Mapping[str, Any]) -> TLSFEDState: return TLSFEDState.from_dict(state=state) @@ -600,11 +617,12 @@ async def lookup_client_key(self) -> Optional[GrantResponse]: # Look for a key in the TLS fed metadata logger.info("Trying to load key from TLS fed auth") self.state.entity = await get_entity(entity_id=key_id) - client_key = await entity_to_key(self.state.entity) + client_keys = await entity_to_keys(self.state.entity) - if not client_key: + if not client_keys: raise NextFlowException(status_code=400, detail=f"no client key found for {key_id}") - self.state.grant_request.client.key = client_key + + self.state.keys_from_metadata = client_keys return None async def create_claims(self) -> TLSFEDClaims: diff --git a/src/auth_server/mdq.py b/src/auth_server/mdq.py index b778ca6..74475b2 100644 --- a/src/auth_server/mdq.py +++ b/src/auth_server/mdq.py @@ -1,18 +1,17 @@ # -*- coding: utf-8 -*- import logging -from base64 import b64encode from collections import OrderedDict as _OrderedDict from enum import Enum -from typing import Any, List, Optional, OrderedDict +from typing import Any, List, OrderedDict import aiohttp import xmltodict -from cryptography.hazmat.primitives.hashes import SHA1, SHA256 +from cryptography.hazmat.primitives.hashes import SHA1 from cryptography.x509 import Certificate from pydantic import BaseModel, ConfigDict, Field, field_validator, model_serializer from pyexpat import ExpatError -from auth_server.cert_utils import load_pem_from_str, serialize_certificate +from auth_server.cert_utils import load_pem_from_str, rfc8705_fingerprint, serialize_certificate from auth_server.models.gnap import Key, Proof, ProofMethod from auth_server.utils import get_values, hash_with @@ -83,7 +82,7 @@ async def xml_mdq_get(entity_id: str, mdq_url: str) -> MDQData: entity = xmltodict.parse(xml, process_namespaces=True) certs = [] # Certs - for key_descriptor in get_values(key="urn:oasis:names:tc:SAML:2.0:metadata:KeyDescriptor", obj=entity): + for key_descriptor in list(get_values(key="urn:oasis:names:tc:SAML:2.0:metadata:KeyDescriptor", obj=entity))[0]: use = list(get_values(key="@use", obj=key_descriptor))[0] raw_cert = list(get_values(key="http://www.w3.org/2000/09/xmldsig#:X509Certificate", obj=key_descriptor))[0] cert = load_pem_from_str(raw_cert) @@ -94,13 +93,16 @@ async def xml_mdq_get(entity_id: str, mdq_url: str) -> MDQData: return MDQData() -async def mdq_data_to_key(mdq_data: MDQData) -> Optional[Key]: - signing_cert = [item.cert for item in mdq_data.certs if item.use == KeyUse.SIGNING] - # There should only be one or zero signing certs - if signing_cert: - logger.info("Found cert in metadata") - return Key( - proof=Proof(method=ProofMethod.MTLS), - cert_S256=b64encode(signing_cert[0].fingerprint(algorithm=SHA256())).decode("utf-8"), +async def mdq_data_to_keys(mdq_data: MDQData) -> list[Key]: + keys = list() + signing_certs = [item.cert for item in mdq_data.certs if item.use == KeyUse.SIGNING] + for cert in signing_certs: + _fingerprint = rfc8705_fingerprint(cert) + logger.info(f"Found cert in metadata, S256: {_fingerprint}") + keys.append( + Key( + proof=Proof(method=ProofMethod.MTLS), + cert_S256=_fingerprint, + ) ) - return None + return keys diff --git a/src/auth_server/tls_fed_auth.py b/src/auth_server/tls_fed_auth.py index 0c2a891..1b5446b 100644 --- a/src/auth_server/tls_fed_auth.py +++ b/src/auth_server/tls_fed_auth.py @@ -271,20 +271,23 @@ async def get_entity(entity_id: str) -> Optional[MetadataEntity]: return None -async def entity_to_key(entity: Optional[MetadataEntity]) -> Optional[Key]: +async def entity_to_keys(entity: Optional[MetadataEntity]) -> list[Key]: + keys: list[Key] = [] if entity is None: - return None + return keys certs = [ load_pem_x509_certificate(item.x509certificate.encode()) for item in entity.issuers if item.x509certificate is not None ] - if certs: - # TODO: how do we handle multiple certs? - logger.info("Found cert in metadata") - return Key( - proof=Proof(method=ProofMethod.MTLS), - cert_S256=rfc8705_fingerprint(certs[0]), + for cert in certs: + _fingerprint = rfc8705_fingerprint(cert) + logger.info(f"Found cert in metadata, S256: {_fingerprint}") + keys.append( + Key( + proof=Proof(method=ProofMethod.MTLS), + cert_S256=_fingerprint, + ) ) - return None + return keys