diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 70bc4229..b71ac38d 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -15,7 +15,7 @@ Change log ``auth_logger_dependency`` returns the same logger as ``logger_dependency`` but with the ``user`` parameter bound to the username from ``auth_dependency``. - Add utility functions to initialize a database and create a sync or async session. The session creation functions optionally support a health check to ensure the database schema has been initialized. -- Add new FastAPI dependency ``db_session_dependency`` that creates a task-local async SQLAlchemy session and optionally manages the database transaction for that session. +- Add new FastAPI dependency ``db_session_dependency`` that creates a task-local async SQLAlchemy session. - Add utility functions ``datetime_from_db`` and ``datetime_to_db`` to convert between timezone-naive UTC datetimes stored in a database and timezone-aware UTC datetimes used elsewhere in a program. - Add a ``run_with_asyncio`` decorator that runs the decorated async function synchronously. This is primarily useful for decorating Click command functions (for a command-line interface) that need to make async calls. 
diff --git a/docs/database.rst b/docs/database.rst index 8a931293..3c5a35b4 100644 --- a/docs/database.rst +++ b/docs/database.rst @@ -45,7 +45,7 @@ For applications using `Click`_ (the recommended way to implement a command-line import click import structlog from safir.asyncio import run_with_asyncio - from safir.database import initialize_database + from safir.database import create_database_engine, initialize_database from .config import config from .schema import Base @@ -61,12 +61,11 @@ For applications using `Click`_ (the recommended way to implement a command-line @run_with_asyncio async def init() -> None: logger = structlog.get_logger(config.logger_name) - engine = await initialize_database( - config.database_url, - config.database_password, - logger, - schema=Base.metadata, - reset=reset, + engine = create_database_engine( + config.database_url, config.database_password + ) + await initialize_database( + engine, logger, schema=Base.metadata, reset=reset ) await engine.dispose() @@ -158,15 +157,25 @@ Then, any handler that needs a database session can depend on the `~safir.depend async def get_index( session: async_scoped_session = Depends(db_session_dependency), ) -> Dict[str, str]: - # ... do something with session here ... - return {} + async with session.begin(): + # ... do something with session here ... + return {} -By default, the session returned by this dependency will be inside a transaction that will automatically be committed when the route handler returns. -This is normally the best way to write database code for a RESTful web application, since each request should be a single transaction. -However, be aware that this means you should call ``await session.flush()`` and not ``await session.commit()`` to make changes visible to subsequent database statements. 
+Transaction management +---------------------- -If you need to manage the transactions directly, disable automatic transaction management by passing ``manage_transactions=False`` to ``initialize`` during application startup. -The session returned by the dependency will then not have an open transaction, and you should put any database code inside an ``async with session.begin()`` block to create and commit a transaction. +The application must manage transactions when using the Safir database dependency. +SQLAlchemy will automatically start a transaction if you perform any database operation using a session (including read-only operations). +If that transaction is not explicitly ended, `asyncpg`_ may leave it open, which will cause database deadlocks and other problems. + +Generally it's best to manage the transaction in the handler function (see the ``get_index`` example, above). +Wrap all code that may make database calls in an ``async with session.begin()`` block. +This will open a transaction, commit the transaction at the end of the block, and roll back the transaction if the block raises an exception. + +.. note:: + + Due to an as-yet-unexplained interaction with FastAPI 0.74 and later, managing the transaction inside the database session dependency does not work. + Calling ``await session.commit()`` there, either explicitly or implicitly via a context manager, immediately fails by raising ``asyncio.CancelledError`` and the transaction is not committed or closed. 
Handling datetimes in database tables ===================================== @@ -290,7 +299,7 @@ For example: import pytest_asyncio from asgi_lifespan import LifespanManager from fastapi import FastAPI - from safir.database import initialize_database + from safir.database import create_database_engine, initialize_database from application import main from application.config import config @@ -300,13 +309,10 @@ For example: @pytest_asyncio.fixture async def app() -> AsyncIterator[FastAPI]: logger = structlog.get_logger(config.logger_name) - engine = await initialize_database( - config.database_url, - config.database_password, - logger, - schema=Base.metadata, - reset=True, + engine = create_database_engine( + config.database_url, config.database_password ) + await initialize_database(engine, logger, schema=Base.metadata, reset=True) await engine.dispose() async with LifespanManager(main.app): yield main.app diff --git a/src/safir/database.py b/src/safir/database.py index b0d22a98..59bb1d03 100644 --- a/src/safir/database.py +++ b/src/safir/database.py @@ -278,21 +278,18 @@ def create_sync_session( async def initialize_database( - url: str, - password: Optional[str], + engine: AsyncEngine, logger: BoundLogger, *, schema: MetaData, reset: bool = False, -) -> AsyncEngine: +) -> None: """Create and initialize a new database. Parameters ---------- - url : `str` - Database connection URL, not including the password. - password : `str` or `None` - Database connection password. + engine : `sqlalchemy.ext.asyncio.AsyncEngine` + Database engine to use. Create with `create_database_engine`. logger : ``structlog.stdlib.BoundLogger`` Logger used to report problems schema : `sqlalchemy.sql.schema.MetaData` @@ -306,14 +303,6 @@ async def initialize_database( Useful when running tests with an external database. Default is `False`. - Returns - ------- - engine : `sqlalchemy.ext.asyncio.AsyncEngine` - Database engine for the initialized database. 
This may be used by the - caller to perform any additional necessary database initialization not - included in the schema, such as adding default table rows. The engine - must then be closed with ``await engine.dispose()``. - Raises ------ DatabaseInitializationError @@ -322,7 +311,6 @@ async def initialize_database( """ success = False error = None - engine = create_database_engine(url, password) for _ in range(5): try: async with engine.begin() as conn: @@ -343,4 +331,3 @@ async def initialize_database( logger.error(msg) await engine.dispose() raise DatabaseInitializationError(error) - return engine diff --git a/src/safir/dependencies/db_session.py b/src/safir/dependencies/db_session.py index 4f1f0d5e..81670458 100644 --- a/src/safir/dependencies/db_session.py +++ b/src/safir/dependencies/db_session.py @@ -46,16 +46,11 @@ class DatabaseSessionDependency: def __init__(self) -> None: self._engine: Optional[AsyncEngine] = None + self._override_engine: Optional[AsyncEngine] = None self._session: Optional[async_scoped_session] = None - self._manage_transactions = True async def __call__(self) -> AsyncIterator[async_scoped_session]: - """Create a database session and open a transaction. - - By default, this implements a policy of one request equals one - transaction, which is closed when that request returns. To disable - managed transactions, pass ``manage_transactions=False`` to the - `initialize` method. + """Return the database session manager. Returns ------- @@ -63,11 +58,7 @@ async def __call__(self) -> AsyncIterator[async_scoped_session]: The newly-created session. """ assert self._session, "db_session_dependency not initialized" - if self._manage_transactions: - async with self._session.begin(): - yield self._session - else: - yield self._session + yield self._session # Following the recommendations in the SQLAlchemy documentation, each # session is scoped to a single web request. 
However, this all uses @@ -87,7 +78,6 @@ async def initialize( password: Optional[str], *, isolation_level: Optional[_IsolationLevel] = None, - manage_transactions: bool = True, ) -> None: """Initialize the session dependency. @@ -100,19 +90,29 @@ async def initialize( isolation_level : `str`, optional If specified, sets a non-default isolation level for the database engine. - manage_transactions : `bool`, optional - Whether the dependency should open a new transaction for each - request and commit that transaction at the end of the request. - This is the default behavior; to manage transactions manually, - set this parameter to `False`. (Disabling managed transactions - may be necessary if the application database code has to retry - failed transactions due to a non-default isolation level.) """ - self._manage_transactions = manage_transactions - self._engine = create_database_engine( - url, password, isolation_level=isolation_level - ) - self._session = await create_async_session(self._engine) + if self._override_engine: + self._session = await create_async_session(self._override_engine) + else: + self._engine = create_database_engine( + url, password, isolation_level=isolation_level + ) + self._session = await create_async_session(self._engine) + + def override_engine(self, engine: AsyncEngine) -> None: + """Force the dependency to use the provided engine. + + Intended for testing, this allows the test suite to configure a single + database engine and share it across all of the tests, benefiting from + connection pooling for a minor test speed-up. (This is not + significant enough to bother with except for an extensive test suite.) + + Parameters + ---------- + engine : `sqlalchemy.ext.asyncio.AsyncEngine` + Database engine to use for all sessions. 
+ """ + self._override_engine = engine db_session_dependency = DatabaseSessionDependency() diff --git a/tests/database_test.py b/tests/database_test.py index 3f9d130c..bcea0e65 100644 --- a/tests/database_test.py +++ b/tests/database_test.py @@ -39,41 +39,23 @@ class User(Base): @pytest.mark.asyncio async def test_database_init() -> None: logger = structlog.get_logger(__name__) - engine = await initialize_database( - TEST_DATABASE_URL, - TEST_DATABASE_PASSWORD, - logger, - schema=Base.metadata, - reset=True, - ) + engine = create_database_engine(TEST_DATABASE_URL, TEST_DATABASE_PASSWORD) + await initialize_database(engine, logger, schema=Base.metadata, reset=True) session = await create_async_session(engine, logger) async with session.begin(): session.add(User(username="someuser")) await session.remove() - await engine.dispose() # Reinitializing the database without reset should preserve the row. - engine = await initialize_database( - TEST_DATABASE_URL, - TEST_DATABASE_PASSWORD, - logger, - schema=Base.metadata, - ) + await initialize_database(engine, logger, schema=Base.metadata) session = await create_async_session(engine, logger) async with session.begin(): result = await session.scalars(select(User.username)) assert result.all() == ["someuser"] await session.remove() - await engine.dispose() # Reinitializing the database with reset should delete the data. 
- engine = await initialize_database( - TEST_DATABASE_URL, - TEST_DATABASE_PASSWORD, - logger, - schema=Base.metadata, - reset=True, - ) + await initialize_database(engine, logger, schema=Base.metadata, reset=True) session = await create_async_session(engine, logger) async with session.begin(): result = await session.scalars(select(User.username)) @@ -105,16 +87,9 @@ def test_build_database_url() -> None: @pytest.mark.asyncio async def test_create_async_session() -> None: logger = structlog.get_logger(__name__) - engine = await initialize_database( - TEST_DATABASE_URL, - TEST_DATABASE_PASSWORD, - logger, - schema=Base.metadata, - reset=True, - ) - await engine.dispose() - engine = create_database_engine(TEST_DATABASE_URL, TEST_DATABASE_PASSWORD) + await initialize_database(engine, logger, schema=Base.metadata, reset=True) + session = await create_async_session( engine, logger, statement=select(User) ) @@ -136,13 +111,8 @@ async def test_create_async_session() -> None: @pytest.mark.asyncio async def test_create_sync_session() -> None: logger = structlog.get_logger(__name__) - engine = await initialize_database( - TEST_DATABASE_URL, - TEST_DATABASE_PASSWORD, - logger, - schema=Base.metadata, - reset=True, - ) + engine = create_database_engine(TEST_DATABASE_URL, TEST_DATABASE_PASSWORD) + await initialize_database(engine, logger, schema=Base.metadata, reset=True) await engine.dispose() session = create_sync_session( diff --git a/tests/dependencies/db_session_test.py b/tests/dependencies/db_session_test.py index 160f7e97..48d60da3 100644 --- a/tests/dependencies/db_session_test.py +++ b/tests/dependencies/db_session_test.py @@ -14,7 +14,11 @@ from sqlalchemy.future import select from sqlalchemy.orm import declarative_base -from safir.database import create_async_session, initialize_database +from safir.database import ( + create_async_session, + create_database_engine, + initialize_database, +) from safir.dependencies.db_session import db_session_dependency 
TEST_DATABASE_URL = os.environ["TEST_DATABASE_URL"] @@ -34,13 +38,8 @@ class User(Base): @pytest.mark.asyncio async def test_session() -> None: logger = structlog.get_logger(__name__) - engine = await initialize_database( - TEST_DATABASE_URL, - TEST_DATABASE_PASSWORD, - logger, - schema=Base.metadata, - reset=True, - ) + engine = create_database_engine(TEST_DATABASE_URL, TEST_DATABASE_PASSWORD) + await initialize_database(engine, logger, schema=Base.metadata, reset=True) session = await create_async_session(engine, logger) await db_session_dependency.initialize( TEST_DATABASE_URL, TEST_DATABASE_PASSWORD @@ -52,14 +51,16 @@ async def test_session() -> None: async def add( session: async_scoped_session = Depends(db_session_dependency), ) -> None: - session.add(User(username="foo")) + async with session.begin(): + session.add(User(username="foo")) @app.get("/list") async def list( session: async_scoped_session = Depends(db_session_dependency), ) -> List[str]: - result = await session.scalars(select(User.username)) - return result.all() + async with session.begin(): + result = await session.scalars(select(User.username)) + return result.all() async with AsyncClient(app=app, base_url="https://example.com") as client: r = await client.get("/list") @@ -82,48 +83,3 @@ async def list( await session.remove() await engine.dispose() await db_session_dependency.aclose() - - -@pytest.mark.asyncio -async def test_unmanaged_transactions() -> None: - logger = structlog.get_logger(__name__) - engine = await initialize_database( - TEST_DATABASE_URL, - TEST_DATABASE_PASSWORD, - logger, - schema=Base.metadata, - reset=True, - ) - await engine.dispose() - await db_session_dependency.initialize( - TEST_DATABASE_URL, TEST_DATABASE_PASSWORD, manage_transactions=False - ) - - app = FastAPI() - - @app.post("/add") - async def add( - session: async_scoped_session = Depends(db_session_dependency), - ) -> None: - # If a transaction was already started, this will fail, so it tests - # that 
automatic transactions were disabled. - async with session.begin(): - session.add(User(username="foo")) - - @app.get("/list") - async def list( - session: async_scoped_session = Depends(db_session_dependency), - ) -> List[str]: - async with session.begin(): - result = await session.scalars(select(User.username)) - return result.all() - - async with AsyncClient(app=app, base_url="https://example.com") as client: - r = await client.post("/add") - assert r.status_code == 200 - - r = await client.get("/list") - assert r.status_code == 200 - assert r.json() == ["foo"] - - await db_session_dependency.aclose()