Skip to content

Commit

Permalink
feat(wip): work in progress experimentation on API style
Browse files Browse the repository at this point in the history
  • Loading branch information
rjw57 committed Nov 17, 2024
1 parent 1d2b0af commit 505d596
Show file tree
Hide file tree
Showing 11 changed files with 278 additions and 30 deletions.
6 changes: 6 additions & 0 deletions docs/reference/exceptions.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
---
title: Exceptions
---
# Exceptions

::: federatedidentity.exceptions
File renamed without changes.
6 changes: 6 additions & 0 deletions docs/reference/oidc.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
---
title: OpenID Connect
---
# OpenID Connect (OIDC)

::: federatedidentity.oidc
22 changes: 3 additions & 19 deletions federatedidentity/__init__.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,6 @@
from .exceptions import (
FederatedIdentityError,
InvalidClaimsError,
InvalidIssuerError,
InvalidJWKSUrlError,
InvalidOIDCDiscoveryDocumentError,
InvalidTokenError,
TransportError,
)
from .oidc import AsyncOIDCTokenIssuer, OIDCTokenIssuer
from ._verify import async_verify_oidc_token, verify_oidc_token

__all__ = [
"FederatedIdentityError",
"InvalidClaimsError",
"InvalidIssuerError",
"InvalidJWKSUrlError",
"InvalidOIDCDiscoveryDocumentError",
"InvalidTokenError",
"TransportError",
"OIDCTokenIssuer",
"AsyncOIDCTokenIssuer",
"verify_oidc_token",
"async_verify_oidc_token",
]
23 changes: 23 additions & 0 deletions federatedidentity/_async_helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
"""
Utility functions for asyncio.
"""

from collections.abc import Awaitable
from inspect import isawaitable
from typing import TypeVar, Union

_T = TypeVar("_T")


async def make_awaitable(v: _T) -> _T:
"""
Wrap a constant value into an awaitable.
"""
return v


def ensure_awaitable(v: Union[_T, Awaitable[_T]]) -> Awaitable[_T]:
"""
Return argument if it is awaitable otherwise wrap it as an awaitable.
"""
return v if isawaitable(v) else make_awaitable(v)
100 changes: 100 additions & 0 deletions federatedidentity/_verify.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
from collections.abc import Awaitable, Callable, Iterable
from typing import Any, Optional, Union

from ._async_helpers import ensure_awaitable

__all__ = [
"verify_oidc_token",
"async_verify_oidc_token",
]


def verify_oidc_token(
token: Union[str, bytes],
issuers: Iterable[Union[str, Callable[[str], bool]]],
*,
claims: Optional[Iterable[Union[dict[str, Any], Callable[[dict[str, Any]], None]]]] = None,
request: Optional[Callable[[str], bytes]] = None,
) -> dict[str, Any]:
"""
Verify an OIDC identity token.
Returns:
the token's claims dictionary.
Parameters:
token: OIDC token to verify. If a string is passed it is first encoded using the ASCII
codec before verification.
issuers: Iterable of acceptable issuers. An issuer may be a static string to match or a
callable used. A match callable should return True if the issuer is acceptable or False
otherwise. At least one issuer must match the token's issuer for the token to be
verified.
claims: Iterable of claim verifiers. A claim verifier may be a dictionary of
acceptable claim values or a callable which takes the claims dictionary. A claims
verifier callable should raise [federatedidentity.exceptions.InvalidClaimsError][] if
the claims do not match the expected values. If omitted only claims required by OIDC
are validated.
request: Callable used to fetch issuer discovery documents. If omitted a default
implementation based on the [requests][] library is used.
Raises:
federatedidentity.exceptions.FederatedIdentityError: the token could not be verified.
"""
raise NotImplementedError()


async def async_verify_oidc_token(
token: Union[str, bytes],
issuers: Iterable[Union[str, Awaitable[str], Callable[[str], Union[bool, Awaitable[bool]]]]],
*,
claims: Optional[
Iterable[Union[dict[str, Any], Callable[[dict[str, Any]], Union[None, Awaitable[None]]]]]
] = None,
request: Optional[Callable[[str], bytes | Awaitable[bytes]]] = None,
) -> dict[str, Any]:
"""
Asynchronously verify an OIDC identity token.
Returns:
the token's claims dictionary.
Parameters:
token: OIDC token to verify. If a string is passed it is first encoded using the ASCII
codec before verification.
issuers: Iterable of acceptable issuers. An issuer may be a static string to match or a
callable used. A match callable should return True if the issuer is acceptable or False
otherwise. At least one issuer must match the token's issuer for the token to be
verified. Callables may be asynchronous.
claims: Iterable of claim verifiers. A claim verifier may be a dictionary of acceptable
claim values or a callable which takes the claims dictionary. A claims verifier
callable should raise [federatedidentity.exceptions.InvalidClaimsError][] if the claims
do not match the expected values. If omitted only claims required by OIDC are
validated.
request: Callable used to fetch issuer discovery documents. If omitted a default
implementation based on the [requests][] library is used.
Raises:
federatedidentity.exceptions.FederatedIdentityError: the token could not be verified.
"""
raise NotImplementedError()


def _any_issuer_matches(issuer: str, issuers: Iterable[Union[str, Callable[[str], bool]]]) -> bool:
for element in issuers:
if callable(element) and element(issuer):
return True
if element == issuer:
return True
return False


