Skip to content

Commit

Permalink
Add endpoint tests
Browse files Browse the repository at this point in the history
  • Loading branch information
lsbardel committed Nov 24, 2024
1 parent 098901f commit 22edaa2
Show file tree
Hide file tree
Showing 10 changed files with 155 additions and 36 deletions.
2 changes: 1 addition & 1 deletion .vscode/launch.json
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
"env": {},
"args": [
"-x",
"tests/scheduler/test_scheduler.py::test_async_handler"
"tests/scheduler/test_endpoints.py"
],
"debugOptions": [
"RedirectOutput"
Expand Down
30 changes: 17 additions & 13 deletions fluid/scheduler/broker.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,12 +102,10 @@ def task_from_registry(self, task: str | Task) -> Task:
def register_task(self, task: Task) -> None:
self.registry[task.name] = task

async def enable_task(self, task_name: str, enable: bool = True) -> TaskInfo:
async def enable_task(self, task: str | Task, enable: bool = True) -> TaskInfo:
"""Enable or disable a registered task"""
task = self.registry.get(task_name)
if not task:
raise UnknownTaskError(task_name)
return await self.update_task(task, dict(enabled=enable))
task_ = self.task_from_registry(task)
return await self.update_task(task_, dict(enabled=enable))

@classmethod
def from_url(cls, url: str = "") -> TaskBroker:
Expand Down Expand Up @@ -154,6 +152,13 @@ def task_hash_name(self, name: str) -> str:
def task_queue_name(self, priority: TaskPriority) -> str:
return f"{self.name}-queue-{priority.name}"

async def clear(self) -> int:
pipe = self.redis_cli.pipeline()
async for key in self.redis_cli.scan_iter(f"{self.name}-*"):
pipe.delete(key)
r = await pipe.execute()
return len(r)

async def get_tasks_info(self, *task_names: str) -> list[TaskInfo]:
pipe = self.redis_cli.pipeline()
names = task_names or self.registry
Expand All @@ -170,14 +175,13 @@ async def get_tasks_info(self, *task_names: str) -> list[TaskInfo]:

async def update_task(self, task: Task, params: dict[str, Any]) -> TaskInfo:
pipe = self.redis_cli.pipeline()
info = json.loads(TaskInfoUpdate(**params).model_dump_json())
pipe.hset(
self.task_hash_name(task.name),
mapping={name: json.dumps(value) for name, value in info.items()},
)
pipe.hgetall(self.task_hash_name(task.name))
_, info = await pipe.execute()
return self._decode_task(task, info)
info = json.loads(TaskInfoUpdate(**params).model_dump_json(exclude_unset=True))
update = {name: json.dumps(value) for name, value in info.items()}
key = self.task_hash_name(task.name)
pipe.hset(key, mapping=update)
pipe.hgetall(key)
_, data = await pipe.execute()
return self._decode_task(task, data)

async def queue_length(self) -> dict[str, int]:
if self.task_queue_names:
Expand Down
20 changes: 18 additions & 2 deletions fluid/scheduler/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,26 @@ async def get_tasks(task_manager: TaskManagerDep) -> list[TaskInfo]:


@router.get(
"/tasks/status",
"/tasks/{task_name}",
response_model=TaskInfo,
summary="Get a Task",
description="Retrieve information about a task",
)
async def get_task(
task_manager: TaskManagerDep,
task_name: str = Path(title="Task name"),
) -> TaskInfo:
data = await task_manager.broker.get_tasks_info(task_name)
if not data:
raise HTTPException(status_code=404, detail="Task not found")
return data[0]


@router.get(
"/tasks-status",
response_model=dict,
summary="Task consumer status",
description="Retrieve a list of tasks runs",
description="Status of the task consumer",
)
async def get_task_status(task_manager: TaskManagerDep) -> dict:
if isinstance(task_manager, Worker):
Expand Down
2 changes: 1 addition & 1 deletion fluid/scheduler/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ class TaskInfoUpdate(BaseModel):
last_run_duration: timedelta | None = Field(
default=None, description="Task last run duration in milliseconds"
)
last_run_state: str | None = Field(
last_run_state: TaskState | None = Field(
default=None, description="State of last task run"
)

Expand Down
8 changes: 0 additions & 8 deletions fluid/utils/http_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,11 +268,3 @@ async def close(self) -> None:
if self.session and self.session_owner:
await self.session.aclose()
self.session = None


C = TypeVar("C", bound=HttpClient)


class HttpComponent(Generic[C]):
def __init__(self, cli: C) -> None:
self.cli = cli
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ build-backend = "poetry.core.masonry.api"

[tool.pytest.ini_options]
asyncio_mode = "auto"
asyncio_default_fixture_loop_scope = "module"
testpaths = [
"tests"
]
Expand Down
37 changes: 28 additions & 9 deletions tests/scheduler/conftest.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
import os
from contextlib import asynccontextmanager
from typing import AsyncIterator, cast

import pytest
from fastapi import FastAPI
from httpx import ASGITransport, AsyncClient
from redis.asyncio import Redis

from fluid.scheduler import TaskManager, TaskScheduler
from fluid.scheduler.broker import RedisTaskBroker
from fluid.scheduler.endpoints import get_task_manger, setup_fastapi
from tests.scheduler.tasks import task_application

os.environ["TASK_MANAGER_APP"] = "tests.scheduler.tasks:task_application"
from tests.scheduler.tasks import TaskClient, task_application


@asynccontextmanager
Expand All @@ -20,13 +18,34 @@ async def start_fastapi(app: FastAPI) -> AsyncIterator:
yield app


@pytest.fixture
async def task_scheduler() -> AsyncIterator[TaskManager]:
def redis_broker(task_manager: TaskManager) -> RedisTaskBroker:
return cast(RedisTaskBroker, task_manager.broker)


@pytest.fixture(scope="module")
async def task_app():
task_manager = task_application(TaskScheduler())
async with start_fastapi(setup_fastapi(task_manager)) as app:
yield get_task_manger(app)
broker = redis_broker(task_manager)
await broker.clear()
yield app


@pytest.fixture
@pytest.fixture(scope="module")
async def task_scheduler(task_app) -> TaskManager:
return get_task_manger(task_app)


@pytest.fixture(scope="module")
def redis(task_scheduler: TaskScheduler) -> Redis: # type: ignore
return cast(RedisTaskBroker, task_scheduler.broker).redis.redis_cli
return redis_broker(task_scheduler).redis.redis_cli


@pytest.fixture(scope="module")
async def cli(task_app):
base_url = TaskClient().url
async with AsyncClient(
transport=ASGITransport(app=task_app), base_url=base_url
) as session:
async with TaskClient(url=base_url, session=session) as client:
yield client
34 changes: 33 additions & 1 deletion tests/scheduler/tasks.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,37 @@
import asyncio
from dataclasses import dataclass
from datetime import datetime

from examples import tasks
from fluid.scheduler import TaskManager
from fluid.scheduler import TaskInfo, TaskManager
from fluid.utils.http_client import HttpxClient


@dataclass
class TaskClient(HttpxClient):
url: str = "http://test_api"

async def get_tasks(self) -> list[TaskInfo]:
data = await self.get(f"{self.url}/tasks")
return [TaskInfo(**task) for task in data]

async def get_task(self, task_name: str) -> TaskInfo:
data = await self.get(f"{self.url}/tasks/{task_name}")
return TaskInfo(**data)

async def wait_for_task(
self,
task_name: str,
last_run_end: datetime | None = None,
timeout: float = 1.0,
) -> TaskInfo:
sleep = min(timeout / 10.0, 0.1)
async with asyncio.timeout(timeout):
while True:
task = await self.get_task(task_name)
if task.last_run_end != last_run_end:
return task
await asyncio.sleep(sleep)


def task_application(manager: TaskManager | None = None) -> TaskManager:
Expand Down
53 changes: 53 additions & 0 deletions tests/scheduler/test_endpoints.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import pytest

from fluid.scheduler.models import TaskInfo, TaskState
from fluid.utils.http_client import HttpResponseError
from tests.scheduler.tasks import TaskClient

pytestmark = pytest.mark.asyncio(loop_scope="module")


async def test_get_tasks(cli: TaskClient) -> None:
data = await cli.get(f"{cli.url}/tasks")
assert len(data) == 4
tasks = {task["name"]: TaskInfo(**task) for task in data}
dummy = tasks["dummy"]
assert dummy.name == "dummy"


async def test_get_tasks_status(cli: TaskClient) -> None:
data = await cli.get(f"{cli.url}/tasks-status")
assert data


async def test_run_task_404(cli: TaskClient) -> None:
with pytest.raises(HttpResponseError):
await cli.post(f"{cli.url}/tasks", json=dict(name="whatever"))


async def test_run_task(cli: TaskClient) -> None:
task = await cli.get_task("dummy")
assert task.last_run_end is None
task = await cli.get_task("dummy")
data = await cli.post(f"{cli.url}/tasks", json=dict(name="dummy"))
assert data["task"] == "dummy"
# wait for task
task = await cli.wait_for_task("dummy")
assert task.last_run_end is not None


async def test_patch_task_404(cli: TaskClient) -> None:
with pytest.raises(HttpResponseError):
await cli.patch(f"{cli.url}/tasks/whatever", json=dict(enabled=False))


async def test_patch_task(cli: TaskClient) -> None:
data = await cli.patch(f"{cli.url}/tasks/dummy", json=dict(enabled=False))
assert data["enabled"] is False
task = await cli.get_task("dummy")
assert task.enabled is False
data = await cli.post(f"{cli.url}/tasks", json=dict(name="dummy"))
task = await cli.wait_for_task("dummy", last_run_end=task.last_run_end)
assert task.enabled is False
assert task.last_run_state == TaskState.aborted
assert task.last_run_end is not None
4 changes: 3 additions & 1 deletion tests/scheduler/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
from fluid.scheduler.errors import UnknownTaskError
from fluid.utils.waiter import wait_for

pytestmark = pytest.mark.asyncio(loop_scope="module")


@dataclass
class WaitFor:
Expand All @@ -26,7 +28,7 @@ def __call__(self, task_run: TaskRun) -> None:
self.runs.append(task_run)


def test_scheduler_manager(task_scheduler: TaskScheduler) -> None:
async def test_scheduler_manager(task_scheduler: TaskScheduler) -> None:
assert task_scheduler
assert task_scheduler.broker.registry
assert "dummy" in task_scheduler.registry
Expand Down

0 comments on commit 22edaa2

Please sign in to comment.