From a7aa3ab03f309e8303dc230b1816d14f6d322e2f Mon Sep 17 00:00:00 2001 From: Christophe Bornet Date: Sun, 17 Nov 2024 12:51:24 +0100 Subject: [PATCH] fix: Use AsyncSession in delete_vertex_builds (#4653) Use AsyncSession in delete_vertex_builds --- src/backend/base/langflow/api/v1/monitor.py | 6 +- .../database/models/vertex_builds/crud.py | 7 ++- src/backend/tests/conftest.py | 60 ++++++++++--------- 3 files changed, 39 insertions(+), 34 deletions(-) diff --git a/src/backend/base/langflow/api/v1/monitor.py b/src/backend/base/langflow/api/v1/monitor.py index cf7e28a55bf4..e8a457383dc3 100644 --- a/src/backend/base/langflow/api/v1/monitor.py +++ b/src/backend/base/langflow/api/v1/monitor.py @@ -5,7 +5,7 @@ from sqlalchemy import delete from sqlmodel import col, select -from langflow.api.utils import AsyncDbSession, DbSession +from langflow.api.utils import AsyncDbSession from langflow.schema.message import MessageResponse from langflow.services.auth.utils import get_current_active_user from langflow.services.database.models.message.model import MessageRead, MessageTable, MessageUpdate @@ -30,9 +30,9 @@ async def get_vertex_builds(flow_id: Annotated[UUID, Query()], session: AsyncDbS @router.delete("/builds", status_code=204) -def delete_vertex_builds(flow_id: Annotated[UUID, Query()], session: DbSession) -> None: +async def delete_vertex_builds(flow_id: Annotated[UUID, Query()], session: AsyncDbSession) -> None: try: - delete_vertex_builds_by_flow_id(session, flow_id) + await delete_vertex_builds_by_flow_id(session, flow_id) except Exception as e: raise HTTPException(status_code=500, detail=str(e)) from e diff --git a/src/backend/base/langflow/services/database/models/vertex_builds/crud.py b/src/backend/base/langflow/services/database/models/vertex_builds/crud.py index 286640465401..6bd068e9616e 100644 --- a/src/backend/base/langflow/services/database/models/vertex_builds/crud.py +++ b/src/backend/base/langflow/services/database/models/vertex_builds/crud.py @@ -32,6 +32,7 @@ def log_vertex_build(db: Session, vertex_build: VertexBuildBase) -> VertexBuildT return table -def delete_vertex_builds_by_flow_id(db: Session, flow_id: UUID) -> None: - db.exec(delete(VertexBuildTable).where(VertexBuildTable.flow_id == flow_id)) - db.commit() +async def delete_vertex_builds_by_flow_id(db: AsyncSession, flow_id: UUID) -> None: + stmt = delete(VertexBuildTable).where(VertexBuildTable.flow_id == flow_id) + await db.exec(stmt) + await db.commit() diff --git a/src/backend/tests/conftest.py b/src/backend/tests/conftest.py index b236a733f962..4f33ab1f22be 100644 --- a/src/backend/tests/conftest.py +++ b/src/backend/tests/conftest.py @@ -28,7 +28,9 @@ from langflow.services.database.utils import session_getter from langflow.services.deps import get_db_service from loguru import logger +from sqlalchemy.orm import selectinload from sqlmodel import Session, SQLModel, create_engine, select +from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.pool import StaticPool from typer.testing import CliRunner @@ -85,21 +87,21 @@ def get_text(): assert path.exists(), f"File {path} does not exist. Available files: {list(data_path.iterdir())}" -def delete_transactions_by_flow_id(db: Session, flow_id: UUID): +async def delete_transactions_by_flow_id(db: AsyncSession, flow_id: UUID): stmt = select(TransactionTable).where(TransactionTable.flow_id == flow_id) - transactions = db.exec(stmt) + transactions = await db.exec(stmt) for transaction in transactions: - db.delete(transaction) - db.commit() + await db.delete(transaction) + await db.commit() -def _delete_transactions_and_vertex_builds(session, user: User): - flow_ids = [flow.id for flow in user.flows] +async def _delete_transactions_and_vertex_builds(session, flows: list[Flow]): + flow_ids = [flow.id for flow in flows] for flow_id in flow_ids: if not flow_id: continue - delete_vertex_builds_by_flow_id(session, flow_id) - delete_transactions_by_flow_id(session, flow_id) + await delete_vertex_builds_by_flow_id(session, flow_id) + await delete_transactions_by_flow_id(session, flow_id) @pytest.fixture @@ -361,31 +363,32 @@ async def test_user(client): @pytest.fixture -def active_user(client): # noqa: ARG001 +async def active_user(client): # noqa: ARG001 db_manager = get_db_service() - with db_manager.with_session() as session: + async with db_manager.with_async_session() as session: user = User( username="activeuser", password=get_password_hash("testpassword"), is_active=True, is_superuser=False, ) - if active_user := session.exec(select(User).where(User.username == user.username)).first(): + stmt = select(User).where(User.username == user.username) + if active_user := (await session.exec(stmt)).first(): user = active_user else: session.add(user) - session.commit() - session.refresh(user) + await session.commit() + await session.refresh(user) user = UserRead.model_validate(user, from_attributes=True) yield user # Clean up # Now cleanup transactions, vertex_build - with db_manager.with_session() as session: - user = session.get(User, user.id) - _delete_transactions_and_vertex_builds(session, user) - session.delete(user) + async with db_manager.with_async_session() as session: + user = await session.get(User, user.id, options=[selectinload(User.flows)]) + await _delete_transactions_and_vertex_builds(session, user.flows) + await session.delete(user) - session.commit() + await session.commit() @pytest.fixture @@ -399,31 +402,32 @@ async def logged_in_headers(client, active_user): @pytest.fixture -def active_super_user(client): # noqa: ARG001 +async def active_super_user(client): # noqa: ARG001 db_manager = get_db_service() - with db_manager.with_session() as session: + async with db_manager.with_async_session() as session: user = User( username="activeuser", password=get_password_hash("testpassword"), is_active=True, is_superuser=True, ) - if active_user := session.exec(select(User).where(User.username == user.username)).first(): + stmt = select(User).where(User.username == user.username) + if active_user := (await session.exec(stmt)).first(): user = active_user else: session.add(user) - session.commit() - session.refresh(user) + await session.commit() + await session.refresh(user) user = UserRead.model_validate(user, from_attributes=True) yield user # Clean up # Now cleanup transactions, vertex_build - with db_manager.with_session() as session: - user = session.get(User, user.id) - _delete_transactions_and_vertex_builds(session, user) - session.delete(user) + async with db_manager.with_async_session() as session: + user = await session.get(User, user.id, options=[selectinload(User.flows)]) + await _delete_transactions_and_vertex_builds(session, user.flows) + await session.delete(user) - session.commit() + await session.commit() @pytest.fixture