From a77c855d56e21d2c8a414380f507e8a7d7818ada Mon Sep 17 00:00:00 2001 From: Oleg A Date: Wed, 30 Jul 2025 14:26:39 +0300 Subject: [PATCH 1/6] Add async_sessionmaker helper for sqlalchemy --- hasql/asyncsqlalchemy.py | 58 +++++++++++++++++++++++++++++++++-- tests/test_asyncsqlalchemy.py | 48 +++++++++++++++++++++++++++-- 2 files changed, 101 insertions(+), 5 deletions(-) diff --git a/hasql/asyncsqlalchemy.py b/hasql/asyncsqlalchemy.py index b409828..ed485e6 100644 --- a/hasql/asyncsqlalchemy.py +++ b/hasql/asyncsqlalchemy.py @@ -1,8 +1,9 @@ import asyncio -from typing import Sequence +from contextlib import _AsyncGeneratorContextManager, asynccontextmanager +from typing import Any, AsyncIterator, Callable, Dict, Optional, Sequence, Type import sqlalchemy as sa # type: ignore -from sqlalchemy.ext.asyncio import AsyncConnection # type: ignore +from sqlalchemy.ext.asyncio import AsyncConnection, AsyncSession # type: ignore from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine from sqlalchemy.pool import QueuePool # type: ignore @@ -70,4 +71,55 @@ def _driver_metrics(self) -> Sequence[DriverMetrics]: ] -__all__ = ("PoolManager",) +def async_sessionmaker( + pool_manager: PoolManager, + *, + class_: Type[AsyncSession] = AsyncSession, + autoflush: bool = True, + expire_on_commit: bool = True, + info: Optional[Dict[Any, Any]] = None, + acquire_kwargs: Optional[Dict[str, Any]] = None, + **kw: Any, +) -> Callable[..., _AsyncGeneratorContextManager]: + """Create async session maker with hasql pool support. + + This function replaces the default `async_sessionmaker` from + SQLAlchemy to work with the `PoolManager` class. It allows you to + create an async session that is bound to a connection acquired from + the pool. The session will automatically release the connection + back to the pool when the session is closed. + + You also can specify the session class to use with the `class_` + parameter, and you can customize the session's behavior with + parameters like `autoflush`, `expire_on_commit`, and `info`. + + Use the `acquire_kwargs` to pass additional parameters to the + `pool.acquire()` method. E.g. to create a session with replica + connection: + >>> ReplicaSession = async_sessionmaker( + >>> pool_manager, + >>> acquire_kwargs={"read_only": True} + >>> ) + + """ + if acquire_kwargs is None: + acquire_kwargs = {} + + @asynccontextmanager + async def create_async_session() -> AsyncIterator[AsyncSession]: + """Create an async session with connection from the pool.""" + # TODO(PY310): Use parentheses to break the statement in multiple lines + async with pool_manager.acquire(**acquire_kwargs) as connection, \ + class_( + bind=connection, + autoflush=autoflush, + expire_on_commit=expire_on_commit, + info=info, + **kw, + ) as session: + yield session + + return create_async_session + + +__all__ = ("PoolManager", "async_sessionmaker") diff --git a/tests/test_asyncsqlalchemy.py b/tests/test_asyncsqlalchemy.py index a1e1acd..b1040bf 100644 --- a/tests/test_asyncsqlalchemy.py +++ b/tests/test_asyncsqlalchemy.py @@ -1,9 +1,11 @@ +from typing import Optional + import mock import pytest import sqlalchemy as sa -from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine +from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine, AsyncSession -from hasql.asyncsqlalchemy import PoolManager +from hasql.asyncsqlalchemy import PoolManager, async_sessionmaker from hasql.metrics import DriverMetrics @@ -73,3 +75,45 @@ async def test_metrics(pool_manager): assert pool_manager.metrics().drivers == [ DriverMetrics(max=11, min=0, idle=0, used=2, host=mock.ANY) ] + + +@pytest.mark.parametrize("expire_on_commit", [True, False, None]) +@pytest.mark.parametrize("autoflush", [True, False, None]) +@pytest.mark.parametrize("read_only", [True, False, None]) +async def test_async_sessionmaker( + pool_manager: PoolManager, + expire_on_commit: Optional[bool], + autoflush: Optional[bool], + read_only: Optional[bool], +): + acquire_kwargs = {} + + if read_only is not None: + acquire_kwargs["read_only"] = read_only + + kwargs = {} + + if expire_on_commit is not None: + kwargs["expire_on_commit"] = expire_on_commit + + if autoflush is not None: + kwargs["autoflush"] = autoflush + + if acquire_kwargs: + kwargs["acquire_kwargs"] = acquire_kwargs + + session_factory = async_sessionmaker( + pool_manager=pool_manager, + class_=AsyncSession, + **kwargs, + ) + + async with session_factory() as session: + result = await session.execute(sa.text("SELECT 1")) + assert result.scalar() == 1 + + if expire_on_commit is not None: + assert session.expire_on_commit == expire_on_commit + + if autoflush is not None: + assert session.autoflush == autoflush From 307aabb5b03d6555ea5c1af12bc49f017961273c Mon Sep 17 00:00:00 2001 From: Oleg A Date: Wed, 30 Jul 2025 15:53:54 +0300 Subject: [PATCH 2/6] tests: check session class instance --- tests/test_asyncsqlalchemy.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/tests/test_asyncsqlalchemy.py b/tests/test_asyncsqlalchemy.py index b1040bf..c0488c5 100644 --- a/tests/test_asyncsqlalchemy.py +++ b/tests/test_asyncsqlalchemy.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional, Type import mock import pytest @@ -19,7 +19,6 @@ async def pool_manager(pg_dsn): try: await pg_pool.ready() yield pg_pool - pass finally: await pg_pool.close() @@ -80,11 +79,13 @@ async def test_metrics(pool_manager): @pytest.mark.parametrize("expire_on_commit", [True, False, None]) @pytest.mark.parametrize("autoflush", [True, False, None]) @pytest.mark.parametrize("read_only", [True, False, None]) +@pytest.mark.parametrize("class_", [AsyncSession, None]) async def test_async_sessionmaker( pool_manager: PoolManager, expire_on_commit: Optional[bool], autoflush: Optional[bool], read_only: Optional[bool], + class_: Optional[Type[AsyncSession]], ): acquire_kwargs = {} @@ -99,21 +100,23 @@ async def test_async_sessionmaker( if autoflush is not None: kwargs["autoflush"] = autoflush + if class_ is not None: + kwargs["class_"] = class_ + if acquire_kwargs: kwargs["acquire_kwargs"] = acquire_kwargs - session_factory = async_sessionmaker( - pool_manager=pool_manager, - class_=AsyncSession, - **kwargs, - ) + session_factory = async_sessionmaker(pool_manager=pool_manager, **kwargs) - async with session_factory() as session: + async with session_factory() as session: # type: AsyncSession result = await session.execute(sa.text("SELECT 1")) assert result.scalar() == 1 if expire_on_commit is not None: - assert session.expire_on_commit == expire_on_commit + assert session.sync_session.expire_on_commit == expire_on_commit if autoflush is not None: - assert session.autoflush == autoflush + assert session.sync_session.autoflush == autoflush + + if class_ is not None: + assert isinstance(session, class_) From 2124642bb87c0a916dcf2cdebb7411191eae274a Mon Sep 17 00:00:00 2001 From: Oleg A Date: Wed, 30 Jul 2025 17:35:01 +0300 Subject: [PATCH 3/6] chore: use with+parenthesis --- hasql/asyncsqlalchemy.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/hasql/asyncsqlalchemy.py b/hasql/asyncsqlalchemy.py index ed485e6..dbdf941 100644 --- a/hasql/asyncsqlalchemy.py +++ b/hasql/asyncsqlalchemy.py @@ -108,15 +108,16 @@ def async_sessionmaker( @asynccontextmanager async def create_async_session() -> AsyncIterator[AsyncSession]: """Create an async session with connection from the pool.""" - # TODO(PY310): Use parentheses to break the statement in multiple lines - async with pool_manager.acquire(**acquire_kwargs) as connection, \ - class_( - bind=connection, - autoflush=autoflush, - expire_on_commit=expire_on_commit, - info=info, - **kw, - ) as session: + async with ( + pool_manager.acquire(**acquire_kwargs) as connection, + class_( + bind=connection, + autoflush=autoflush, + expire_on_commit=expire_on_commit, + info=info, + **kw, + ) as session, + ): yield session return create_async_session From d9e6902bfbe25eaa7a7739603829410bc9d55d0d Mon Sep 17 00:00:00 2001 From: Oleg A Date: Wed, 30 Jul 2025 17:55:35 +0300 Subject: [PATCH 4/6] fix: cleanup mypy errors --- hasql/asyncsqlalchemy.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/hasql/asyncsqlalchemy.py b/hasql/asyncsqlalchemy.py index dbdf941..301dd66 100644 --- a/hasql/asyncsqlalchemy.py +++ b/hasql/asyncsqlalchemy.py @@ -2,10 +2,10 @@ from contextlib import _AsyncGeneratorContextManager, asynccontextmanager from typing import Any, AsyncIterator, Callable, Dict, Optional, Sequence, Type -import sqlalchemy as sa # type: ignore -from sqlalchemy.ext.asyncio import AsyncConnection, AsyncSession # type: ignore +import sqlalchemy as sa +from sqlalchemy.ext.asyncio import AsyncConnection, AsyncSession from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine -from sqlalchemy.pool import QueuePool # type: ignore +from sqlalchemy.pool import QueuePool from hasql.base import BasePoolManager from hasql.metrics import DriverMetrics @@ -14,7 +14,7 @@ class PoolManager(BasePoolManager): def get_pool_freesize(self, pool: AsyncEngine): - queue_pool: QueuePool = pool.sync_engine.pool + queue_pool: QueuePool = pool.sync_engine.pool # type: ignore[assignment] return queue_pool.size() - queue_pool.checkedout() def acquire_from_pool(self, pool: AsyncEngine, **kwargs): From 8d7effed5c8c73ad26fbd0397c5a7c952ba0ff3f Mon Sep 17 00:00:00 2001 From: Oleg A Date: Wed, 30 Jul 2025 18:01:04 +0300 Subject: [PATCH 5/6] fix: cleanup lint errors --- hasql/asyncsqlalchemy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hasql/asyncsqlalchemy.py b/hasql/asyncsqlalchemy.py index 301dd66..6263335 100644 --- a/hasql/asyncsqlalchemy.py +++ b/hasql/asyncsqlalchemy.py @@ -14,7 +14,7 @@ class PoolManager(BasePoolManager): def get_pool_freesize(self, pool: AsyncEngine): - queue_pool: QueuePool = pool.sync_engine.pool # type: ignore[assignment] + queue_pool: QueuePool = pool.sync_engine.pool # type: ignore return queue_pool.size() - queue_pool.checkedout() def acquire_from_pool(self, pool: AsyncEngine, **kwargs): From 1f973e32424465bcc3c98c40ed05144550d82762 Mon Sep 17 00:00:00 2001 From: Oleg A Date: Thu, 31 Jul 2025 21:35:23 +0300 Subject: [PATCH 6/6] fix: update mypy --- tox.ini | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tox.ini b/tox.ini index c887bc8..9ee1f7a 100644 --- a/tox.ini +++ b/tox.ini @@ -39,7 +39,7 @@ basepython = python3.10 usedevelop = true deps = - mypy==1.5.1 + mypy==1.17.1 commands = mypy --install-types --non-interactive hasql tests