diff --git a/not_my_board/_auth/_openid.py b/not_my_board/_auth/_openid.py index b7b4710..3b53d7c 100644 --- a/not_my_board/_auth/_openid.py +++ b/not_my_board/_auth/_openid.py @@ -8,6 +8,8 @@ import jwt +import not_my_board._http as http + @dataclasses.dataclass class IdentityProvider: @@ -17,11 +19,11 @@ class IdentityProvider: jwks_uri: str @classmethod - async def from_url(cls, issuer_url, http_client): + async def from_url(cls, issuer_url, http_client, cache=None): config_url = urllib.parse.urljoin( f"{issuer_url}/", ".well-known/openid-configuration" ) - config = await http_client.get_json(config_url) + config = await http_client.get_json(config_url, cache=cache) init_args = { field.name: config[field.name] for field in dataclasses.fields(cls) @@ -149,6 +151,11 @@ def __init__(self, client_id, http_client, trusted_issuers=None): self._client_id = client_id self._http = http_client self._trusted_issuers = trusted_issuers + if trusted_issuers is not None: + self._caches = { + issuer: (http.CacheEntry(), http.CacheEntry()) + for issuer in trusted_issuers + } async def extract_claims(self, id_token, leeway=0): unverified_token = jwt.api_jwt.decode_complete( @@ -157,11 +164,18 @@ async def extract_claims(self, id_token, leeway=0): key_id = unverified_token["header"]["kid"] issuer = unverified_token["payload"]["iss"] - if self._trusted_issuers is not None and issuer not in self._trusted_issuers: - raise RuntimeError(f"Unknown issuer: {issuer}") + if self._trusted_issuers is not None: + if issuer not in self._trusted_issuers: + raise RuntimeError(f"Unknown issuer: {issuer}") + + idp_cache, jwk_cache = self._caches[issuer] + else: + idp_cache = jwk_cache = None - identity_provider = await IdentityProvider.from_url(issuer, self._http) - jwk_set_raw = await self._http.get_json(identity_provider.jwks_uri) + identity_provider = await IdentityProvider.from_url( + issuer, self._http, idp_cache + ) + jwk_set_raw = await self._http.get_json(identity_provider.jwks_uri, jwk_cache) jwk_set = jwt.PyJWKSet.from_dict(jwk_set_raw) for key in jwk_set.keys: diff --git a/tests/test_auth.py b/tests/test_auth.py index 635ff20..9653e1d 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -43,7 +43,7 @@ def set_hub(self, hub_): def set_sub(self, sub): self._sub = sub - async def get_json(self, url): + async def get_json(self, url, cache=None): # noqa: ARG002 if url == f"{HUB_URL}/api/v1/auth-info": response = self._hub.auth_info() elif url == f"{ISSUER_URL}/.well-known/openid-configuration":