Skip to content

Commit

Permalink
[DOP-21799] Refactor celery initialization
Browse files Browse the repository at this point in the history
  • Loading branch information
Ilyas Gasanov committed Dec 10, 2024
1 parent 26fce0c commit ca85372
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 50 deletions.
4 changes: 4 additions & 0 deletions syncmaster/backend/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# SPDX-FileCopyrightText: 2023-2024 MTS PJSC
# SPDX-License-Identifier: Apache-2.0
from celery import Celery
from fastapi import FastAPI, HTTPException
from fastapi.exceptions import RequestValidationError
from pydantic import ValidationError
Expand All @@ -18,6 +19,7 @@
from syncmaster.backend.settings import ServerAppSettings as Settings
from syncmaster.db.factory import create_session_factory, get_uow
from syncmaster.exceptions import SyncmasterError
from syncmaster.worker import celery_factory


def application_factory(settings: Settings) -> FastAPI:
Expand All @@ -30,6 +32,7 @@ def application_factory(settings: Settings) -> FastAPI:
redoc_url=None,
)
application.state.settings = settings
application.state.celery = celery_factory(settings)
application.include_router(api_router)
application.exception_handler(RequestValidationError)(validation_exception_handler)
application.exception_handler(ValidationError)(validation_exception_handler)
Expand All @@ -44,6 +47,7 @@ def application_factory(settings: Settings) -> FastAPI:
{
Settings: lambda: settings,
UnitOfWork: get_uow(session_factory, settings=settings),
Celery: lambda: application.state.celery,
},
)

Expand Down
3 changes: 2 additions & 1 deletion syncmaster/backend/api/v1/runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@
from typing import Annotated

from asgi_correlation_id import correlation_id
from celery import Celery
from fastapi import APIRouter, Depends, Query
from jinja2 import Template
from kombu.exceptions import KombuError

