diff --git a/example_configs/external_service/custom.py b/example_configs/external_service/custom.py index 01d528716..1320de1ad 100644 --- a/example_configs/external_service/custom.py +++ b/example_configs/external_service/custom.py @@ -1,25 +1,16 @@ import numpy +from pydantic import Secret from tiled.adapters.array import ArrayAdapter -from tiled.authenticators import Mode, UserSessionState from tiled.structures.core import StructureFamily -class Authenticator: - "This accepts any password and stashes it in session state as 'token'." - mode = Mode.password - - async def authenticate(self, username: str, password: str) -> UserSessionState: - return UserSessionState(username, {"token": password}) - - -# This stands in for a secret token issued by the external service. -SERVICE_ISSUED_TOKEN = "secret" - - class MockClient: - def __init__(self, base_url): - self.base_url = base_url + + def __init__(self, base_url: str, example_token: str = "secret"): + self._base_url = base_url + # This stands in for a secret token issued by the external service. + self._example_token = Secret(example_token) # This API (get_contents, get_metadata, get_data) is just made up and not important. # Could be anything. @@ -27,19 +18,19 @@ def __init__(self, base_url): async def get_metadata(self, url, token): # This assert stands in for the mocked service # authenticating a request. - assert token == SERVICE_ISSUED_TOKEN + assert token == self._example_token.get_secret_value() return {"metadata": str(url)} async def get_contents(self, url, token): # This assert stands in for the mocked service # authenticating a request. - assert token == SERVICE_ISSUED_TOKEN + assert token == self._example_token.get_secret_value() return ["a", "b", "c"] async def get_data(self, url, token): # This assert stands in for the mocked service # authenticating a request. - assert token == SERVICE_ISSUED_TOKEN + assert token == self._example_token.get_secret_value() return numpy.ones((3, 3)) diff --git a/example_configs/mock-oidc-server.yml b/example_configs/mock-oidc-server.yml index c8531f4cd..762e5ba58 100644 --- a/example_configs/mock-oidc-server.yml +++ b/example_configs/mock-oidc-server.yml @@ -9,7 +9,9 @@ authentication: client_secret: secret well_known_uri: http://localhost:8080/.well-known/openid-configuration trees: - # Just some arbitrary example data... - # The point of this example is the authenticaiton above. - - tree: tiled.examples.generated_minimal:tree - path: / + - path: / + tree: catalog + args: + uri: "sqlite+aiosqlite:///:memory:" + writable_storage: "/tmp/data" + init_if_not_exists: true diff --git a/tiled/authenticators.py b/tiled/authenticators.py index 759ea91e5..d838ee7fe 100644 --- a/tiled/authenticators.py +++ b/tiled/authenticators.py @@ -9,28 +9,24 @@ import httpx from fastapi import APIRouter, Request -from jose import JWTError, jwk, jwt +from jose import JWTError, jwt from pydantic import Secret from starlette.responses import RedirectResponse -from .server.authentication import Mode -from .server.protocols import UserSessionState +from .server.protocols import ExternalAuthenticator, UserSessionState, PasswordAuthenticator from .server.utils import get_root_url from .utils import modules_available logger = logging.getLogger(__name__) -class DummyAuthenticator: +class DummyAuthenticator(PasswordAuthenticator): """ For test and demo purposes only! Accept any username and any password. """ - - mode = Mode.password - def __init__(self, confirmation_message=""): self.confirmation_message = confirmation_message @@ -38,14 +34,12 @@ async def authenticate(self, username: str, password: str) -> UserSessionState: return UserSessionState(username, {}) -class DictionaryAuthenticator: +class DictionaryAuthenticator(PasswordAuthenticator): """ For test and demo purposes only! Check passwords from a dictionary of usernames mapped to passwords. """ - - mode = Mode.password configuration_schema = """ $schema": http://json-schema.org/draft-07/schema# type: object @@ -74,8 +68,7 @@ async def authenticate(self, username: str, password: str) -> UserSessionState: return UserSessionState(username, {}) -class PAMAuthenticator: - mode = Mode.password +class PAMAuthenticator(PasswordAuthenticator): configuration_schema = """ $schema": http://json-schema.org/draft-07/schema# type: object @@ -110,8 +103,7 @@ async def authenticate(self, username: str, password: str) -> UserSessionState: return UserSessionState(username, {}) -class OIDCAuthenticator: - mode = Mode.external +class OIDCAuthenticator(ExternalAuthenticator): configuration_schema = """ $schema": http://json-schema.org/draft-07/schema# type: object @@ -164,7 +156,7 @@ def jwks_uri(self) -> str: def token_endpoint(self) -> str: return cast(str, self._config_from_oidc_url.get("token_endpoint")) - async def authenticate(self, request: Request) -> UserSessionState: + async def authenticate(self, request: Request) -> UserSessionState | None: code = request.query_params["code"] # A proxy in the middle may make the request into something like # 'http://localhost:8000/...' so we fix the first part but keep @@ -228,8 +220,7 @@ async def exchange_code(token_uri, auth_code, client_id, client_secret, redirect return response -class SAMLAuthenticator: - mode = Mode.external +class SAMLAuthenticator(ExternalAuthenticator): def __init__( self, @@ -271,7 +262,7 @@ async def saml_login(request: Request): self.include_routers = [router] - async def authenticate(self, request) -> UserSessionState: + async def authenticate(self, request) -> UserSessionState | None: if not modules_available("onelogin"): raise ModuleNotFoundError( "This SAMLAuthenticator requires the module 'oneline' to be installed." @@ -323,7 +314,7 @@ async def prepare_saml_from_fastapi_request(request, debug=False): return rv -class LDAPAuthenticator: +class LDAPAuthenticator(PasswordAuthenticator): """ The authenticator code is based on https://github.com/jupyterhub/ldapauthenticator The parameter ``use_tls`` was added for convenience of testing. @@ -506,8 +497,6 @@ class LDAPAuthenticator: id: user02 """ - mode = Mode.password - def __init__( self, server_address, diff --git a/tiled/server/app.py b/tiled/server/app.py index 3a5846063..c308ced6e 100644 --- a/tiled/server/app.py +++ b/tiled/server/app.py @@ -10,7 +10,7 @@ from contextlib import asynccontextmanager from functools import lru_cache, partial from pathlib import Path -from typing import List +from typing import Any, Dict, List import anyio import packaging.version @@ -34,7 +34,8 @@ HTTP_500_INTERNAL_SERVER_ERROR, ) -from ..authenticators import Mode +from tiled.server.protocols import ExternalAuthenticator, PasswordAuthenticator + from ..config import construct_build_app_kwargs from ..media_type_registration import ( compression_registry as default_compression_registry, @@ -81,7 +82,7 @@ current_principal = contextvars.ContextVar("current_principal") -def custom_openapi(app: FastAPI): +def custom_openapi(app: FastAPI) -> Dict[str, Any]: """ The app's openapi method will be monkey-patched with this. @@ -118,7 +119,7 @@ def build_app( validation_registry=None, tasks=None, scalable=False, -): +) -> FastAPI: """ Serve a Tree @@ -385,12 +386,11 @@ async def unhandled_exception_handler( for spec in authentication["providers"]: provider = spec["provider"] authenticator = spec["authenticator"] - mode = authenticator.mode - if mode == Mode.password: + if isinstance(authenticator, PasswordAuthenticator): authentication_router.post(f"/provider/{provider}/token")( build_handle_credentials_route(authenticator, provider) ) - elif mode == Mode.external: + elif isinstance(authenticator, ExternalAuthenticator): # Client starts here to create a PendingSession. authentication_router.post(f"/provider/{provider}/authorize")( build_device_code_authorize_route(authenticator, provider) @@ -415,7 +415,7 @@ async def unhandled_exception_handler( # build_auth_code_route(authenticator, provider) # ) else: - raise ValueError(f"unknown authentication mode {mode}") + raise ValueError(f"Unexpected authenticator type {type(authenticator)}") for custom_router in getattr(authenticator, "include_routers", []): authentication_router.include_router( custom_router, prefix=f"/provider/{provider}" diff --git a/tiled/server/authentication.py b/tiled/server/authentication.py index 19b9ad9b7..f5c0b627e 100644 --- a/tiled/server/authentication.py +++ b/tiled/server/authentication.py @@ -1,12 +1,12 @@ -import enum import hashlib import secrets import uuid as uuid_module import warnings from datetime import datetime, timedelta from pathlib import Path -from typing import Optional +from typing import Optional, cast +import httpx import sqlalchemy.exc from fastapi import ( APIRouter, @@ -38,6 +38,8 @@ HTTP_409_CONFLICT, ) +from tiled.authenticators import OIDCAuthenticator + # To hide third-party warning # .../jose/backends/cryptography_backend.py:18: CryptographyDeprecationWarning: # int_from_bytes is deprecated, use int.from_bytes instead @@ -61,7 +63,7 @@ from ..utils import SHARE_TILED_PATH, SpecialUsers from . import schemas from .core import DEFAULT_PAGE_SIZE, MAX_PAGE_SIZE, json_or_msgpack -from .protocols import UsernamePasswordAuthenticator, UserSessionState +from .protocols import ExternalAuthenticator, PasswordAuthenticator, UserSessionState from .settings import Settings, get_settings from .utils import API_KEY_COOKIE_NAME, get_authenticators, get_base_url @@ -85,11 +87,6 @@ def utcnow(): return datetime.utcnow().replace(microsecond=0) -class Mode(enum.Enum): - password = "password" - external = "external" - - class Token(BaseModel): access_token: str token_type: str @@ -166,6 +163,10 @@ def create_refresh_token(session_id, secret_key, expires_delta): return encoded_jwt +def decode_oidc_token(token: str, authentictor: OIDCAuthenticator): + return jwt.decode(token, httpx.get(authentictor.jwks_uri), algorithms=[ALGORITHM]) + + def decode_token(token: str, secret_keys: list[str]): credentials_exception = HTTPException( status_code=HTTP_401_UNAUTHORIZED, @@ -177,12 +178,15 @@ def decode_token(token: str, secret_keys: list[str]): # fail. They supports key rotation. for secret_key in secret_keys: try: + """ DO NOT MERGE! """ + print(secret_key) # Remove this!!!!!! payload = jwt.decode(token, secret_key, algorithms=[ALGORITHM]) break except ExpiredSignatureError: # Do not let this be caught below with the other JWTError types. raise - except JWTError: + except JWTError as e: + print(e) # Try the next key in the key rotation. continue else: @@ -221,11 +225,15 @@ async def get_decoded_access_token( access_token: str = Depends(oauth2_scheme), settings: Settings = Depends(get_settings), ): - print("Got access_token") if not access_token: return None try: - payload = decode_token(access_token, settings.secret_keys) + print(settings.authenticator) + if isinstance(settings.authenticator, OIDCAuthenticator): + payload = decode_oidc_token(access_token, settings.authenticator) + print("proof of concept!") + else: + payload = decode_token(access_token, settings.secret_keys) except ExpiredSignatureError: raise HTTPException( status_code=HTTP_401_UNAUTHORIZED, @@ -510,7 +518,7 @@ async def create_tokens_from_session(settings: Settings, db, session, provider): } -def build_auth_code_route(authenticator, provider): +def build_auth_code_route(authenticator: ExternalAuthenticator, provider): "Build an auth_code route function for this Authenticator." async def route( @@ -524,6 +532,7 @@ async def route( raise HTTPException( status_code=HTTP_401_UNAUTHORIZED, detail="Authentication failure" ) + user_session_state = cast(UserSessionState, user_session_state) session = await create_session( settings, db, @@ -537,7 +546,7 @@ async def route( return route -def build_device_code_authorize_route(authenticator, provider): +def build_device_code_authorize_route(authenticator: ExternalAuthenticator, provider): "Build an /authorize route function for this Authenticator." async def route( @@ -571,7 +580,7 @@ async def route( return route -def build_device_code_user_code_form_route(authentication, provider): +def build_device_code_user_code_form_route(authentication: ExternalAuthenticator, provider): if not SHARE_TILED_PATH: raise Exception( "Static assets could not be found and are required for " @@ -598,7 +607,7 @@ async def route( return route -def build_device_code_user_code_submit_route(authenticator, provider): +def build_device_code_user_code_submit_route(authenticator: ExternalAuthenticator, provider): "Build an /authorize route function for this Authenticator." if not SHARE_TILED_PATH: @@ -670,7 +679,7 @@ async def route( return route -def build_device_code_token_route(authenticator, provider): +def build_device_code_token_route(authenticator: ExternalAuthenticator, provider): "Build an /authorize route function for this Authenticator." async def route( @@ -711,7 +720,7 @@ async def route( def build_handle_credentials_route( - authenticator: UsernamePasswordAuthenticator, provider + authenticator: PasswordAuthenticator, provider ): "Register a handle_credentials route function for this Authenticator." @@ -988,7 +997,11 @@ async def revoke_session( ): "Mark a Session as revoked so it cannot be refreshed again." request.state.endpoint = "auth" - payload = decode_token(refresh_token.refresh_token, settings.secret_keys) + if isinstance(settings.authenticator, OIDCAuthenticator): + payload = decode_oidc_token(refresh_token.refresh_token, settings.authenticator) + print("proof of concept!") + else: + payload = decode_token(refresh_token.refresh_token, settings.secret_keys) session_id = payload["sid"] # Find this session in the database. session = await lookup_valid_session(db, session_id) @@ -1027,7 +1040,11 @@ async def revoke_session_by_id( async def slide_session(refresh_token, settings: Settings, db): try: - payload = decode_token(refresh_token, settings.secret_keys) + if isinstance(settings.authenticator, OIDCAuthenticator): + payload = decode_oidc_token(refresh_token, settings.authenticator) + print("proof of concept!") + else: + payload = decode_token(refresh_token, settings.secret_keys) except ExpiredSignatureError: raise HTTPException( status_code=HTTP_401_UNAUTHORIZED, diff --git a/tiled/server/protocols.py b/tiled/server/protocols.py index 232a5145a..728f40d00 100644 --- a/tiled/server/protocols.py +++ b/tiled/server/protocols.py @@ -1,5 +1,6 @@ +from abc import abstractmethod, ABC from dataclasses import dataclass -from typing import Protocol +from typing import Protocol, runtime_checkable from fastapi import Request @@ -10,13 +11,20 @@ class UserSessionState: user_name: str state: dict = None + + +@runtime_checkable # Required to be a field on a BaseSettings +class Authenticator(Protocol): + ... -class UsernamePasswordAuthenticator(Protocol): - def authenticate(self, username: str, password: str) -> UserSessionState: +class PasswordAuthenticator(Authenticator, ABC): + @abstractmethod + def authenticate(self, username: str, password: str) -> UserSessionState | None: pass -class Authenticator(Protocol): - def authenticate(self, request: Request) -> UserSessionState: +class ExternalAuthenticator(Authenticator, ABC): + @abstractmethod + def authenticate(self, request: Request) -> UserSessionState | None: pass diff --git a/tiled/server/router.py b/tiled/server/router.py index 050baf62f..7e5ef7097 100644 --- a/tiled/server/router.py +++ b/tiled/server/router.py @@ -26,12 +26,14 @@ HTTP_422_UNPROCESSABLE_ENTITY, ) +from tiled.server.protocols import ExternalAuthenticator, PasswordAuthenticator + from .. import __version__ from ..structures.core import Spec, StructureFamily from ..utils import ensure_awaitable, patch_mimetypes, path_from_uri from ..validation_registration import ValidationError from . import schemas -from .authentication import Mode, get_authenticators, get_current_principal +from .authentication import get_authenticators, get_current_principal from .core import ( DEFAULT_PAGE_SIZE, DEPTH_LIMIT, @@ -90,10 +92,9 @@ async def about( } provider_specs = [] for provider, authenticator in authenticators.items(): - if authenticator.mode == Mode.password: + if isinstance(authenticator, PasswordAuthenticator): spec = { "provider": provider, - "mode": authenticator.mode.value, "links": { "auth_endpoint": f"{base_url}/auth/provider/{provider}/token" }, @@ -101,10 +102,9 @@ async def about( authenticator, "confirmation_message", None ), } - elif authenticator.mode == Mode.external: + elif isinstance(authenticator, ExternalAuthenticator): spec = { "provider": provider, - "mode": authenticator.mode.value, "links": { "auth_endpoint": f"{base_url}/auth/provider/{provider}/authorize" }, diff --git a/tiled/server/settings.py b/tiled/server/settings.py index e68c9b283..535052784 100644 --- a/tiled/server/settings.py +++ b/tiled/server/settings.py @@ -7,6 +7,8 @@ from pydantic_settings import BaseSettings +from tiled.server.protocols import Authenticator + DatabaseSettings = collections.namedtuple( "DatabaseSettings", "uri pool_size pool_pre_ping max_overflow" ) @@ -20,7 +22,7 @@ class Settings(BaseSettings): allow_origins: List[str] = [ item for item in os.getenv("TILED_ALLOW_ORIGINS", "").split() if item ] - authenticator: Any = None + authenticator: Authenticator | None = None # These 'single user' settings are only applicable if authenticator is None. single_user_api_key: str = os.getenv( "TILED_SINGLE_USER_API_KEY", secrets.token_hex(32)