Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: re-work API a little in preparation for release #4

Merged
merged 1 commit into from
Nov 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/reference/exceptions.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
---
title: Exceptions
---
# Exceptions

::: federatedidentity.exceptions
File renamed without changes.
8 changes: 8 additions & 0 deletions docs/reference/transport.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
---
title: HTTP Transport
---
# HTTP transport providers

::: federatedidentity.transport

::: federatedidentity.transport.requests
24 changes: 5 additions & 19 deletions federatedidentity/__init__.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,8 @@
from .exceptions import (
FederatedIdentityError,
InvalidClaimsError,
InvalidIssuerError,
InvalidJWKSUrlError,
InvalidOIDCDiscoveryDocumentError,
InvalidTokenError,
TransportError,
)
from .oidc import AsyncOIDCTokenIssuer, OIDCTokenIssuer
from ._oidc import Issuer
from ._verify import ClaimVerifier, verify_id_token

__all__ = [
"FederatedIdentityError",
"InvalidClaimsError",
"InvalidIssuerError",
"InvalidJWKSUrlError",
"InvalidOIDCDiscoveryDocumentError",
"InvalidTokenError",
"TransportError",
"OIDCTokenIssuer",
"AsyncOIDCTokenIssuer",
"Issuer",
"ClaimVerifier",
"verify_id_token",
]
173 changes: 61 additions & 112 deletions federatedidentity/oidc.py → federatedidentity/_oidc.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import dataclasses
import json
from collections.abc import Mapping
from typing import Any, NewType, Optional, cast
from urllib.parse import urlparse

Expand All @@ -8,9 +8,7 @@
from jwcrypto.jwt import JWT
from validators.url import url as validate_url

