From 22edaa230a8cd8936144ed23d5c31f9c5cf576d4 Mon Sep 17 00:00:00 2001 From: Luca Date: Sun, 24 Nov 2024 14:15:54 +0000 Subject: [PATCH] Add endpoint tests --- .vscode/launch.json | 2 +- fluid/scheduler/broker.py | 30 +++++++++-------- fluid/scheduler/endpoints.py | 20 ++++++++++-- fluid/scheduler/models.py | 2 +- fluid/utils/http_client.py | 8 ----- pyproject.toml | 1 + tests/scheduler/conftest.py | 37 +++++++++++++++------ tests/scheduler/tasks.py | 34 +++++++++++++++++++- tests/scheduler/test_endpoints.py | 53 +++++++++++++++++++++++++++++++ tests/scheduler/test_scheduler.py | 4 ++- 10 files changed, 155 insertions(+), 36 deletions(-) create mode 100644 tests/scheduler/test_endpoints.py diff --git a/.vscode/launch.json b/.vscode/launch.json index a649d76..425115e 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -12,7 +12,7 @@ "env": {}, "args": [ "-x", - "tests/scheduler/test_scheduler.py::test_async_handler" + "tests/scheduler/test_endpoints.py" ], "debugOptions": [ "RedirectOutput" diff --git a/fluid/scheduler/broker.py b/fluid/scheduler/broker.py index 2542a33..a8c92b4 100644 --- a/fluid/scheduler/broker.py +++ b/fluid/scheduler/broker.py @@ -102,12 +102,10 @@ def task_from_registry(self, task: str | Task) -> Task: def register_task(self, task: Task) -> None: self.registry[task.name] = task - async def enable_task(self, task_name: str, enable: bool = True) -> TaskInfo: + async def enable_task(self, task: str | Task, enable: bool = True) -> TaskInfo: """Enable or disable a registered task""" - task = self.registry.get(task_name) - if not task: - raise UnknownTaskError(task_name) - return await self.update_task(task, dict(enabled=enable)) + task_ = self.task_from_registry(task) + return await self.update_task(task_, dict(enabled=enable)) @classmethod def from_url(cls, url: str = "") -> TaskBroker: @@ -154,6 +152,13 @@ def task_hash_name(self, name: str) -> str: def task_queue_name(self, priority: TaskPriority) -> str: return f"{self.name}-queue-{priority.name}" + async def clear(self) -> int: + pipe = self.redis_cli.pipeline() + async for key in self.redis_cli.scan_iter(f"{self.name}-*"): + pipe.delete(key) + r = await pipe.execute() + return len(r) + async def get_tasks_info(self, *task_names: str) -> list[TaskInfo]: pipe = self.redis_cli.pipeline() names = task_names or self.registry @@ -170,14 +175,13 @@ async def get_tasks_info(self, *task_names: str) -> list[TaskInfo]: async def update_task(self, task: Task, params: dict[str, Any]) -> TaskInfo: pipe = self.redis_cli.pipeline() - info = json.loads(TaskInfoUpdate(**params).model_dump_json()) - pipe.hset( - self.task_hash_name(task.name), - mapping={name: json.dumps(value) for name, value in info.items()}, - ) - pipe.hgetall(self.task_hash_name(task.name)) - _, info = await pipe.execute() - return self._decode_task(task, info) + info = json.loads(TaskInfoUpdate(**params).model_dump_json(exclude_unset=True)) + update = {name: json.dumps(value) for name, value in info.items()} + key = self.task_hash_name(task.name) + pipe.hset(key, mapping=update) + pipe.hgetall(key) + _, data = await pipe.execute() + return self._decode_task(task, data) async def queue_length(self) -> dict[str, int]: if self.task_queue_names: diff --git a/fluid/scheduler/endpoints.py b/fluid/scheduler/endpoints.py index 70678cc..0c31fb2 100644 --- a/fluid/scheduler/endpoints.py +++ b/fluid/scheduler/endpoints.py @@ -48,10 +48,26 @@ async def get_tasks(task_manager: TaskManagerDep) -> list[TaskInfo]: @router.get( - "/tasks/status", + "/tasks/{task_name}", + response_model=TaskInfo, + summary="Get a Task", + description="Retrieve information about a task", +) +async def get_task( + task_manager: TaskManagerDep, + task_name: str = Path(title="Task name"), +) -> TaskInfo: + data = await task_manager.broker.get_tasks_info(task_name) + if not data: + raise HTTPException(status_code=404, detail="Task not found") + return data[0] + + +@router.get( + "/tasks-status", response_model=dict, summary="Task consumer status", - description="Retrieve a list of tasks runs", + description="Status of the task consumer", ) async def get_task_status(task_manager: TaskManagerDep) -> dict: if isinstance(task_manager, Worker): diff --git a/fluid/scheduler/models.py b/fluid/scheduler/models.py index 7d6ab03..979dfb8 100644 --- a/fluid/scheduler/models.py +++ b/fluid/scheduler/models.py @@ -86,7 +86,7 @@ class TaskInfoUpdate(BaseModel): last_run_duration: timedelta | None = Field( default=None, description="Task last run duration in milliseconds" ) - last_run_state: str | None = Field( + last_run_state: TaskState | None = Field( default=None, description="State of last task run" ) diff --git a/fluid/utils/http_client.py b/fluid/utils/http_client.py index dc92a17..ce495e8 100644 --- a/fluid/utils/http_client.py +++ b/fluid/utils/http_client.py @@ -268,11 +268,3 @@ async def close(self) -> None: if self.session and self.session_owner: await self.session.aclose() self.session = None - - -C = TypeVar("C", bound=HttpClient) - - -class HttpComponent(Generic[C]): - def __init__(self, cli: C) -> None: - self.cli = cli diff --git a/pyproject.toml b/pyproject.toml index d736718..8d06145 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -69,6 +69,7 @@ build-backend = "poetry.core.masonry.api" [tool.pytest.ini_options] asyncio_mode = "auto" +asyncio_default_fixture_loop_scope = "module" testpaths = [ "tests" ] diff --git a/tests/scheduler/conftest.py b/tests/scheduler/conftest.py index e6750fd..44152df 100644 --- a/tests/scheduler/conftest.py +++ b/tests/scheduler/conftest.py @@ -1,17 +1,15 @@ -import os from contextlib import asynccontextmanager from typing import AsyncIterator, cast import pytest from fastapi import FastAPI +from httpx import ASGITransport, AsyncClient from redis.asyncio import Redis from fluid.scheduler import TaskManager, TaskScheduler from fluid.scheduler.broker import RedisTaskBroker from fluid.scheduler.endpoints import get_task_manger, setup_fastapi -from tests.scheduler.tasks import task_application - -os.environ["TASK_MANAGER_APP"] = "tests.scheduler.tasks:task_application" +from tests.scheduler.tasks import TaskClient, task_application @asynccontextmanager @@ -20,13 +18,34 @@ async def start_fastapi(app: FastAPI) -> AsyncIterator: yield app -@pytest.fixture -async def task_scheduler() -> AsyncIterator[TaskManager]: +def redis_broker(task_manager: TaskManager) -> RedisTaskBroker: + return cast(RedisTaskBroker, task_manager.broker) + + +@pytest.fixture(scope="module") +async def task_app(): task_manager = task_application(TaskScheduler()) async with start_fastapi(setup_fastapi(task_manager)) as app: - yield get_task_manger(app) + broker = redis_broker(task_manager) + await broker.clear() + yield app -@pytest.fixture +@pytest.fixture(scope="module") +async def task_scheduler(task_app) -> TaskManager: + return get_task_manger(task_app) + + +@pytest.fixture(scope="module") def redis(task_scheduler: TaskScheduler) -> Redis: # type: ignore - return cast(RedisTaskBroker, task_scheduler.broker).redis.redis_cli + return redis_broker(task_scheduler).redis.redis_cli + + +@pytest.fixture(scope="module") +async def cli(task_app): + base_url = TaskClient().url + async with AsyncClient( + transport=ASGITransport(app=task_app), base_url=base_url + ) as session: + async with TaskClient(url=base_url, session=session) as client: + yield client diff --git a/tests/scheduler/tasks.py b/tests/scheduler/tasks.py index 377aa7c..79e8ecf 100644 --- a/tests/scheduler/tasks.py +++ b/tests/scheduler/tasks.py @@ -1,5 +1,37 @@ +import asyncio +from dataclasses import dataclass +from datetime import datetime + from examples import tasks -from fluid.scheduler import TaskManager +from fluid.scheduler import TaskInfo, TaskManager +from fluid.utils.http_client import HttpxClient + + +@dataclass +class TaskClient(HttpxClient): + url: str = "http://test_api" + + async def get_tasks(self) -> list[TaskInfo]: + data = await self.get(f"{self.url}/tasks") + return [TaskInfo(**task) for task in data] + + async def get_task(self, task_name: str) -> TaskInfo: + data = await self.get(f"{self.url}/tasks/{task_name}") + return TaskInfo(**data) + + async def wait_for_task( + self, + task_name: str, + last_run_end: datetime | None = None, + timeout: float = 1.0, + ) -> TaskInfo: + sleep = min(timeout / 10.0, 0.1) + async with asyncio.timeout(timeout): + while True: + task = await self.get_task(task_name) + if task.last_run_end != last_run_end: + return task + await asyncio.sleep(sleep) def task_application(manager: TaskManager | None = None) -> TaskManager: diff --git a/tests/scheduler/test_endpoints.py b/tests/scheduler/test_endpoints.py new file mode 100644 index 0000000..eeba0fe --- /dev/null +++ b/tests/scheduler/test_endpoints.py @@ -0,0 +1,53 @@ +import pytest + +from fluid.scheduler.models import TaskInfo, TaskState +from fluid.utils.http_client import HttpResponseError +from tests.scheduler.tasks import TaskClient + +pytestmark = pytest.mark.asyncio(loop_scope="module") + + +async def test_get_tasks(cli: TaskClient) -> None: + data = await cli.get(f"{cli.url}/tasks") + assert len(data) == 4 + tasks = {task["name"]: TaskInfo(**task) for task in data} + dummy = tasks["dummy"] + assert dummy.name == "dummy" + + +async def test_get_tasks_status(cli: TaskClient) -> None: + data = await cli.get(f"{cli.url}/tasks-status") + assert data + + +async def test_run_task_404(cli: TaskClient) -> None: + with pytest.raises(HttpResponseError): + await cli.post(f"{cli.url}/tasks", json=dict(name="whatever")) + + +async def test_run_task(cli: TaskClient) -> None: + task = await cli.get_task("dummy") + assert task.last_run_end is None + task = await cli.get_task("dummy") + data = await cli.post(f"{cli.url}/tasks", json=dict(name="dummy")) + assert data["task"] == "dummy" + # wait for task + task = await cli.wait_for_task("dummy") + assert task.last_run_end is not None + + +async def test_patch_task_404(cli: TaskClient) -> None: + with pytest.raises(HttpResponseError): + await cli.patch(f"{cli.url}/tasks/whatever", json=dict(enabled=False)) + + +async def test_patch_task(cli: TaskClient) -> None: + data = await cli.patch(f"{cli.url}/tasks/dummy", json=dict(enabled=False)) + assert data["enabled"] is False + task = await cli.get_task("dummy") + assert task.enabled is False + data = await cli.post(f"{cli.url}/tasks", json=dict(name="dummy")) + task = await cli.wait_for_task("dummy", last_run_end=task.last_run_end) + assert task.enabled is False + assert task.last_run_state == TaskState.aborted + assert task.last_run_end is not None diff --git a/tests/scheduler/test_scheduler.py b/tests/scheduler/test_scheduler.py index 350fc5d..e5a7444 100644 --- a/tests/scheduler/test_scheduler.py +++ b/tests/scheduler/test_scheduler.py @@ -14,6 +14,8 @@ from fluid.scheduler.errors import UnknownTaskError from fluid.utils.waiter import wait_for +pytestmark = pytest.mark.asyncio(loop_scope="module") + @dataclass class WaitFor: @@ -26,7 +28,7 @@ def __call__(self, task_run: TaskRun) -> None: self.runs.append(task_run) -def test_scheduler_manager(task_scheduler: TaskScheduler) -> None: +async def test_scheduler_manager(task_scheduler: TaskScheduler) -> None: assert task_scheduler assert task_scheduler.broker.registry assert "dummy" in task_scheduler.registry