From 1dd9def4f2e97aeb7d546604de842d0718258715 Mon Sep 17 00:00:00 2001 From: Waket Zheng Date: Fri, 7 Feb 2025 20:09:16 +0800 Subject: [PATCH] fix: fastapi exception handlers not work (#1880) --- examples/fastapi/_tests.py | 34 +++++++ examples/fastapi/config.py | 1 - examples/fastapi/main.py | 9 +- examples/fastapi/main_custom_timezone.py | 1 + examples/fastapi/routers.py | 11 +++ tortoise/contrib/fastapi/__init__.py | 114 ++++++++++++----------- 6 files changed, 110 insertions(+), 60 deletions(-) diff --git a/examples/fastapi/_tests.py b/examples/fastapi/_tests.py index ff229ffdb..d46564783 100644 --- a/examples/fastapi/_tests.py +++ b/examples/fastapi/_tests.py @@ -96,6 +96,23 @@ async def test_user_list(self, client: AsyncClient) -> None: # nosec await self.user_list(client) +@pytest.mark.anyio +async def test_404(client: AsyncClient) -> None: + response = await client.get("/404") + assert response.status_code == 404, response.text + data = response.json() + assert isinstance(data["detail"], str) + + +@pytest.mark.anyio +async def test_422(client: AsyncClient) -> None: + response = await client.get("/422") + assert response.status_code == 422, response.text + data = response.json() + assert isinstance(data["detail"], list) + assert isinstance(data["detail"][0], dict) + + class TestUserEast(UserTester): timezone = "Asia/Shanghai" delta_hours = 8 @@ -123,6 +140,23 @@ async def test_user_list(self, client_east: AsyncClient) -> None: # nosec assert item.model_dump()["created_at"].hour == created_at.hour +@pytest.mark.anyio +async def test_404_east(client_east: AsyncClient) -> None: + response = await client_east.get("/404") + assert response.status_code == 404, response.text + data = response.json() + assert isinstance(data["detail"], str) + + +@pytest.mark.anyio +async def test_422_east(client_east: AsyncClient) -> None: + response = await client_east.get("/422") + assert response.status_code == 422, response.text + data = response.json() + assert isinstance(data["detail"], list) + assert isinstance(data["detail"][0], dict) + + def query_without_app(pk: int) -> int: async def runner() -> bool: async with register_orm(): diff --git a/examples/fastapi/config.py b/examples/fastapi/config.py index 60df75599..da0657289 100644 --- a/examples/fastapi/config.py +++ b/examples/fastapi/config.py @@ -8,5 +8,4 @@ db_url=os.getenv("DB_URL", "sqlite://db.sqlite3"), modules={"models": ["models"]}, generate_schemas=True, - add_exception_handlers=True, ) diff --git a/examples/fastapi/main.py b/examples/fastapi/main.py index 7aba0b88b..b4bd35e7f 100644 --- a/examples/fastapi/main.py +++ b/examples/fastapi/main.py @@ -8,7 +8,7 @@ from examples.fastapi.config import register_orm from tortoise import Tortoise, generate_config -from tortoise.contrib.fastapi import RegisterTortoise +from tortoise.contrib.fastapi import RegisterTortoise, tortoise_exception_handlers @asynccontextmanager @@ -23,7 +23,6 @@ async def lifespan_test(app: FastAPI) -> AsyncGenerator[None, None]: app=app, config=config, generate_schemas=True, - add_exception_handlers=True, _create_db=True, ): # db connected @@ -47,5 +46,9 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: # db connections closed -app = FastAPI(title="Tortoise ORM FastAPI example", lifespan=lifespan) +app = FastAPI( + title="Tortoise ORM FastAPI example", + lifespan=lifespan, + exception_handlers=tortoise_exception_handlers(), +) app.include_router(users_router, prefix="") diff --git a/examples/fastapi/main_custom_timezone.py b/examples/fastapi/main_custom_timezone.py index b9cc8b855..e75151155 100644 --- a/examples/fastapi/main_custom_timezone.py +++ b/examples/fastapi/main_custom_timezone.py @@ -14,6 +14,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: app, use_tz=False, timezone="Asia/Shanghai", + add_exception_handlers=True, ): # db connected yield diff --git a/examples/fastapi/routers.py b/examples/fastapi/routers.py index 1530af073..6813b1430 100644 --- a/examples/fastapi/routers.py +++ b/examples/fastapi/routers.py @@ -33,3 +33,14 @@ async def delete_user(user_id: int): if not deleted_count: raise HTTPException(status_code=404, detail=f"User {user_id} not found") return Status(message=f"Deleted user {user_id}") + + +@router.get("/404") +async def get_404(): + await Users.get(id=0) + + +@router.get("/422") +async def get_422(): + obj = await Users.create(username="foo") + await Users.create(username=obj.username) diff --git a/tortoise/contrib/fastapi/__init__.py b/tortoise/contrib/fastapi/__init__.py index 0c300f855..e9890a8b3 100644 --- a/tortoise/contrib/fastapi/__init__.py +++ b/tortoise/contrib/fastapi/__init__.py @@ -7,10 +7,6 @@ from types import ModuleType from typing import TYPE_CHECKING -from fastapi.responses import JSONResponse -from pydantic import BaseModel # pylint: disable=E0611 -from starlette.routing import _DefaultLifespan - from tortoise import Tortoise, connections from tortoise.exceptions import DoesNotExist, IntegrityError from tortoise.log import logger @@ -18,14 +14,29 @@ if TYPE_CHECKING: from fastapi import FastAPI, Request + if sys.version_info >= (3, 11): from typing import Self else: from typing_extensions import Self -class HTTPNotFoundError(BaseModel): - detail: str +def tortoise_exception_handlers() -> dict: + from fastapi.responses import JSONResponse + + async def doesnotexist_exception_handler(request: "Request", exc: DoesNotExist): + return JSONResponse(status_code=404, content={"detail": str(exc)}) + + async def integrityerror_exception_handler(request: "Request", exc: IntegrityError): + return JSONResponse( + status_code=422, + content={"detail": [{"loc": [], "msg": str(exc), "type": "IntegrityError"}]}, + ) + + return { + DoesNotExist: doesnotexist_exception_handler, + IntegrityError: integrityerror_exception_handler, + } class RegisterTortoise(AbstractAsyncContextManager): @@ -122,17 +133,22 @@ def __init__( self._create_db = _create_db if add_exception_handlers and app is not None: + from starlette.middleware.exceptions import ExceptionMiddleware + + warnings.warn( + "Setting `add_exception_handlers` to be true is deprecated, " + "use `FastAPI(exception_handlers=tortoise_exception_handlers())` instead." + "See more about it on https://tortoise.github.io/examples/fastapi", + DeprecationWarning, + ) + original_call_func = ExceptionMiddleware.__call__ - @app.exception_handler(DoesNotExist) - async def doesnotexist_exception_handler(request: "Request", exc: DoesNotExist): - return JSONResponse(status_code=404, content={"detail": str(exc)}) + async def wrap_middleware_call(self, *args, **kw) -> None: + if DoesNotExist not in self._exception_handlers: + self._exception_handlers.update(tortoise_exception_handlers()) + await original_call_func(self, *args, **kw) - @app.exception_handler(IntegrityError) - async def integrityerror_exception_handler(request: "Request", exc: IntegrityError): - return JSONResponse( - status_code=422, - content={"detail": [{"loc": [], "msg": str(exc), "type": "IntegrityError"}]}, - ) + ExceptionMiddleware.__call__ = wrap_middleware_call # type:ignore async def init_orm(self) -> None: # pylint: disable=W0612 await Tortoise.init( @@ -166,8 +182,7 @@ async def __aexit__(self, *args, **kw) -> None: def __await__(self) -> Generator[None, None, Self]: async def _self() -> Self: - await self.init_orm() - return self + return await self.__aenter__() return _self().__await__() @@ -182,8 +197,9 @@ def register_tortoise( add_exception_handlers: bool = False, ) -> None: """ - Registers ``startup`` and ``shutdown`` events to set-up and tear-down Tortoise-ORM - inside a FastAPI application. + Registers Tortoise-ORM with set-up at the beginning of FastAPI application's lifespan + (which allow user to read/write data from/to db inside the lifespan function), + and tear-down at the end of that lifespan. You can configure using only one of ``config``, ``config_file`` and ``(db_url, modules)``. @@ -245,40 +261,26 @@ def register_tortoise( ConfigurationError For any configuration error """ - orm = RegisterTortoise( - app, - config, - config_file, - db_url, - modules, - generate_schemas, - add_exception_handlers, - ) - if isinstance(lifespan := app.router.lifespan_context, _DefaultLifespan): - # Leave on_event here to compare with old versions - # So people can upgrade tortoise-orm in running project without changing any code - - @app.on_event("startup") - async def init_orm() -> None: # pylint: disable=W0612 - await orm.init_orm() - - @app.on_event("shutdown") - async def close_orm() -> None: # pylint: disable=W0612 - await orm.close_orm() - - else: - # If custom lifespan was passed to app, register tortoise in it - warnings.warn( - "`register_tortoise` function is deprecated, " - "use the `RegisterTortoise` class instead." - "See more about it on https://tortoise.github.io/examples/fastapi", - DeprecationWarning, - ) - - @asynccontextmanager - async def orm_lifespan(app_instance: "FastAPI"): - async with orm: - async with lifespan(app_instance): - yield - - app.router.lifespan_context = orm_lifespan + from fastapi.routing import _merge_lifespan_context + + # Leave this function here to compare with old versions + # So people can upgrade tortoise-orm in running project without changing any code + + @asynccontextmanager + async def orm_lifespan(app_instance: "FastAPI"): + async with RegisterTortoise( + app_instance, + config, + config_file, + db_url, + modules, + generate_schemas, + ): + yield + + original_lifespan = app.router.lifespan_context + app.router.lifespan_context = _merge_lifespan_context(orm_lifespan, original_lifespan) + + if add_exception_handlers: + for exp_type, endpoint in tortoise_exception_handlers().items(): + app.exception_handler(exp_type)(endpoint)