diff --git a/hasql/asyncsqlalchemy.py b/hasql/asyncsqlalchemy.py index b409828..6263335 100644 --- a/hasql/asyncsqlalchemy.py +++ b/hasql/asyncsqlalchemy.py @@ -1,10 +1,11 @@ import asyncio -from typing import Sequence +from contextlib import _AsyncGeneratorContextManager, asynccontextmanager +from typing import Any, AsyncIterator, Callable, Dict, Optional, Sequence, Type -import sqlalchemy as sa # type: ignore -from sqlalchemy.ext.asyncio import AsyncConnection # type: ignore +import sqlalchemy as sa +from sqlalchemy.ext.asyncio import AsyncConnection, AsyncSession from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine -from sqlalchemy.pool import QueuePool # type: ignore +from sqlalchemy.pool import QueuePool from hasql.base import BasePoolManager from hasql.metrics import DriverMetrics @@ -13,7 +14,7 @@ class PoolManager(BasePoolManager): def get_pool_freesize(self, pool: AsyncEngine): - queue_pool: QueuePool = pool.sync_engine.pool + queue_pool: QueuePool = pool.sync_engine.pool # type: ignore return queue_pool.size() - queue_pool.checkedout() def acquire_from_pool(self, pool: AsyncEngine, **kwargs): @@ -70,4 +71,56 @@ def _driver_metrics(self) -> Sequence[DriverMetrics]: ] -__all__ = ("PoolManager",) +def async_sessionmaker( + pool_manager: PoolManager, + *, + class_: Type[AsyncSession] = AsyncSession, + autoflush: bool = True, + expire_on_commit: bool = True, + info: Optional[Dict[Any, Any]] = None, + acquire_kwargs: Optional[Dict[str, Any]] = None, + **kw: Any, +) -> Callable[..., _AsyncGeneratorContextManager]: + """Create async session maker with hasql pool support. + + This function replaces the default `async_sessionmaker` from + SQLAlchemy to work with the `PoolManager` class. It allows you to + create an async session that is bound to a connection acquired from + the pool. The session will automatically release the connection + back to the pool when the session is closed. + + You also can specify the session class to use with the `class_` + parameter, and you can customize the session's behavior with + parameters like `autoflush`, `expire_on_commit`, and `info`. + + Use the `acquire_kwargs` to pass additional parameters to the + `pool.acquire()` method. E.g. to create a session with replica + connection: + >>> ReplicaSession = async_sessionmaker( + >>> pool_manager, + >>> acquire_kwargs={"read_only": True} + >>> ) + + """ + if acquire_kwargs is None: + acquire_kwargs = {} + + @asynccontextmanager + async def create_async_session() -> AsyncIterator[AsyncSession]: + """Create an async session with connection from the pool.""" + async with ( + pool_manager.acquire(**acquire_kwargs) as connection, + class_( + bind=connection, + autoflush=autoflush, + expire_on_commit=expire_on_commit, + info=info, + **kw, + ) as session, + ): + yield session + + return create_async_session + + +__all__ = ("PoolManager", "async_sessionmaker") diff --git a/tests/test_asyncsqlalchemy.py b/tests/test_asyncsqlalchemy.py index a1e1acd..c0488c5 100644 --- a/tests/test_asyncsqlalchemy.py +++ b/tests/test_asyncsqlalchemy.py @@ -1,9 +1,11 @@ +from typing import Optional, Type + import mock import pytest import sqlalchemy as sa -from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine +from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine, AsyncSession -from hasql.asyncsqlalchemy import PoolManager +from hasql.asyncsqlalchemy import PoolManager, async_sessionmaker from hasql.metrics import DriverMetrics @@ -17,7 +19,6 @@ async def pool_manager(pg_dsn): try: await pg_pool.ready() yield pg_pool - pass finally: await pg_pool.close() @@ -73,3 +74,49 @@ async def test_metrics(pool_manager): assert pool_manager.metrics().drivers == [ DriverMetrics(max=11, min=0, idle=0, used=2, host=mock.ANY) ] + + +@pytest.mark.parametrize("expire_on_commit", [True, False, None]) +@pytest.mark.parametrize("autoflush", [True, False, None]) +@pytest.mark.parametrize("read_only", [True, False, None]) +@pytest.mark.parametrize("class_", [AsyncSession, None]) +async def test_async_sessionmaker( + pool_manager: PoolManager, + expire_on_commit: Optional[bool], + autoflush: Optional[bool], + read_only: Optional[bool], + class_: Optional[Type[AsyncSession]], +): + acquire_kwargs = {} + + if read_only is not None: + acquire_kwargs["read_only"] = read_only + + kwargs = {} + + if expire_on_commit is not None: + kwargs["expire_on_commit"] = expire_on_commit + + if autoflush is not None: + kwargs["autoflush"] = autoflush + + if class_ is not None: + kwargs["class_"] = class_ + + if acquire_kwargs: + kwargs["acquire_kwargs"] = acquire_kwargs + + session_factory = async_sessionmaker(pool_manager=pool_manager, **kwargs) + + async with session_factory() as session: # type: AsyncSession + result = await session.execute(sa.text("SELECT 1")) + assert result.scalar() == 1 + + if expire_on_commit is not None: + assert session.sync_session.expire_on_commit == expire_on_commit + + if autoflush is not None: + assert session.sync_session.autoflush == autoflush + + if class_ is not None: + assert isinstance(session, class_) diff --git a/tox.ini b/tox.ini index c887bc8..9ee1f7a 100644 --- a/tox.ini +++ b/tox.ini @@ -39,7 +39,7 @@ basepython = python3.10 usedevelop = true deps = - mypy==1.5.1 + mypy==1.17.1 commands = mypy --install-types --non-interactive hasql tests