From c93e2539b8728050c54b464f37b5936626af3d17 Mon Sep 17 00:00:00 2001 From: Ilyas Gasanov Date: Tue, 10 Dec 2024 10:13:40 +0300 Subject: [PATCH] [DOP-21799] Refactor celery initialization --- syncmaster/backend/__init__.py | 12 +++ syncmaster/backend/api/v1/runs.py | 3 +- syncmaster/backend/celery.py | 6 -- syncmaster/scheduler/__init__.py | 14 +++- syncmaster/scheduler/celery.py | 3 +- syncmaster/scheduler/transfer_job_manager.py | 5 +- syncmaster/worker/__init__.py | 3 +- tests/conftest.py | 22 ++++++ tests/test_integration/celery_test.py | 2 +- .../test_scheduler/test_scheduler.py | 3 +- tests/test_unit/test_runs/test_create_run.py | 76 +++++++------------ 11 files changed, 86 insertions(+), 63 deletions(-) delete mode 100644 syncmaster/backend/celery.py diff --git a/syncmaster/backend/__init__.py b/syncmaster/backend/__init__.py index 5262ed34..c00158b0 100644 --- a/syncmaster/backend/__init__.py +++ b/syncmaster/backend/__init__.py @@ -1,5 +1,6 @@ # SPDX-FileCopyrightText: 2023-2024 MTS PJSC # SPDX-License-Identifier: Apache-2.0 +from celery import Celery from fastapi import FastAPI, HTTPException from fastapi.exceptions import RequestValidationError from pydantic import ValidationError @@ -20,6 +21,15 @@ from syncmaster.exceptions import SyncmasterError +def celery_factory(settings: Settings) -> Celery: + app = Celery( + __name__, + broker=settings.broker.url, + backend="db+" + settings.database.sync_url, + ) + return app + + def application_factory(settings: Settings) -> FastAPI: application = FastAPI( title="Syncmaster", @@ -30,6 +40,7 @@ def application_factory(settings: Settings) -> FastAPI: redoc_url=None, ) application.state.settings = settings + application.state.celery = celery_factory(settings) application.include_router(api_router) application.exception_handler(RequestValidationError)(validation_exception_handler) application.exception_handler(ValidationError)(validation_exception_handler) @@ -44,6 +55,7 @@ def application_factory(settings: Settings) -> FastAPI: { Settings: lambda: settings, UnitOfWork: get_uow(session_factory, settings=settings), + Celery: lambda: application.state.celery, }, ) diff --git a/syncmaster/backend/api/v1/runs.py b/syncmaster/backend/api/v1/runs.py index 5d069955..d6bf1b16 100644 --- a/syncmaster/backend/api/v1/runs.py +++ b/syncmaster/backend/api/v1/runs.py @@ -5,11 +5,11 @@ from typing import Annotated from asgi_correlation_id import correlation_id +from celery import Celery from fastapi import APIRouter, Depends, Query from jinja2 import Template from kombu.exceptions import KombuError -from syncmaster.backend.celery import app as celery from syncmaster.backend.dependencies import Stub from syncmaster.backend.services import UnitOfWork, get_user from syncmaster.backend.settings import ServerAppSettings as Settings @@ -84,6 +84,7 @@ async def read_run( async def start_run( create_run_data: CreateRunSchema, settings: Annotated[Settings, Depends(Stub(Settings))], + celery: Annotated[Celery, Depends(Stub(Celery))], unit_of_work: UnitOfWork = Depends(UnitOfWork), current_user: User = Depends(get_user(is_active=True)), ) -> ReadRunSchema: diff --git a/syncmaster/backend/celery.py b/syncmaster/backend/celery.py deleted file mode 100644 index e85268f9..00000000 --- a/syncmaster/backend/celery.py +++ /dev/null @@ -1,6 +0,0 @@ -# SPDX-FileCopyrightText: 2023-2024 MTS PJSC -# SPDX-License-Identifier: Apache-2.0 -from syncmaster.backend.settings import ServerAppSettings -from syncmaster.worker import celery_factory - -app = celery_factory(ServerAppSettings()) diff --git a/syncmaster/scheduler/__init__.py b/syncmaster/scheduler/__init__.py index aa0c378f..3cd219cc 100644 --- a/syncmaster/scheduler/__init__.py +++ b/syncmaster/scheduler/__init__.py @@ -1,4 +1,14 @@ # SPDX-FileCopyrightText: 2023-2024 MTS PJSC # SPDX-License-Identifier: Apache-2.0 -from syncmaster.scheduler.transfer_fetcher import TransferFetcher -from syncmaster.scheduler.transfer_job_manager import TransferJobManager +from celery import Celery + +from syncmaster.scheduler.settings import SchedulerAppSettings + + +def celery_factory(settings: SchedulerAppSettings) -> Celery: + app = Celery( + __name__, + broker=settings.broker.url, + backend="db+" + settings.database.sync_url, + ) + return app diff --git a/syncmaster/scheduler/celery.py b/syncmaster/scheduler/celery.py index c6bba44e..da11f406 100644 --- a/syncmaster/scheduler/celery.py +++ b/syncmaster/scheduler/celery.py @@ -1,6 +1,7 @@ # SPDX-FileCopyrightText: 2023-2024 MTS PJSC # SPDX-License-Identifier: Apache-2.0 +from syncmaster.scheduler import celery_factory from syncmaster.scheduler.settings import SchedulerAppSettings -from syncmaster.worker import celery_factory +# Global object, since the TransferJobManager.send_job_to_celery method is static app = celery_factory(SchedulerAppSettings()) diff --git a/syncmaster/scheduler/transfer_job_manager.py b/syncmaster/scheduler/transfer_job_manager.py index 83c03718..36469972 100644 --- a/syncmaster/scheduler/transfer_job_manager.py +++ b/syncmaster/scheduler/transfer_job_manager.py @@ -50,8 +50,11 @@ def update_jobs(self, transfers: list[Transfer]) -> None: @staticmethod async def send_job_to_celery(transfer_id: int) -> None: """ - Do not pass additional arguments like settings, + 1. Do not pass additional arguments like settings, otherwise they will be serialized in jobs table. + 2. Instance methods are bound to specific objects and cannot be reliably serialized + due to the weak reference problem. Use a static method instead, as it is not + object-specific and can be serialized. """ settings = Settings() diff --git a/syncmaster/worker/__init__.py b/syncmaster/worker/__init__.py index 09df658a..1dd8d3a8 100644 --- a/syncmaster/worker/__init__.py +++ b/syncmaster/worker/__init__.py @@ -3,9 +3,10 @@ from celery import Celery from syncmaster.worker.base import WorkerTask +from syncmaster.worker.settings import WorkerAppSettings -def celery_factory(settings) -> Celery: +def celery_factory(settings: WorkerAppSettings) -> Celery: app = Celery( __name__, broker=settings.broker.url, diff --git a/tests/conftest.py b/tests/conftest.py index 1d3c779d..5bafb3d3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,11 +4,13 @@ import time from collections.abc import AsyncGenerator, Callable from pathlib import Path +from unittest.mock import AsyncMock, Mock import pytest import pytest_asyncio from alembic.config import Config as AlembicConfig from celery import Celery +from fastapi import FastAPI from httpx import AsyncClient from sqlalchemy.ext.asyncio import ( AsyncEngine, @@ -129,6 +131,26 @@ async def session(sessionmaker: async_sessionmaker[AsyncSession]): await session.close() +@pytest.fixture(scope="session") +def mocked_celery() -> Celery: + celery_app = Mock(Celery) + celery_app.send_task = AsyncMock() + return celery_app + + +@pytest_asyncio.fixture(scope="session") +async def app(settings: Settings, mocked_celery: Celery) -> FastAPI: + app = application_factory(settings=settings) + app.dependency_overrides[Celery] = lambda: mocked_celery + return app + + +@pytest_asyncio.fixture(scope="session") +async def client_with_mocked_celery(app: FastAPI) -> AsyncGenerator: + async with AsyncClient(app=app, base_url="http://testserver") as client: + yield client + + @pytest_asyncio.fixture(scope="session") async def client(settings: Settings) -> AsyncGenerator: logger.info("START CLIENT FIXTURE") diff --git a/tests/test_integration/celery_test.py b/tests/test_integration/celery_test.py index 389383d0..5fee7307 100644 --- a/tests/test_integration/celery_test.py +++ b/tests/test_integration/celery_test.py @@ -1,3 +1,3 @@ -from syncmaster.scheduler.celery import app as celery +from syncmaster.worker.celery import app as celery celery.conf.update(imports=list(celery.conf.imports) + ["tests.test_integration.test_scheduler.test_task"]) diff --git a/tests/test_integration/test_scheduler/test_scheduler.py b/tests/test_integration/test_scheduler/test_scheduler.py index 00e6ebfc..fbf8ebe0 100644 --- a/tests/test_integration/test_scheduler/test_scheduler.py +++ b/tests/test_integration/test_scheduler/test_scheduler.py @@ -7,7 +7,8 @@ from syncmaster.backend.settings import ServerAppSettings as Settings from syncmaster.db.models import Run, Status -from syncmaster.scheduler import TransferFetcher, TransferJobManager +from syncmaster.scheduler.transfer_fetcher import TransferFetcher +from syncmaster.scheduler.transfer_job_manager import TransferJobManager from tests.mocks import MockTransfer pytestmark = [pytest.mark.asyncio, pytest.mark.worker, pytest.mark.scheduler_integration] diff --git a/tests/test_unit/test_runs/test_create_run.py b/tests/test_unit/test_runs/test_create_run.py index 20c3c379..426b4b12 100644 --- a/tests/test_unit/test_runs/test_create_run.py +++ b/tests/test_unit/test_runs/test_create_run.py @@ -1,7 +1,9 @@ from unittest.mock import AsyncMock import pytest +from celery import Celery from httpx import AsyncClient +from pytest_mock import MockerFixture from sqlalchemy import desc, select from sqlalchemy.ext.asyncio import AsyncSession @@ -12,15 +14,15 @@ async def test_developer_plus_can_create_run_of_transfer_his_group( - client: AsyncClient, + client_with_mocked_celery: AsyncClient, + mocked_celery: Celery, group_transfer: MockTransfer, session: AsyncSession, - mocker, + mocker: MockerFixture, role_developer_plus: UserTestRoles, ) -> None: - # Arrange user = group_transfer.owner_group.get_member_of_role(role_developer_plus) - mock_send_task = mocker.patch("syncmaster.backend.celery.app.send_task") + mock_send_task = mocked_celery.send_task mock_to_thread = mocker.patch("asyncio.to_thread", new_callable=AsyncMock) run = ( @@ -31,14 +33,12 @@ async def test_developer_plus_can_create_run_of_transfer_his_group( assert not run - # Act - result = await client.post( + result = await client_with_mocked_celery.post( "v1/runs", headers={"Authorization": f"Bearer {user.token}"}, json={"transfer_id": group_transfer.id}, ) - # Assert run = ( await session.scalars( select(Run).filter_by(transfer_id=group_transfer.id, status=Status.CREATED).order_by(desc(Run.created_at)), @@ -66,24 +66,20 @@ async def test_developer_plus_can_create_run_of_transfer_his_group( async def test_groupless_user_cannot_create_run( - client: AsyncClient, + client_with_mocked_celery: AsyncClient, simple_user: MockUser, group_transfer: MockTransfer, session: AsyncSession, - mocker, + mocker: MockerFixture, ) -> None: - # Arrange - mocker.patch("syncmaster.backend.celery.app.send_task") mocker.patch("asyncio.to_thread", new_callable=AsyncMock) - # Act - result = await client.post( + result = await client_with_mocked_celery.post( "v1/runs", headers={"Authorization": f"Bearer {simple_user.token}"}, json={"transfer_id": group_transfer.id}, ) - # Assert assert result.json() == { "error": { "code": "not_found", @@ -95,26 +91,22 @@ async def test_groupless_user_cannot_create_run( async def test_group_member_cannot_create_run_of_other_group_transfer( - client: AsyncClient, + client_with_mocked_celery: AsyncClient, group_transfer: MockTransfer, group: MockGroup, session: AsyncSession, - mocker, + mocker: MockerFixture, role_guest_plus: UserTestRoles, ): - # Arrange - mocker.patch("syncmaster.backend.celery.app.send_task") mocker.patch("asyncio.to_thread", new_callable=AsyncMock) user = group.get_member_of_role(role_guest_plus) - # Act - result = await client.post( + result = await client_with_mocked_celery.post( "v1/runs", headers={"Authorization": f"Bearer {user.token}"}, json={"transfer_id": group_transfer.id}, ) - # Assert assert result.json() == { "error": { "code": "not_found", @@ -132,18 +124,17 @@ async def test_group_member_cannot_create_run_of_other_group_transfer( async def test_superuser_can_create_run( - client: AsyncClient, + client_with_mocked_celery: AsyncClient, + mocked_celery: Celery, superuser: MockUser, group_transfer: MockTransfer, session: AsyncSession, - mocker, + mocker: MockerFixture, ) -> None: - # Arrange - mock_send_task = mocker.patch("syncmaster.backend.celery.app.send_task") + mock_send_task = mocked_celery.send_task mock_to_thread = mocker.patch("asyncio.to_thread", new_callable=AsyncMock) - # Act - result = await client.post( + result = await client_with_mocked_celery.post( "v1/runs", headers={"Authorization": f"Bearer {superuser.token}"}, json={"transfer_id": group_transfer.id}, @@ -154,7 +145,6 @@ async def test_superuser_can_create_run( ) ).first() - # Assert response = result.json() assert response == { "id": run.id, @@ -178,21 +168,17 @@ async def test_superuser_can_create_run( async def test_unauthorized_user_cannot_create_run( - client: AsyncClient, + client_with_mocked_celery: AsyncClient, group_transfer: MockTransfer, - mocker, + mocker: MockerFixture, ) -> None: - # Arrange - mocker.patch("syncmaster.backend.celery.app.send_task") mocker.patch("asyncio.to_thread", new_callable=AsyncMock) - # Act - result = await client.post( + result = await client_with_mocked_celery.post( "v1/runs", json={"transfer_id": group_transfer.id}, ) - # Assert assert result.json() == { "error": { "code": "unauthorized", @@ -204,25 +190,21 @@ async def test_unauthorized_user_cannot_create_run( async def test_group_member_cannot_create_run_of_unknown_transfer_error( - client: AsyncClient, + client_with_mocked_celery: AsyncClient, group_transfer: MockTransfer, session: AsyncSession, - mocker, + mocker: MockerFixture, role_guest_plus: UserTestRoles, ) -> None: - # Arrange user = group_transfer.owner_group.get_member_of_role(role_guest_plus) - mocker.patch("syncmaster.backend.celery.app.send_task") mocker.patch("asyncio.to_thread", new_callable=AsyncMock) - # Act - result = await client.post( + result = await client_with_mocked_celery.post( "v1/runs", headers={"Authorization": f"Bearer {user.token}"}, json={"transfer_id": -1}, ) - # Assert assert result.json() == { "error": { "code": "not_found", @@ -233,24 +215,20 @@ async def test_group_member_cannot_create_run_of_unknown_transfer_error( async def test_superuser_cannot_create_run_of_unknown_transfer_error( - client: AsyncClient, + client_with_mocked_celery: AsyncClient, superuser: MockUser, group_transfer: MockTransfer, session: AsyncSession, - mocker, + mocker: MockerFixture, ) -> None: - # Arrange - mocker.patch("syncmaster.backend.celery.app.send_task") mocker.patch("asyncio.to_thread", new_callable=AsyncMock) - # Act - result = await client.post( + result = await client_with_mocked_celery.post( "v1/runs", headers={"Authorization": f"Bearer {superuser.token}"}, json={"transfer_id": -1}, ) - # Assert assert result.json() == { "error": { "code": "not_found",