Skip to content

Commit

Permalink
Add modified OIDC authenticator
Browse files Browse the repository at this point in the history
  • Loading branch information
DiamondJoseph committed Jan 10, 2025
1 parent cc21998 commit 52695e2
Showing 1 changed file with 40 additions and 49 deletions.
89 changes: 40 additions & 49 deletions tiled/authenticators.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@
import re
import secrets
from collections.abc import Iterable
from typing import Any, cast

import httpx
from fastapi import APIRouter, Request
from jose import JWTError, jwk, jwt
from pydantic import Secret
from starlette.responses import RedirectResponse

from .server.authentication import Mode
Expand Down Expand Up @@ -107,7 +109,6 @@ async def authenticate(self, username: str, password: str) -> UserSessionState:
else:
return UserSessionState(username, {})


class OIDCAuthenticator:
mode = Mode.external
configuration_schema = """
Expand All @@ -119,63 +120,53 @@ class OIDCAuthenticator:
type: string
client_secret:
type: string
token_uri:
type: string
authorization_endpoint:
well_known_uri:
type: string
public_keys:
type: array
item:
type: object
properties:
- alg:
type: string
- e
type: string
- kid
type: string
- kty
type: string
- n
type: string
- use
type: string
required:
- alg
- e
- kid
- kty
- n
- use
confirmation_message:
type: string
description: May be displayed by client after successful login.
"""

def __init__(
self,
client_id,
client_secret,
public_keys,
token_uri,
authorization_endpoint,
confirmation_message,
client_id: str,
client_secret: str,
well_known_uri: str
):
self.client_id = client_id
self.client_secret = client_secret
self.confirmation_message = confirmation_message
self.public_keys = public_keys
self.token_uri = token_uri
self.authorization_endpoint = httpx.URL(authorization_endpoint)
self._client_id = client_id
self._client_secret = Secret(client_secret)
self.well_known_url = well_known_uri

@functools.cached_property
def _config_from_oidc_url(self) -> dict[str, Any]:
response: httpx.Response = httpx.get(self.well_known_url)
response.raise_for_status()
return response.json()

@functools.cached_property
def token_endpoint(self) -> str:
return cast(str, self._config_from_oidc_url.get("token_endpoint"))

@functools.cached_property
def jwks_uri(self) -> str:
return cast(str, self._config_from_oidc_url.get("jwks_uri"))

@functools.cached_property
def id_token_signing_alg_values_supported(self) -> list[str]:
return cast(
list[str],
self._config_from_oidc_url.get("id_token_signing_alg_values_supported"),
)

async def authenticate(self, request) -> UserSessionState:
async def authenticate(self, request: Request) -> UserSessionState:
code = request.query_params["code"]
# A proxy in the middle may make the request into something like
# 'http://localhost:8000/...' so we fix the first part but keep
# the original URI path.
redirect_uri = f"{get_root_url(request)}{request.url.path}"
response = await exchange_code(
self.token_uri, code, self.client_id, self.client_secret, redirect_uri
self.token_endpoint,
code,
self._client_id,
self._client_secret.get_secret_value(),
redirect_uri
)
response_body = response.json()
if response.is_error:
Expand All @@ -184,11 +175,12 @@ async def authenticate(self, request) -> UserSessionState:
response_body = response.json()
id_token = response_body["id_token"]
access_token = response_body["access_token"]
# Match the kid in id_token to a key in the list of public_keys.
key = find_key(id_token, self.public_keys)
keys = request.get(self.jwks_uri)
try:
verified_body = jwt.decode(
id_token, key, access_token=access_token, audience=self.client_id
access_token,
keys,
algorithms=self.id_token_signing_alg_values_supported,
)
except JWTError:
logger.exception(
Expand All @@ -198,7 +190,6 @@ async def authenticate(self, request) -> UserSessionState:
return None
return UserSessionState(verified_body["sub"], {})


class KeyNotFoundError(Exception):
pass

Expand Down

0 comments on commit 52695e2

Please sign in to comment.