Commit: Wip

levsh committed Dec 26, 2024
1 parent 6c3083b commit 6ce65dd
Showing 18 changed files with 252 additions and 203 deletions.
7 changes: 4 additions & 3 deletions arrlio/__init__.py
@@ -22,11 +22,12 @@
log_hndl.setFormatter(log_frmt)
logger.addHandler(log_hndl)

from arrlio import settings

# ruff: noqa: E402
from arrlio.configs import Config, TaskConfig # noqa
from arrlio.core import App, AsyncResult, registered_tasks, task # noqa
from arrlio.core import App, AsyncResult, get_app, registered_tasks, task # noqa
from arrlio.models import Event, Graph, Task, TaskInstance, TaskResult # noqa
from arrlio.settings import LOG_LEVEL


logger.setLevel(LOG_LEVEL)
logger.setLevel(settings.LOG_LEVEL)
12 changes: 6 additions & 6 deletions arrlio/abc.py
@@ -165,7 +165,7 @@ async def stop_consume_events(self, callback_id: str | None = None):

class AbstractSerializer(ABC):
@abstractmethod
def dumps_task_instance(self, task_instance: TaskInstance, **kwds) -> bytes | TaskInstance:
def dumps_task_instance(self, task_instance: TaskInstance, **kwds) -> tuple[bytes | TaskInstance, dict]:
"""
Dump `TaskInstance`.
@@ -174,7 +174,7 @@ def dumps_task_instance(self, task_instance: TaskInstance, **kwds) -> bytes | TaskInstance:
"""

@abstractmethod
def loads_task_instance(self, data: bytes | TaskInstance, **kwds) -> TaskInstance:
def loads_task_instance(self, data: bytes | TaskInstance, headers: dict, **kwds) -> TaskInstance:
"""
Load `data` into `TaskInstance`.
@@ -189,7 +189,7 @@ def dumps_task_result(
*,
task_instance: TaskInstance | None = None,
**kwds,
) -> bytes | TaskResult:
) -> tuple[bytes | TaskResult, dict]:
"""
Dump `TaskResult`.
@@ -199,7 +199,7 @@
"""

@abstractmethod
def loads_task_result(self, data: bytes | TaskResult, **kwds) -> TaskResult:
def loads_task_result(self, data: bytes | TaskResult, headers: dict, **kwds) -> TaskResult:
"""
Load data into `TaskResult`.
@@ -208,7 +208,7 @@ def loads_task_result(self, data: bytes | TaskResult, **kwds) -> TaskResult:
"""

@abstractmethod
def dumps_event(self, event: Event, **kwds) -> bytes | Event:
def dumps_event(self, event: Event, **kwds) -> tuple[bytes | Event, dict]:
"""
Dump `arrlio.models.Event`.
@@ -217,7 +217,7 @@ def dumps_event(self, event: Event, **kwds) -> bytes | Event:
"""

@abstractmethod
def loads_event(self, data: bytes | Event, **kwds) -> Event:
def loads_event(self, data: bytes | Event, headers: dict, **kwds) -> Event:
"""
Load `data` into `Event`.
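For orientation, here is a minimal sketch of the revised serializer contract (a hypothetical class, not part of this commit): every dumps_* method now returns a (payload, headers) pair, and every loads_* method receives the headers that travelled alongside the payload. Only the event pair is shown; a real implementation of AbstractSerializer would cover all six methods.

    import json

    from arrlio.models import Event


    class JsonEventSerializer:
        """Partial sketch of the new AbstractSerializer contract."""

        def dumps_event(self, event: Event, **kwds) -> tuple[bytes, dict]:
            # Event.asdict() is assumed here purely for illustration.
            payload = json.dumps(event.asdict(), default=str).encode()
            return payload, {"content_type": "application/json"}

        def loads_event(self, data: bytes, headers: dict, **kwds) -> Event:
            # The headers produced by dumps_event come back on the loads side.
            assert headers.get("content_type") == "application/json"
            return Event(**json.loads(data))
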
16 changes: 10 additions & 6 deletions arrlio/backends/brokers/rabbitmq.py
@@ -1,6 +1,7 @@
import logging

from datetime import datetime, timezone
from itertools import repeat
from typing import Callable, Coroutine, Optional
from uuid import uuid4

@@ -164,7 +165,7 @@ def __repr__(self):
return self.__str__()

async def init(self):
await self._conn.open()
await retry(
msg=f"{self} init error",
retry_timeouts=repeat(5),
exc_filter=exc_filter,
)(self._conn.open)()

async def close(self):
await self._exchange.close()
@@ -207,7 +212,7 @@ async def _on_task_message(

task_instance = self.serializer.loads_task_instance(
message.body,
headers=message.header.properties.headers,
message.header.properties.headers,
)

reply_to = message.header.properties.reply_to
@@ -229,8 +234,8 @@ async def _on_task_message(
if task_instance.ack_late:
await channel.basic_ack(message.delivery.delivery_tag)

except Exception as e:
logger.exception(e)
except Exception:
logger.exception(message.header.properties)

async def _send_task(self, task_instance: TaskInstance, **kwds):
if is_debug_level():
@@ -240,8 +245,7 @@ async def _send_task(self, task_instance: TaskInstance, **kwds):
task_instance.pretty_repr(sanitize=settings.LOG_SANITIZE),
)

headers = {}
data = self.serializer.dumps_task_instance(task_instance, headers=headers)
data, headers = self.serializer.dumps_task_instance(task_instance)
task_headers = task_instance.headers

reply_to = task_headers.get("rabbitmq:reply_to")
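The init() change above wraps the connection open in a retry helper. A simplified stand-in with the same call shape is sketched below (arrlio's real helper lives in its utils module and is not shown in this diff): it accepts an iterator of timeouts such as itertools.repeat(5) and an exception filter, and returns a decorator for the coroutine function to retry.

    import asyncio
    import logging
    from itertools import repeat


    def retry(msg: str = "", retry_timeouts=None, exc_filter=lambda e: True):
        timeouts = iter(retry_timeouts if retry_timeouts is not None else repeat(5))

        def decorator(fn):
            async def wrapper(*args, **kwds):
                while True:
                    try:
                        return await fn(*args, **kwds)
                    except Exception as e:
                        if not exc_filter(e):
                            raise
                        timeout = next(timeouts)
                        logging.getLogger(__name__).warning("%s: %s (retry in %ss)", msg, e, timeout)
                        await asyncio.sleep(timeout)

            return wrapper

        return decorator


    # Usage mirroring the diff: await retry(msg="init error", retry_timeouts=repeat(5))(conn.open)()
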
16 changes: 13 additions & 3 deletions arrlio/backends/event_backends/rabbitmq.py
@@ -5,6 +5,7 @@
from datetime import datetime, timezone
from functools import partial
from inspect import iscoroutinefunction
from itertools import repeat
from typing import Any, Callable, Optional
from uuid import uuid4

@@ -177,19 +178,25 @@ def __repr__(self):
return self.__str__()

async def init(self):
await self._conn.open()
await retry(
msg=f"{self} init error",
retry_timeouts=repeat(5),
exc_filter=exc_filter,
)(self._conn.open)()

async def close(self):
await super().close()

async def _send_event(self, event: Event):
data, headers = self.serializer.dumps_event(event)
await self._exchange.publish(
self.serializer.dumps_event(event),
data,
routing_key=event.type,
properties={
"delivery_mode": 2,
"timestamp": datetime.now(tz=timezone.utc),
"expiration": f"{int(event.ttl * 1000)}" if event.ttl is not None else None,
"headers": headers,
},
)

@@ -223,7 +230,10 @@ async def consume_events(

async def on_message(channel: aiormq.Channel, message: aiormq.abc.DeliveredMessage):
try:
event: Event = self.serializer.loads_event(message.body)
event: Event = self.serializer.loads_event(
message.body,
message.header.properties.headers,
)

if is_debug_level():
logger.debug(_("%s got event\n%s"), self, event.pretty_repr(sanitize=settings.LOG_SANITIZE))
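For reference, a raw-aiormq sketch of the message properties assembled above (broker URL and exchange name are assumptions): delivery_mode=2 marks the message persistent, the per-message TTL is expressed in milliseconds as a string, and the headers returned by the serializer ride in properties.headers.

    import asyncio
    from datetime import datetime, timezone

    import aiormq


    async def main():
        conn = await aiormq.connect("amqp://guest:guest@localhost/")  # assumed URL
        channel = await conn.channel()
        await channel.basic_publish(
            b'{"type": "demo"}',
            exchange="arrlio.events",  # assumed exchange name
            routing_key="demo",
            properties=aiormq.spec.Basic.Properties(
                delivery_mode=2,  # persistent
                timestamp=datetime.now(tz=timezone.utc),
                expiration="5000",  # 5 second TTL, milliseconds as a string
                headers={"content_type": "application/json"},
            ),
        )
        await conn.close()


    asyncio.run(main())
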
2 changes: 1 addition & 1 deletion arrlio/backends/rabbitmq.py
@@ -63,4 +63,4 @@ def connection_factory(url: SecretAmqpDsn | list[SecretAmqpDsn]) -> Connection:

if not isinstance(url, list):
url = [url]
return Connection([u.get_secret_value() for u in url])
return Connection([f"{u.get_secret_value()}" for u in url])
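A hedged illustration of why the f-string coercion matters, assuming SecretAmqpDsn is (or behaves like) pydantic's Secret[AmqpDsn] (pydantic >= 2.7): get_secret_value() then returns a validated URL object rather than a str, and formatting it yields the plain DSN string the Connection expects.

    from pydantic import AmqpDsn, BaseModel, Secret


    class Cfg(BaseModel):
        url: Secret[AmqpDsn]  # stand-in for SecretAmqpDsn; an assumption, not arrlio's definition


    cfg = Cfg(url="amqp://guest:guest@localhost/")
    secret = cfg.url.get_secret_value()
    print(type(secret))  # an AmqpDsn/Url object, not a str
    print(f"{secret}")   # "amqp://guest:guest@localhost/" as a plain string
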
16 changes: 10 additions & 6 deletions arrlio/backends/result_backends/rabbitmq.py
@@ -5,6 +5,7 @@
from datetime import datetime, timezone
from functools import partial
from inspect import isasyncgenfunction, isgeneratorfunction
from itertools import repeat
from typing import AsyncGenerator, Optional
from uuid import UUID, uuid4

@@ -58,10 +59,10 @@
QUEUE_DURABLE = False
"""ResultBackend queue `durable` option."""

QUEUE_EXCLUSIVE = False
QUEUE_EXCLUSIVE = True
"""ResultBackend queue `excusive` option."""

QUEUE_AUTO_DELETE = True
QUEUE_AUTO_DELETE = False
"""ResultBackend queue `auto-delete` option."""

PREFETCH_COUNT = 10
@@ -205,7 +206,7 @@ def __repr__(self):
return self.__str__()

async def init(self):
await self._conn.open()
await retry(
msg=f"{self} init error",
retry_timeouts=repeat(5),
exc_filter=exc_filter,
)(self._conn.open)()

async def close(self):
await super().close()
@@ -289,7 +294,7 @@ async def _on_result_message(
properties: aiormq.spec.Basic.Properties = message.header.properties
task_id: UUID = UUID(properties.message_id)

task_result: TaskResult = self.serializer.loads_task_result(message.body, headers=properties.headers)
task_result: TaskResult = self.serializer.loads_task_result(message.body, properties.headers)

if not no_ack:
await channel.basic_ack(message.delivery.delivery_tag)
@@ -348,8 +353,7 @@ async def _push_task_result(
task_result.pretty_repr(sanitize=settings.LOG_SANITIZE),
)

headers = {}
data = self.serializer.dumps_task_result(task_result, task_instance=task_instance, headers=headers)
data, headers = self.serializer.dumps_task_result(task_result, task_instance=task_instance)

properties = {
"delivery_mode": 2,
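Context for the flipped queue defaults above: in RabbitMQ an exclusive queue is private to the connection that declares it and is deleted when that connection closes, so auto_delete becomes redundant. A raw-aiormq sketch (URL and queue name are assumptions):

    import asyncio

    import aiormq


    async def main():
        conn = await aiormq.connect("amqp://guest:guest@localhost/")  # assumed URL
        channel = await conn.channel()
        await channel.queue_declare(
            "arrlio.results.demo",  # assumed queue name
            durable=False,
            exclusive=True,     # visible only to this connection, dropped when it closes
            auto_delete=False,  # cleanup is already guaranteed by exclusivity
        )
        await conn.close()


    asyncio.run(main())
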
26 changes: 20 additions & 6 deletions arrlio/core.py
@@ -41,11 +41,14 @@
registered_tasks = rodict({}, nested=True)


_curr_app = ContextVar("curr_app", default=None)


def task(
func: FunctionType | MethodType | Type | None = None,
name: str | None = None,
base: Type[Task] | None = None,
**kwds: dict,
**kwds,
):
"""Task decorator.
@@ -168,7 +171,7 @@ def executor(self) -> Executor:
return self._executor

@property
def context(self):
def context(self) -> dict:
"""Application current context."""

return self._context.get()
@@ -300,7 +303,7 @@ async def send_task(

if is_info_level():
logger.info(
_("%s send task instance\n%s"),
_("%s send task\n%s"),
self,
task_instance.pretty_repr(sanitize=settings.LOG_SANITIZE),
)
@@ -378,12 +381,15 @@ async def cb(task_instance: TaskInstance):
idx_0 = uuid4().hex
idx_1 = 0

self._context.set({})
context = self.context

try:
task_result: TaskResult = TaskResult()

async with AsyncExitStack() as stack:
try:
self.context["task_instance"] = task_instance
context["task_instance"] = task_instance

for context_hook in self._hooks["task_context"]:
await stack.enter_async_context(context_hook(task_instance))
@@ -461,8 +467,12 @@ async def execute_task(self, task_instance: TaskInstance) -> AsyncGenerator[Task
Task result.
"""

async for task_result in self._executor(task_instance):
yield task_result
token = _curr_app.set(self)
try:
async for task_result in self._executor(task_instance):
yield task_result
finally:
_curr_app.reset(token)

async def consume_events(
self,
@@ -579,3 +589,7 @@ async def get(self) -> Any:
if noresult:
raise TaskClosedError(self.task_id)
return self._result


def get_app() -> App | None:
return _curr_app.get()
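A usage sketch of the new get_app() helper (the task name and body are illustrative): App.execute_task now sets a ContextVar around executor iteration, so code running inside a task can resolve the App driving it, while outside of task execution get_app() returns None.

    from arrlio import get_app, task


    @task(name="whoami")
    async def whoami() -> str | None:
        app = get_app()  # resolved from the ContextVar set by App.execute_task
        return None if app is None else str(app)


    assert get_app() is None  # nothing is executing a task in this context
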
23 changes: 14 additions & 9 deletions arrlio/serializers/base.py
@@ -48,7 +48,7 @@ def loads(self, data: bytes | Any) -> Any:
data: data to load.
"""

def dumps_task_instance(self, task_instance: TaskInstance, **kwds) -> bytes:
def dumps_task_instance(self, task_instance: TaskInstance, **kwds) -> tuple[bytes, dict]:
"""
Dump `arrlio.models.TaskInstance` object as json encoded string.
@@ -60,9 +60,9 @@ def dumps_task_instance(self, task_instance: TaskInstance, **kwds) -> bytes:
headers = data["headers"]
if graph := headers.get("graph:graph"):
headers["graph:graph"] = graph.asdict()
return self.dumps({k: v for k, v in data.items() if v is not None})
return self.dumps({k: v for k, v in data.items() if v is not None}), {}

def loads_task_instance(self, data: bytes, **kwds) -> TaskInstance:
def loads_task_instance(self, data: bytes, headers: dict, **kwds) -> TaskInstance:
"""
Load `arrlio.models.TaskInstance` object from json encoded string.
@@ -133,7 +133,12 @@ def loads_trb(self, trb: str) -> str:

return trb

def dumps_task_result(self, task_result: TaskResult, task_instance: TaskInstance | None = None, **kwds) -> bytes:
def dumps_task_result(
self,
task_result: TaskResult,
task_instance: TaskInstance | None = None,
**kwds,
) -> tuple[bytes, dict]:
"""
Dump `arrlio.models.TaskResult` as json encoded string.
@@ -148,9 +153,9 @@ def dumps_task_result(self, task_result: TaskResult, task_instance: TaskInstance | None = None, **kwds) -> bytes:
data["trb"] = self.dumps_trb(data["trb"])
elif task_instance and task_instance.dumps:
data["res"] = task_instance.dumps(data["res"])
return self.dumps(data)
return self.dumps(data), {}

def loads_task_result(self, data: bytes, **kwds) -> TaskResult:
def loads_task_result(self, data: bytes, headers: dict, **kwds) -> TaskResult:
"""
Load `arrlio.models.TaskResult` from json encoded string.
@@ -164,7 +169,7 @@ def loads_task_result(self, data: bytes, **kwds) -> TaskResult:
data["trb"] = self.loads_trb(data["trb"])
return TaskResult(**data)

def dumps_event(self, event: Event, **kwds) -> bytes:
def dumps_event(self, event: Event, **kwds) -> tuple[bytes, dict]:
"""
Dump `arrlio.models.Event` as json encoded string.
@@ -184,9 +189,9 @@ def dumps_event(self, event: Event, **kwds) -> bytes:
if result["exc"]:
result["exc"] = self.dumps_exc(result["exc"])
result["trb"] = self.dumps_trb(result["trb"])
return self.dumps(data)
return self.dumps(data), {}

def loads_event(self, data: bytes, **kwds) -> Event:
def loads_event(self, data: bytes, headers: dict, **kwds) -> Event:
"""
Load `arrlio.models.Event` from json encoded string.
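The base serializer above always returns empty headers; the point of the new tuple is that subclasses can populate them. A hypothetical subclass (not part of arrlio, and the base class name Serializer is an assumption) that compresses large task results and records that fact in headers so the loads side can reverse it symmetrically:

    import zlib

    from arrlio.models import TaskInstance, TaskResult
    from arrlio.serializers.base import Serializer  # assumed class name for the base serializer


    class CompressingSerializer(Serializer):
        def dumps_task_result(
            self,
            task_result: TaskResult,
            task_instance: TaskInstance | None = None,
            **kwds,
        ) -> tuple[bytes, dict]:
            data, headers = super().dumps_task_result(task_result, task_instance=task_instance, **kwds)
            if len(data) > 1024:
                data = zlib.compress(data)
                headers = {**headers, "compression": "zlib"}
            return data, headers

        def loads_task_result(self, data: bytes, headers: dict, **kwds) -> TaskResult:
            if headers.get("compression") == "zlib":
                data = zlib.decompress(data)
            return super().loads_task_result(data, headers, **kwds)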