async def _async_any_issuer_matches(
issuer: str,
issuers: Iterable[Union[str, Awaitable[str], Callable[[str], Union[bool, Awaitable[bool]]]]],
) -> bool:
for element in issuers:
if callable(element) and await ensure_awaitable(element(issuer)):
return True
if (await ensure_awaitable(element)) == issuer:
return True
return False
8 changes: 6 additions & 2 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ site_url: "https://rjw57.github.io/verify-oidc-identity"
repo_url: "https://github.com/rjw57/verify-oidc-identity"
repo_name: "rjw57/verify-oidc-identity"
site_dir: "site"
watch: [mkdocs.yml, README.md]
watch: [mkdocs.yml, README.md, federatedidentity/]
copyright: Copyright © Rich Wareham
edit_uri: edit/main/docs/

Expand All @@ -13,7 +13,10 @@ nav:
- index.md
- changelog.md
- license.md
- reference.md
- API Reference:
- reference/index.md
- reference/exceptions.md
- reference/oidc.md

theme:
name: material
Expand Down Expand Up @@ -65,6 +68,7 @@ plugins:
python:
import:
- https://docs.python.org/3/objects.inv
- https://requests.readthedocs.io/en/latest/objects.inv
options:
filters: ["!^_"]
members_order: source
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ black = "^24.10.0"

[tool.pytest.ini_options]
addopts = "--cov --cov-report term --cov-report html"
asyncio_default_fixture_loop_scope = "function"

[tool.mypy]
ignore_missing_imports = true
Expand Down
21 changes: 21 additions & 0 deletions tests/test_async_helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import pytest

from federatedidentity._async_helpers import ensure_awaitable


@pytest.mark.asyncio
async def test_ensure_non_awaitable(faker):
"ensure_awaitable() converts a non-awaitable to awaitable"
v = faker.slug()
assert (await ensure_awaitable(v)) is v


@pytest.mark.asyncio
async def test_ensure_awaitable(faker):
"ensure_awaitable() returns an awaitable as-is"
v = faker.slug()

async def f():
return v

assert (await ensure_awaitable(f())) is v
17 changes: 8 additions & 9 deletions tests/test_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from jwcrypto.jws import JWS
from jwcrypto.jwt import JWT

from federatedidentity import AsyncOIDCTokenIssuer, OIDCTokenIssuer
from federatedidentity import exceptions as exc
from federatedidentity import oidc

Expand Down Expand Up @@ -40,7 +39,7 @@ def test_missing_issuer_claim(oidc_claims: dict[str, str], ec_jwk: JWK):


def test_basic_verification(faker: Faker, oidc_token: str, oidc_audience: str, jwt_issuer: str):
provider = OIDCTokenIssuer(jwt_issuer, oidc_audience)
provider = oidc.OIDCTokenIssuer(jwt_issuer, oidc_audience)
provider.prepare()
provider.validate(oidc_token)

Expand All @@ -49,26 +48,26 @@ def test_basic_verification(faker: Faker, oidc_token: str, oidc_audience: str, j
async def test_basic_async_verification(
faker: Faker, oidc_token: str, oidc_audience: str, jwt_issuer: str
):
provider = AsyncOIDCTokenIssuer(jwt_issuer, oidc_audience)
provider = oidc.AsyncOIDCTokenIssuer(jwt_issuer, oidc_audience)
await provider.prepare()
provider.validate(oidc_token)


def test_mismatched_audience(faker: Faker, oidc_token: str, jwt_issuer: str):
provider = OIDCTokenIssuer(jwt_issuer, faker.url(schemes=["https"]))
provider = oidc.OIDCTokenIssuer(jwt_issuer, faker.url(schemes=["https"]))
provider.prepare()
with pytest.raises(exc.InvalidClaimsError):
provider.validate(oidc_token)


def test_issuer_not_url(oidc_token: str, oidc_audience: str):
provider = OIDCTokenIssuer("-not a url-", oidc_audience)
provider = oidc.OIDCTokenIssuer("-not a url-", oidc_audience)
with pytest.raises(exc.InvalidIssuerError):
provider.prepare()


def test_issuer_bad_scheme(faker: Faker, oidc_token: str, oidc_audience: str):
provider = OIDCTokenIssuer(faker.url(schemes=["http"]), oidc_audience)
provider = oidc.OIDCTokenIssuer(faker.url(schemes=["http"]), oidc_audience)
with pytest.raises(exc.InvalidIssuerError):
provider.prepare()

Expand All @@ -82,7 +81,7 @@ def test_mismatched_issuer(
jwt_issuer: str,
jwks: dict[str, JWK],
):
provider = OIDCTokenIssuer(jwt_issuer, oidc_audience)
provider = oidc.OIDCTokenIssuer(jwt_issuer, oidc_audience)
provider.prepare()
iss = faker.url(schemes=["https"])
oidc_claims["iss"] = iss
Expand All @@ -100,7 +99,7 @@ def test_exp_claim_in_past(
jwt_issuer: str,
jwks: dict[str, JWK],
):
provider = OIDCTokenIssuer(jwt_issuer, oidc_audience)
provider = oidc.OIDCTokenIssuer(jwt_issuer, oidc_audience)
provider.prepare()
oidc_claims["exp"] = datetime.datetime.now(datetime.UTC).timestamp() - 100000
token = make_jwt(oidc_claims, jwks[alg], alg)
Expand All @@ -117,7 +116,7 @@ def test_nbf_claim_in_future(
jwt_issuer: str,
jwks: dict[str, JWK],
):
provider = OIDCTokenIssuer(jwt_issuer, oidc_audience)
provider = oidc.OIDCTokenIssuer(jwt_issuer, oidc_audience)
provider.prepare()
oidc_claims["nbf"] = datetime.datetime.now(datetime.UTC).timestamp() + 100000
token = make_jwt(oidc_claims, jwks[alg], alg)
Expand Down
Loading

0 comments on commit 505d596

Please sign in to comment.