Skip to content

Commit

Permalink
fix: explicitly show the response error
Browse files Browse the repository at this point in the history
  • Loading branch information
aliev committed Jan 26, 2025
1 parent e380fe2 commit 7352502
Show file tree
Hide file tree
Showing 5 changed files with 73 additions and 47 deletions.
62 changes: 48 additions & 14 deletions examples/fastapi_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import json
import html
from http import HTTPStatus
import logging
from typing import Optional, cast

from fastapi import FastAPI, Form, Request, Depends, Response
Expand All @@ -29,6 +29,13 @@
app.add_middleware(SessionMiddleware)


logging.basicConfig(
level=logging.DEBUG,
format="%(asctime)s,%(msecs)d %(levelname)s: %(message)s",
datefmt="%H:%M:%S",
)


async def get_auth_server() -> AuthServer:
"""
initialize oauth authorization server
Expand Down Expand Up @@ -74,8 +81,24 @@ async def authorize(
oauth2 authorization endpoint using aioauth
"""
oauthreq = await to_request(request)
user = request.session.get("user", None)

response = await oauth.create_authorization_response(oauthreq)
if response.status_code == HTTPStatus.UNAUTHORIZED:

# A demonstration example of request validation before checking the user's credentials.
# See a discussion here: https://github.com/aliev/aioauth/issues/101
if response.status_code >= 400:
content = f"""
<html>
<body>
<h3>{response.content['error']}</h3>
<p style="color: red">{response.content['description']}</p>
</body>
</html>
"""
return HTMLResponse(content, status_code=response.status_code)

if user is None:
request.session["oauth"] = oauthreq
return RedirectResponse("/login")
return to_response(response)
Expand Down Expand Up @@ -155,18 +178,29 @@ async def approve(request: Request):
if "user" not in request.session:
redirect = request.url_for("login")
return RedirectResponse(redirect)
oauthreq: OAuthRequest = request.session["oauth"]
content = f"""
<html>
<body>
<h3>{oauthreq.query.client_id} would like permissions.</h3>
<form method="POST">
<button name="approval" value="0" type="submit">Deny</button>
<button name="approval" value="1" type="submit">Approve</button>
</form>
</body>
</html>
"""

oauth = request.session.get("oauth", None)
if oauth:
oauthreq: OAuthRequest = request.session["oauth"]
content = f"""
<html>
<body>
<h3>{oauthreq.query.client_id} would like permissions.</h3>
<form method="POST">
<button name="approval" value="0" type="submit">Deny</button>
<button name="approval" value="1" type="submit">Approve</button>
</form>
</body>
</html>
"""
else:
content = f"""
<html>
<body>
<h3>Hello, {request.session['user'].username}.</h3>
</body>
</html>
"""
return HTMLResponse(content)


Expand Down
10 changes: 5 additions & 5 deletions examples/shared/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
"AuthServer",
"BackendStore",
"engine",
"config",
"app_config",
"settings",
"try_login",
"lifespan",
Expand All @@ -32,8 +32,8 @@
"sqlite+aiosqlite:///:memory:", echo=False, future=True
)

config = load_config(CONFIG_PATH)
settings = config.settings
app_config = load_config(CONFIG_PATH)
settings = app_config.settings


async def try_login(username: str, password: str) -> Optional[User]:
Expand All @@ -59,9 +59,9 @@ async def lifespan(*_):
await conn.run_sync(SQLModel.metadata.create_all)
# create test records
async with AsyncSession(engine) as session:
for user in config.fixtures.users:
for user in app_config.fixtures.users:
session.add(user)
for client in config.fixtures.clients:
for client in app_config.fixtures.clients:
session.add(client)
await session.commit()
yield
Expand Down
14 changes: 7 additions & 7 deletions examples/shared/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,6 @@
from .models import User, Client


def load_config(fpath: str) -> "Config":
"""load configuration from filepath"""
with open(fpath, "r") as f:
json = f.read()
return Config.model_validate_json(json)


class Fixtures(BaseModel):
users: List[User]
clients: List[Client]
Expand All @@ -25,3 +18,10 @@ class Fixtures(BaseModel):
class Config(BaseModel):
fixtures: Fixtures
settings: Settings


def load_config(fpath: str) -> Config:
"""load configuration from filepath"""
with open(fpath, "r") as f:
json = f.read()
return Config.model_validate_json(json)
12 changes: 10 additions & 2 deletions examples/shared/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,11 @@ async def get_authorization_code(
) -> Optional[AuthorizationCode]:
""" """
async with self.session:
sql = select(AuthCodeTable).where(AuthCodeTable.client_id == client_id)
sql = (
select(AuthCodeTable)
.where(AuthCodeTable.client_id == client_id)
.where(AuthCodeTable.code == code)
)
result = (await self.session.exec(sql)).one_or_none()
if result is not None:
return AuthorizationCode(
Expand All @@ -138,7 +142,11 @@ async def delete_authorization_code(
) -> None:
""" """
async with self.session:
sql = select(AuthCodeTable).where(AuthCodeTable.client_id == client_id)
sql = (
select(AuthCodeTable)
.where(AuthCodeTable.client_id == client_id)
.where(AuthCodeTable.code == code)
)
result = (await self.session.exec(sql)).one()
await self.session.delete(result)
await self.session.commit()
Expand Down
22 changes: 3 additions & 19 deletions tests/oidc/core/test_flow.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from http import HTTPStatus
from typing import Optional

import pytest

Expand All @@ -8,27 +7,12 @@
generate_token,
)

from tests.classes import User
from tests.utils import check_request_validators


@pytest.mark.asyncio
@pytest.mark.parametrize(
"user, expected_status_code",
[
("username", HTTPStatus.FOUND),
(None, HTTPStatus.FOUND),
],
)
async def test_authorization_endpoint_allows_prompt_query_param(
expected_status_code: HTTPStatus,
user: Optional[User],
context_factory,
):
if user is None:
context = context_factory()
else:
context = context_factory(users={user: "password"})
async def test_authorization_endpoint_allows_prompt_query_param(context_factory):
context = context_factory()
server = context.server
client = context.clients[0]
client_id = client.client_id
Expand All @@ -52,4 +36,4 @@ async def test_authorization_endpoint_allows_prompt_query_param(
await check_request_validators(request, server.create_authorization_response)

response = await server.create_authorization_response(request)
assert response.status_code == expected_status_code
assert response.status_code == HTTPStatus.FOUND

0 comments on commit 7352502

Please sign in to comment.