Skip to content

Commit

Permalink
[DOP-21799] Refactor celery initialization
Browse files Browse the repository at this point in the history
  • Loading branch information
Ilyas Gasanov committed Dec 10, 2024
1 parent a093ada commit c93e253
Show file tree
Hide file tree
Showing 11 changed files with 86 additions and 63 deletions.
12 changes: 12 additions & 0 deletions syncmaster/backend/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# SPDX-FileCopyrightText: 2023-2024 MTS PJSC
# SPDX-License-Identifier: Apache-2.0
from celery import Celery
from fastapi import FastAPI, HTTPException
from fastapi.exceptions import RequestValidationError
from pydantic import ValidationError
Expand All @@ -20,6 +21,15 @@
from syncmaster.exceptions import SyncmasterError


def celery_factory(settings: Settings) -> Celery:
app = Celery(
__name__,
broker=settings.broker.url,
backend="db+" + settings.database.sync_url,
)
return app


def application_factory(settings: Settings) -> FastAPI:
application = FastAPI(
title="Syncmaster",
Expand All @@ -30,6 +40,7 @@ def application_factory(settings: Settings) -> FastAPI:
redoc_url=None,
)
application.state.settings = settings
application.state.celery = celery_factory(settings)
application.include_router(api_router)
application.exception_handler(RequestValidationError)(validation_exception_handler)
application.exception_handler(ValidationError)(validation_exception_handler)
Expand All @@ -44,6 +55,7 @@ def application_factory(settings: Settings) -> FastAPI:
{
Settings: lambda: settings,
UnitOfWork: get_uow(session_factory, settings=settings),
Celery: lambda: application.state.celery,
},
)

Expand Down
3 changes: 2 additions & 1 deletion syncmaster/backend/api/v1/runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@
from typing import Annotated

from asgi_correlation_id import correlation_id
from celery import Celery
from fastapi import APIRouter, Depends, Query
from jinja2 import Template
from kombu.exceptions import KombuError