from syncmaster.backend.celery import app as celery
from syncmaster.backend.dependencies import Stub
from syncmaster.backend.services import UnitOfWork, get_user
from syncmaster.backend.settings import ServerAppSettings as Settings
Expand Down Expand Up @@ -84,6 +84,7 @@ async def read_run(
async def start_run(
create_run_data: CreateRunSchema,
settings: Annotated[Settings, Depends(Stub(Settings))],
celery: Annotated[Celery, Depends(Stub(Celery))],
unit_of_work: UnitOfWork = Depends(UnitOfWork),
current_user: User = Depends(get_user(is_active=True)),
) -> ReadRunSchema:
Expand Down
22 changes: 22 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@
import time
from collections.abc import AsyncGenerator, Callable
from pathlib import Path
from unittest.mock import AsyncMock, Mock

import pytest
import pytest_asyncio
from alembic.config import Config as AlembicConfig
from celery import Celery
from fastapi import FastAPI
from httpx import AsyncClient
from sqlalchemy.ext.asyncio import (
AsyncEngine,
Expand Down Expand Up @@ -129,6 +131,26 @@ async def session(sessionmaker: async_sessionmaker[AsyncSession]):
await session.close()


@pytest.fixture(scope="session")
def mocked_celery() -> Celery:
celery_app = Mock(Celery)
celery_app.send_task = AsyncMock()
return celery_app


@pytest_asyncio.fixture(scope="session")
async def app(settings: Settings, mocked_celery: Celery) -> FastAPI:
app = application_factory(settings=settings)
app.dependency_overrides[Celery] = lambda: mocked_celery
return app


@pytest_asyncio.fixture(scope="session")
async def client_with_mocked_celery(app: FastAPI) -> AsyncGenerator:
async with AsyncClient(app=app, base_url="http://testserver") as client:
yield client


@pytest_asyncio.fixture(scope="session")
async def client(settings: Settings) -> AsyncGenerator:
logger.info("START CLIENT FIXTURE")
Expand Down
76 changes: 27 additions & 49 deletions tests/test_unit/test_runs/test_create_run.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from unittest.mock import AsyncMock

import pytest
from celery import Celery
from httpx import AsyncClient
from pytest_mock import MockerFixture
from sqlalchemy import desc, select
from sqlalchemy.ext.asyncio import AsyncSession

Expand All @@ -12,15 +14,15 @@


async def test_developer_plus_can_create_run_of_transfer_his_group(
client: AsyncClient,
client_with_mocked_celery: AsyncClient,
mocked_celery: Celery,
group_transfer: MockTransfer,
session: AsyncSession,
mocker,
mocker: MockerFixture,
role_developer_plus: UserTestRoles,
) -> None:
# Arrange
user = group_transfer.owner_group.get_member_of_role(role_developer_plus)
mock_send_task = mocker.patch("syncmaster.backend.celery.app.send_task")
mock_send_task = mocked_celery.send_task
mock_to_thread = mocker.patch("asyncio.to_thread", new_callable=AsyncMock)

run = (
Expand All @@ -31,14 +33,12 @@ async def test_developer_plus_can_create_run_of_transfer_his_group(

assert not run

# Act
result = await client.post(
result = await client_with_mocked_celery.post(
"v1/runs",
headers={"Authorization": f"Bearer {user.token}"},
json={"transfer_id": group_transfer.id},
)

# Assert
run = (
await session.scalars(
select(Run).filter_by(transfer_id=group_transfer.id, status=Status.CREATED).order_by(desc(Run.created_at)),
Expand Down Expand Up @@ -66,24 +66,20 @@ async def test_developer_plus_can_create_run_of_transfer_his_group(


async def test_groupless_user_cannot_create_run(
client: AsyncClient,
client_with_mocked_celery: AsyncClient,
simple_user: MockUser,
group_transfer: MockTransfer,
session: AsyncSession,
mocker,
mocker: MockerFixture,
) -> None:
# Arrange
mocker.patch("syncmaster.backend.celery.app.send_task")
mocker.patch("asyncio.to_thread", new_callable=AsyncMock)

# Act
result = await client.post(
result = await client_with_mocked_celery.post(
"v1/runs",
headers={"Authorization": f"Bearer {simple_user.token}"},
json={"transfer_id": group_transfer.id},
)

# Assert
assert result.json() == {
"error": {
"code": "not_found",
Expand All @@ -95,26 +91,22 @@ async def test_groupless_user_cannot_create_run(


async def test_group_member_cannot_create_run_of_other_group_transfer(
client: AsyncClient,
client_with_mocked_celery: AsyncClient,
group_transfer: MockTransfer,
group: MockGroup,
session: AsyncSession,
mocker,
mocker: MockerFixture,
role_guest_plus: UserTestRoles,
):
# Arrange
mocker.patch("syncmaster.backend.celery.app.send_task")
mocker.patch("asyncio.to_thread", new_callable=AsyncMock)
user = group.get_member_of_role(role_guest_plus)

# Act
result = await client.post(
result = await client_with_mocked_celery.post(
"v1/runs",
headers={"Authorization": f"Bearer {user.token}"},
json={"transfer_id": group_transfer.id},
)

# Assert
assert result.json() == {
"error": {
"code": "not_found",
Expand All @@ -132,18 +124,17 @@ async def test_group_member_cannot_create_run_of_other_group_transfer(


async def test_superuser_can_create_run(
client: AsyncClient,
client_with_mocked_celery: AsyncClient,
mocked_celery: Celery,
superuser: MockUser,
group_transfer: MockTransfer,
session: AsyncSession,
mocker,
mocker: MockerFixture,
) -> None:
# Arrange
mock_send_task = mocker.patch("syncmaster.backend.celery.app.send_task")
mock_send_task = mocked_celery.send_task
mock_to_thread = mocker.patch("asyncio.to_thread", new_callable=AsyncMock)

# Act
result = await client.post(
result = await client_with_mocked_celery.post(
"v1/runs",
headers={"Authorization": f"Bearer {superuser.token}"},
json={"transfer_id": group_transfer.id},
Expand All @@ -154,7 +145,6 @@ async def test_superuser_can_create_run(
)
).first()

# Assert
response = result.json()
assert response == {
"id": run.id,
Expand All @@ -178,21 +168,17 @@ async def test_superuser_can_create_run(


async def test_unauthorized_user_cannot_create_run(
client: AsyncClient,
client_with_mocked_celery: AsyncClient,
group_transfer: MockTransfer,
mocker,
mocker: MockerFixture,
) -> None:
# Arrange
mocker.patch("syncmaster.backend.celery.app.send_task")
mocker.patch("asyncio.to_thread", new_callable=AsyncMock)

# Act
result = await client.post(
result = await client_with_mocked_celery.post(
"v1/runs",
json={"transfer_id": group_transfer.id},
)

# Assert
assert result.json() == {
"error": {
"code": "unauthorized",
Expand All @@ -204,25 +190,21 @@ async def test_unauthorized_user_cannot_create_run(


async def test_group_member_cannot_create_run_of_unknown_transfer_error(
client: AsyncClient,
client_with_mocked_celery: AsyncClient,
group_transfer: MockTransfer,
session: AsyncSession,
mocker,
mocker: MockerFixture,
role_guest_plus: UserTestRoles,
) -> None:
# Arrange
user = group_transfer.owner_group.get_member_of_role(role_guest_plus)
mocker.patch("syncmaster.backend.celery.app.send_task")
mocker.patch("asyncio.to_thread", new_callable=AsyncMock)

# Act
result = await client.post(
result = await client_with_mocked_celery.post(
"v1/runs",
headers={"Authorization": f"Bearer {user.token}"},
json={"transfer_id": -1},
)

# Assert
assert result.json() == {
"error": {
"code": "not_found",
Expand All @@ -233,24 +215,20 @@ async def test_group_member_cannot_create_run_of_unknown_transfer_error(


async def test_superuser_cannot_create_run_of_unknown_transfer_error(
client: AsyncClient,
client_with_mocked_celery: AsyncClient,
superuser: MockUser,
group_transfer: MockTransfer,
session: AsyncSession,
mocker,
mocker: MockerFixture,
) -> None:
# Arrange
mocker.patch("syncmaster.backend.celery.app.send_task")
mocker.patch("asyncio.to_thread", new_callable=AsyncMock)

# Act
result = await client.post(
result = await client_with_mocked_celery.post(
"v1/runs",
headers={"Authorization": f"Bearer {superuser.token}"},
json={"transfer_id": -1},
)

# Assert
assert result.json() == {
"error": {
"code": "not_found",
Expand Down

0 comments on commit ca85372

Please sign in to comment.