Skip to content

Commit

Permalink
release
Browse files Browse the repository at this point in the history
  • Loading branch information
lsbardel committed Oct 20, 2024
1 parent c1add4b commit bd75a48
Show file tree
Hide file tree
Showing 15 changed files with 161 additions and 33 deletions.
2 changes: 1 addition & 1 deletion .vscode/launch.json
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
"env": {},
"args": [
"-x",
"tests/scheduler/test_scheduler.py::test_disabled_execution"
"tests/scheduler/test_scheduler.py::test_async_handler"
],
"debugOptions": [
"RedirectOutput"
Expand Down
13 changes: 13 additions & 0 deletions docs/reference/dispatchers.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Event Dispatchers

A set of classes for dispatching events, they can be imported from `fluid.utils.dispatcher`:

```python
from fluid.utils.dispatcher import Dispatcher
```

::: fluid.utils.dispatcher.BaseDispatcher

::: fluid.utils.dispatcher.Dispatcher

::: fluid.utils.dispatcher.AsyncDispatcher
10 changes: 10 additions & 0 deletions docs/reference/workers.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,18 @@ from fastapi.utils.worker import StoppingWorker

::: fluid.utils.worker.Worker

::: fluid.utils.worker.RunningWorker

::: fluid.utils.worker.StoppingWorker

::: fluid.utils.worker.WorkerFunction

::: fluid.utils.worker.QueueConsumer

::: fluid.utils.worker.QueueConsumerWorker

::: fluid.utils.worker.AsyncConsumer

::: fluid.utils.worker.Workers

::: fluid.utils.worker.DynamicWorkers
11 changes: 11 additions & 0 deletions docs/tutorials/dispatchers.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# Event Dispatchers

Event dispatchers are a way to decouple the event source from the event handler. This is useful when you want to have multiple handlers for the same event, or when you want to have a single handler for multiple events.

```python
from fluid.utils.dispatcher import SimpleDispatcher

simple = SimpleDispatcher[Any]()

simple.dispatch("you can dispatch anything to this generic dispatcher")
```
58 changes: 49 additions & 9 deletions fluid/scheduler/consumer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,15 @@
from collections import defaultdict, deque
from contextlib import AsyncExitStack
from functools import partial
from typing import Any, Callable, Coroutine, Self
from typing import Any, Awaitable, Callable, Self

import async_timeout
from inflection import underscore
from typing_extensions import Annotated, Doc

from fluid.utils import log
from fluid.utils.dispatcher import Dispatcher
from fluid.utils.worker import WorkerFunction, Workers
from fluid.utils.dispatcher import AsyncDispatcher, Dispatcher, Event
from fluid.utils.worker import AsyncConsumer, WorkerFunction, Workers

from .broker import TaskBroker, TaskRegistry
from .errors import TaskAbortedError, TaskRunError, UnknownTaskError
Expand All @@ -31,13 +32,19 @@
TaskManagerCLI = None # type: ignore[assignment,misc]


AsyncExecutor = Callable[..., Coroutine[Any, Any, None]]
AsyncMessage = tuple[AsyncExecutor, tuple[Any, ...]]
AsyncHandler = Callable[[TaskRun], Awaitable[None]]

logger = log.get_logger(__name__)


class TaskDispatcher(Dispatcher[TaskRun]):
"""The task dispatcher is responsible for dispatching task run messages"""

def message_type(self, message: TaskRun) -> str:
return message.state


class AsyncTaskDispatcher(AsyncDispatcher[TaskRun]):

def message_type(self, message: TaskRun) -> str:
return message.state
Expand All @@ -49,7 +56,16 @@ class TaskManager:
def __init__(self, **kwargs: Any) -> None:
self.state: dict[str, Any] = {}
self.config: TaskManagerConfig = TaskManagerConfig(**kwargs)
self.dispatcher = TaskDispatcher()
self.dispatcher: Annotated[
TaskDispatcher,
Doc(
"""
A dispatcher of task run events.
Register handlers to listen for task run events.
"""
),
] = TaskDispatcher()
self.broker = TaskBroker.from_url(self.config.broker_url)
self._stack = AsyncExitStack()

Expand Down Expand Up @@ -83,9 +99,7 @@ async def on_shutdown(self) -> None:
await self.broker.close()

def execute_sync(self, task: Task | str, **params: Any) -> TaskRun:
return asyncio.get_event_loop().run_until_complete(
self._execute_and_exit(task, **params)
)
return asyncio.run(self._execute_and_exit(task, **params))

def register_task(self, task: Task) -> None:
"""Register a task with the task manager
Expand Down Expand Up @@ -139,6 +153,19 @@ def register_from_module(self, module: Any) -> None:
if isinstance(obj := getattr(module, name), Task):
self.register_task(obj)

def register_async_handler(self, event: str, handler: AsyncHandler) -> None:
"""Register an async handler for a given event
This method is a no op for a TaskManager that is not a worker
"""

def unregister_async_handler(self, event: Event | str) -> AsyncHandler | None:
"""Unregister an async handler for a given event
This method is a no op for a TaskManager that is not a worker
"""
return None

def cli(self, **kwargs: Any) -> Any:
"""Create the task manager command line interface"""
try:
Expand All @@ -163,13 +190,15 @@ class TaskConsumer(TaskManager, Workers):
def __init__(self, **config: Any) -> None:
super().__init__(**config)
Workers.__init__(self)
self._async_dispatcher_worker = AsyncConsumer(AsyncTaskDispatcher())
self._concurrent_tasks: dict[str, dict[str, TaskRun]] = defaultdict(dict)
self._task_to_queue: deque[str | Task] = deque()
self._priority_task_run_queue: deque[TaskRun] = deque()
self._queue_tasks_worker = WorkerFunction(
self._queue_task, name="queue-task-worker"
)
self.add_workers(self._queue_tasks_worker)
self.add_workers(self._async_dispatcher_worker)
for i in range(self.config.max_concurrent_tasks):
worker_name = f"task-worker-{i+1}"
self.add_workers(
Expand Down Expand Up @@ -200,6 +229,17 @@ async def queue_and_wait(
with TaskRunWaiter(self) as waiter:
return await waiter.wait(await self.queue(task, **params), timeout=timeout)

def register_async_handler(self, event: Event | str, handler: AsyncHandler) -> None:
event = Event.from_string_or_event(event)
self.dispatcher.register_handler(
f"{event.type}.async_dispatch",
self._async_dispatcher_worker.send,
)
self._async_dispatcher_worker.dispatcher.register_handler(event, handler)

def unregister_async_handler(self, event: Event | str) -> AsyncHandler | None:
return self._async_dispatcher_worker.dispatcher.unregister_handler(event)

# Internals

# process tasks from the internal queue
Expand Down
6 changes: 3 additions & 3 deletions fluid/scheduler/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from fluid import settings
from fluid.utils import kernel, log
from fluid.utils.data import compact_dict
from fluid.utils.dates import utcnow
from fluid.utils.dates import as_utc
from fluid.utils.text import create_uid, trim_docstring

from .crontab import Scheduler
Expand Down Expand Up @@ -223,7 +223,7 @@ def set_state(
) -> None:
if self.state == state:
return
state_time = state_time or utcnow()
state_time = as_utc(state_time)
match (self.state, state):
case (TaskState.init, TaskState.queued):
self.queued = state_time
Expand Down Expand Up @@ -262,7 +262,7 @@ def lock(self, timeout: float | None) -> Lock:
return self.task_manager.broker.lock(self.name, timeout=timeout)

def _dispatch(self) -> None:
self.task_manager.dispatcher.dispatch(self)
self.task_manager.dispatcher.dispatch(self.model_copy())


@dataclass
Expand Down
11 changes: 8 additions & 3 deletions fluid/utils/dates.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from datetime import datetime, timezone
from datetime import date, datetime, timezone
from typing import Any

from zoneinfo import ZoneInfo
Expand All @@ -10,8 +10,13 @@ def utcnow() -> datetime:
return datetime.now(tz=UTC)


def as_utc(dt: datetime) -> datetime:
return dt.replace(tzinfo=UTC)
def as_utc(dt: date | None) -> datetime:
if dt is None:
return utcnow()
elif isinstance(dt, datetime):
return dt.replace(tzinfo=UTC)
else:
return datetime(dt.year, dt.month, dt.day, tzinfo=UTC)


def isoformat(dt: datetime, **kwargs: Any) -> str:
Expand Down
24 changes: 18 additions & 6 deletions fluid/utils/dispatcher.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import asyncio
from abc import ABC, abstractmethod
from collections import defaultdict
Expand All @@ -11,6 +13,12 @@ class Event(NamedTuple):
type: str
tag: str

@classmethod
def from_string_or_event(cls, event: str | Self) -> Self:
if isinstance(event, str):
return cls.from_string(event)
return event

@classmethod
def from_string(cls, event: str) -> Self:
bits = event.split(".")
Expand All @@ -27,23 +35,23 @@ def __init__(self) -> None:

def register_handler(
self,
message_type: str,
event: Event | str,
handler: MessageHandlerType,
) -> MessageHandlerType | None:
event = Event.from_string(message_type)
event = Event.from_string_or_event(event)
previous = self._msg_handlers[event.type].get(event.tag)
self._msg_handlers[event.type][event.tag] = handler
return previous

def unregister_handler(self, message_type: str) -> MessageHandlerType | None:
event = Event.from_string(message_type)
def unregister_handler(self, event: Event | str) -> MessageHandlerType | None:
event = Event.from_string_or_event(event)
return self._msg_handlers[event.type].pop(event.tag, None)

def get_handlers(
self,
message: MessageType,
) -> dict[str, MessageHandlerType] | None:
message_type = self.message_type(message)
message_type = str(self.message_type(message))
return self._msg_handlers.get(message_type)

@abstractmethod
Expand All @@ -52,6 +60,8 @@ def message_type(self, message: MessageType) -> str:


class Dispatcher(BaseDispatcher[MessageType, Callable[[MessageType], None]]):
"""Dispatcher for sync handlers"""

def dispatch(self, message: MessageType) -> int:
"""dispatch the message"""
handlers = self.get_handlers(message)
Expand All @@ -64,8 +74,10 @@ def dispatch(self, message: MessageType) -> int:
class AsyncDispatcher(
BaseDispatcher[MessageType, Callable[[MessageType], Awaitable[None]]],
):
"""Dispatcher for async handlers"""

async def dispatch(self, message: MessageType) -> int:
"""Dispatch the message"""
"""Dispatch the message and wait for all handlers to complete"""
handlers = self.get_handlers(message)
if handlers:
await asyncio.gather(*[handler(message) for handler in handlers.values()])
Expand Down
4 changes: 1 addition & 3 deletions fluid/utils/waiter.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
import asyncio
from typing import Callable

import async_timeout


async def wait_for(assertion: Callable[[], bool], timeout: float = 1.0) -> None:
async with async_timeout.timeout(timeout):
async with asyncio.timeout(timeout):
while True:
if assertion():
return
Expand Down
14 changes: 12 additions & 2 deletions fluid/utils/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,12 @@ def __init__(self, name: str = "") -> None:

@property
def worker_name(self) -> str:
"""The name of the worker"""
return self._name

@property
def num_workers(self) -> int:
"""The number of workers in this worker"""
return 1

@abstractmethod
Expand All @@ -71,6 +73,7 @@ async def run(self) -> None:


class RunningWorker(Worker):
"""A Worker that can be started"""

def __init__(self, name: str = "") -> None:
super().__init__(name)
Expand Down Expand Up @@ -110,6 +113,8 @@ async def status(self) -> dict:


class WorkerFunction(StoppingWorker):
"""A Worker that runs a coroutine function"""

def __init__(
self,
run_function: Callable[[], Awaitable[None]],
Expand All @@ -135,15 +140,18 @@ def send(self, message: T | None) -> None: ...


class QueueConsumer(StoppingWorker, MessageProducer[MessageType]):
"""A Worker that can receive messages"""
"""A Worker that can receive messages
This worker can receive messages but not consume them.
"""

def __init__(self, name: str = "") -> None:
super().__init__(name=name)
self._queue: asyncio.Queue[MessageType | None] = asyncio.Queue()

async def get_message(self, timeout: float = 0.5) -> MessageType | None:
try:
async with async_timeout.timeout(timeout):
async with asyncio.timeout(timeout):
return await self._queue.get()
except asyncio.TimeoutError:
return None
Expand All @@ -166,6 +174,8 @@ def send(self, message: MessageType | None) -> None:


class QueueConsumerWorker(QueueConsumer[MessageType]):
"""A Worker that can receive and consume messages"""

def __init__(
self,
on_message: Callable[[MessageType], Awaitable[None]],
Expand Down
15 changes: 11 additions & 4 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ theme:
toggle:
icon: material/lightbulb-outline
name: Switch to light mode
features:
- content.code.copy
plugins:
search: null
mkdocstrings:
Expand All @@ -41,7 +43,12 @@ plugins:
show_symbol_type_heading: true
show_symbol_type_toc: true
markdown_extensions:
toc:
permalink: true
markdown.extensions.codehilite:
guess_lang: false
- pymdownx.highlight:
anchor_linenums: true
line_spans: __span
pygments_lang_class: true
- pymdownx.inlinehilite
- pymdownx.snippets
- pymdownx.superfences
- toc:
permalink: true
Loading

0 comments on commit bd75a48

Please sign in to comment.