2
2
3
3
from abc import ABC , abstractmethod , abstractproperty
4
4
from functools import cached_property
5
- from typing import Any , Dict , Iterable , List , NamedTuple , Optional
5
+ from typing import TYPE_CHECKING , Any , Iterable
6
6
from uuid import uuid4
7
7
8
+ from redis .asyncio .lock import Lock
8
9
from yarl import URL
9
10
10
- from fluid import json
11
- from fluid .tools .redis import FluidRedis
12
- from fluid .tools .timestamp import Timestamp
11
+ from fluid import settings
12
+ from fluid .utils .redis import Redis , FluidRedis
13
+ import json
14
+ from .errors import UnknownTaskError
13
15
14
- from . import settings
15
- from .constants import TaskPriority , TaskState
16
- from .task import Task
17
- from .task_info import TaskInfo
18
- from .task_run import TaskRun
19
-
20
- _brokers : dict [str , type [Broker ]] = {}
21
-
22
-
23
- def broker_url_from_env () -> URL :
24
- return URL (settings .SCHEDULER_BROKER_URL )
16
+ from .models import QueuedTask , Task , TaskInfo , TaskPriority , TaskRun
25
17
18
+ if TYPE_CHECKING : # pragma: no cover
19
+ from .consumer import TaskManager
26
20
27
- class TaskError (RuntimeError ):
28
- pass
29
21
30
-
31
- class UnknownTask (TaskError ):
32
- pass
22
+ _brokers : dict [str , type [Broker ]] = {}
33
23
34
24
35
- class DisabledTask ( TaskError ) :
36
- pass
25
+ def broker_url_from_env () -> URL :
26
+ return URL ( settings . BROKER_URL )
37
27
38
28
39
- class TaskRegistry (Dict [str , Task ]):
29
+ class TaskRegistry (dict [str , Task ]):
40
30
def periodic (self ) -> Iterable [Task ]:
41
31
for task in self .values ():
42
32
yield task
43
33
44
34
45
- class QueuedTask (NamedTuple ):
46
- run_id : str
47
- task : str
48
- params : Dict [str , Any ]
49
- priority : Optional [TaskPriority ] = None
50
-
51
-
52
35
class Broker (ABC ):
53
36
def __init__ (self , url : URL ) -> None :
54
37
self .url : URL = url
@@ -59,15 +42,17 @@ def task_queue_names(self) -> tuple[str, ...]:
59
42
"""Names of the task queues"""
60
43
61
44
@abstractmethod
62
- async def queue_task (self , queued_task : QueuedTask ) -> TaskRun :
45
+ async def queue_task (
46
+ self , task_manager : TaskManager , queued_task : QueuedTask
47
+ ) -> TaskRun :
63
48
"""Queue a task"""
64
49
65
50
@abstractmethod
66
- async def get_task_run (self ) -> Optional [ TaskRun ] :
51
+ async def get_task_run (self , task_manager : TaskManager ) -> TaskRun | None :
67
52
"""Get a Task run from the task queue"""
68
53
69
54
@abstractmethod
70
- async def queue_length (self ) -> Dict [str , int ]:
55
+ async def queue_length (self ) -> dict [str , int ]:
71
56
"""Length of task queues"""
72
57
73
58
@abstractmethod
@@ -78,15 +63,22 @@ async def get_tasks_info(self, *task_names: str) -> list[TaskInfo]:
78
63
async def update_task (self , task : Task , params : dict [str , Any ]) -> TaskInfo :
79
64
"""Update a task dynamic parameters"""
80
65
66
+ @abstractmethod
81
67
async def close (self ) -> None :
82
68
"""Close the broker on shutdown"""
83
69
70
+ @abstractmethod
71
+ def lock (self , name : str , timeout : float | None = None ) -> Lock :
72
+ """Create a lock"""
73
+
84
74
def new_uuid (self ) -> str :
85
75
return uuid4 ().hex
86
76
87
77
async def filter_tasks (
88
- self , scheduled : Optional [bool ] = None , enabled : Optional [bool ] = None
89
- ) -> List [Task ]:
78
+ self ,
79
+ scheduled : bool | None = None ,
80
+ enabled : bool | None = None ,
81
+ ) -> list [Task ]:
90
82
task_info = await self .get_tasks_info ()
91
83
task_map = {info .name : info for info in task_info }
92
84
tasks = []
@@ -105,7 +97,7 @@ def task_from_registry(self, task: str | Task) -> Task:
105
97
else :
106
98
if task_ := self .registry .get (task ):
107
99
return task_
108
- raise UnknownTask (task )
100
+ raise UnknownTaskError (task )
109
101
110
102
def register_task (self , task : Task ) -> None :
111
103
self .registry [task .name ] = task
@@ -114,41 +106,15 @@ async def enable_task(self, task_name: str, enable: bool = True) -> TaskInfo:
114
106
"""Enable or disable a registered task"""
115
107
task = self .registry .get (task_name )
116
108
if not task :
117
- raise UnknownTask (task_name )
109
+ raise UnknownTaskError (task_name )
118
110
return await self .update_task (task , dict (enabled = enable ))
119
111
120
- def task_run_from_data (self , data : Dict [str , Any ]) -> TaskRun :
121
- """Build a TaskRun object from its metadata"""
122
- data = data .copy ()
123
- name = data .pop ("name" )
124
- data ["task" ] = self .task_from_registry (name )
125
- return TaskRun (** data )
126
-
127
- def task_run_data (
128
- self , queued_task : QueuedTask , state : TaskState
129
- ) -> Dict [str , Any ]:
130
- """Create a dictionary of metadata required by a task run
131
-
132
- This dictionary must be serializable by the broker
133
- """
134
- task = self .task_from_registry (queued_task .task )
135
- priority = queued_task .priority or task .priority
136
- return dict (
137
- id = queued_task .run_id ,
138
- name = task .name ,
139
- priority = priority .name ,
140
- state = state .name ,
141
- params = queued_task .params ,
142
- queued = Timestamp .utcnow (),
143
- )
144
-
145
112
@classmethod
146
113
def from_url (cls , url : str = "" ) -> Broker :
147
114
p = URL (url or broker_url_from_env ())
148
- Factory = _brokers .get (p .scheme )
149
- if not Factory :
150
- raise RuntimeError (f"Invalid broker { p } " )
151
- return Factory (p )
115
+ if factory := _brokers .get (p .scheme ):
116
+ return factory (p )
117
+ raise RuntimeError (f"Invalid broker { p } " )
152
118
153
119
@classmethod
154
120
def register_broker (cls , name : str , factory : type [Broker ]) -> None :
@@ -160,7 +126,11 @@ class RedisBroker(Broker):
160
126
161
127
@cached_property
162
128
def redis (self ) -> FluidRedis :
163
- return FluidRedis (str (self .url .with_query ({})), name = self .name )
129
+ return FluidRedis .create (str (self .url .with_query ({})), name = self .name )
130
+
131
+ @property
132
+ def redis_cli (self ) -> Redis :
133
+ return self .redis .redis_cli
164
134
165
135
@property
166
136
def name (self ) -> str :
@@ -185,7 +155,7 @@ def task_queue_name(self, priority: TaskPriority) -> str:
185
155
return f"{ self .name } -queue-{ priority .name } "
186
156
187
157
async def get_tasks_info (self , * task_names : str ) -> list [TaskInfo ]:
188
- pipe = self .redis . cli .pipeline ()
158
+ pipe = self .redis_cli .pipeline ()
189
159
names = task_names or self .registry
190
160
requested_task_names = []
191
161
for name in names :
@@ -199,7 +169,7 @@ async def get_tasks_info(self, *task_names: str) -> list[TaskInfo]:
199
169
]
200
170
201
171
async def update_task (self , task : Task , params : dict [str , Any ]) -> TaskInfo :
202
- pipe = self .redis . cli .pipeline ()
172
+ pipe = self .redis_cli .pipeline ()
203
173
pipe .hset (
204
174
self .task_hash_name (task .name ),
205
175
mapping = {name : json .dumps (value ) for name , value in params .items ()},
@@ -208,41 +178,53 @@ async def update_task(self, task: Task, params: dict[str, Any]) -> TaskInfo:
208
178
_ , info = await pipe .execute ()
209
179
return self ._decode_task (task , info )
210
180
211
- async def queue_length (self ) -> Dict [str , int ]:
181
+ async def queue_length (self ) -> dict [str , int ]:
212
182
if self .task_queue_names :
213
- pipe = self .redis . cli .pipeline ()
183
+ pipe = self .redis_cli .pipeline ()
214
184
for name in self .task_queue_names :
215
185
pipe .llen (name )
216
186
result = await pipe .execute ()
217
- return { p . name : r for p , r in zip (TaskPriority , result )}
187
+ return dict ( zip (TaskPriority , result ))
218
188
return {}
219
189
220
190
async def close (self ) -> None :
221
191
"""Close the broker on shutdown"""
222
192
await self .redis .close ()
223
193
224
- async def get_task_run (self ) -> Optional [ TaskRun ] :
194
+ async def get_task_run (self , task_manager : TaskManager ) -> TaskRun | None :
225
195
if self .task_queue_names :
226
- data = await self .redis .cli .brpop (self .task_queue_names , timeout = 1 )
227
- if data :
228
- data_str = data [1 ].decode ("utf-8" )
229
- return self .task_run_from_data (json .loads (data_str ))
196
+ if redis_data := await self .redis_cli .brpop ( # type: ignore [misc]
197
+ self .task_queue_names , # type: ignore [arg-type]
198
+ timeout = 1 ,
199
+ ):
200
+ data = json .loads (redis_data [1 ])
201
+ data .update (
202
+ task = self .task_from_registry (data ["task" ]),
203
+ task_manager = task_manager ,
204
+ )
205
+ return TaskRun (** data )
230
206
return None
231
207
232
- async def queue_task (self , queued_task : QueuedTask ) -> TaskRun :
233
- task = self .task_from_registry (queued_task .task )
234
- priority = queued_task .priority or task .priority
235
- data = self .task_run_data (queued_task , TaskState .queued )
236
- await self .redis .cli .lpush (self .task_queue_name (priority ), json .dumps (data ))
237
- return self .task_run_from_data (data )
208
+ async def queue_task (
209
+ self , task_manager : TaskManager , queued_task : QueuedTask
210
+ ) -> TaskRun :
211
+ task_run = self .create_task_run (task_manager , queued_task )
212
+ await self .redis_cli .lpush ( # type: ignore [misc]
213
+ self .task_queue_name (task_run .priority ),
214
+ task_run .model_dump_json (),
215
+ )
216
+ return task_run
217
+
218
+ def lock (self , name : str , timeout : float | None = None ) -> Lock :
219
+ return self .redis_cli .lock (name , timeout = timeout )
238
220
239
221
def _decode_task (self , task : Task , data : dict [bytes , Any ]) -> TaskInfo :
240
222
info = {name .decode (): json .loads (value ) for name , value in data .items ()}
241
223
return TaskInfo (
242
224
name = task .name ,
243
225
description = task .description ,
244
226
schedule = str (task .schedule ) if task .schedule else None ,
245
- priority = task .priority . name ,
227
+ priority = task .priority ,
246
228
enabled = info .get ("enabled" , True ),
247
229
last_run_duration = info .get ("last_run_duration" ),
248
230
last_run_end = info .get ("last_run_end" ),
@@ -251,3 +233,4 @@ def _decode_task(self, task: Task, data: dict[bytes, Any]) -> TaskInfo:
251
233
252
234
253
235
Broker .register_broker ("redis" , RedisBroker )
236
+ Broker .register_broker ("rediss" , RedisBroker )
0 commit comments