Skip to content

Commit

Permalink
Refactoring and type hints
Browse files Browse the repository at this point in the history
  • Loading branch information
DiamondJoseph committed Jan 14, 2025
1 parent cd69a67 commit 920ceab
Show file tree
Hide file tree
Showing 8 changed files with 90 additions and 81 deletions.
27 changes: 9 additions & 18 deletions example_configs/external_service/custom.py
Original file line number Diff line number Diff line change
@@ -1,45 +1,36 @@
import numpy
from pydantic import Secret

from tiled.adapters.array import ArrayAdapter
from tiled.authenticators import Mode, UserSessionState
from tiled.structures.core import StructureFamily


class Authenticator:
"This accepts any password and stashes it in session state as 'token'."
mode = Mode.password

async def authenticate(self, username: str, password: str) -> UserSessionState:
return UserSessionState(username, {"token": password})


# This stands in for a secret token issued by the external service.
SERVICE_ISSUED_TOKEN = "secret"


class MockClient:
def __init__(self, base_url):
self.base_url = base_url

def __init__(self, base_url: str, example_token: str = "secret"):
self._base_url = base_url
# This stands in for a secret token issued by the external service.
self._example_token = Secret(example_token)

# This API (get_contents, get_metadata, get_data) is just made up and not important.
# Could be anything.

async def get_metadata(self, url, token):
# This assert stands in for the mocked service
# authenticating a request.
assert token == SERVICE_ISSUED_TOKEN
assert token == self._example_token.get_secret_value()
return {"metadata": str(url)}

async def get_contents(self, url, token):
# This assert stands in for the mocked service
# authenticating a request.
assert token == SERVICE_ISSUED_TOKEN
assert token == self._example_token.get_secret_value()
return ["a", "b", "c"]

async def get_data(self, url, token):
# This assert stands in for the mocked service
# authenticating a request.
assert token == SERVICE_ISSUED_TOKEN
assert token == self._example_token.get_secret_value()
return numpy.ones((3, 3))


Expand Down
10 changes: 6 additions & 4 deletions example_configs/mock-oidc-server.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@ authentication:
client_secret: secret
well_known_uri: http://localhost:8080/.well-known/openid-configuration
trees:
# Just some arbitrary example data...
# The point of this example is the authenticaiton above.
- tree: tiled.examples.generated_minimal:tree
path: /
- path: /
tree: catalog
args:
uri: "sqlite+aiosqlite:///:memory:"
writable_storage: "/tmp/data"
init_if_not_exists: true
31 changes: 10 additions & 21 deletions tiled/authenticators.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,43 +9,37 @@

import httpx
from fastapi import APIRouter, Request
from jose import JWTError, jwk, jwt
from jose import JWTError, jwt
from pydantic import Secret
from starlette.responses import RedirectResponse

from .server.authentication import Mode
from .server.protocols import UserSessionState
from .server.protocols import ExternalAuthenticator, UserSessionState, PasswordAuthenticator
from .server.utils import get_root_url
from .utils import modules_available

logger = logging.getLogger(__name__)


class DummyAuthenticator:
class DummyAuthenticator(PasswordAuthenticator):
"""
For test and demo purposes only!
Accept any username and any password.
"""

mode = Mode.password

def __init__(self, confirmation_message=""):
self.confirmation_message = confirmation_message

async def authenticate(self, username: str, password: str) -> UserSessionState:
return UserSessionState(username, {})


class DictionaryAuthenticator:
class DictionaryAuthenticator(PasswordAuthenticator):
"""
For test and demo purposes only!
Check passwords from a dictionary of usernames mapped to passwords.
"""

mode = Mode.password
configuration_schema = """
$schema": http://json-schema.org/draft-07/schema#
type: object
Expand Down Expand Up @@ -74,8 +68,7 @@ async def authenticate(self, username: str, password: str) -> UserSessionState:
return UserSessionState(username, {})


class PAMAuthenticator:
mode = Mode.password
class PAMAuthenticator(PasswordAuthenticator):
configuration_schema = """
$schema": http://json-schema.org/draft-07/schema#
type: object
Expand Down Expand Up @@ -110,8 +103,7 @@ async def authenticate(self, username: str, password: str) -> UserSessionState:
return UserSessionState(username, {})


