Skip to content

Commit d106203

Browse files
committed
Better cpubound
1 parent d811720 commit d106203

Some content is hidden

Large commits have some content hidden by default. Use the searchbox below for content that may be hidden.

41 files changed

+1356
-1169
lines changed

.dev/install

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
#!/usr/bin/env bash
22
pip install -U pip poetry
3-
poetry install --all-extras
3+
poetry install --all-extras --with docs

docs/index.md

+17
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
# Welcome to MkDocs
2+
3+
For full documentation visit [mkdocs.org](https://www.mkdocs.org).
4+
5+
## Commands
6+
7+
* `mkdocs new [dir-name]` - Create a new project.
8+
* `mkdocs serve` - Start the live-reloading docs server.
9+
* `mkdocs build` - Build the documentation site.
10+
* `mkdocs -h` - Print help message and exit.
11+
12+
## Project layout
13+
14+
mkdocs.yml # The configuration file.
15+
docs/
16+
index.md # The documentation homepage.
17+
... # Other markdown pages, images and other files.

fluid/scheduler/readme.md docs/scheduler.md

+9-9
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,18 @@
1-
# Distrubuted Task Producer/Consumer
1+
# Distributed Task Producer/Consumer
22

33
This module has a lightweight implementation of a distributed task producer (TaskScheduler) and consumer (TaskConsumer).
44
The middleware for distributing tasks can be configured via the Broker interface.
55
A redis broker is provided for convenience.
66

77
## Tasks
88

9-
Tasks are standard python async functions decorated with the `task` or `cpu_task` decorators.
9+
Tasks are standard python async functions decorated with the `task` decorator.
1010

1111
```python
12-
from fluid.scheduler import task, TaskContext
12+
from fluid.scheduler import task, TaskRun
1313

1414
@task
15-
async def say_hi(ctx: TaskContext):
15+
async def say_hi(ctx: TaskRun):
1616
return "Hi!"
1717
```
1818

@@ -21,10 +21,10 @@ There are two types of tasks implemented
2121
* **Simple concurrent tasks** - they run concurrently with the task consumer - they must be IO type tasks (no heavy CPU bound operations)
2222

2323
```python
24-
from fluid.scheduler import task, TaskContext
24+
from fluid.scheduler import task, TaskRun
2525

2626
@task
27-
async def fecth_data(ctx: TaskContext):
27+
async def fecth_data(ctx: TaskRun):
2828
# fetch data
2929
data = await http_cli.get("https://...")
3030
data_id = await datastore_cli.stote(data)
@@ -35,10 +35,10 @@ There are two types of tasks implemented
3535
* **CPU bound tasks** - they run on a subprocess
3636

3737
```python
38-
from fluid.scheduler import cpu_task, TaskContext
38+
from fluid.scheduler import task, TaskRun
3939

40-
@cpu_task
41-
async def heavy_calculation(ctx: TaskContext):
40+
@task(cpu_bound=True)
41+
async def heavy_calculation(ctx: TaskRun):
4242
# perform some heavy calculation
4343
data = await datastore_cli.get(ctx.params["data_id"])
4444
...

fluid/scheduler/__init__.py

+1-19
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,9 @@
11
from .broker import Broker, QueuedTask
2-
from .constants import TaskPriority, TaskState
32
from .consumer import TaskConsumer, TaskManager
43
from .crontab import Scheduler, crontab
54
from .every import every
5+
from .models import Task, TaskInfo, TaskPriority, TaskRun, TaskState, task
66
from .scheduler import TaskScheduler
7-
from .task import (
8-
Task,
9-
TaskConstructor,
10-
TaskContext,
11-
TaskDecoratorError,
12-
TaskExecutor,
13-
TaskRunError,
14-
create_task_app,
15-
task,
16-
)
17-
from .task_info import TaskInfo
18-
from .task_run import TaskRun
197

208
__all__ = [
219
"Scheduler",
@@ -24,18 +12,12 @@
2412
"TaskScheduler",
2513
"task",
2614
"Task",
27-
"TaskContext",
28-
"TaskRunError",
29-
"TaskExecutor",
30-
"TaskConstructor",
31-
"TaskDecoratorError",
3215
"TaskManager",
3316
"TaskConsumer",
3417
"TaskInfo",
3518
"TaskRun",
3619
"QueuedTask",
3720
"Broker",
38-
"create_task_app",
3921
"crontab",
4022
"every",
4123
]

fluid/scheduler/broker.py

+67-84
Original file line numberDiff line numberDiff line change
@@ -2,53 +2,36 @@
22

33
from abc import ABC, abstractmethod, abstractproperty
44
from functools import cached_property
5-
from typing import Any, Dict, Iterable, List, NamedTuple, Optional
5+
from typing import TYPE_CHECKING, Any, Iterable
66
from uuid import uuid4
77

8+
from redis.asyncio.lock import Lock
89
from yarl import URL
910

10-
from fluid import json
11-
from fluid.tools.redis import FluidRedis
12-
from fluid.tools.timestamp import Timestamp
11+
from fluid import settings
12+
from fluid.utils.redis import Redis, FluidRedis
13+
import json
14+
from .errors import UnknownTaskError
1315

14-
from . import settings
15-
from .constants import TaskPriority, TaskState
16-
from .task import Task
17-
from .task_info import TaskInfo
18-
from .task_run import TaskRun
19-
20-
_brokers: dict[str, type[Broker]] = {}
21-
22-
23-
def broker_url_from_env() -> URL:
24-
return URL(settings.SCHEDULER_BROKER_URL)
16+
from .models import QueuedTask, Task, TaskInfo, TaskPriority, TaskRun
2517

18+
if TYPE_CHECKING: # pragma: no cover
19+
from .consumer import TaskManager
2620

27-
class TaskError(RuntimeError):
28-
pass
2921

30-
31-
class UnknownTask(TaskError):
32-
pass
22+
_brokers: dict[str, type[Broker]] = {}
3323

3424

35-
class DisabledTask(TaskError):
36-
pass
25+
def broker_url_from_env() -> URL:
26+
return URL(settings.BROKER_URL)
3727

3828

39-
class TaskRegistry(Dict[str, Task]):
29+
class TaskRegistry(dict[str, Task]):
4030
def periodic(self) -> Iterable[Task]:
4131
for task in self.values():
4232
yield task
4333

4434

45-
class QueuedTask(NamedTuple):
46-
run_id: str
47-
task: str
48-
params: Dict[str, Any]
49-
priority: Optional[TaskPriority] = None
50-
51-
5235
class Broker(ABC):
5336
def __init__(self, url: URL) -> None:
5437
self.url: URL = url
@@ -59,15 +42,17 @@ def task_queue_names(self) -> tuple[str, ...]:
5942
"""Names of the task queues"""
6043

6144
@abstractmethod
62-
async def queue_task(self, queued_task: QueuedTask) -> TaskRun:
45+
async def queue_task(
46+
self, task_manager: TaskManager, queued_task: QueuedTask
47+
) -> TaskRun:
6348
"""Queue a task"""
6449

6550
@abstractmethod
66-
async def get_task_run(self) -> Optional[TaskRun]:
51+
async def get_task_run(self, task_manager: TaskManager) -> TaskRun | None:
6752
"""Get a Task run from the task queue"""
6853

6954
@abstractmethod
70-
async def queue_length(self) -> Dict[str, int]:
55+
async def queue_length(self) -> dict[str, int]:
7156
"""Length of task queues"""
7257

7358
@abstractmethod
@@ -78,15 +63,22 @@ async def get_tasks_info(self, *task_names: str) -> list[TaskInfo]:
7863
async def update_task(self, task: Task, params: dict[str, Any]) -> TaskInfo:
7964
"""Update a task dynamic parameters"""
8065

66+
@abstractmethod
8167
async def close(self) -> None:
8268
"""Close the broker on shutdown"""
8369

70+
@abstractmethod
71+
def lock(self, name: str, timeout: float | None = None) -> Lock:
72+
"""Create a lock"""
73+
8474
def new_uuid(self) -> str:
8575
return uuid4().hex
8676

8777
async def filter_tasks(
88-
self, scheduled: Optional[bool] = None, enabled: Optional[bool] = None
89-
) -> List[Task]:
78+
self,
79+
scheduled: bool | None = None,
80+
enabled: bool | None = None,
81+
) -> list[Task]:
9082
task_info = await self.get_tasks_info()
9183
task_map = {info.name: info for info in task_info}
9284
tasks = []
@@ -105,7 +97,7 @@ def task_from_registry(self, task: str | Task) -> Task:
10597
else:
10698
if task_ := self.registry.get(task):
10799
return task_
108-
raise UnknownTask(task)
100+
raise UnknownTaskError(task)
109101

