Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 59 additions & 6 deletions hasql/asyncsqlalchemy.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import asyncio
from typing import Sequence
from contextlib import _AsyncGeneratorContextManager, asynccontextmanager
from typing import Any, AsyncIterator, Callable, Dict, Optional, Sequence, Type

import sqlalchemy as sa # type: ignore
from sqlalchemy.ext.asyncio import AsyncConnection # type: ignore
import sqlalchemy as sa
from sqlalchemy.ext.asyncio import AsyncConnection, AsyncSession
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
from sqlalchemy.pool import QueuePool # type: ignore
from sqlalchemy.pool import QueuePool

from hasql.base import BasePoolManager
from hasql.metrics import DriverMetrics
Expand All @@ -13,7 +14,7 @@

class PoolManager(BasePoolManager):
def get_pool_freesize(self, pool: AsyncEngine):
queue_pool: QueuePool = pool.sync_engine.pool
queue_pool: QueuePool = pool.sync_engine.pool # type: ignore
return queue_pool.size() - queue_pool.checkedout()

def acquire_from_pool(self, pool: AsyncEngine, **kwargs):
Expand Down Expand Up @@ -70,4 +71,56 @@ def _driver_metrics(self) -> Sequence[DriverMetrics]:
]


__all__ = ("PoolManager",)
def async_sessionmaker(
pool_manager: PoolManager,
*,
class_: Type[AsyncSession] = AsyncSession,
autoflush: bool = True,
expire_on_commit: bool = True,
info: Optional[Dict[Any, Any]] = None,
acquire_kwargs: Optional[Dict[str, Any]] = None,
**kw: Any,
) -> Callable[..., _AsyncGeneratorContextManager]:
"""Create async session maker with hasql pool support.

This function replaces the default `async_sessionmaker` from
SQLAlchemy to work with the `PoolManager` class. It allows you to
create an async session that is bound to a connection acquired from
the pool. The session will automatically release the connection
back to the pool when the session is closed.

You also can specify the session class to use with the `class_`
parameter, and you can customize the session's behavior with
parameters like `autoflush`, `expire_on_commit`, and `info`.

Use the `acquire_kwargs` to pass additional parameters to the
`pool.acquire()` method. E.g. to create a session with replica
connection:
>>> ReplicaSession = async_sessionmaker(
>>> pool_manager,
>>> acquire_kwargs={"read_only": True}
>>> )

"""
if acquire_kwargs is None:
acquire_kwargs = {}

@asynccontextmanager
async def create_async_session() -> AsyncIterator[AsyncSession]:
"""Create an async session with connection from the pool."""
async with (
pool_manager.acquire(**acquire_kwargs) as connection,
class_(
bind=connection,
autoflush=autoflush,
expire_on_commit=expire_on_commit,
info=info,
**kw,
) as session,
):
yield session

return create_async_session


__all__ = ("PoolManager", "async_sessionmaker")
53 changes: 50 additions & 3 deletions tests/test_asyncsqlalchemy.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from typing import Optional, Type

import mock
import pytest
import sqlalchemy as sa
from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine
from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine, AsyncSession

from hasql.asyncsqlalchemy import PoolManager
from hasql.asyncsqlalchemy import PoolManager, async_sessionmaker
from hasql.metrics import DriverMetrics


Expand All @@ -17,7 +19,6 @@ async def pool_manager(pg_dsn):
try:
await pg_pool.ready()
yield pg_pool
pass
finally:
await pg_pool.close()

Expand Down Expand Up @@ -73,3 +74,49 @@ async def test_metrics(pool_manager):
assert pool_manager.metrics().drivers == [
DriverMetrics(max=11, min=0, idle=0, used=2, host=mock.ANY)
]


@pytest.mark.parametrize("expire_on_commit", [True, False, None])
@pytest.mark.parametrize("autoflush", [True, False, None])
@pytest.mark.parametrize("read_only", [True, False, None])
@pytest.mark.parametrize("class_", [AsyncSession, None])
async def test_async_sessionmaker(
pool_manager: PoolManager,
expire_on_commit: Optional[bool],
autoflush: Optional[bool],
read_only: Optional[bool],
class_: Optional[Type[AsyncSession]],
):
acquire_kwargs = {}

if read_only is not None:
acquire_kwargs["read_only"] = read_only

kwargs = {}

if expire_on_commit is not None:
kwargs["expire_on_commit"] = expire_on_commit

if autoflush is not None:
kwargs["autoflush"] = autoflush

if class_ is not None:
kwargs["class_"] = class_

if acquire_kwargs:
kwargs["acquire_kwargs"] = acquire_kwargs

session_factory = async_sessionmaker(pool_manager=pool_manager, **kwargs)

async with session_factory() as session: # type: AsyncSession
result = await session.execute(sa.text("SELECT 1"))
assert result.scalar() == 1

if expire_on_commit is not None:
assert session.sync_session.expire_on_commit == expire_on_commit

if autoflush is not None:
assert session.sync_session.autoflush == autoflush

if class_ is not None:
assert isinstance(session, class_)
2 changes: 1 addition & 1 deletion tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ basepython = python3.10
usedevelop = true

deps =
mypy==1.5.1
mypy==1.17.1

commands =
mypy --install-types --non-interactive hasql tests