diff --git a/docs/reference/exceptions.md b/docs/reference/exceptions.md new file mode 100644 index 0000000..625136f --- /dev/null +++ b/docs/reference/exceptions.md @@ -0,0 +1,6 @@ +--- +title: Exceptions +--- +# Exceptions + +::: federatedidentity.exceptions diff --git a/docs/reference.md b/docs/reference/index.md similarity index 100% rename from docs/reference.md rename to docs/reference/index.md diff --git a/docs/reference/transport.md b/docs/reference/transport.md new file mode 100644 index 0000000..f0933d7 --- /dev/null +++ b/docs/reference/transport.md @@ -0,0 +1,8 @@ +--- +title: HTTP Transport +--- +# HTTP transport providers + +::: federatedidentity.transport + +::: federatedidentity.transport.requests diff --git a/federatedidentity/__init__.py b/federatedidentity/__init__.py index 6db27d3..9752cff 100644 --- a/federatedidentity/__init__.py +++ b/federatedidentity/__init__.py @@ -1,22 +1,8 @@ -from .exceptions import ( - FederatedIdentityError, - InvalidClaimsError, - InvalidIssuerError, - InvalidJWKSUrlError, - InvalidOIDCDiscoveryDocumentError, - InvalidTokenError, - TransportError, -) -from .oidc import AsyncOIDCTokenIssuer, OIDCTokenIssuer +from ._oidc import Issuer +from ._verify import ClaimVerifier, verify_id_token __all__ = [ - "FederatedIdentityError", - "InvalidClaimsError", - "InvalidIssuerError", - "InvalidJWKSUrlError", - "InvalidOIDCDiscoveryDocumentError", - "InvalidTokenError", - "TransportError", - "OIDCTokenIssuer", - "AsyncOIDCTokenIssuer", + "Issuer", + "ClaimVerifier", + "verify_id_token", ] diff --git a/federatedidentity/oidc.py b/federatedidentity/_oidc.py similarity index 60% rename from federatedidentity/oidc.py rename to federatedidentity/_oidc.py index 932fb16..76e4e46 100644 --- a/federatedidentity/oidc.py +++ b/federatedidentity/_oidc.py @@ -1,5 +1,5 @@ +import dataclasses import json -from collections.abc import Mapping from typing import Any, NewType, Optional, cast from urllib.parse import urlparse @@ -8,9 +8,7 @@ from jwcrypto.jwt import JWT from validators.url import url as validate_url -from .baseprovider import AsyncBaseProvider, BaseProvider from .exceptions import ( - InvalidClaimsError, InvalidIssuerError, InvalidJWKSUrlError, InvalidOIDCDiscoveryDocumentError, @@ -20,11 +18,71 @@ from .transport import AsyncRequestBase, RequestBase from .transport import requests as requests_transport +__all__ = [ + "Issuer", +] + ValidatedIssuer = NewType("ValidatedIssuer", str) ValidatedJWKSUrl = NewType("ValidatedJWKSUrl", str) UnvalidatedClaims = NewType("UnvalidatedClaims", dict[str, Any]) +@dataclasses.dataclass(frozen=True) +class Issuer: + """ + Represents an issuer of OIDC id tokens. + """ + + name: str + "Name of the issuer as it appears in `iss` claims." + key_set: JWKSet + "JWK key set associated with the issuer used to verify JWT signatures." + + @classmethod + def from_discovery(cls, name: str, request: Optional[RequestBase] = None) -> "Issuer": + """ + Initialise an issuer fetching key sets as per [OpenID Connect Discovery](oidc-discovery). + + [oidc-discovery]: https://openid.net/specs/openid-connect-discovery-1_0.html + + Arguments: + name: The name of the issuer as it would appear in the "iss" claim of a token + request: An optional HTTP request callable. If omitted a default implementation based + on the [requests][] module is used. + + Returns: + a newly-created issuer + + Raises: + federatedidentity.exceptions.FederatedIdentityError + """ + request = request if request is not None else requests_transport.request + return Issuer(name=name, key_set=fetch_jwks(name, request)) + + @classmethod + async def async_from_discovery( + cls, name: str, request: Optional[AsyncRequestBase] = None + ) -> "Issuer": + """ + Initialise an issuer fetching key sets as per [OpenID Connect Discovery](oidc-discovery). + + [oidc-discovery]: https://openid.net/specs/openid-connect-discovery-1_0.html + + Arguments: + name: The name of the issuer as it would appear in the "iss" claim of a token + request: An optional asynchronous HTTP request callable. If omitted a default + implementation based on the [requests][] module is used. + + Returns: + a newly-created issuer + + Raises: + federatedidentity.exceptions.FederatedIdentityError + """ + request = request if request is not None else requests_transport.async_request + return Issuer(name=name, key_set=await async_fetch_jwks(name, request)) + + def validate_issuer(unvalidated_issuer: str) -> ValidatedIssuer: """ Validate issuer is correctly formed. @@ -178,112 +236,3 @@ def validate_token(unvalidated_token: str, jwk_set: JWKSet) -> JWT: except JWException as e: raise InvalidTokenError(f"Invalid token: {e}") return jwt - - -class _BaseOIDCTokenIssuer: - - issuer: str - audience: str - _key_set: Optional[JWKSet] - - def __init__(self, issuer: str, audience: str): - self.issuer = issuer - self.audience = audience - self._key_set = None - - def validate(self, credential: str) -> Mapping[str, Any]: - """ - Validate a credential as being issued by this provider, having the required claims and - those claims having expected values. - - Returns the verified claims as a mapping. - - Raises: - FederatedIdentityError: if the token is invalid - ValueError: if prepare() has not been called - """ - if self._key_set is None: - raise ValueError("prepare() must have been called prior to validation") - - unvalidated_claims = unvalidated_claims_from_token(credential) - - if "iss" not in unvalidated_claims: - raise InvalidClaimsError("'iss' claim missing from token") - if unvalidated_claims["iss"] != self.issuer: - raise InvalidClaimsError( - f"'iss' claims has value '{unvalidated_claims['iss']}', " - f"expected '{self.issuer}'." - ) - - if "aud" not in unvalidated_claims: - raise InvalidClaimsError("'aud' claim is missing from token") - if unvalidated_claims["aud"] != self.audience: - raise InvalidClaimsError( - f"'aud' claims has value '{unvalidated_claims['aud']}', " - f"expected '{self.audience}'." - ) - - return json.loads(validate_token(credential, self._key_set).claims) - - -class OIDCTokenIssuer(_BaseOIDCTokenIssuer, BaseProvider): - """ - Represents an issuer of federated credentials in the form of OpenID Connect identity tokens. - - The issuer must publish an OIDC Discovery document as per - https://openid.net/specs/openid-connect-discovery-1_0.html. - - The id token is verified to have a signature which matches one of the keys in the issuer's - published key set and that it has at least an "iss", "sub", "aud" and "exp" claim. If an "exp" - claim is present, it is verified to be in the future. If a "nbf" claim is present it is - verified to be in the past and if a "iat" claim is present it is verified to be an integer. - - Args: - issuer: issuer of tokens as represented in the "iss" claim of the OIDC token. - audience: expected audience of tokens as represented in the "aud" claim of the OIDC token. - """ - - def prepare(self, request: Optional[RequestBase] = None) -> None: - """ - Prepare this issuer for token verification, fetching the issuer's public key if necessary. - The public key is only fetched once so it is safe to call this method repeatedly. - - Args: - request: HTTP transport to use to fetch the issuer public key set. Defaults to a - transport based on the requests library. - - Raises: - FederatedIdentityError: if the issuer, OIDC discovery document or JWKS is invalid or - some transport error ocurred. - """ - if self._key_set is not None: - return - request = request if request is not None else requests_transport.request - self._key_set = fetch_jwks(self.issuer, request) - - -class AsyncOIDCTokenIssuer(_BaseOIDCTokenIssuer, AsyncBaseProvider): - """ - Asynchronous version of OIDCTokenIssuer. The only difference being that prepare() takes an - optional AsyncRequestBase and must be awaited. - - """ - - async def prepare(self, request: Optional[AsyncRequestBase] = None) -> None: - """ - Prepare this issuer for token verification, fetching the issuer's public key if necessary. - The public key is only fetched once so it is safe to call this method repeatedly. - - Args: - request: Asynchronous HTTP transport to use to fetch the issuer public key set. - Defaults to a transport based on the requests library which runs in a separate - thread. - - Raises: - FederatedIdentityError: if the issuer, OIDC discovery document or JWKS is invalid or - some transport error ocurred. - """ - if self._key_set is not None: - return - request = request if request is not None else requests_transport.async_request - self._key_set = await async_fetch_jwks(self.issuer, request) diff --git a/federatedidentity/_verify.py b/federatedidentity/_verify.py new file mode 100644 index 0000000..9871537 --- /dev/null +++ b/federatedidentity/_verify.py @@ -0,0 +1,94 @@ +import json +from collections.abc import Callable, Iterable +from typing import Any, Optional, Union + +from . import _oidc +from .exceptions import InvalidClaimsError + +__all__ = [ + "ClaimVerifier", + "verify_id_token", +] + + +ClaimVerifier = Union[dict[str, Any], Callable[[dict[str, Any]], None]] +""" +Type representing a claim verifier. A claim verifier may be a dictionary of acceptable claim values +or a callable which takes the claims dictionary. A claims verifier callable should raise +[`InvalidClaimsError`][federatedidentity.exceptions.InvalidClaimsError] if the claims do not match +the expected values. +""" + + +def verify_id_token( + token: Union[str, bytes], + valid_issuers: Iterable[_oidc.Issuer], + valid_audiences: Iterable[str], + *, + required_claims: Optional[Iterable[ClaimVerifier]] = None, +) -> dict[str, Any]: + """ + Verify an OIDC identity token. + + Returns: + the token's claims dictionary. + + Parameters: + token: OIDC token to verify. If a [bytes][] object is passed it is decoded using the ASCII + codec before verification. + valid_issuers: Iterable of valid issuers. At least one Issuer must match the token issuer + for verification to succeed. + valid_audiences: Iterable of valid audiences. At least one audience must match the `aud` + claim for verification to succeed. + required_claims: Iterable of required claim verifiers. Claims are passed to verifiers after + the token's signature has been verified. Claims required by OIDC are always + validated. All claim verifiers must pass for verification to succeed. + + Raises: + federatedidentity.exceptions.FederatedIdentityError: The token failed verification. + UnicodeDecodeError: The token could not be decoded into an ASCII string. + """ + if isinstance(token, bytes): + token = token.decode("ascii") + + unvalidated_claims = _oidc.unvalidated_claims_from_token(token) + + # For required claims, see: https://openid.net/specs/openid-connect-core-1_0.html#IDToken + for claim in ["iss", "sub", "aud", "exp", "iat"]: + if claim not in unvalidated_claims: + raise InvalidClaimsError(f"'{claim}' claim not present in token") + + # Check that the token "aud" claim matches at least one of our expected audiences. + if not any(unvalidated_claims["aud"] == audience for audience in valid_audiences): + raise InvalidClaimsError( + f"Token issuer '{unvalidated_claims['aud']}' did not match any valid issuer" + ) + + # Determine which issuer matches the token. + for issuer in valid_issuers: + if issuer.name == unvalidated_claims["iss"]: + break + else: + # No issuer matched the token if the for loop exited without "break". + raise InvalidClaimsError( + f"Token issuer '{unvalidated_claims['iss']}' did not match any valid issuer" + ) + + # Note: validate_token() validates "exp", "iat" and "nbf" claims and that the "alg" header has + # an appropriate value. + verified_claims = json.loads(_oidc.validate_token(token, issuer.key_set).claims) + + required_claims = required_claims if required_claims is not None else [] + for claims_verifier in required_claims: + if callable(claims_verifier): + claims_verifier(verified_claims) + else: + for claim, value in claims_verifier.items(): + if claim not in verified_claims: + raise InvalidClaimsError(f"Required claim '{claim}' not present in token") + if verified_claims[claim] != value: + raise InvalidClaimsError( + f"Required claim '{claim}' has invalid value {value!r}" + ) + + return verified_claims diff --git a/federatedidentity/baseprovider.py b/federatedidentity/baseprovider.py deleted file mode 100644 index b05fe49..0000000 --- a/federatedidentity/baseprovider.py +++ /dev/null @@ -1,23 +0,0 @@ -from abc import ABCMeta, abstractmethod -from typing import Any, Optional - -from . import transport - - -class BaseProvider(metaclass=ABCMeta): - """ - Base class for credential providers. - """ - - @abstractmethod - def prepare(self, request: Optional[transport.RequestBase] = None) -> None: ... - - @abstractmethod - def validate(self, credential: str) -> Any: ... - - -class AsyncBaseProvider: - @abstractmethod - async def prepare( - self, request: Optional[transport.AsyncRequestBase] = None - ) -> None: ... # type: ignore[override] diff --git a/federatedidentity/transport/__init__.py b/federatedidentity/transport/__init__.py index d0d1337..70126a3 100644 --- a/federatedidentity/transport/__init__.py +++ b/federatedidentity/transport/__init__.py @@ -1,6 +1,11 @@ +""" +Generic HTTP transport base classes and utilities. +""" + import dataclasses from abc import ABCMeta, abstractmethod -from typing import Mapping, Optional +from collections.abc import Mapping +from typing import Optional @dataclasses.dataclass @@ -36,8 +41,8 @@ def __call__( The response from the resource server. Raises: - TransportError: on any transport error such as DNS resolution failure. Note that error - status codes from the server do not raise. + federatedidentity.exceptions.TransportError: on any transport error such as DNS + resolution failure. Note that error status codes from the server do not raise. """ @@ -67,6 +72,6 @@ async def __call__( The response from the resource server. Raises: - TransportError: on any transport error such as DNS resolution failure. Note that error - status codes from the server do not raise. + federatedidentity.exceptions.TransportError: on any transport error such as DNS + resolution failure. Note that error status codes from the server do not raise. """ diff --git a/federatedidentity/transport/requests.py b/federatedidentity/transport/requests.py index 0c60925..211fc63 100644 --- a/federatedidentity/transport/requests.py +++ b/federatedidentity/transport/requests.py @@ -1,3 +1,7 @@ +""" +HTTP transport based on [requests][]. +""" + import asyncio from typing import Mapping, Optional @@ -49,8 +53,12 @@ async def __call__(self, *args, **kwargs) -> Response: return await asyncio.to_thread(self._sync_request, *args, **kwargs) -#: RequestsSession object which uses a default requests.Session. request = RequestsSession() +""" +A HTTP transport implementation which uses a default [requests.Session][]. +""" -#: AsyncRequestsSession object which uses a default requests.Session. async_request = AsyncRequestsSession() +""" +An asynchronous HTTP transport implementation which uses a default [requests.Session][]. +""" diff --git a/mkdocs.yml b/mkdocs.yml index deaaf2a..f6b4970 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -4,7 +4,7 @@ site_url: "https://rjw57.github.io/verify-oidc-identity" repo_url: "https://github.com/rjw57/verify-oidc-identity" repo_name: "rjw57/verify-oidc-identity" site_dir: "site" -watch: [mkdocs.yml, README.md] +watch: [mkdocs.yml, README.md, federatedidentity/] copyright: Copyright © Rich Wareham edit_uri: edit/main/docs/ @@ -13,7 +13,10 @@ nav: - index.md - changelog.md - license.md - - reference.md + - API Reference: + - reference/index.md + - reference/exceptions.md + - reference/transport.md theme: name: material @@ -65,7 +68,10 @@ plugins: python: import: - https://docs.python.org/3/objects.inv + - https://requests.readthedocs.io/en/latest/objects.inv + - https://jwcrypto.readthedocs.io/en/latest/objects.inv options: + show_root_heading: true filters: ["!^_"] members_order: source separate_signature: true @@ -74,3 +80,4 @@ plugins: show_symbol_type_heading: true signature_crossrefs: true show_root_toc_entry: false + show_if_no_docstring: true diff --git a/pyproject.toml b/pyproject.toml index 0e5def2..92375fb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,6 +40,7 @@ black = "^24.10.0" [tool.pytest.ini_options] addopts = "--cov --cov-report term --cov-report html" +asyncio_default_fixture_loop_scope = "function" [tool.mypy] ignore_missing_imports = true diff --git a/tests/oidcfixtures.py b/tests/oidcfixtures.py index 73e6247..7cbb7ea 100644 --- a/tests/oidcfixtures.py +++ b/tests/oidcfixtures.py @@ -1,3 +1,4 @@ +import datetime import json from typing import Any @@ -7,6 +8,8 @@ from jwcrypto.jwt import JWT from responses import RequestsMock +from federatedidentity._oidc import Issuer + @pytest.fixture def jwt_issuer(faker: Faker, jwks_uri: str, mocked_responses: RequestsMock) -> str: @@ -17,6 +20,11 @@ def jwt_issuer(faker: Faker, jwks_uri: str, mocked_responses: RequestsMock) -> s return issuer_url +@pytest.fixture +def oidc_issuer(jwt_issuer): + return Issuer.from_discovery(jwt_issuer) + + @pytest.fixture def oidc_subject(faker: Faker) -> str: return faker.slug() @@ -36,6 +44,8 @@ def oidc_claims( "sub": oidc_subject, "aud": oidc_audience, "jti": faker.uuid4(), + "iat": (faker.past_datetime() - datetime.timedelta(seconds=30)).timestamp(), + "exp": (faker.future_datetime() + datetime.timedelta(hours=1)).timestamp(), } diff --git a/tests/test_jwks_fetching.py b/tests/test_jwks_fetching.py index 93f5684..28391fa 100644 --- a/tests/test_jwks_fetching.py +++ b/tests/test_jwks_fetching.py @@ -4,98 +4,100 @@ from faker import Faker from jwcrypto.jwk import JWKSet -from federatedidentity import exceptions, oidc +from federatedidentity import _oidc, exceptions from federatedidentity.transport.requests import async_request, request def test_basic_case(jwt_issuer: str, jwk_set: JWKSet): - fetched_jwk_set = oidc.fetch_jwks(oidc.validate_issuer(jwt_issuer), request) + fetched_jwk_set = _oidc.fetch_jwks(_oidc.validate_issuer(jwt_issuer), request) assert fetched_jwk_set == jwk_set @pytest.mark.asyncio async def test_basic_case_async(jwt_issuer: str, jwk_set: JWKSet): - fetched_jwk_set = await oidc.async_fetch_jwks(oidc.validate_issuer(jwt_issuer), async_request) + fetched_jwk_set = await _oidc.async_fetch_jwks( + _oidc.validate_issuer(jwt_issuer), async_request + ) assert fetched_jwk_set == jwk_set @pytest.mark.parametrize("field", ["jwks_uri", "issuer"]) def test_missing_field_in_discovery_doc(field: str, jwt_issuer: str, mocked_responses): - doc_url = oidc.oidc_discovery_document_url(oidc.validate_issuer(jwt_issuer)) + doc_url = _oidc.oidc_discovery_document_url(_oidc.validate_issuer(jwt_issuer)) doc = json.loads(request(doc_url).content) del doc[field] mocked_responses.remove("GET", doc_url) mocked_responses.get(doc_url, body=json.dumps(doc), content_type="application/json") with pytest.raises(exceptions.InvalidOIDCDiscoveryDocumentError): - oidc.fetch_jwks(jwt_issuer, request) + _oidc.fetch_jwks(jwt_issuer, request) def test_discovery_doc_not_present(jwt_issuer: str, mocked_responses): - doc_url = oidc.oidc_discovery_document_url(oidc.validate_issuer(jwt_issuer)) + doc_url = _oidc.oidc_discovery_document_url(_oidc.validate_issuer(jwt_issuer)) doc_body = request(doc_url).content mocked_responses.remove("GET", doc_url) mocked_responses.get(doc_url, body=doc_body, status=404, content_type="application/json") with pytest.raises(exceptions.TransportError): - oidc.fetch_jwks(jwt_issuer, request) + _oidc.fetch_jwks(jwt_issuer, request) @pytest.mark.asyncio async def test_discovery_doc_not_present_async(jwt_issuer: str, mocked_responses): - doc_url = oidc.oidc_discovery_document_url(oidc.validate_issuer(jwt_issuer)) + doc_url = _oidc.oidc_discovery_document_url(_oidc.validate_issuer(jwt_issuer)) doc_body = (await async_request(doc_url)).content mocked_responses.remove("GET", doc_url) mocked_responses.get(doc_url, body=doc_body, status=404, content_type="application/json") with pytest.raises(exceptions.TransportError): - await oidc.async_fetch_jwks(jwt_issuer, async_request) + await _oidc.async_fetch_jwks(jwt_issuer, async_request) def test_discovery_doc_not_json(jwt_issuer: str, mocked_responses): - doc_url = oidc.oidc_discovery_document_url(oidc.validate_issuer(jwt_issuer)) + doc_url = _oidc.oidc_discovery_document_url(_oidc.validate_issuer(jwt_issuer)) mocked_responses.remove("GET", doc_url) mocked_responses.get(doc_url, body="this is not json", content_type="application/json") with pytest.raises(exceptions.InvalidOIDCDiscoveryDocumentError): - oidc.fetch_jwks(jwt_issuer, request) + _oidc.fetch_jwks(jwt_issuer, request) def test_mismatched_issuer_in_discovery_doc(faker: Faker, jwt_issuer: str, mocked_responses): - doc_url = oidc.oidc_discovery_document_url(oidc.validate_issuer(jwt_issuer)) + doc_url = _oidc.oidc_discovery_document_url(_oidc.validate_issuer(jwt_issuer)) doc = json.loads(request(doc_url).content) doc["issuer"] = faker.url() mocked_responses.remove("GET", doc_url) mocked_responses.get(doc_url, body=json.dumps(doc), content_type="application/json") with pytest.raises(exceptions.InvalidOIDCDiscoveryDocumentError): - oidc.fetch_jwks(jwt_issuer, request) + _oidc.fetch_jwks(jwt_issuer, request) def test_issuer_not_url(): with pytest.raises(exceptions.InvalidIssuerError) as e: - oidc.fetch_jwks("not/a/ url", request) + _oidc.fetch_jwks("not/a/ url", request) assert str(e.value) == "Issuer is not a valid URL." def test_issuer_not_https(faker: Faker): with pytest.raises(exceptions.InvalidIssuerError) as e: - oidc.fetch_jwks(faker.url(schemes=["http"]), request) + _oidc.fetch_jwks(faker.url(schemes=["http"]), request) assert str(e.value) == "Issuer does not have a https scheme." def test_jwks_uri_not_url(jwt_issuer: str, mocked_responses): - doc_url = oidc.oidc_discovery_document_url(oidc.validate_issuer(jwt_issuer)) + doc_url = _oidc.oidc_discovery_document_url(_oidc.validate_issuer(jwt_issuer)) doc = json.loads(request(doc_url).content) doc["jwks_uri"] = "not a /url" mocked_responses.remove("GET", doc_url) mocked_responses.get(doc_url, body=json.dumps(doc), content_type="application/json") with pytest.raises(exceptions.InvalidJWKSUrlError) as e: - oidc.fetch_jwks(jwt_issuer, request) + _oidc.fetch_jwks(jwt_issuer, request) assert str(e.value) == "JWKS URL is not a valid URL." def test_jwks_uri_not_https(faker: Faker, jwt_issuer: str, mocked_responses): - doc_url = oidc.oidc_discovery_document_url(oidc.validate_issuer(jwt_issuer)) + doc_url = _oidc.oidc_discovery_document_url(_oidc.validate_issuer(jwt_issuer)) doc = json.loads(request(doc_url).content) doc["jwks_uri"] = faker.url(schemes=["http"]) mocked_responses.remove("GET", doc_url) mocked_responses.get(doc_url, body=json.dumps(doc), content_type="application/json") with pytest.raises(exceptions.InvalidJWKSUrlError) as e: - oidc.fetch_jwks(jwt_issuer, request) + _oidc.fetch_jwks(jwt_issuer, request) assert str(e.value) == "JWKS URL does not have a https scheme." diff --git a/tests/test_oidc.py b/tests/test_oidc.py index 1e121f3..9c4ed62 100644 --- a/tests/test_oidc.py +++ b/tests/test_oidc.py @@ -1,22 +1,9 @@ import pytest import requests -import responses from jwcrypto.jwk import JWKSet from jwcrypto.jwt import JWT -from federatedidentity import oidc - - -@pytest.fixture -def unprepared_oidc_token_issuer(jwt_issuer, oidc_audience): - return oidc.OIDCTokenIssuer(issuer=jwt_issuer, audience=oidc_audience) - - -@pytest.fixture -def prepared_oidc_token_issuer(jwt_issuer, oidc_audience): - issuer = oidc.OIDCTokenIssuer(issuer=jwt_issuer, audience=oidc_audience) - issuer.prepare() - return issuer +from federatedidentity import Issuer, _oidc, exceptions def test_jwt_issuer(jwt_issuer: str, jwks_uri: str): @@ -30,44 +17,23 @@ def test_oidc_token(oidc_token: str, jwk_set: JWKSet, jwt_issuer: str): jwt.deserialize(oidc_token, jwk_set) -def test_oidc_token_issuer(oidc_token: str, jwt_issuer: str): - assert oidc.unvalidated_claim_from_token(oidc_token, "iss") == jwt_issuer - - def test_oidc_token_subject(oidc_token: str, oidc_subject: str): - assert oidc.unvalidated_claim_from_token(oidc_token, "sub") == oidc_subject + assert _oidc.unvalidated_claim_from_token(oidc_token, "sub") == oidc_subject def test_oidc_token_audience(oidc_token: str, oidc_audience: str): - assert oidc.unvalidated_claim_from_token(oidc_token, "aud") == oidc_audience - - -def test_basic_validated(oidc_token, oidc_claims, prepared_oidc_token_issuer): - assert prepared_oidc_token_issuer.validate(oidc_token) == oidc_claims + assert _oidc.unvalidated_claim_from_token(oidc_token, "aud") == oidc_audience -@pytest.mark.parametrize("missing_claim", ["iss", "aud"]) -def test_oidc_token_missing_claim( - missing_claim, make_oidc_token, oidc_claims, prepared_oidc_token_issuer -): - claims = {**oidc_claims} - del claims[missing_claim] - with pytest.raises(oidc.InvalidClaimsError): - prepared_oidc_token_issuer.validate(make_oidc_token(claims)) +def test_oidc_token_issuer(oidc_token: str, oidc_issuer: Issuer): + assert _oidc.unvalidated_claim_from_token(oidc_token, "iss") == oidc_issuer.name -def test_unprepared_issuer(oidc_token, unprepared_oidc_token_issuer): - with pytest.raises(ValueError): - unprepared_oidc_token_issuer.validate(oidc_token) +def test_malformed_token(): + with pytest.raises(exceptions.InvalidTokenError): + assert _oidc.unvalidated_claim_from_token("not a JWT", "iss") -def test_multiple_prepare_only_fetches_once( - oidc_token, oidc_claims, unprepared_oidc_token_issuer, mocked_responses: responses.RequestsMock -): - mocked_responses.assert_call_count(oidc.oidc_discovery_document_url(oidc_claims["iss"]), 0) - unprepared_oidc_token_issuer.prepare() - mocked_responses.assert_call_count(oidc.oidc_discovery_document_url(oidc_claims["iss"]), 1) - unprepared_oidc_token_issuer.prepare() - mocked_responses.assert_call_count(oidc.oidc_discovery_document_url(oidc_claims["iss"]), 1) - assert unprepared_oidc_token_issuer.validate(oidc_token) == oidc_claims - mocked_responses.assert_call_count(oidc.oidc_discovery_document_url(oidc_claims["iss"]), 1) +def test_missing_claim(faker, oidc_token: str): + with pytest.raises(exceptions.InvalidTokenError): + assert _oidc.unvalidated_claim_from_token(oidc_token, faker.slug()) diff --git a/tests/test_oidc_discovery.py b/tests/test_oidc_discovery.py new file mode 100644 index 0000000..85dfb86 --- /dev/null +++ b/tests/test_oidc_discovery.py @@ -0,0 +1,101 @@ +import json + +import pytest +from faker import Faker +from jwcrypto.jwk import JWKSet + +from federatedidentity import Issuer, _oidc, exceptions +from federatedidentity.transport.requests import async_request, request + + +def test_basic_case(jwt_issuer: str, jwk_set: JWKSet): + issuer = Issuer.from_discovery(jwt_issuer) + assert issuer.key_set == jwk_set + + +@pytest.mark.asyncio +async def test_basic_case_async(jwt_issuer: str, jwk_set: JWKSet): + issuer = await Issuer.async_from_discovery(jwt_issuer) + assert issuer.key_set == jwk_set + + +@pytest.mark.parametrize("field", ["jwks_uri", "issuer"]) +def test_missing_field_in_discovery_doc(field: str, jwt_issuer: str, mocked_responses): + doc_url = _oidc.oidc_discovery_document_url(_oidc.validate_issuer(jwt_issuer)) + doc = json.loads(request(doc_url).content) + del doc[field] + mocked_responses.remove("GET", doc_url) + mocked_responses.get(doc_url, body=json.dumps(doc), content_type="application/json") + with pytest.raises(exceptions.InvalidOIDCDiscoveryDocumentError): + Issuer.from_discovery(jwt_issuer) + + +def test_discovery_doc_not_present(jwt_issuer: str, mocked_responses): + doc_url = _oidc.oidc_discovery_document_url(_oidc.validate_issuer(jwt_issuer)) + doc_body = request(doc_url).content + mocked_responses.remove("GET", doc_url) + mocked_responses.get(doc_url, body=doc_body, status=404, content_type="application/json") + with pytest.raises(exceptions.TransportError): + Issuer.from_discovery(jwt_issuer) + + +@pytest.mark.asyncio +async def test_discovery_doc_not_present_async(jwt_issuer: str, mocked_responses): + doc_url = _oidc.oidc_discovery_document_url(_oidc.validate_issuer(jwt_issuer)) + doc_body = (await async_request(doc_url)).content + mocked_responses.remove("GET", doc_url) + mocked_responses.get(doc_url, body=doc_body, status=404, content_type="application/json") + with pytest.raises(exceptions.TransportError): + await Issuer.async_from_discovery(jwt_issuer) + + +def test_discovery_doc_not_json(jwt_issuer: str, mocked_responses): + doc_url = _oidc.oidc_discovery_document_url(_oidc.validate_issuer(jwt_issuer)) + mocked_responses.remove("GET", doc_url) + mocked_responses.get(doc_url, body="this is not json", content_type="application/json") + with pytest.raises(exceptions.InvalidOIDCDiscoveryDocumentError): + Issuer.from_discovery(jwt_issuer) + + +def test_mismatched_issuer_in_discovery_doc(faker: Faker, jwt_issuer: str, mocked_responses): + doc_url = _oidc.oidc_discovery_document_url(_oidc.validate_issuer(jwt_issuer)) + doc = json.loads(request(doc_url).content) + doc["issuer"] = faker.url() + mocked_responses.remove("GET", doc_url) + mocked_responses.get(doc_url, body=json.dumps(doc), content_type="application/json") + with pytest.raises(exceptions.InvalidOIDCDiscoveryDocumentError): + Issuer.from_discovery(jwt_issuer) + + +def test_issuer_not_url(): + with pytest.raises(exceptions.InvalidIssuerError) as e: + Issuer.from_discovery("not/a/ url") + assert str(e.value) == "Issuer is not a valid URL." + + +def test_issuer_not_https(faker: Faker): + with pytest.raises(exceptions.InvalidIssuerError) as e: + Issuer.from_discovery(faker.url(schemes=["http"])) + assert str(e.value) == "Issuer does not have a https scheme." + + +def test_jwks_uri_not_url(jwt_issuer: str, mocked_responses): + doc_url = _oidc.oidc_discovery_document_url(_oidc.validate_issuer(jwt_issuer)) + doc = json.loads(request(doc_url).content) + doc["jwks_uri"] = "not a /url" + mocked_responses.remove("GET", doc_url) + mocked_responses.get(doc_url, body=json.dumps(doc), content_type="application/json") + with pytest.raises(exceptions.InvalidJWKSUrlError) as e: + Issuer.from_discovery(jwt_issuer) + assert str(e.value) == "JWKS URL is not a valid URL." + + +def test_jwks_uri_not_https(faker: Faker, jwt_issuer: str, mocked_responses): + doc_url = _oidc.oidc_discovery_document_url(_oidc.validate_issuer(jwt_issuer)) + doc = json.loads(request(doc_url).content) + doc["jwks_uri"] = faker.url(schemes=["http"]) + mocked_responses.remove("GET", doc_url) + mocked_responses.get(doc_url, body=json.dumps(doc), content_type="application/json") + with pytest.raises(exceptions.InvalidJWKSUrlError) as e: + Issuer.from_discovery(jwt_issuer) + assert str(e.value) == "JWKS URL does not have a https scheme." diff --git a/tests/test_validation.py b/tests/test_validation.py deleted file mode 100644 index d60f4a9..0000000 --- a/tests/test_validation.py +++ /dev/null @@ -1,125 +0,0 @@ -import datetime -from typing import Any - -import pytest -from faker import Faker -from jwcrypto.jwk import JWK -from jwcrypto.jws import JWS -from jwcrypto.jwt import JWT - -from federatedidentity import AsyncOIDCTokenIssuer, OIDCTokenIssuer -from federatedidentity import exceptions as exc -from federatedidentity import oidc - -from .oidcfixtures import make_jwt - - -def test_oidc_token_issuer(oidc_token: str, jwt_issuer: str): - assert oidc.unvalidated_claim_from_token(oidc_token, "iss") == jwt_issuer - - -def test_token_payload_is_not_json(ec_jwk: JWK): - jws = JWS("not json") - jws.add_signature( - ec_jwk, alg="ES256", protected={"alg": "ES256", "kid": ec_jwk["kid"], "type": "JWT"} - ) - with pytest.raises(exc.InvalidTokenError): - oidc.unvalidated_claim_from_token(jws.serialize(compact=True), "iss") - - -def test_missing_issuer_claim(oidc_claims: dict[str, str], ec_jwk: JWK): - del oidc_claims["iss"] - jwt = JWT( - header={"alg": "ES256", "kid": ec_jwk["kid"], "type": "JWT"}, - claims=oidc_claims, - ) - jwt.make_signed_token(ec_jwk) - token = jwt.serialize() - with pytest.raises(exc.InvalidTokenError): - oidc.unvalidated_claim_from_token(token, "iss") - - -def test_basic_verification(faker: Faker, oidc_token: str, oidc_audience: str, jwt_issuer: str): - provider = OIDCTokenIssuer(jwt_issuer, oidc_audience) - provider.prepare() - provider.validate(oidc_token) - - -@pytest.mark.asyncio -async def test_basic_async_verification( - faker: Faker, oidc_token: str, oidc_audience: str, jwt_issuer: str -): - provider = AsyncOIDCTokenIssuer(jwt_issuer, oidc_audience) - await provider.prepare() - provider.validate(oidc_token) - - -def test_mismatched_audience(faker: Faker, oidc_token: str, jwt_issuer: str): - provider = OIDCTokenIssuer(jwt_issuer, faker.url(schemes=["https"])) - provider.prepare() - with pytest.raises(exc.InvalidClaimsError): - provider.validate(oidc_token) - - -def test_issuer_not_url(oidc_token: str, oidc_audience: str): - provider = OIDCTokenIssuer("-not a url-", oidc_audience) - with pytest.raises(exc.InvalidIssuerError): - provider.prepare() - - -def test_issuer_bad_scheme(faker: Faker, oidc_token: str, oidc_audience: str): - provider = OIDCTokenIssuer(faker.url(schemes=["http"]), oidc_audience) - with pytest.raises(exc.InvalidIssuerError): - provider.prepare() - - -@pytest.mark.parametrize("alg", ["RS256", "ES256"]) -def test_mismatched_issuer( - alg: str, - faker: Faker, - oidc_claims: dict[str, str], - oidc_audience: str, - jwt_issuer: str, - jwks: dict[str, JWK], -): - provider = OIDCTokenIssuer(jwt_issuer, oidc_audience) - provider.prepare() - iss = faker.url(schemes=["https"]) - oidc_claims["iss"] = iss - token = make_jwt(oidc_claims, jwks[alg], alg) - with pytest.raises(exc.InvalidClaimsError): - provider.validate(token) - - -@pytest.mark.parametrize("alg", ["RS256", "ES256"]) -def test_exp_claim_in_past( - alg: str, - faker: Faker, - oidc_claims: dict[str, Any], - oidc_audience: str, - jwt_issuer: str, - jwks: dict[str, JWK], -): - provider = OIDCTokenIssuer(jwt_issuer, oidc_audience) - provider.prepare() - oidc_claims["exp"] = datetime.datetime.now(datetime.UTC).timestamp() - 100000 - token = make_jwt(oidc_claims, jwks[alg], alg) - with pytest.raises(exc.InvalidTokenError): - provider.validate(token) - - -@pytest.mark.parametrize("alg", ["RS256", "ES256"]) -def test_nbf_claim_in_future( - alg: str, - faker: Faker, - oidc_claims: dict[str, Any], - oidc_audience: str, - jwt_issuer: str, - jwks: dict[str, JWK], -): - provider = OIDCTokenIssuer(jwt_issuer, oidc_audience) - provider.prepare() - oidc_claims["nbf"] = datetime.datetime.now(datetime.UTC).timestamp() + 100000 - token = make_jwt(oidc_claims, jwks[alg], alg) - with pytest.raises(exc.InvalidTokenError): - provider.validate(token) diff --git a/tests/test_verification.py b/tests/test_verification.py new file mode 100644 index 0000000..4d3acaa --- /dev/null +++ b/tests/test_verification.py @@ -0,0 +1,168 @@ +import datetime +from typing import Any +from unittest import mock + +import pytest +from faker import Faker +from jwcrypto.jwk import JWK +from jwcrypto.jws import JWS +from jwcrypto.jwt import JWT + +from federatedidentity import Issuer +from federatedidentity import exceptions as exc +from federatedidentity import verify_id_token + +from .oidcfixtures import make_jwt + + +def test_basic_verification( + faker: Faker, oidc_token: str, oidc_audience: str, oidc_issuer: Issuer +): + verify_id_token(oidc_token, [oidc_issuer], [oidc_audience]) + + +def test_auto_decode(faker: Faker, oidc_token: str, oidc_audience: str, oidc_issuer: Issuer): + verify_id_token(oidc_token.encode("ascii"), [oidc_issuer], [oidc_audience]) + + +def test_non_ascii_token(faker: Faker, oidc_token: str, oidc_audience: str, oidc_issuer: Issuer): + with pytest.raises(UnicodeDecodeError): + verify_id_token( + "\N{LATIN SMALL LETTER E}\N{COMBINING CIRCUMFLEX ACCENT}".encode("utf8"), + [oidc_issuer], + [oidc_audience], + ) + + +def test_good_subject(oidc_token: str, oidc_audience: str, oidc_issuer: Issuer, oidc_subject: str): + verify_id_token( + oidc_token, [oidc_issuer], [oidc_audience], required_claims=[{"sub": oidc_subject}] + ) + + +def test_bad_subject( + faker: Faker, oidc_token: str, oidc_audience: str, oidc_issuer: Issuer, oidc_subject: str +): + with pytest.raises(exc.InvalidClaimsError): + verify_id_token( + oidc_token, [oidc_issuer], [oidc_audience], required_claims=[{"sub": faker.slug()}] + ) + + +def test_claim_verifier_missing_claim( + faker: Faker, oidc_token: str, oidc_audience: str, oidc_issuer: Issuer, oidc_subject: str +): + with pytest.raises(exc.InvalidClaimsError): + verify_id_token( + oidc_token, + [oidc_issuer], + [oidc_audience], + required_claims=[{"some-other-claim": faker.slug()}], + ) + + +def test_claims_verifier_callable( + faker: Faker, + oidc_token: str, + oidc_audience: str, + oidc_issuer: Issuer, + oidc_subject: str, + oidc_claims: dict[str, str], +): + message = faker.bothify("####????") + validator = mock.Mock(side_effect=exc.InvalidClaimsError(message)) + with pytest.raises(exc.InvalidClaimsError, match=message): + verify_id_token( + oidc_token, + [oidc_issuer], + [oidc_audience], + required_claims=[{"sub": oidc_subject}, validator], + ) + validator.assert_called_once_with(oidc_claims) + + +def test_token_payload_is_not_json(ec_jwk: JWK, oidc_issuer, oidc_audience): + jws = JWS("not json") + jws.add_signature( + ec_jwk, alg="ES256", protected={"alg": "ES256", "kid": ec_jwk["kid"], "type": "JWT"} + ) + with pytest.raises(exc.InvalidTokenError): + verify_id_token(jws.serialize(compact=True), [oidc_issuer], [oidc_audience]) + + +@pytest.mark.parametrize("claim", ["iss", "sub", "aud", "iat", "exp"]) +def test_missing_required_claim( + claim, oidc_claims: dict[str, str], ec_jwk: JWK, oidc_audience, oidc_issuer +): + # Claims are OK as is + jwt = JWT( + header={"alg": "ES256", "kid": ec_jwk["kid"], "type": "JWT"}, + claims=oidc_claims, + ) + jwt.make_signed_token(ec_jwk) + verify_id_token(jwt.serialize(), [oidc_issuer], [oidc_audience]) + + del oidc_claims[claim] + jwt = JWT( + header={"alg": "ES256", "kid": ec_jwk["kid"], "type": "JWT"}, + claims=oidc_claims, + ) + jwt.make_signed_token(ec_jwk) + with pytest.raises(exc.InvalidClaimsError): + verify_id_token(jwt.serialize(), [oidc_issuer], [oidc_audience]) + + +def test_mismatched_audience(faker: Faker, oidc_token: str, oidc_issuer: Issuer): + with pytest.raises(exc.InvalidClaimsError): + verify_id_token(oidc_token, [oidc_issuer], [faker.url(schemes=["https"])]) + + +@pytest.mark.parametrize("issuer", ["not-a-url", "http://example.com/", "ftp://example.com/"]) +def test_malformed_issuer(issuer, oidc_token: str, oidc_audience: str): + with pytest.raises(exc.InvalidIssuerError): + Issuer.from_discovery(issuer) + + +@pytest.mark.parametrize("alg", ["RS256", "ES256"]) +def test_mismatched_issuer( + alg: str, + faker: Faker, + oidc_claims: dict[str, str], + oidc_audience: str, + oidc_issuer: Issuer, + jwks: dict[str, JWK], +): + iss = faker.url(schemes=["https"]) + token = make_jwt({**oidc_claims, "iss": iss}, jwks[alg], alg) + with pytest.raises(exc.InvalidClaimsError, match="Token issuer '.*' did not match"): + verify_id_token(token, [oidc_issuer], [oidc_audience]) + + +@pytest.mark.parametrize("alg", ["RS256", "ES256"]) +def test_exp_claim_in_past( + alg: str, + faker: Faker, + oidc_claims: dict[str, Any], + oidc_audience: str, + oidc_issuer: Issuer, + jwks: dict[str, JWK], +): + oidc_claims["exp"] = datetime.datetime.now(datetime.UTC).timestamp() - 100000 + token = make_jwt(oidc_claims, jwks[alg], alg) + with pytest.raises(exc.InvalidTokenError, match="Expired"): + verify_id_token(token, [oidc_issuer], [oidc_audience]) + + +@pytest.mark.parametrize("alg", ["RS256", "ES256"]) +def test_nbf_claim_in_future( + alg: str, + faker: Faker, + oidc_claims: dict[str, Any], + oidc_audience: str, + oidc_issuer: Issuer, + jwks: dict[str, JWK], +): + oidc_claims["nbf"] = datetime.datetime.now(datetime.UTC).timestamp() + 100000 + token = make_jwt(oidc_claims, jwks[alg], alg) + with pytest.raises(exc.InvalidTokenError, match="Valid from"): + verify_id_token(token, [oidc_issuer], [oidc_audience])