diff --git a/examples/fastapi_example.py b/examples/fastapi_example.py
index fc5aa67..7adb5fd 100644
--- a/examples/fastapi_example.py
+++ b/examples/fastapi_example.py
@@ -6,7 +6,7 @@
import json
import html
-from http import HTTPStatus
+import logging
from typing import Optional, cast
from fastapi import FastAPI, Form, Request, Depends, Response
@@ -29,6 +29,13 @@
app.add_middleware(SessionMiddleware)
+logging.basicConfig(
+ level=logging.DEBUG,
+ format="%(asctime)s,%(msecs)d %(levelname)s: %(message)s",
+ datefmt="%H:%M:%S",
+)
+
+
async def get_auth_server() -> AuthServer:
"""
initialize oauth authorization server
@@ -74,8 +81,24 @@ async def authorize(
oauth2 authorization endpoint using aioauth
"""
oauthreq = await to_request(request)
+ user = request.session.get("user", None)
+
response = await oauth.create_authorization_response(oauthreq)
- if response.status_code == HTTPStatus.UNAUTHORIZED:
+
+ # A demonstration example of request validation before checking the user's credentials.
+ # See a discussion here: https://github.com/aliev/aioauth/issues/101
+ if response.status_code >= 400:
+ content = f"""
+
+
+ {response.content['error']}
+ {response.content['description']}
+
+
+ """
+ return HTMLResponse(content, status_code=response.status_code)
+
+ if user is None:
request.session["oauth"] = oauthreq
return RedirectResponse("/login")
return to_response(response)
@@ -155,18 +178,29 @@ async def approve(request: Request):
if "user" not in request.session:
redirect = request.url_for("login")
return RedirectResponse(redirect)
- oauthreq: OAuthRequest = request.session["oauth"]
- content = f"""
-
-
- {oauthreq.query.client_id} would like permissions.
-
-
-
- """
+
+ oauth = request.session.get("oauth", None)
+ if oauth:
+ oauthreq: OAuthRequest = request.session["oauth"]
+ content = f"""
+
+
+ {oauthreq.query.client_id} would like permissions.
+
+
+
+ """
+ else:
+ content = f"""
+
+
+ Hello, {request.session['user'].username}.
+
+
+ """
return HTMLResponse(content)
diff --git a/examples/shared/__init__.py b/examples/shared/__init__.py
index ec88b6e..7d0e4a7 100644
--- a/examples/shared/__init__.py
+++ b/examples/shared/__init__.py
@@ -20,7 +20,7 @@
"AuthServer",
"BackendStore",
"engine",
- "config",
+ "app_config",
"settings",
"try_login",
"lifespan",
@@ -32,8 +32,8 @@
"sqlite+aiosqlite:///:memory:", echo=False, future=True
)
-config = load_config(CONFIG_PATH)
-settings = config.settings
+app_config = load_config(CONFIG_PATH)
+settings = app_config.settings
async def try_login(username: str, password: str) -> Optional[User]:
@@ -59,9 +59,9 @@ async def lifespan(*_):
await conn.run_sync(SQLModel.metadata.create_all)
# create test records
async with AsyncSession(engine) as session:
- for user in config.fixtures.users:
+ for user in app_config.fixtures.users:
session.add(user)
- for client in config.fixtures.clients:
+ for client in app_config.fixtures.clients:
session.add(client)
await session.commit()
yield
diff --git a/examples/shared/config.py b/examples/shared/config.py
index ccf89b6..6f10c24 100644
--- a/examples/shared/config.py
+++ b/examples/shared/config.py
@@ -10,13 +10,6 @@
from .models import User, Client
-def load_config(fpath: str) -> "Config":
- """load configuration from filepath"""
- with open(fpath, "r") as f:
- json = f.read()
- return Config.model_validate_json(json)
-
-
class Fixtures(BaseModel):
users: List[User]
clients: List[Client]
@@ -25,3 +18,10 @@ class Fixtures(BaseModel):
class Config(BaseModel):
fixtures: Fixtures
settings: Settings
+
+
+def load_config(fpath: str) -> Config:
+ """load configuration from filepath"""
+ with open(fpath, "r") as f:
+ json = f.read()
+ return Config.model_validate_json(json)
diff --git a/examples/shared/storage.py b/examples/shared/storage.py
index adfc50d..4d5f442 100644
--- a/examples/shared/storage.py
+++ b/examples/shared/storage.py
@@ -113,7 +113,11 @@ async def get_authorization_code(
) -> Optional[AuthorizationCode]:
""" """
async with self.session:
- sql = select(AuthCodeTable).where(AuthCodeTable.client_id == client_id)
+ sql = (
+ select(AuthCodeTable)
+ .where(AuthCodeTable.client_id == client_id)
+ .where(AuthCodeTable.code == code)
+ )
result = (await self.session.exec(sql)).one_or_none()
if result is not None:
return AuthorizationCode(
@@ -138,7 +142,11 @@ async def delete_authorization_code(
) -> None:
""" """
async with self.session:
- sql = select(AuthCodeTable).where(AuthCodeTable.client_id == client_id)
+ sql = (
+ select(AuthCodeTable)
+ .where(AuthCodeTable.client_id == client_id)
+ .where(AuthCodeTable.code == code)
+ )
result = (await self.session.exec(sql)).one()
await self.session.delete(result)
await self.session.commit()
diff --git a/tests/oidc/core/test_flow.py b/tests/oidc/core/test_flow.py
index 6ef920c..d793a7f 100644
--- a/tests/oidc/core/test_flow.py
+++ b/tests/oidc/core/test_flow.py
@@ -1,5 +1,4 @@
from http import HTTPStatus
-from typing import Optional
import pytest
@@ -8,27 +7,12 @@
generate_token,
)
-from tests.classes import User
from tests.utils import check_request_validators
@pytest.mark.asyncio
-@pytest.mark.parametrize(
- "user, expected_status_code",
- [
- ("username", HTTPStatus.FOUND),
- (None, HTTPStatus.FOUND),
- ],
-)
-async def test_authorization_endpoint_allows_prompt_query_param(
- expected_status_code: HTTPStatus,
- user: Optional[User],
- context_factory,
-):
- if user is None:
- context = context_factory()
- else:
- context = context_factory(users={user: "password"})
+async def test_authorization_endpoint_allows_prompt_query_param(context_factory):
+ context = context_factory()
server = context.server
client = context.clients[0]
client_id = client.client_id
@@ -52,4 +36,4 @@ async def test_authorization_endpoint_allows_prompt_query_param(
await check_request_validators(request, server.create_authorization_response)
response = await server.create_authorization_response(request)
- assert response.status_code == expected_status_code
+ assert response.status_code == HTTPStatus.FOUND