110102
def register_task(self, task: Task) -> None:
111103
self.registry[task.name] = task
@@ -114,41 +106,15 @@ async def enable_task(self, task_name: str, enable: bool = True) -> TaskInfo:
114106
"""Enable or disable a registered task"""
115107
task = self.registry.get(task_name)
116108
if not task:
117-
raise UnknownTask(task_name)
109+
raise UnknownTaskError(task_name)
118110
return await self.update_task(task, dict(enabled=enable))
119111

120-
def task_run_from_data(self, data: Dict[str, Any]) -> TaskRun:
121-
"""Build a TaskRun object from its metadata"""
122-
data = data.copy()
123-
name = data.pop("name")
124-
data["task"] = self.task_from_registry(name)
125-
return TaskRun(**data)
126-
127-
def task_run_data(
128-
self, queued_task: QueuedTask, state: TaskState
129-
) -> Dict[str, Any]:
130-
"""Create a dictionary of metadata required by a task run
131-
132-
This dictionary must be serializable by the broker
133-
"""
134-
task = self.task_from_registry(queued_task.task)
135-
priority = queued_task.priority or task.priority
136-
return dict(
137-
id=queued_task.run_id,
138-
name=task.name,
139-
priority=priority.name,
140-
state=state.name,
141-
params=queued_task.params,
142-
queued=Timestamp.utcnow(),
143-
)
144-
145112
@classmethod
146113
def from_url(cls, url: str = "") -> Broker:
147114
p = URL(url or broker_url_from_env())
148-
Factory = _brokers.get(p.scheme)
149-
if not Factory:
150-
raise RuntimeError(f"Invalid broker {p}")
151-
return Factory(p)
115+
if factory := _brokers.get(p.scheme):
116+
return factory(p)
117+
raise RuntimeError(f"Invalid broker {p}")
152118

