diff --git a/examples/fastapi_example.py b/examples/fastapi_example.py index fc5aa67..7adb5fd 100644 --- a/examples/fastapi_example.py +++ b/examples/fastapi_example.py @@ -6,7 +6,7 @@ import json import html -from http import HTTPStatus +import logging from typing import Optional, cast from fastapi import FastAPI, Form, Request, Depends, Response @@ -29,6 +29,13 @@ app.add_middleware(SessionMiddleware) +logging.basicConfig( + level=logging.DEBUG, + format="%(asctime)s,%(msecs)d %(levelname)s: %(message)s", + datefmt="%H:%M:%S", +) + + async def get_auth_server() -> AuthServer: """ initialize oauth authorization server @@ -74,8 +81,24 @@ async def authorize( oauth2 authorization endpoint using aioauth """ oauthreq = await to_request(request) + user = request.session.get("user", None) + response = await oauth.create_authorization_response(oauthreq) - if response.status_code == HTTPStatus.UNAUTHORIZED: + + # A demonstration example of request validation before checking the user's credentials. + # See a discussion here: https://github.com/aliev/aioauth/issues/101 + if response.status_code >= 400: + content = f""" + + +

{response.content['error']}

+

{response.content['description']}

+ + + """ + return HTMLResponse(content, status_code=response.status_code) + + if user is None: request.session["oauth"] = oauthreq return RedirectResponse("/login") return to_response(response) @@ -155,18 +178,29 @@ async def approve(request: Request): if "user" not in request.session: redirect = request.url_for("login") return RedirectResponse(redirect) - oauthreq: OAuthRequest = request.session["oauth"] - content = f""" - - -

{oauthreq.query.client_id} would like permissions.

-
- - -
- - - """ + + oauth = request.session.get("oauth", None) + if oauth: + oauthreq: OAuthRequest = request.session["oauth"] + content = f""" + + +

{oauthreq.query.client_id} would like permissions.

+
+ + +
+ + + """ + else: + content = f""" + + +

Hello, {request.session['user'].username}.

+ + + """ return HTMLResponse(content) diff --git a/examples/shared/__init__.py b/examples/shared/__init__.py index ec88b6e..7d0e4a7 100644 --- a/examples/shared/__init__.py +++ b/examples/shared/__init__.py @@ -20,7 +20,7 @@ "AuthServer", "BackendStore", "engine", - "config", + "app_config", "settings", "try_login", "lifespan", @@ -32,8 +32,8 @@ "sqlite+aiosqlite:///:memory:", echo=False, future=True ) -config = load_config(CONFIG_PATH) -settings = config.settings +app_config = load_config(CONFIG_PATH) +settings = app_config.settings async def try_login(username: str, password: str) -> Optional[User]: @@ -59,9 +59,9 @@ async def lifespan(*_): await conn.run_sync(SQLModel.metadata.create_all) # create test records async with AsyncSession(engine) as session: - for user in config.fixtures.users: + for user in app_config.fixtures.users: session.add(user) - for client in config.fixtures.clients: + for client in app_config.fixtures.clients: session.add(client) await session.commit() yield diff --git a/examples/shared/config.py b/examples/shared/config.py index ccf89b6..6f10c24 100644 --- a/examples/shared/config.py +++ b/examples/shared/config.py @@ -10,13 +10,6 @@ from .models import User, Client -def load_config(fpath: str) -> "Config": - """load configuration from filepath""" - with open(fpath, "r") as f: - json = f.read() - return Config.model_validate_json(json) - - class Fixtures(BaseModel): users: List[User] clients: List[Client] @@ -25,3 +18,10 @@ class Fixtures(BaseModel): class Config(BaseModel): fixtures: Fixtures settings: Settings + + +def load_config(fpath: str) -> Config: + """load configuration from filepath""" + with open(fpath, "r") as f: + json = f.read() + return Config.model_validate_json(json) diff --git a/examples/shared/storage.py b/examples/shared/storage.py index adfc50d..4d5f442 100644 --- a/examples/shared/storage.py +++ b/examples/shared/storage.py @@ -113,7 +113,11 @@ async def get_authorization_code( ) -> Optional[AuthorizationCode]: """ """ async with self.session: - sql = select(AuthCodeTable).where(AuthCodeTable.client_id == client_id) + sql = ( + select(AuthCodeTable) + .where(AuthCodeTable.client_id == client_id) + .where(AuthCodeTable.code == code) + ) result = (await self.session.exec(sql)).one_or_none() if result is not None: return AuthorizationCode( @@ -138,7 +142,11 @@ async def delete_authorization_code( ) -> None: """ """ async with self.session: - sql = select(AuthCodeTable).where(AuthCodeTable.client_id == client_id) + sql = ( + select(AuthCodeTable) + .where(AuthCodeTable.client_id == client_id) + .where(AuthCodeTable.code == code) + ) result = (await self.session.exec(sql)).one() await self.session.delete(result) await self.session.commit() diff --git a/tests/oidc/core/test_flow.py b/tests/oidc/core/test_flow.py index 6ef920c..d793a7f 100644 --- a/tests/oidc/core/test_flow.py +++ b/tests/oidc/core/test_flow.py @@ -1,5 +1,4 @@ from http import HTTPStatus -from typing import Optional import pytest @@ -8,27 +7,12 @@ generate_token, ) -from tests.classes import User from tests.utils import check_request_validators @pytest.mark.asyncio -@pytest.mark.parametrize( - "user, expected_status_code", - [ - ("username", HTTPStatus.FOUND), - (None, HTTPStatus.FOUND), - ], -) -async def test_authorization_endpoint_allows_prompt_query_param( - expected_status_code: HTTPStatus, - user: Optional[User], - context_factory, -): - if user is None: - context = context_factory() - else: - context = context_factory(users={user: "password"}) +async def test_authorization_endpoint_allows_prompt_query_param(context_factory): + context = context_factory() server = context.server client = context.clients[0] client_id = client.client_id @@ -52,4 +36,4 @@ async def test_authorization_endpoint_allows_prompt_query_param( await check_request_validators(request, server.create_authorization_response) response = await server.create_authorization_response(request) - assert response.status_code == expected_status_code + assert response.status_code == HTTPStatus.FOUND