class OIDCAuthenticator:
mode = Mode.external
class OIDCAuthenticator(ExternalAuthenticator):
configuration_schema = """
$schema": http://json-schema.org/draft-07/schema#
type: object
Expand Down Expand Up @@ -164,7 +156,7 @@ def jwks_uri(self) -> str:
def token_endpoint(self) -> str:
return cast(str, self._config_from_oidc_url.get("token_endpoint"))

async def authenticate(self, request: Request) -> UserSessionState:
async def authenticate(self, request: Request) -> UserSessionState | None:
code = request.query_params["code"]
# A proxy in the middle may make the request into something like
# 'http://localhost:8000/...' so we fix the first part but keep
Expand Down Expand Up @@ -228,8 +220,7 @@ async def exchange_code(token_uri, auth_code, client_id, client_secret, redirect
return response


class SAMLAuthenticator:
mode = Mode.external
class SAMLAuthenticator(ExternalAuthenticator):

def __init__(
self,
Expand Down Expand Up @@ -271,7 +262,7 @@ async def saml_login(request: Request):

self.include_routers = [router]

async def authenticate(self, request) -> UserSessionState:
async def authenticate(self, request) -> UserSessionState | None:
if not modules_available("onelogin"):
raise ModuleNotFoundError(
"This SAMLAuthenticator requires the module 'oneline' to be installed."
Expand Down Expand Up @@ -323,7 +314,7 @@ async def prepare_saml_from_fastapi_request(request, debug=False):
return rv


class LDAPAuthenticator:
class LDAPAuthenticator(PasswordAuthenticator):
"""
The authenticator code is based on https://github.com/jupyterhub/ldapauthenticator
The parameter ``use_tls`` was added for convenience of testing.
Expand Down Expand Up @@ -506,8 +497,6 @@ class LDAPAuthenticator:
id: user02
"""

mode = Mode.password

def __init__(
self,
server_address,
Expand Down
16 changes: 8 additions & 8 deletions tiled/server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from contextlib import asynccontextmanager
from functools import lru_cache, partial
from pathlib import Path
from typing import List
from typing import Any, Dict, List

import anyio
import packaging.version
Expand All @@ -34,7 +34,8 @@
HTTP_500_INTERNAL_SERVER_ERROR,
)

from ..authenticators import Mode
from tiled.server.protocols import ExternalAuthenticator, PasswordAuthenticator

from ..config import construct_build_app_kwargs
from ..media_type_registration import (
compression_registry as default_compression_registry,
Expand Down Expand Up @@ -81,7 +82,7 @@
current_principal = contextvars.ContextVar("current_principal")


def custom_openapi(app: FastAPI):
def custom_openapi(app: FastAPI) -> Dict[str, Any]:
"""
The app's openapi method will be monkey-patched with this.
Expand Down Expand Up @@ -118,7 +119,7 @@ def build_app(
validation_registry=None,
tasks=None,
scalable=False,
):
) -> FastAPI:
"""
Serve a Tree
Expand Down Expand Up @@ -385,12 +386,11 @@ async def unhandled_exception_handler(
for spec in authentication["providers"]:
provider = spec["provider"]
authenticator = spec["authenticator"]
mode = authenticator.mode
if mode == Mode.password:
if isinstance(authenticator, PasswordAuthenticator):
authentication_router.post(f"/provider/{provider}/token")(
build_handle_credentials_route(authenticator, provider)
)
elif mode == Mode.external:
elif isinstance(authenticator, ExternalAuthenticator):
# Client starts here to create a PendingSession.
authentication_router.post(f"/provider/{provider}/authorize")(
build_device_code_authorize_route(authenticator, provider)
Expand All @@ -415,7 +415,7 @@ async def unhandled_exception_handler(
# build_auth_code_route(authenticator, provider)
# )
else:
raise ValueError(f"unknown authentication mode {mode}")
raise ValueError(f"Unexpected authenticator type {type(authenticator)}")
for custom_router in getattr(authenticator, "include_routers", []):
authentication_router.include_router(
custom_router, prefix=f"/provider/{provider}"
Expand Down
Loading

0 comments on commit 920ceab

Please sign in to comment.