diff --git a/docs/reference/exceptions.md b/docs/reference/exceptions.md new file mode 100644 index 0000000..625136f --- /dev/null +++ b/docs/reference/exceptions.md @@ -0,0 +1,6 @@ +--- +title: Exceptions +--- +# Exceptions + +::: federatedidentity.exceptions diff --git a/docs/reference.md b/docs/reference/index.md similarity index 100% rename from docs/reference.md rename to docs/reference/index.md diff --git a/docs/reference/oidc.md b/docs/reference/oidc.md new file mode 100644 index 0000000..e297a88 --- /dev/null +++ b/docs/reference/oidc.md @@ -0,0 +1,6 @@ +--- +title: OpenID Connect +--- +# OpenID Connect (OIDC) + +::: federatedidentity.oidc diff --git a/federatedidentity/__init__.py b/federatedidentity/__init__.py index 6db27d3..d95e49d 100644 --- a/federatedidentity/__init__.py +++ b/federatedidentity/__init__.py @@ -1,22 +1,6 @@ -from .exceptions import ( - FederatedIdentityError, - InvalidClaimsError, - InvalidIssuerError, - InvalidJWKSUrlError, - InvalidOIDCDiscoveryDocumentError, - InvalidTokenError, - TransportError, -) -from .oidc import AsyncOIDCTokenIssuer, OIDCTokenIssuer +from ._verify import async_verify_oidc_token, verify_oidc_token __all__ = [ - "FederatedIdentityError", - "InvalidClaimsError", - "InvalidIssuerError", - "InvalidJWKSUrlError", - "InvalidOIDCDiscoveryDocumentError", - "InvalidTokenError", - "TransportError", - "OIDCTokenIssuer", - "AsyncOIDCTokenIssuer", + "verify_oidc_token", + "async_verify_oidc_token", ] diff --git a/federatedidentity/_async_helpers.py b/federatedidentity/_async_helpers.py new file mode 100644 index 0000000..ccff970 --- /dev/null +++ b/federatedidentity/_async_helpers.py @@ -0,0 +1,23 @@ +""" +Utility functions for asyncio. +""" + +from collections.abc import Awaitable +from inspect import isawaitable +from typing import TypeVar, Union + +_T = TypeVar("_T") + + +async def make_awaitable(v: _T) -> _T: + """ + Wrap a constant value into an awaitable. + """ + return v + + +def ensure_awaitable(v: Union[_T, Awaitable[_T]]) -> Awaitable[_T]: + """ + Return argument if it is awaitable otherwise wrap it as an awaitable. + """ + return v if isawaitable(v) else make_awaitable(v) diff --git a/federatedidentity/_verify.py b/federatedidentity/_verify.py new file mode 100644 index 0000000..72ebe49 --- /dev/null +++ b/federatedidentity/_verify.py @@ -0,0 +1,100 @@ +from collections.abc import Awaitable, Callable, Iterable +from typing import Any, Optional, Union + +from ._async_helpers import ensure_awaitable + +__all__ = [ + "verify_oidc_token", + "async_verify_oidc_token", +] + + +def verify_oidc_token( + token: Union[str, bytes], + issuers: Iterable[Union[str, Callable[[str], bool]]], + *, + claims: Optional[Iterable[Union[dict[str, Any], Callable[[dict[str, Any]], None]]]] = None, + request: Optional[Callable[[str], bytes]] = None, +) -> dict[str, Any]: + """ + Verify an OIDC identity token. + + Returns: + the token's claims dictionary. + + Parameters: + token: OIDC token to verify. If a string is passed it is first encoded using the ASCII + codec before verification. + issuers: Iterable of acceptable issuers. An issuer may be a static string to match or a + callable used. A match callable should return True if the issuer is acceptable or False + otherwise. At least one issuer must match the token's issuer for the token to be + verified. + claims: Iterable of claim verifiers. A claim verifier may be a dictionary of + acceptable claim values or a callable which takes the claims dictionary. A claims + verifier callable should raise [federatedidentity.exceptions.InvalidClaimsError][] if + the claims do not match the expected values. If omitted only claims required by OIDC + are validated. + request: Callable used to fetch issuer discovery documents. If omitted a default + implementation based on the [requests][] library is used. + + Raises: + federatedidentity.exceptions.FederatedIdentityError: the token could not be verified. + """ + raise NotImplementedError() + + +async def async_verify_oidc_token( + token: Union[str, bytes], + issuers: Iterable[Union[str, Awaitable[str], Callable[[str], Union[bool, Awaitable[bool]]]]], + *, + claims: Optional[ + Iterable[Union[dict[str, Any], Callable[[dict[str, Any]], Union[None, Awaitable[None]]]]] + ] = None, + request: Optional[Callable[[str], bytes | Awaitable[bytes]]] = None, +) -> dict[str, Any]: + """ + Asynchronously verify an OIDC identity token. + + Returns: + the token's claims dictionary. + + Parameters: + token: OIDC token to verify. If a string is passed it is first encoded using the ASCII + codec before verification. + issuers: Iterable of acceptable issuers. An issuer may be a static string to match or a + callable used. A match callable should return True if the issuer is acceptable or False + otherwise. At least one issuer must match the token's issuer for the token to be + verified. Callables may be asynchronous. + claims: Iterable of claim verifiers. A claim verifier may be a dictionary of acceptable + claim values or a callable which takes the claims dictionary. A claims verifier + callable should raise [federatedidentity.exceptions.InvalidClaimsError][] if the claims + do not match the expected values. If omitted only claims required by OIDC are + validated. + request: Callable used to fetch issuer discovery documents. If omitted a default + implementation based on the [requests][] library is used. + + Raises: + federatedidentity.exceptions.FederatedIdentityError: the token could not be verified. + """ + raise NotImplementedError() + + +def _any_issuer_matches(issuer: str, issuers: Iterable[Union[str, Callable[[str], bool]]]) -> bool: + for element in issuers: + if callable(element) and element(issuer): + return True + if element == issuer: + return True + return False + + +async def _async_any_issuer_matches( + issuer: str, + issuers: Iterable[Union[str, Awaitable[str], Callable[[str], Union[bool, Awaitable[bool]]]]], +) -> bool: + for element in issuers: + if callable(element) and await ensure_awaitable(element(issuer)): + return True + if (await ensure_awaitable(element)) == issuer: + return True + return False diff --git a/mkdocs.yml b/mkdocs.yml index deaaf2a..259e996 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -4,7 +4,7 @@ site_url: "https://rjw57.github.io/verify-oidc-identity" repo_url: "https://github.com/rjw57/verify-oidc-identity" repo_name: "rjw57/verify-oidc-identity" site_dir: "site" -watch: [mkdocs.yml, README.md] +watch: [mkdocs.yml, README.md, federatedidentity/] copyright: Copyright © Rich Wareham edit_uri: edit/main/docs/ @@ -13,7 +13,10 @@ nav: - index.md - changelog.md - license.md - - reference.md + - API Reference: + - reference/index.md + - reference/exceptions.md + - reference/oidc.md theme: name: material @@ -65,6 +68,7 @@ plugins: python: import: - https://docs.python.org/3/objects.inv + - https://requests.readthedocs.io/en/latest/objects.inv options: filters: ["!^_"] members_order: source diff --git a/pyproject.toml b/pyproject.toml index 0e5def2..92375fb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,6 +40,7 @@ black = "^24.10.0" [tool.pytest.ini_options] addopts = "--cov --cov-report term --cov-report html" +asyncio_default_fixture_loop_scope = "function" [tool.mypy] ignore_missing_imports = true diff --git a/tests/test_async_helpers.py b/tests/test_async_helpers.py new file mode 100644 index 0000000..3d4bcf0 --- /dev/null +++ b/tests/test_async_helpers.py @@ -0,0 +1,21 @@ +import pytest + +from federatedidentity._async_helpers import ensure_awaitable + + +@pytest.mark.asyncio +async def test_ensure_non_awaitable(faker): + "ensure_awaitable() converts a non-awaitable to awaitable" + v = faker.slug() + assert (await ensure_awaitable(v)) is v + + +@pytest.mark.asyncio +async def test_ensure_awaitable(faker): + "ensure_awaitable() returns an awaitable as-is" + v = faker.slug() + + async def f(): + return v + + assert (await ensure_awaitable(f())) is v diff --git a/tests/test_validation.py b/tests/test_validation.py index d60f4a9..e99278a 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -7,7 +7,6 @@ from jwcrypto.jws import JWS from jwcrypto.jwt import JWT -from federatedidentity import AsyncOIDCTokenIssuer, OIDCTokenIssuer from federatedidentity import exceptions as exc from federatedidentity import oidc @@ -40,7 +39,7 @@ def test_missing_issuer_claim(oidc_claims: dict[str, str], ec_jwk: JWK): def test_basic_verification(faker: Faker, oidc_token: str, oidc_audience: str, jwt_issuer: str): - provider = OIDCTokenIssuer(jwt_issuer, oidc_audience) + provider = oidc.OIDCTokenIssuer(jwt_issuer, oidc_audience) provider.prepare() provider.validate(oidc_token) @@ -49,26 +48,26 @@ def test_basic_verification(faker: Faker, oidc_token: str, oidc_audience: str, j async def test_basic_async_verification( faker: Faker, oidc_token: str, oidc_audience: str, jwt_issuer: str ): - provider = AsyncOIDCTokenIssuer(jwt_issuer, oidc_audience) + provider = oidc.AsyncOIDCTokenIssuer(jwt_issuer, oidc_audience) await provider.prepare() provider.validate(oidc_token) def test_mismatched_audience(faker: Faker, oidc_token: str, jwt_issuer: str): - provider = OIDCTokenIssuer(jwt_issuer, faker.url(schemes=["https"])) + provider = oidc.OIDCTokenIssuer(jwt_issuer, faker.url(schemes=["https"])) provider.prepare() with pytest.raises(exc.InvalidClaimsError): provider.validate(oidc_token) def test_issuer_not_url(oidc_token: str, oidc_audience: str): - provider = OIDCTokenIssuer("-not a url-", oidc_audience) + provider = oidc.OIDCTokenIssuer("-not a url-", oidc_audience) with pytest.raises(exc.InvalidIssuerError): provider.prepare() def test_issuer_bad_scheme(faker: Faker, oidc_token: str, oidc_audience: str): - provider = OIDCTokenIssuer(faker.url(schemes=["http"]), oidc_audience) + provider = oidc.OIDCTokenIssuer(faker.url(schemes=["http"]), oidc_audience) with pytest.raises(exc.InvalidIssuerError): provider.prepare() @@ -82,7 +81,7 @@ def test_mismatched_issuer( jwt_issuer: str, jwks: dict[str, JWK], ): - provider = OIDCTokenIssuer(jwt_issuer, oidc_audience) + provider = oidc.OIDCTokenIssuer(jwt_issuer, oidc_audience) provider.prepare() iss = faker.url(schemes=["https"]) oidc_claims["iss"] = iss @@ -100,7 +99,7 @@ def test_exp_claim_in_past( jwt_issuer: str, jwks: dict[str, JWK], ): - provider = OIDCTokenIssuer(jwt_issuer, oidc_audience) + provider = oidc.OIDCTokenIssuer(jwt_issuer, oidc_audience) provider.prepare() oidc_claims["exp"] = datetime.datetime.now(datetime.UTC).timestamp() - 100000 token = make_jwt(oidc_claims, jwks[alg], alg) @@ -117,7 +116,7 @@ def test_nbf_claim_in_future( jwt_issuer: str, jwks: dict[str, JWK], ): - provider = OIDCTokenIssuer(jwt_issuer, oidc_audience) + provider = oidc.OIDCTokenIssuer(jwt_issuer, oidc_audience) provider.prepare() oidc_claims["nbf"] = datetime.datetime.now(datetime.UTC).timestamp() + 100000 token = make_jwt(oidc_claims, jwks[alg], alg) diff --git a/tests/test_verify_issuer_match.py b/tests/test_verify_issuer_match.py new file mode 100644 index 0000000..1386506 --- /dev/null +++ b/tests/test_verify_issuer_match.py @@ -0,0 +1,104 @@ +import pytest + +from federatedidentity import _verify +from federatedidentity._async_helpers import make_awaitable + + +@pytest.mark.parametrize( + "issuer,issuers", + [ + ["https://example.com/", ["https://some.other.invalid", "https://example.com/"]], + [ + "https://callable.example.com/", + [ + "https://some.other.invalid", + lambda issuer: issuer == "https://callable.example.com/", + ], + ], + ], +) +def test_expected_issuer_match(issuer, issuers): + assert _verify._any_issuer_matches(issuer, issuers) + + +@pytest.mark.parametrize( + "issuer,issuers", + [ + ["https://example.com/", ["https://some.other.invalid", "https://example.com/"]], + [ + "https://example.com/", + ["https://some.other.invalid", make_awaitable("https://example.com/")], + ], + [ + "https://example.com/", + [make_awaitable("https://some.other.invalid"), make_awaitable("https://example.com/")], + ], + [ + "https://callable.example.com/", + [ + "https://some.other.invalid", + lambda issuer: make_awaitable(issuer == "https://callable.example.com/"), + ], + ], + [ + "https://callable.example.com/", + [ + "https://some.other.invalid", + lambda issuer: issuer == "https://callable.example.com/", + ], + ], + ], +) +@pytest.mark.asyncio +async def test_async_expected_issuer_match(issuer, issuers): + assert await _verify._async_any_issuer_matches(issuer, issuers) + + +@pytest.mark.parametrize( + "issuer,issuers", + [ + ["https://nomatch.example.com/", ["https://some.other.invalid", "https://example.com/"]], + [ + "https://nomatch.callable.example.com/", + [ + "https://some.other.invalid", + lambda issuer: issuer == "https://callable.example.com/", + ], + ], + ], +) +def test_expected_no_issuer_match(issuer, issuers): + assert not _verify._any_issuer_matches(issuer, issuers) + + +@pytest.mark.parametrize( + "issuer,issuers", + [ + ["https://nomatch.example.com/", ["https://some.other.invalid", "https://example.com/"]], + [ + "https://nomatch.example.com/", + ["https://some.other.invalid", make_awaitable("https://example.com/")], + ], + [ + "https://nomatch.example.com/", + [make_awaitable("https://some.other.invalid"), make_awaitable("https://example.com/")], + ], + [ + "https://nomatch.example.com/", + [ + "https://some.other.invalid", + lambda issuer: make_awaitable(issuer == "https://callable.example.com/"), + ], + ], + [ + "https://nomatch.example.com/", + [ + "https://some.other.invalid", + lambda issuer: issuer == "https://callable.example.com/", + ], + ], + ], +) +@pytest.mark.asyncio +async def test_async_no_expected_issuer_match(issuer, issuers): + assert not await _verify._async_any_issuer_matches(issuer, issuers)