from .baseprovider import AsyncBaseProvider, BaseProvider
from .exceptions import (
InvalidClaimsError,
InvalidIssuerError,
InvalidJWKSUrlError,
InvalidOIDCDiscoveryDocumentError,
Expand All @@ -20,11 +18,71 @@
from .transport import AsyncRequestBase, RequestBase
from .transport import requests as requests_transport

__all__ = [
"Issuer",
]

ValidatedIssuer = NewType("ValidatedIssuer", str)
ValidatedJWKSUrl = NewType("ValidatedJWKSUrl", str)
UnvalidatedClaims = NewType("UnvalidatedClaims", dict[str, Any])


@dataclasses.dataclass(frozen=True)
class Issuer:
"""
Represents an issuer of OIDC id tokens.
"""

name: str
"Name of the issuer as it appears in `iss` claims."
key_set: JWKSet
"JWK key set associated with the issuer used to verify JWT signatures."

@classmethod
def from_discovery(cls, name: str, request: Optional[RequestBase] = None) -> "Issuer":
"""
Initialise an issuer fetching key sets as per [OpenID Connect Discovery](oidc-discovery).

[oidc-discovery]: https://openid.net/specs/openid-connect-discovery-1_0.html

Arguments:
name: The name of the issuer as it would appear in the "iss" claim of a token
request: An optional HTTP request callable. If omitted a default implementation based
on the [requests][] module is used.

Returns:
a newly-created issuer

Raises:
federatedidentity.exceptions.FederatedIdentityError
"""
request = request if request is not None else requests_transport.request
return Issuer(name=name, key_set=fetch_jwks(name, request))

@classmethod
async def async_from_discovery(
cls, name: str, request: Optional[AsyncRequestBase] = None
) -> "Issuer":
"""
Initialise an issuer fetching key sets as per [OpenID Connect Discovery](oidc-discovery).

[oidc-discovery]: https://openid.net/specs/openid-connect-discovery-1_0.html

Arguments:
name: The name of the issuer as it would appear in the "iss" claim of a token
request: An optional asynchronous HTTP request callable. If omitted a default
implementation based on the [requests][] module is used.

Returns:
a newly-created issuer

Raises:
federatedidentity.exceptions.FederatedIdentityError
"""
request = request if request is not None else requests_transport.async_request
return Issuer(name=name, key_set=await async_fetch_jwks(name, request))


def validate_issuer(unvalidated_issuer: str) -> ValidatedIssuer:
"""
Validate issuer is correctly formed.
Expand Down Expand Up @@ -178,112 +236,3 @@ def validate_token(unvalidated_token: str, jwk_set: JWKSet) -> JWT:
except JWException as e:
raise InvalidTokenError(f"Invalid token: {e}")
return jwt


class _BaseOIDCTokenIssuer:

issuer: str
audience: str
_key_set: Optional[JWKSet]

def __init__(self, issuer: str, audience: str):
self.issuer = issuer
self.audience = audience
self._key_set = None

def validate(self, credential: str) -> Mapping[str, Any]:
"""
Validate a credential as being issued by this provider, having the required claims and
those claims having expected values.

Returns the verified claims as a mapping.

Raises:
FederatedIdentityError: if the token is invalid
ValueError: if prepare() has not been called
"""
if self._key_set is None:
raise ValueError("prepare() must have been called prior to validation")

unvalidated_claims = unvalidated_claims_from_token(credential)

if "iss" not in unvalidated_claims:
raise InvalidClaimsError("'iss' claim missing from token")
if unvalidated_claims["iss"] != self.issuer:
raise InvalidClaimsError(
f"'iss' claims has value '{unvalidated_claims['iss']}', "
f"expected '{self.issuer}'."
)

if "aud" not in unvalidated_claims:
raise InvalidClaimsError("'aud' claim is missing from token")
if unvalidated_claims["aud"] != self.audience:
raise InvalidClaimsError(
f"'aud' claims has value '{unvalidated_claims['aud']}', "
f"expected '{self.audience}'."
)

return json.loads(validate_token(credential, self._key_set).claims)


class OIDCTokenIssuer(_BaseOIDCTokenIssuer, BaseProvider):
"""
Represents an issuer of federated credentials in the form of OpenID Connect identity tokens.

The issuer must publish an OIDC Discovery document as per
https://openid.net/specs/openid-connect-discovery-1_0.html.

The id token is verified to have a signature which matches one of the keys in the issuer's
published key set and that it has at least an "iss", "sub", "aud" and "exp" claim. If an "exp"
claim is present, it is verified to be in the future. If a "nbf" claim is present it is
verified to be in the past and if a "iat" claim is present it is verified to be an integer.

Args:
issuer: issuer of tokens as represented in the "iss" claim of the OIDC token.
audience: expected audience of tokens as represented in the "aud" claim of the OIDC token.
"""

def prepare(self, request: Optional[RequestBase] = None) -> None:
"""
Prepare this issuer for token verification, fetching the issuer's public key if necessary.
The public key is only fetched once so it is safe to call this method repeatedly.

Args:
request: HTTP transport to use to fetch the issuer public key set. Defaults to a
transport based on the requests library.

Raises:
FederatedIdentityError: if the issuer, OIDC discovery document or JWKS is invalid or
some transport error ocurred.
"""
if self._key_set is not None:
return
request = request if request is not None else requests_transport.request
self._key_set = fetch_jwks(self.issuer, request)


class AsyncOIDCTokenIssuer(_BaseOIDCTokenIssuer, AsyncBaseProvider):
"""
Asynchronous version of OIDCTokenIssuer. The only difference being that prepare() takes an
optional AsyncRequestBase and must be awaited.

"""

async def prepare(self, request: Optional[AsyncRequestBase] = None) -> None:
"""
Prepare this issuer for token verification, fetching the issuer's public key if necessary.
The public key is only fetched once so it is safe to call this method repeatedly.

Args:
request: Asynchronous HTTP transport to use to fetch the issuer public key set.
Defaults to a transport based on the requests library which runs in a separate
thread.

Raises:
FederatedIdentityError: if the issuer, OIDC discovery document or JWKS is invalid or
some transport error ocurred.
"""
if self._key_set is not None:
return
request = request if request is not None else requests_transport.async_request
self._key_set = await async_fetch_jwks(self.issuer, request)
94 changes: 94 additions & 0 deletions federatedidentity/_verify.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
import json
from collections.abc import Callable, Iterable
from typing import Any, Optional, Union

from . import _oidc
from .exceptions import InvalidClaimsError

__all__ = [
"ClaimVerifier",
"verify_id_token",
]


ClaimVerifier = Union[dict[str, Any], Callable[[dict[str, Any]], None]]
"""
Type representing a claim verifier. A claim verifier may be a dictionary of acceptable claim values
or a callable which takes the claims dictionary. A claims verifier callable should raise
[`InvalidClaimsError`][federatedidentity.exceptions.InvalidClaimsError] if the claims do not match
the expected values.
"""


def verify_id_token(
token: Union[str, bytes],
valid_issuers: Iterable[_oidc.Issuer],
valid_audiences: Iterable[str],
*,
required_claims: Optional[Iterable[ClaimVerifier]] = None,
) -> dict[str, Any]:
"""
Verify an OIDC identity token.

Returns:
the token's claims dictionary.

Parameters:
token: OIDC token to verify. If a [bytes][] object is passed it is decoded using the ASCII
codec before verification.
valid_issuers: Iterable of valid issuers. At least one Issuer must match the token issuer
for verification to succeed.
valid_audiences: Iterable of valid audiences. At least one audience must match the `aud`
claim for verification to succeed.
required_claims: Iterable of required claim verifiers. Claims are passed to verifiers after
the token's signature has been verified. Claims required by OIDC are always
validated. All claim verifiers must pass for verification to succeed.

Raises:
federatedidentity.exceptions.FederatedIdentityError: The token failed verification.
UnicodeDecodeError: The token could not be decoded into an ASCII string.
"""
if isinstance(token, bytes):
token = token.decode("ascii")

unvalidated_claims = _oidc.unvalidated_claims_from_token(token)

# For required claims, see: https://openid.net/specs/openid-connect-core-1_0.html#IDToken
for claim in ["iss", "sub", "aud", "exp", "iat"]:
if claim not in unvalidated_claims:
raise InvalidClaimsError(f"'{claim}' claim not present in token")

# Check that the token "aud" claim matches at least one of our expected audiences.
if not any(unvalidated_claims["aud"] == audience for audience in valid_audiences):
raise InvalidClaimsError(
f"Token issuer '{unvalidated_claims['aud']}' did not match any valid issuer"
)

# Determine which issuer matches the token.
for issuer in valid_issuers:
if issuer.name == unvalidated_claims["iss"]:
break
else:
# No issuer matched the token if the for loop exited without "break".
raise InvalidClaimsError(
f"Token issuer '{unvalidated_claims['iss']}' did not match any valid issuer"
)

# Note: validate_token() validates "exp", "iat" and "nbf" claims and that the "alg" header has
# an appropriate value.
verified_claims = json.loads(_oidc.validate_token(token, issuer.key_set).claims)

required_claims = required_claims if required_claims is not None else []
for claims_verifier in required_claims:
if callable(claims_verifier):
claims_verifier(verified_claims)
else:
for claim, value in claims_verifier.items():
if claim not in verified_claims:
raise InvalidClaimsError(f"Required claim '{claim}' not present in token")
if verified_claims[claim] != value:
raise InvalidClaimsError(
f"Required claim '{claim}' has invalid value {value!r}"
)

return verified_claims
23 changes: 0 additions & 23 deletions federatedidentity/baseprovider.py

This file was deleted.

15 changes: 10 additions & 5 deletions federatedidentity/transport/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
"""
Generic HTTP transport base classes and utilities.
"""

import dataclasses
from abc import ABCMeta, abstractmethod
from typing import Mapping, Optional
from collections.abc import Mapping
from typing import Optional


@dataclasses.dataclass
Expand Down Expand Up @@ -36,8 +41,8 @@ def __call__(
The response from the resource server.

Raises:
TransportError: on any transport error such as DNS resolution failure. Note that error
status codes from the server do not raise.
federatedidentity.exceptions.TransportError: on any transport error such as DNS
resolution failure. Note that error status codes from the server do not raise.
"""


Expand Down Expand Up @@ -67,6 +72,6 @@ async def __call__(
The response from the resource server.

Raises:
TransportError: on any transport error such as DNS resolution failure. Note that error
status codes from the server do not raise.
federatedidentity.exceptions.TransportError: on any transport error such as DNS
resolution failure. Note that error status codes from the server do not raise.
"""
Loading