153119
@classmethod
154120
def register_broker(cls, name: str, factory: type[Broker]) -> None:
@@ -160,7 +126,11 @@ class RedisBroker(Broker):
160126

161127
@cached_property
162128
def redis(self) -> FluidRedis:
163-
return FluidRedis(str(self.url.with_query({})), name=self.name)
129+
return FluidRedis.create(str(self.url.with_query({})), name=self.name)
130+
131+
@property
132+
def redis_cli(self) -> Redis:
133+
return self.redis.redis_cli
164134

165135
@property
166136
def name(self) -> str:
@@ -185,7 +155,7 @@ def task_queue_name(self, priority: TaskPriority) -> str:
185155
return f"{self.name}-queue-{priority.name}"
186156

187157
async def get_tasks_info(self, *task_names: str) -> list[TaskInfo]:
188-
pipe = self.redis.cli.pipeline()
158+
pipe = self.redis_cli.pipeline()
189159
names = task_names or self.registry
190160
requested_task_names = []
191161
for name in names:
@@ -199,7 +169,7 @@ async def get_tasks_info(self, *task_names: str) -> list[TaskInfo]:
199169
]
200170

201171
async def update_task(self, task: Task, params: dict[str, Any]) -> TaskInfo:
202-
pipe = self.redis.cli.pipeline()
172+
pipe = self.redis_cli.pipeline()
203173
pipe.hset(
204174
self.task_hash_name(task.name),
205175
mapping={name: json.dumps(value) for name, value in params.items()},
@@ -208,41 +178,53 @@ async def update_task(self, task: Task, params: dict[str, Any]) -> TaskInfo:
208178
_, info = await pipe.execute()
209179
return self._decode_task(task, info)
210180

211-
async def queue_length(self) -> Dict[str, int]:
181+
async def queue_length(self) -> dict[str, int]:
212182
if self.task_queue_names:
213-
pipe = self.redis.cli.pipeline()
183+
pipe = self.redis_cli.pipeline()
214184
for name in self.task_queue_names:
215185
pipe.llen(name)
216186
result = await pipe.execute()
217-
return {p.name: r for p, r in zip(TaskPriority, result)}
187+
return dict(zip(TaskPriority, result))
218188
return {}
219189

220190
async def close(self) -> None:
221191
"""Close the broker on shutdown"""
222192
await self.redis.close()
223193

224-
async def get_task_run(self) -> Optional[TaskRun]:
194+
async def get_task_run(self, task_manager: TaskManager) -> TaskRun | None:
225195
if self.task_queue_names:
226-
data = await self.redis.cli.brpop(self.task_queue_names, timeout=1)
227-
if data:
228-
data_str = data[1].decode("utf-8")
229-
return self.task_run_from_data(json.loads(data_str))
196+
if redis_data := await self.redis_cli.brpop( # type: ignore [misc]
197+
self.task_queue_names, # type: ignore [arg-type]
198+
timeout=1,
199+
):
200+
data = json.loads(redis_data[1])
201+
data.update(
202+
task=self.task_from_registry(data["task"]),
203+
task_manager=task_manager,
204+
)
205+
return TaskRun(**data)
230206
return None
231207

232-
async def queue_task(self, queued_task: QueuedTask) -> TaskRun:
233-
task = self.task_from_registry(queued_task.task)
234-
priority = queued_task.priority or task.priority
235-
data = self.task_run_data(queued_task, TaskState.queued)
236-
await self.redis.cli.lpush(self.task_queue_name(priority), json.dumps(data))
237-
return self.task_run_from_data(data)
208+
async def queue_task(
209+
self, task_manager: TaskManager, queued_task: QueuedTask
210+
) -> TaskRun:
211+
task_run = self.create_task_run(task_manager, queued_task)
212+
await self.redis_cli.lpush( # type: ignore [misc]
213+
self.task_queue_name(task_run.priority),
214+
task_run.model_dump_json(),
215+
)
216+
return task_run
217+
218+
def lock(self, name: str, timeout: float | None = None) -> Lock:
219+
return self.redis_cli.lock(name, timeout=timeout)
238220

239221
def _decode_task(self, task: Task, data: dict[bytes, Any]) -> TaskInfo:
240222
info = {name.decode(): json.loads(value) for name, value in data.items()}
241223
return TaskInfo(
242224
name=task.name,
243225
description=task.description,
244226
schedule=str(task.schedule) if task.schedule else None,
245-
priority=task.priority.name,
227+
priority=task.priority,
246228
enabled=info.get("enabled", True),
247229
last_run_duration=info.get("last_run_duration"),
248230
last_run_end=info.get("last_run_end"),
@@ -251,3 +233,4 @@ def _decode_task(self, task: Task, data: dict[bytes, Any]) -> TaskInfo:
251233

252234

253235
Broker.register_broker("redis", RedisBroker)
236+
Broker.register_broker("rediss", RedisBroker)

0 commit comments

Comments
 (0)