from syncmaster.backend.celery import app as celery
from syncmaster.backend.dependencies import Stub
from syncmaster.backend.services import UnitOfWork, get_user
from syncmaster.backend.settings import ServerAppSettings as Settings
Expand Down Expand Up @@ -84,6 +84,7 @@ async def read_run(
async def start_run(
create_run_data: CreateRunSchema,
settings: Annotated[Settings, Depends(Stub(Settings))],
celery: Annotated[Celery, Depends(Stub(Celery))],
unit_of_work: UnitOfWork = Depends(UnitOfWork),
current_user: User = Depends(get_user(is_active=True)),
) -> ReadRunSchema:
Expand Down
6 changes: 0 additions & 6 deletions syncmaster/backend/celery.py

This file was deleted.

14 changes: 12 additions & 2 deletions syncmaster/scheduler/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,14 @@
# SPDX-FileCopyrightText: 2023-2024 MTS PJSC
# SPDX-License-Identifier: Apache-2.0
from syncmaster.scheduler.transfer_fetcher import TransferFetcher
from syncmaster.scheduler.transfer_job_manager import TransferJobManager
from celery import Celery

from syncmaster.scheduler.settings import SchedulerAppSettings


def celery_factory(settings: SchedulerAppSettings) -> Celery:
app = Celery(
__name__,
broker=settings.broker.url,
backend="db+" + settings.database.sync_url,
)
return app
3 changes: 2 additions & 1 deletion syncmaster/scheduler/celery.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# SPDX-FileCopyrightText: 2023-2024 MTS PJSC
# SPDX-License-Identifier: Apache-2.0
from syncmaster.scheduler import celery_factory
from syncmaster.scheduler.settings import SchedulerAppSettings
from syncmaster.worker import celery_factory

# Global object, since the TransferJobManager.send_job_to_celery method is static
app = celery_factory(SchedulerAppSettings())
5 changes: 4 additions & 1 deletion syncmaster/scheduler/transfer_job_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,11 @@ def update_jobs(self, transfers: list[Transfer]) -> None:
@staticmethod
async def send_job_to_celery(transfer_id: int) -> None:
"""
Do not pass additional arguments like settings,
1. Do not pass additional arguments like settings,
otherwise they will be serialized in jobs table.
2. Instance methods are bound to specific objects and cannot be reliably serialized
due to the weak reference problem. Use a static method instead, as it is not
object-specific and can be serialized.
"""
settings = Settings()

Expand Down
3 changes: 2 additions & 1 deletion syncmaster/worker/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
from celery import Celery

from syncmaster.worker.base import WorkerTask
from syncmaster.worker.settings import WorkerAppSettings


def celery_factory(settings) -> Celery:
def celery_factory(settings: WorkerAppSettings) -> Celery:
app = Celery(
__name__,
broker=settings.broker.url,
Expand Down
22 changes: 22 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@
import time
from collections.abc import AsyncGenerator, Callable
from pathlib import Path
from unittest.mock import AsyncMock, Mock

import pytest
import pytest_asyncio
from alembic.config import Config as AlembicConfig
from celery import Celery
from fastapi import FastAPI
from httpx import AsyncClient
from sqlalchemy.ext.asyncio import (
AsyncEngine,
Expand Down Expand Up @@ -129,6 +131,26 @@ async def session(sessionmaker: async_sessionmaker[AsyncSession]):
await session.close()


@pytest.fixture(scope="session")
def mocked_celery() -> Celery:
celery_app = Mock(Celery)
celery_app.send_task = AsyncMock()
return celery_app


@pytest_asyncio.fixture(scope="session")
async def app(settings: Settings, mocked_celery: Celery) -> FastAPI:
app = application_factory(settings=settings)
app.dependency_overrides[Celery] = lambda: mocked_celery
return app


@pytest_asyncio.fixture(scope="session")
async def client_with_mocked_celery(app: FastAPI) -> AsyncGenerator:
async with AsyncClient(app=app, base_url="http://testserver") as client:
yield client


@pytest_asyncio.fixture(scope="session")
async def client(settings: Settings) -> AsyncGenerator:
logger.info("START CLIENT FIXTURE")
Expand Down
2 changes: 1 addition & 1 deletion tests/test_integration/celery_test.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from syncmaster.scheduler.celery import app as celery
from syncmaster.worker.celery import app as celery

celery.conf.update(imports=list(celery.conf.imports) + ["tests.test_integration.test_scheduler.test_task"])
3 changes: 2 additions & 1 deletion tests/test_integration/test_scheduler/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@

from syncmaster.backend.settings import ServerAppSettings as Settings
from syncmaster.db.models import Run, Status
from syncmaster.scheduler import TransferFetcher, TransferJobManager
from syncmaster.scheduler.transfer_fetcher import TransferFetcher
from syncmaster.scheduler.transfer_job_manager import TransferJobManager
from tests.mocks import MockTransfer

pytestmark = [pytest.mark.asyncio, pytest.mark.worker, pytest.mark.scheduler_integration]
Expand Down
76 changes: 27 additions & 49 deletions tests/test_unit/test_runs/test_create_run.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from unittest.mock import AsyncMock

import pytest
from celery import Celery
from httpx import AsyncClient
from pytest_mock import MockerFixture
from sqlalchemy import desc, select
from sqlalchemy.ext.asyncio import AsyncSession

Expand All @@ -12,15 +14,15 @@


async def test_developer_plus_can_create_run_of_transfer_his_group(
client: AsyncClient,
client_with_mocked_celery: AsyncClient,
mocked_celery: Celery,
group_transfer: MockTransfer,
session: AsyncSession,
mocker,
mocker: MockerFixture,
role_developer_plus: UserTestRoles,
) -> None:
# Arrange
user = group_transfer.owner_group.get_member_of_role(role_developer_plus)
mock_send_task = mocker.patch("syncmaster.backend.celery.app.send_task")
mock_send_task = mocked_celery.send_task
mock_to_thread = mocker.patch("asyncio.to_thread", new_callable=AsyncMock)

run = (
Expand All @@ -31,14 +33,12 @@ async def test_developer_plus_can_create_run_of_transfer_his_group(

assert not run

# Act
result = await client.post(
result = await client_with_mocked_celery.post(
"v1/runs",
headers={"Authorization": f"Bearer {user.token}"},
json={"transfer_id": group_transfer.id},
)

# Assert
run = (
await session.scalars(
select(Run).filter_by(transfer_id=group_transfer.id, status=Status.CREATED).order_by(desc(Run.created_at)),
Expand Down Expand Up @@ -66,24 +66,20 @@ async def test_developer_plus_can_create_run_of_transfer_his_group(


async def test_groupless_user_cannot_create_run(
client: AsyncClient,
client_with_mocked_celery: AsyncClient,
simple_user: MockUser,
group_transfer: MockTransfer,
session: AsyncSession,
mocker,
mocker: MockerFixture,
) -> None:
# Arrange
mocker.patch("syncmaster.backend.celery.app.send_task")
mocker.patch("asyncio.to_thread", new_callable=AsyncMock)

# Act
result = await client.post(
result = await client_with_mocked_celery.post(
"v1/runs",
headers={"Authorization": f"Bearer {simple_user.token}"},
json={"transfer_id": group_transfer.id},
)

# Assert
assert result.json() == {
"error": {
"code": "not_found",
Expand All @@ -95,26 +91,22 @@ async def test_groupless_user_cannot_create_run(


async def test_group_member_cannot_create_run_of_other_group_transfer(
client: AsyncClient,
client_with_mocked_celery: AsyncClient,
group_transfer: MockTransfer,
group: MockGroup,
session: AsyncSession,
mocker,
mocker: MockerFixture,
role_guest_plus: UserTestRoles,
):
# Arrange
mocker.patch("syncmaster.backend.celery.app.send_task")
mocker.patch("asyncio.to_thread", new_callable=AsyncMock)
user = group.get_member_of_role(role_guest_plus)

# Act
result = await client.post(
result = await client_with_mocked_celery.post(
"v1/runs",
headers={"Authorization": f"Bearer {user.token}"},
json={"transfer_id": group_transfer.id},
)

# Assert
assert result.json() == {
"error": {
"code": "not_found",
Expand All @@ -132,18 +124,17 @@ async def test_group_member_cannot_create_run_of_other_group_transfer(


async def test_superuser_can_create_run(
client: AsyncClient,
client_with_mocked_celery: AsyncClient,
mocked_celery: Celery,
superuser: MockUser,
group_transfer: MockTransfer,
session: AsyncSession,
mocker,
mocker: MockerFixture,
) -> None:
# Arrange
mock_send_task = mocker.patch("syncmaster.backend.celery.app.send_task")
mock_send_task = mocked_celery.send_task
mock_to_thread = mocker.patch("asyncio.to_thread", new_callable=AsyncMock)

# Act
result = await client.post(
result = await client_with_mocked_celery.post(
"v1/runs",
headers={"Authorization": f"Bearer {superuser.token}"},
json={"transfer_id": group_transfer.id},
Expand All @@ -154,7 +145,6 @@ async def test_superuser_can_create_run(
)
).first()

# Assert
response = result.json()
assert response == {
"id": run.id,
Expand All @@ -178,21 +168,17 @@ async def test_superuser_can_create_run(


async def test_unauthorized_user_cannot_create_run(
client: AsyncClient,
client_with_mocked_celery: AsyncClient,
group_transfer: MockTransfer,
mocker,
mocker: MockerFixture,
) -> None:
# Arrange
mocker.patch("syncmaster.backend.celery.app.send_task")
mocker.patch("asyncio.to_thread", new_callable=AsyncMock)

# Act
result = await client.post(
result = await client_with_mocked_celery.post(
"v1/runs",
json={"transfer_id": group_transfer.id},
)

# Assert
assert result.json() == {
"error": {
"code": "unauthorized",
Expand All @@ -204,25 +190,21 @@ async def test_unauthorized_user_cannot_create_run(


async def test_group_member_cannot_create_run_of_unknown_transfer_error(
client: AsyncClient,
client_with_mocked_celery: AsyncClient,
group_transfer: MockTransfer,
session: AsyncSession,
mocker,
mocker: MockerFixture,
role_guest_plus: UserTestRoles,
) -> None:
# Arrange
user = group_transfer.owner_group.get_member_of_role(role_guest_plus)
mocker.patch("syncmaster.backend.celery.app.send_task")
mocker.patch("asyncio.to_thread", new_callable=AsyncMock)

# Act
result = await client.post(
result = await client_with_mocked_celery.post(
"v1/runs",
headers={"Authorization": f"Bearer {user.token}"},
json={"transfer_id": -1},
)

# Assert
assert result.json() == {
"error": {
"code": "not_found",
Expand All @@ -233,24 +215,20 @@ async def test_group_member_cannot_create_run_of_unknown_transfer_error(


async def test_superuser_cannot_create_run_of_unknown_transfer_error(
client: AsyncClient,
client_with_mocked_celery: AsyncClient,
superuser: MockUser,
group_transfer: MockTransfer,
session: AsyncSession,
mocker,
mocker: MockerFixture,
) -> None:
# Arrange
mocker.patch("syncmaster.backend.celery.app.send_task")
mocker.patch("asyncio.to_thread", new_callable=AsyncMock)

# Act
result = await client.post(
result = await client_with_mocked_celery.post(
"v1/runs",
headers={"Authorization": f"Bearer {superuser.token}"},
json={"transfer_id": -1},
)

# Assert
assert result.json() == {
"error": {
"code": "not_found",
Expand Down

0 comments on commit c93e253

Please sign in to comment.