From f0ba198f81ae7b3651c1ca0a21fa331c112374b1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=D0=9C=D0=B0=D1=80=D1=82=D1=8B=D0=BD=D0=BE=D0=B2=20=D0=9C?= =?UTF-8?q?=D0=B0=D0=BA=D1=81=D0=B8=D0=BC=20=D0=A1=D0=B5=D1=80=D0=B3=D0=B5?= =?UTF-8?q?=D0=B5=D0=B2=D0=B8=D1=87?= Date: Fri, 19 Apr 2024 14:33:14 +0300 Subject: [PATCH] [DOP-14025] Remove asyncio.gather from SQLAlchemy requests --- docs/changelog/next_release/40.bugfix.rst | 1 + syncmaster/backend/api/v1/connections.py | 68 ++++++++----------- syncmaster/backend/api/v1/transfers/router.py | 57 +++++++--------- .../db/repositories/credentials_repository.py | 27 ++++---- 4 files changed, 64 insertions(+), 89 deletions(-) create mode 100644 docs/changelog/next_release/40.bugfix.rst diff --git a/docs/changelog/next_release/40.bugfix.rst b/docs/changelog/next_release/40.bugfix.rst new file mode 100644 index 00000000..f62ecd7c --- /dev/null +++ b/docs/changelog/next_release/40.bugfix.rst @@ -0,0 +1 @@ +Do not use ``asyncio.gather`` with SQLAlchemy requests. diff --git a/syncmaster/backend/api/v1/connections.py b/syncmaster/backend/api/v1/connections.py index acfe9a2a..b6965e70 100644 --- a/syncmaster/backend/api/v1/connections.py +++ b/syncmaster/backend/api/v1/connections.py @@ -1,6 +1,5 @@ # SPDX-FileCopyrightText: 2023-2024 MTS (Mobile Telesystems) # SPDX-License-Identifier: Apache-2.0 -import asyncio from typing import get_args from fastapi import APIRouter, Depends, Query, status @@ -59,19 +58,17 @@ async def read_connections( items: list[ReadConnectionSchema] = [] if pagination.items: - creds = await asyncio.gather( - *[unit_of_work.credentials.get_for_connection(connection_id=item.id) for item in pagination.items] - ) + credentials = await unit_of_work.credentials.read_bulk([item.id for item in pagination.items]) items = [ ReadConnectionSchema( id=item.id, group_id=item.group_id, name=item.name, description=item.description, - auth_data=creds[n_item], + auth_data=credentials.get(item.id, None), data=item.data, ) - for n_item, item in enumerate(pagination.items) + for item in pagination.items ] return ConnectionPageSchema( @@ -121,7 +118,7 @@ async def create_connection( data=data, ) - await unit_of_work.credentials.add_to_connection( + await unit_of_work.credentials.create( connection_id=connection.id, data=auth_data, ) @@ -155,12 +152,9 @@ async def read_connection( if resource_role == Permission.NONE: raise ConnectionNotFoundError - connection = await unit_of_work.connection.read_by_id(connection_id=connection_id) - + connection = await unit_of_work.connection.read_by_id(connection_id) try: - credentials = await unit_of_work.credentials.get_for_connection( - connection_id=connection.id, - ) + credentials = await unit_of_work.credentials.read(connection.id) except AuthDataNotFoundError: credentials = None @@ -206,7 +200,7 @@ async def update_connection( credential_data=connection_data.auth_data.dict(), ) - auth_data = await unit_of_work.credentials.get_for_connection(connection_id) + auth_data = await unit_of_work.credentials.read(connection_id) return ReadConnectionSchema( id=connection.id, group_id=connection.group_id, @@ -227,28 +221,26 @@ async def delete_connection( user=current_user, resource_id=connection_id, ) - if resource_role == Permission.NONE: raise ConnectionNotFoundError if resource_role < Permission.DELETE: raise ActionNotAllowedError - connection = await unit_of_work.connection.read_by_id(connection_id=connection_id) + connection = await unit_of_work.connection.read_by_id(connection_id) + transfers = await unit_of_work.transfer.list_by_connection_id(connection.id) + if transfers: + raise ConnectionDeleteError( + f"The connection has an associated transfers. Number of the connected transfers: {len(transfers)}", + ) - transfers = await unit_of_work.transfer.list_by_connection_id(conn_id=connection.id) async with unit_of_work: - if not transfers: - await unit_of_work.connection.delete(connection_id=connection_id) + await unit_of_work.connection.delete(connection_id) - return StatusResponseSchema( - ok=True, - status_code=status.HTTP_200_OK, - message="Connection was deleted", - ) - - raise ConnectionDeleteError( - f"The connection has an associated transfers. Number of the connected transfers: {len(transfers)}", + return StatusResponseSchema( + ok=True, + status_code=status.HTTP_200_OK, + message="Connection was deleted", ) @@ -259,24 +251,20 @@ async def copy_connection( current_user: User = Depends(get_user(is_active=True)), unit_of_work: UnitOfWork = Depends(UnitOfWorkMarker), ) -> StatusResponseSchema: - target_source_rules = await asyncio.gather( - unit_of_work.connection.get_resource_permission( - user=current_user, - resource_id=connection_id, - ), - unit_of_work.connection.get_group_permission( - user=current_user, - group_id=copy_connection_data.new_group_id, - ), + resource_role = await unit_of_work.connection.get_resource_permission( + user=current_user, + resource_id=connection_id, ) - resource_role, target_group_role = target_source_rules + if resource_role == Permission.NONE: + raise ConnectionNotFoundError if copy_connection_data.remove_source and resource_role < Permission.DELETE: raise ActionNotAllowedError - if resource_role == Permission.NONE: - raise ConnectionNotFoundError - + target_group_role = await unit_of_work.connection.get_group_permission( + user=current_user, + group_id=copy_connection_data.new_group_id, + ) if target_group_role == Permission.NONE: raise GroupNotFoundError @@ -291,7 +279,7 @@ async def copy_connection( ) if copy_connection_data.remove_source: - await unit_of_work.connection.delete(connection_id=connection_id) + await unit_of_work.connection.delete(connection_id) return StatusResponseSchema( ok=True, diff --git a/syncmaster/backend/api/v1/transfers/router.py b/syncmaster/backend/api/v1/transfers/router.py index 4c795d07..fd486767 100644 --- a/syncmaster/backend/api/v1/transfers/router.py +++ b/syncmaster/backend/api/v1/transfers/router.py @@ -1,6 +1,5 @@ # SPDX-FileCopyrightText: 2023-2024 MTS (Mobile Telesystems) # SPDX-License-Identifier: Apache-2.0 -import asyncio from fastapi import APIRouter, Depends, Query, status from kombu.exceptions import KombuError @@ -79,16 +78,11 @@ async def create_transfer( user=current_user, group_id=transfer_data.group_id, ) - if group_permission < Permission.WRITE: raise ActionNotAllowedError - target_connection = await unit_of_work.connection.read_by_id( - connection_id=transfer_data.target_connection_id, - ) - source_connection = await unit_of_work.connection.read_by_id( - connection_id=transfer_data.source_connection_id, - ) + target_connection = await unit_of_work.connection.read_by_id(transfer_data.target_connection_id) + source_connection = await unit_of_work.connection.read_by_id(transfer_data.source_connection_id) queue = await unit_of_work.queue.read_by_id(transfer_data.queue_id) if ( @@ -158,44 +152,39 @@ async def copy_transfer( current_user: User = Depends(get_user(is_active=True)), unit_of_work: UnitOfWork = Depends(UnitOfWorkMarker), ) -> StatusCopyTransferResponseSchema: - # Check: user can copy transfer - target_source_transfer_rules = await asyncio.gather( - unit_of_work.transfer.get_resource_permission( - user=current_user, - resource_id=transfer_id, - ), - unit_of_work.transfer.get_group_permission( - user=current_user, - group_id=transfer_data.new_group_id, - ), + resource_role = await unit_of_work.transfer.get_resource_permission( + user=current_user, + resource_id=transfer_id, ) - resource_role, target_group_role = target_source_transfer_rules - if resource_role == Permission.NONE: raise TransferNotFoundError - if target_group_role < Permission.WRITE: - raise ActionNotAllowedError - # Check: user can delete transfer if transfer_data.remove_source and resource_role < Permission.DELETE: raise ActionNotAllowedError + target_group_role = await unit_of_work.transfer.get_group_permission( + user=current_user, + group_id=transfer_data.new_group_id, + ) + if target_group_role < Permission.WRITE: + raise ActionNotAllowedError + transfer = await unit_of_work.transfer.read_by_id(transfer_id=transfer_id) + # Check: user can copy connection - target_source_connection_rules = await asyncio.gather( - unit_of_work.connection.get_resource_permission( - user=current_user, - resource_id=transfer.source_connection_id, - ), - unit_of_work.connection.get_resource_permission( - user=current_user, - resource_id=transfer.target_connection_id, - ), + source_connection_role = await unit_of_work.connection.get_resource_permission( + user=current_user, + resource_id=transfer.source_connection_id, ) - source_connection_role, target_connection_role = target_source_connection_rules + if source_connection_role == Permission.NONE: + raise ConnectionNotFoundError - if source_connection_role == Permission.NONE or target_connection_role == Permission.NONE: + target_connection_role = await unit_of_work.connection.get_resource_permission( + user=current_user, + resource_id=transfer.target_connection_id, + ) + if target_connection_role == Permission.NONE: raise ConnectionNotFoundError # Check: new queue exists diff --git a/syncmaster/db/repositories/credentials_repository.py b/syncmaster/db/repositories/credentials_repository.py index b16e94d5..e96b651a 100644 --- a/syncmaster/db/repositories/credentials_repository.py +++ b/syncmaster/db/repositories/credentials_repository.py @@ -4,7 +4,7 @@ from typing import NoReturn -from sqlalchemy import ScalarResult, delete, insert, select +from sqlalchemy import ScalarResult, insert, select from sqlalchemy.exc import DBAPIError, IntegrityError, NoResultFound from sqlalchemy.ext.asyncio import AsyncSession @@ -26,7 +26,7 @@ def __init__( super().__init__(model=model, session=session) self._settings = settings - async def get_for_connection( + async def read( self, connection_id: int, ) -> dict: @@ -37,7 +37,15 @@ async def get_for_connection( except NoResultFound as e: raise AuthDataNotFoundError(f"Connection id = {connection_id}") from e - async def add_to_connection(self, connection_id: int, data: dict) -> AuthData: + async def read_bulk( + self, + connection_ids: list[int], + ) -> dict[int, dict]: + query = select(AuthData).where(AuthData.connection_id.in_(connection_ids)) + result: ScalarResult[AuthData] = await self._session.scalars(query) + return {item.connection_id: decrypt_auth_data(item.value, settings=self._settings) for item in result} + + async def create(self, connection_id: int, data: dict) -> AuthData: query = ( insert(AuthData) .values( @@ -54,23 +62,12 @@ async def add_to_connection(self, connection_id: int, data: dict) -> AuthData: await self._session.flush() return result.one() - async def delete_from_connection(self, connection_id: int) -> AuthData: - query = delete(AuthData).where(AuthData.connection_id == connection_id).returning(AuthData) - - try: - result: ScalarResult[AuthData] = await self._session.scalars(query) - except IntegrityError as e: - self._raise_error(e) - else: - await self._session.flush() - return result.one() - async def update( self, connection_id: int, credential_data: dict, ) -> AuthData: - creds = await self.get_for_connection(connection_id) + creds = await self.read(connection_id) try: for key in creds: if key not in credential_data or credential_data[key] is None: