Skip to content

Commit

Permalink
[DOP-14025] Remove asyncio.gather from SQLAlchemy requests
Browse files Browse the repository at this point in the history
  • Loading branch information
dolfinus committed Apr 19, 2024
1 parent c4366f9 commit 45cd542
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 88 deletions.
1 change: 1 addition & 0 deletions docs/changelog/next_release/40.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Do not use ``asyncio.gather`` with SQLAlchemy requests.
67 changes: 28 additions & 39 deletions syncmaster/backend/api/v1/connections.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# SPDX-FileCopyrightText: 2023-2024 MTS (Mobile Telesystems)
# SPDX-License-Identifier: Apache-2.0
import asyncio
from typing import get_args

from fastapi import APIRouter, Depends, Query, status
Expand Down Expand Up @@ -59,19 +58,17 @@ async def read_connections(
items: list[ReadConnectionSchema] = []

if pagination.items:
creds = await asyncio.gather(
*[unit_of_work.credentials.get_for_connection(connection_id=item.id) for item in pagination.items]
)
credentials = await unit_of_work.credentials.get_bulk([item.id for item in pagination.items])
items = [
ReadConnectionSchema(
id=item.id,
group_id=item.group_id,
name=item.name,
description=item.description,
auth_data=creds[n_item],
auth_data=credentials.get(item.id, None),
data=item.data,
)
for n_item, item in enumerate(pagination.items)
for item in pagination.items
]

return ConnectionPageSchema(
Expand Down Expand Up @@ -121,7 +118,7 @@ async def create_connection(
data=data,
)

await unit_of_work.credentials.add_to_connection(
await unit_of_work.credentials.create(
connection_id=connection.id,
data=auth_data,
)
Expand Down Expand Up @@ -155,12 +152,10 @@ async def read_connection(
if resource_role == Permission.NONE:
raise ConnectionNotFoundError

connection = await unit_of_work.connection.read_by_id(connection_id=connection_id)
connection = await unit_of_work.connection.read_by_id(connection_id)

try:
credentials = await unit_of_work.credentials.get_for_connection(
connection_id=connection.id,
)
credentials = await unit_of_work.credentials.get(connection.id)
except AuthDataNotFoundError:
credentials = None

Expand Down Expand Up @@ -206,7 +201,7 @@ async def update_connection(
credential_data=connection_data.auth_data.dict(),
)

auth_data = await unit_of_work.credentials.get_for_connection(connection_id)
auth_data = await unit_of_work.credentials.get(connection_id)
return ReadConnectionSchema(
id=connection.id,
group_id=connection.group_id,
Expand All @@ -227,28 +222,26 @@ async def delete_connection(
user=current_user,
resource_id=connection_id,
)

if resource_role == Permission.NONE:
raise ConnectionNotFoundError

if resource_role < Permission.DELETE:
raise ActionNotAllowedError

connection = await unit_of_work.connection.read_by_id(connection_id=connection_id)
connection = await unit_of_work.connection.read_by_id(connection_id)
transfers = await unit_of_work.transfer.list_by_connection_id(connection.id)
if transfers:
raise ConnectionDeleteError(
f"The connection has an associated transfers. Number of the connected transfers: {len(transfers)}",
)

transfers = await unit_of_work.transfer.list_by_connection_id(conn_id=connection.id)
async with unit_of_work:
if not transfers:
await unit_of_work.connection.delete(connection_id=connection_id)
await unit_of_work.connection.delete(connection_id)

return StatusResponseSchema(
ok=True,
status_code=status.HTTP_200_OK,
message="Connection was deleted",
)

raise ConnectionDeleteError(
f"The connection has an associated transfers. Number of the connected transfers: {len(transfers)}",
return StatusResponseSchema(
ok=True,
status_code=status.HTTP_200_OK,
message="Connection was deleted",
)


Expand All @@ -259,24 +252,20 @@ async def copy_connection(
current_user: User = Depends(get_user(is_active=True)),
unit_of_work: UnitOfWork = Depends(UnitOfWorkMarker),
) -> StatusResponseSchema:
target_source_rules = await asyncio.gather(
unit_of_work.connection.get_resource_permission(
user=current_user,
resource_id=connection_id,
),
unit_of_work.connection.get_group_permission(
user=current_user,
group_id=copy_connection_data.new_group_id,
),
resource_role = await unit_of_work.connection.get_resource_permission(
user=current_user,
resource_id=connection_id,
)
resource_role, target_group_role = target_source_rules
if resource_role == Permission.NONE:
raise ConnectionNotFoundError

if copy_connection_data.remove_source and resource_role < Permission.DELETE:
raise ActionNotAllowedError

if resource_role == Permission.NONE:
raise ConnectionNotFoundError

target_group_role = await unit_of_work.connection.get_group_permission(
user=current_user,
group_id=copy_connection_data.new_group_id,
)
if target_group_role == Permission.NONE:
raise GroupNotFoundError

Expand All @@ -291,7 +280,7 @@ async def copy_connection(
)

if copy_connection_data.remove_source:
await unit_of_work.connection.delete(connection_id=connection_id)
await unit_of_work.connection.delete(connection_id)

return StatusResponseSchema(
ok=True,
Expand Down
57 changes: 23 additions & 34 deletions syncmaster/backend/api/v1/transfers/router.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# SPDX-FileCopyrightText: 2023-2024 MTS (Mobile Telesystems)
# SPDX-License-Identifier: Apache-2.0
import asyncio

from fastapi import APIRouter, Depends, Query, status
from kombu.exceptions import KombuError
Expand Down Expand Up @@ -79,16 +78,11 @@ async def create_transfer(
user=current_user,
group_id=transfer_data.group_id,
)

if group_permission < Permission.WRITE:
raise ActionNotAllowedError

target_connection = await unit_of_work.connection.read_by_id(
connection_id=transfer_data.target_connection_id,
)
source_connection = await unit_of_work.connection.read_by_id(
connection_id=transfer_data.source_connection_id,
)
target_connection = await unit_of_work.connection.read_by_id(transfer_data.target_connection_id)
source_connection = await unit_of_work.connection.read_by_id(transfer_data.source_connection_id)
queue = await unit_of_work.queue.read_by_id(transfer_data.queue_id)

if (
Expand Down Expand Up @@ -158,44 +152,39 @@ async def copy_transfer(
current_user: User = Depends(get_user(is_active=True)),
unit_of_work: UnitOfWork = Depends(UnitOfWorkMarker),
) -> StatusCopyTransferResponseSchema:
# Check: user can copy transfer
target_source_transfer_rules = await asyncio.gather(
unit_of_work.transfer.get_resource_permission(
user=current_user,
resource_id=transfer_id,
),
unit_of_work.transfer.get_group_permission(
user=current_user,
group_id=transfer_data.new_group_id,
),
resource_role = await unit_of_work.transfer.get_resource_permission(
user=current_user,
resource_id=transfer_id,
)
resource_role, target_group_role = target_source_transfer_rules

if resource_role == Permission.NONE:
raise TransferNotFoundError

if target_group_role < Permission.WRITE:
raise ActionNotAllowedError

# Check: user can delete transfer
if transfer_data.remove_source and resource_role < Permission.DELETE:
raise ActionNotAllowedError

target_group_role = await unit_of_work.transfer.get_group_permission(
user=current_user,
group_id=transfer_data.new_group_id,
)
if target_group_role < Permission.WRITE:
raise ActionNotAllowedError

Check warning on line 171 in syncmaster/backend/api/v1/transfers/router.py

View check run for this annotation

Codecov / codecov/patch

syncmaster/backend/api/v1/transfers/router.py#L171

Added line #L171 was not covered by tests

transfer = await unit_of_work.transfer.read_by_id(transfer_id=transfer_id)

# Check: user can copy connection
target_source_connection_rules = await asyncio.gather(
unit_of_work.connection.get_resource_permission(
user=current_user,
resource_id=transfer.source_connection_id,
),
unit_of_work.connection.get_resource_permission(
user=current_user,
resource_id=transfer.target_connection_id,
),
source_connection_role = await unit_of_work.connection.get_resource_permission(
user=current_user,
resource_id=transfer.source_connection_id,
)
source_connection_role, target_connection_role = target_source_connection_rules
if source_connection_role == Permission.NONE:
raise ConnectionNotFoundError

Check warning on line 181 in syncmaster/backend/api/v1/transfers/router.py

View check run for this annotation

Codecov / codecov/patch

syncmaster/backend/api/v1/transfers/router.py#L181

Added line #L181 was not covered by tests

if source_connection_role == Permission.NONE or target_connection_role == Permission.NONE:
target_connection_role = await unit_of_work.connection.get_resource_permission(
user=current_user,
resource_id=transfer.target_connection_id,
)
if target_connection_role == Permission.NONE:
raise ConnectionNotFoundError

# Check: new queue exists
Expand Down
27 changes: 12 additions & 15 deletions syncmaster/db/repositories/credentials_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from typing import NoReturn

from sqlalchemy import ScalarResult, delete, insert, select
from sqlalchemy import ScalarResult, insert, select
from sqlalchemy.exc import DBAPIError, IntegrityError, NoResultFound
from sqlalchemy.ext.asyncio import AsyncSession

Expand All @@ -26,7 +26,7 @@ def __init__(
super().__init__(model=model, session=session)
self._settings = settings

async def get_for_connection(
async def get(
self,
connection_id: int,
) -> dict:
Expand All @@ -37,7 +37,15 @@ async def get_for_connection(
except NoResultFound as e:
raise AuthDataNotFoundError(f"Connection id = {connection_id}") from e

async def add_to_connection(self, connection_id: int, data: dict) -> AuthData:
async def get_bulk(
self,
connection_ids: list[int],
) -> dict[int, dict]:
query = select(AuthData).where(AuthData.connection_id.in_(connection_ids))
result: ScalarResult[AuthData] = await self._session.scalars(query)
return {item.connection_id: decrypt_auth_data(item.value, settings=self._settings) for item in result}

async def create(self, connection_id: int, data: dict) -> AuthData:
query = (
insert(AuthData)
.values(
Expand All @@ -54,23 +62,12 @@ async def add_to_connection(self, connection_id: int, data: dict) -> AuthData:
await self._session.flush()
return result.one()

async def delete_from_connection(self, connection_id: int) -> AuthData:
query = delete(AuthData).where(AuthData.connection_id == connection_id).returning(AuthData)

try:
result: ScalarResult[AuthData] = await self._session.scalars(query)
except IntegrityError as e:
self._raise_error(e)
else:
await self._session.flush()
return result.one()

async def update(
self,
connection_id: int,
credential_data: dict,
) -> AuthData:
creds = await self.get_for_connection(connection_id)
creds = await self.get(connection_id)
try:
for key in creds:
if key not in credential_data or credential_data[key] is None:
Expand Down

0 comments on commit 45cd542

Please sign in to comment.