Skip to content

Commit

Permalink
fix: fastapi exception handlers not work (#1880)
Browse files Browse the repository at this point in the history
  • Loading branch information
waketzheng authored Feb 7, 2025
1 parent 7cb24a5 commit 1dd9def
Show file tree
Hide file tree
Showing 6 changed files with 110 additions and 60 deletions.
34 changes: 34 additions & 0 deletions examples/fastapi/_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,23 @@ async def test_user_list(self, client: AsyncClient) -> None: # nosec
await self.user_list(client)


@pytest.mark.anyio
async def test_404(client: AsyncClient) -> None:
response = await client.get("/404")
assert response.status_code == 404, response.text
data = response.json()
assert isinstance(data["detail"], str)


@pytest.mark.anyio
async def test_422(client: AsyncClient) -> None:
response = await client.get("/422")
assert response.status_code == 422, response.text
data = response.json()
assert isinstance(data["detail"], list)
assert isinstance(data["detail"][0], dict)


class TestUserEast(UserTester):
timezone = "Asia/Shanghai"
delta_hours = 8
Expand Down Expand Up @@ -123,6 +140,23 @@ async def test_user_list(self, client_east: AsyncClient) -> None: # nosec
assert item.model_dump()["created_at"].hour == created_at.hour


@pytest.mark.anyio
async def test_404_east(client_east: AsyncClient) -> None:
response = await client_east.get("/404")
assert response.status_code == 404, response.text
data = response.json()
assert isinstance(data["detail"], str)


@pytest.mark.anyio
async def test_422_east(client_east: AsyncClient) -> None:
response = await client_east.get("/422")
assert response.status_code == 422, response.text
data = response.json()
assert isinstance(data["detail"], list)
assert isinstance(data["detail"][0], dict)


def query_without_app(pk: int) -> int:
async def runner() -> bool:
async with register_orm():
Expand Down
1 change: 0 additions & 1 deletion examples/fastapi/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,4 @@
db_url=os.getenv("DB_URL", "sqlite://db.sqlite3"),
modules={"models": ["models"]},
generate_schemas=True,
add_exception_handlers=True,
)
9 changes: 6 additions & 3 deletions examples/fastapi/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from examples.fastapi.config import register_orm
from tortoise import Tortoise, generate_config
from tortoise.contrib.fastapi import RegisterTortoise
from tortoise.contrib.fastapi import RegisterTortoise, tortoise_exception_handlers


@asynccontextmanager
Expand All @@ -23,7 +23,6 @@ async def lifespan_test(app: FastAPI) -> AsyncGenerator[None, None]:
app=app,
config=config,
generate_schemas=True,
add_exception_handlers=True,
_create_db=True,
):
# db connected
Expand All @@ -47,5 +46,9 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
# db connections closed


app = FastAPI(title="Tortoise ORM FastAPI example", lifespan=lifespan)
app = FastAPI(
title="Tortoise ORM FastAPI example",
lifespan=lifespan,
exception_handlers=tortoise_exception_handlers(),
)
app.include_router(users_router, prefix="")
1 change: 1 addition & 0 deletions examples/fastapi/main_custom_timezone.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
app,
use_tz=False,
timezone="Asia/Shanghai",
add_exception_handlers=True,
):
# db connected
yield
Expand Down
11 changes: 11 additions & 0 deletions examples/fastapi/routers.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,14 @@ async def delete_user(user_id: int):
if not deleted_count:
raise HTTPException(status_code=404, detail=f"User {user_id} not found")
return Status(message=f"Deleted user {user_id}")


@router.get("/404")
async def get_404():
await Users.get(id=0)


@router.get("/422")
async def get_422():
obj = await Users.create(username="foo")
await Users.create(username=obj.username)
114 changes: 58 additions & 56 deletions tortoise/contrib/fastapi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,25 +7,36 @@
from types import ModuleType
from typing import TYPE_CHECKING

from fastapi.responses import JSONResponse
from pydantic import BaseModel # pylint: disable=E0611
from starlette.routing import _DefaultLifespan

from tortoise import Tortoise, connections
from tortoise.exceptions import DoesNotExist, IntegrityError
from tortoise.log import logger

if TYPE_CHECKING:
from fastapi import FastAPI, Request


if sys.version_info >= (3, 11):
from typing import Self
else:
from typing_extensions import Self


class HTTPNotFoundError(BaseModel):
detail: str
def tortoise_exception_handlers() -> dict:
from fastapi.responses import JSONResponse

async def doesnotexist_exception_handler(request: "Request", exc: DoesNotExist):
return JSONResponse(status_code=404, content={"detail": str(exc)})

async def integrityerror_exception_handler(request: "Request", exc: IntegrityError):
return JSONResponse(
status_code=422,
content={"detail": [{"loc": [], "msg": str(exc), "type": "IntegrityError"}]},
)

return {
DoesNotExist: doesnotexist_exception_handler,
IntegrityError: integrityerror_exception_handler,
}


class RegisterTortoise(AbstractAsyncContextManager):
Expand Down Expand Up @@ -122,17 +133,22 @@ def __init__(
self._create_db = _create_db

if add_exception_handlers and app is not None:
from starlette.middleware.exceptions import ExceptionMiddleware

warnings.warn(
"Setting `add_exception_handlers` to be true is deprecated, "
"use `FastAPI(exception_handlers=tortoise_exception_handlers())` instead."
"See more about it on https://tortoise.github.io/examples/fastapi",
DeprecationWarning,
)
original_call_func = ExceptionMiddleware.__call__

@app.exception_handler(DoesNotExist)
async def doesnotexist_exception_handler(request: "Request", exc: DoesNotExist):
return JSONResponse(status_code=404, content={"detail": str(exc)})
async def wrap_middleware_call(self, *args, **kw) -> None:
if DoesNotExist not in self._exception_handlers:
self._exception_handlers.update(tortoise_exception_handlers())
await original_call_func(self, *args, **kw)

@app.exception_handler(IntegrityError)
async def integrityerror_exception_handler(request: "Request", exc: IntegrityError):
return JSONResponse(
status_code=422,
content={"detail": [{"loc": [], "msg": str(exc), "type": "IntegrityError"}]},
)
ExceptionMiddleware.__call__ = wrap_middleware_call # type:ignore

async def init_orm(self) -> None: # pylint: disable=W0612
await Tortoise.init(
Expand Down Expand Up @@ -166,8 +182,7 @@ async def __aexit__(self, *args, **kw) -> None:

def __await__(self) -> Generator[None, None, Self]:
async def _self() -> Self:
await self.init_orm()
return self
return await self.__aenter__()

return _self().__await__()

Expand All @@ -182,8 +197,9 @@ def register_tortoise(
add_exception_handlers: bool = False,
) -> None:
"""
Registers ``startup`` and ``shutdown`` events to set-up and tear-down Tortoise-ORM
inside a FastAPI application.
Registers Tortoise-ORM with set-up at the beginning of FastAPI application's lifespan
(which allow user to read/write data from/to db inside the lifespan function),
and tear-down at the end of that lifespan.
You can configure using only one of ``config``, ``config_file``
and ``(db_url, modules)``.
Expand Down Expand Up @@ -245,40 +261,26 @@ def register_tortoise(
ConfigurationError
For any configuration error
"""
orm = RegisterTortoise(
app,
config,
config_file,
db_url,
modules,
generate_schemas,
add_exception_handlers,
)
if isinstance(lifespan := app.router.lifespan_context, _DefaultLifespan):
# Leave on_event here to compare with old versions
# So people can upgrade tortoise-orm in running project without changing any code

@app.on_event("startup")
async def init_orm() -> None: # pylint: disable=W0612
await orm.init_orm()

@app.on_event("shutdown")
async def close_orm() -> None: # pylint: disable=W0612
await orm.close_orm()

else:
# If custom lifespan was passed to app, register tortoise in it
warnings.warn(
"`register_tortoise` function is deprecated, "
"use the `RegisterTortoise` class instead."
"See more about it on https://tortoise.github.io/examples/fastapi",
DeprecationWarning,
)

@asynccontextmanager
async def orm_lifespan(app_instance: "FastAPI"):
async with orm:
async with lifespan(app_instance):
yield

app.router.lifespan_context = orm_lifespan
from fastapi.routing import _merge_lifespan_context

# Leave this function here to compare with old versions
# So people can upgrade tortoise-orm in running project without changing any code

@asynccontextmanager
async def orm_lifespan(app_instance: "FastAPI"):
async with RegisterTortoise(
app_instance,
config,
config_file,
db_url,
modules,
generate_schemas,
):
yield

original_lifespan = app.router.lifespan_context
app.router.lifespan_context = _merge_lifespan_context(orm_lifespan, original_lifespan)

if add_exception_handlers:
for exp_type, endpoint in tortoise_exception_handlers().items():
app.exception_handler(exp_type)(endpoint)

0 comments on commit 1dd9def

Please sign in to comment.