Skip to content

Commit

Permalink
Wip
Browse files Browse the repository at this point in the history
  • Loading branch information
levsh committed Apr 17, 2024
1 parent e19cd17 commit 4cad63c
Show file tree
Hide file tree
Showing 6 changed files with 64 additions and 34 deletions.
2 changes: 1 addition & 1 deletion arrlio/backends/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def _cancel_internal_tasks(self, key: str):
for task in self._internal_tasks[key]:
task.cancel()

def _create_internal_task(self, key: str, coro_factory: Callable) -> asyncio.Task:
def _create_internal_task(self, key: str, coro_factory: Callable[[], Coroutine]) -> asyncio.Task:
if self._closed.done():
raise Exception(f"{self} closed")

Expand Down
67 changes: 48 additions & 19 deletions arrlio/backends/rabbitmq.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,13 @@
from functools import partial
from inspect import isasyncgenfunction, iscoroutine, iscoroutinefunction, isgeneratorfunction
from ssl import SSLContext
from typing import Annotated, Any, AsyncGenerator, Callable, Coroutine, Optional
from typing import Any, AsyncGenerator, Callable, Coroutine, Optional
from uuid import UUID

import aiormq
import aiormq.exceptions
import yarl
from pydantic import Field, PlainSerializer, PositiveInt
from pydantic import Field, PositiveInt
from pydantic_settings import SettingsConfigDict

from arrlio import settings
Expand Down Expand Up @@ -313,7 +313,40 @@ async def channel(self) -> aiormq.Channel:
return self._channel


@dataclass(frozen=True)
@dataclass(slots=True, frozen=True)
class SimpleExchange:
conn: Connection
name: str = ""
timeout: int | None = None

async def publish(
self,
data: bytes,
routing_key: str,
properties: dict | None = None,
timeout: int | None = None,
):
channel = await self.conn.channel()

if is_debug_level():
logger.debug(
"Exchange[name='%s'] channel[%s] publish[routing_key='%s'] %s",
self.name,
channel,
routing_key,
data if not settings.LOG_SANITIZE else "<hiden>",
)

await channel.basic_publish(
data,
exchange=self.name,
routing_key=routing_key,
properties=BasicProperties(**(properties or {})),
timeout=timeout or self.timeout,
)


@dataclass(slots=True, frozen=True)
class Exchange:
name: str = ""
type: str = "direct"
Expand Down Expand Up @@ -417,7 +450,7 @@ async def publish(
)


@dataclass(frozen=True)
@dataclass(slots=True, frozen=True)
class Consumer:
channel: aiormq.Channel
consumer_tag: int
Expand All @@ -427,7 +460,7 @@ async def close(self):
await self.channel.close()


@dataclass(frozen=True)
@dataclass(slots=True, frozen=True)
class Queue:
name: str
type: QueueType = QueueType.CLASSIC
Expand Down Expand Up @@ -644,19 +677,11 @@ class SerializerConfig(base.SerializerConfig):
module: SerializerModule = SERIALIZER


HeadersExtender = Annotated[
Callable[[TaskInstance], dict],
PlainSerializer(lambda x: f"{x}", return_type=str, when_used="json"),
# PlainSerializer(lambda x: "<HeadersExtender>", return_type=str, when_used="json"),
]


class Config(base.Config):
"""RabbitMQ backend config."""

model_config = SettingsConfigDict(env_prefix=f"{ENV_PREFIX}RABBITMQ_")

headers_extenders: list[HeadersExtender] = Field(default_factory=list)
serializer: SerializerConfig = Field(default_factory=SerializerConfig)
url: SecretAmqpDsn | list[SecretAmqpDsn] = Field(default_factory=lambda: URL)
"""See amqp [spec](https://www.rabbitmq.com/uri-spec.html)."""
Expand Down Expand Up @@ -759,8 +784,6 @@ def __init__(self, config: Config):
)
self._task_queues: dict[str, Queue] = {}

# self._semaphore = asyncio.Semaphore(value=config.pool_size)

self._results_queue: Queue = Queue(
f"{config.results_queue_prefix}results.{config.id}",
conn=self._conn,
Expand Down Expand Up @@ -936,7 +959,7 @@ async def _on_task_message(self, callback, channel: aiormq.Channel, message: aio
if is_debug_level():
logger.debug("%s got raw message %s", self, message.body if not settings.LOG_SANITIZE else "<hiden>")

task_instance = self.serializer.loads_task_instance(message.body)
task_instance = self.serializer.loads_task_instance(message.body, headers=message.header.properties.headers)

task_instance.extra["rabbitmq:reply_to"] = message.header.properties.reply_to

Expand Down Expand Up @@ -995,14 +1018,16 @@ def _reply_to(self, task_instance: TaskInstance) -> str:
async def _send_task(self, task_instance: TaskInstance, **kwds): # pylint: disable=method-hidden
reply_to = self._reply_to(task_instance)
task_instance.extra["rabbitmq:reply_to"] = reply_to
data: bytes = self.serializer.dumps_task_instance(task_instance)

headers = {}
data: bytes = self.serializer.dumps_task_instance(task_instance, headers=headers)

await self._ensure_task_queue(task_instance.queue)

properties = {
"delivery_mode": 2,
"message_type": "arrlio:task",
"headers": {k: v for extender in self.config.headers_extenders for k, v in extender(task_instance).items()},
"headers": headers,
"message_id": f"{task_instance.task_id}",
"correlation_id": f"{task_instance.task_id}",
"reply_to": reply_to,
Expand Down Expand Up @@ -1060,7 +1085,11 @@ async def stop_consume_tasks(self, queues: list[str] | None = None):
self._task_queues.pop(queue_name)

async def _result_routing(self, task_instance: TaskInstance) -> tuple[Exchange, str]:
exchange = self._tasks_exchange
exchange_name = task_instance.extra.get("rabbitmq:reply_to.exchange", self._tasks_exchange.name)
if exchange_name == self._tasks_exchange.name:
exchange = self._tasks_exchange
else:
exchange = SimpleExchange(self._conn, exchange_name)
routing_key = task_instance.extra["rabbitmq:reply_to"]
if routing_key.startswith("amq.rabbitmq.reply-to."):
exchange = self._default_exchange
Expand Down
11 changes: 6 additions & 5 deletions arrlio/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from arrlio.types import Args, AsyncCallable, Kwds, TaskId, TaskPriority, Timeout


@dataclass(frozen=True)
@dataclass(slots=True, frozen=True)
class Task:
"""Task `dataclass`.
Expand Down Expand Up @@ -112,7 +112,7 @@ def instantiate(
)


@dataclass(frozen=True)
@dataclass(slots=True, frozen=True)
class TaskInstance(Task):
"""Task instance `dataclass`.
Expand All @@ -139,7 +139,8 @@ def __post_init__(self):
object.__setattr__(self, "args", tuple(self.args))

def dict(self, exclude: list[str] | None = None, sanitize: bool | None = None):
data = super().dict(exclude=exclude, sanitize=sanitize)
# pylint: disable=super-with-arguments
data = super(TaskInstance, self).dict(exclude=exclude, sanitize=sanitize)
if sanitize:
if self.sanitizer:
data = self.sanitizer(data) # pylint: disable=not-callable
Expand Down Expand Up @@ -171,7 +172,7 @@ def instantiate(self, *args, **kwds):
raise NotImplementedError


@dataclass(frozen=True)
@dataclass(slots=True, frozen=True)
class TaskResult:
"""Task result `dataclass`."""

Expand Down Expand Up @@ -203,7 +204,7 @@ def pretty_repr(self, sanitize: bool | None = None):
return pretty_repr(self.dict(sanitize=sanitize))


@dataclass(frozen=True)
@dataclass(slots=True, frozen=True)
class Event:
"""Event `dataclass`.
Expand Down
6 changes: 3 additions & 3 deletions arrlio/serializers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def dumps_task_instance(self, task_instance: TaskInstance, **kwds) -> bytes | Ta
pass

@abc.abstractmethod
def loads_task_instance(self, data: bytes | TaskInstance) -> TaskInstance:
def loads_task_instance(self, data: bytes | TaskInstance, **kwds) -> TaskInstance:
pass

@abc.abstractmethod
Expand All @@ -46,13 +46,13 @@ def dumps_task_result(
pass

@abc.abstractmethod
def loads_task_result(self, data: bytes | TaskResult) -> TaskResult:
def loads_task_result(self, data: bytes | TaskResult, **kwds) -> TaskResult:
pass

@abc.abstractmethod
def dumps_event(self, event: Event, **kwds) -> bytes | Event:
pass

@abc.abstractmethod
def loads_event(self, data: bytes | Event) -> Event:
def loads_event(self, data: bytes | Event, **kwds) -> Event:
pass
6 changes: 3 additions & 3 deletions arrlio/serializers/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def dumps_task_instance(self, task_instance: TaskInstance, **kwds) -> bytes:
extra["graph:graph"] = graph.dict()
return self.dumps({k: v for k, v in data.items() if v is not None})

def loads_task_instance(self, data: bytes) -> TaskInstance:
def loads_task_instance(self, data: bytes, **kwds) -> TaskInstance:
"""Loads `arrlio.models.TaskInstance` object from json encoded string."""

data: dict = self.loads(data)
Expand Down Expand Up @@ -123,7 +123,7 @@ def dumps_task_result(self, task_result: TaskResult, task_instance: TaskInstance
data["res"] = task_instance.dumps(data["res"])
return self.dumps(data)

def loads_task_result(self, data: bytes) -> TaskResult:
def loads_task_result(self, data: bytes, **kwds) -> TaskResult:
"""Loads `arrlio.models.TaskResult` from json encoded string."""

data = self.loads(data)
Expand All @@ -149,7 +149,7 @@ def dumps_event(self, event: Event, **kwds) -> bytes:
result["trb"] = self.dumps_trb(result["trb"])
return self.dumps(data)

def loads_event(self, data: bytes) -> Event:
def loads_event(self, data: bytes, **kwds) -> Event:
"""Loads `arrlio.models.Event` from json encoded string."""

event: Event = Event(**self.loads(data))
Expand Down
6 changes: 3 additions & 3 deletions arrlio/serializers/nop.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def loads(self, data: Any) -> Any:
def dumps_task_instance(self, task_instance: TaskInstance, **kwds) -> TaskInstance:
return task_instance

def loads_task_instance(self, data: TaskInstance) -> TaskInstance:
def loads_task_instance(self, data: TaskInstance, **kwds) -> TaskInstance:
return data

def dumps_task_result(
Expand All @@ -29,11 +29,11 @@ def dumps_task_result(
) -> TaskResult:
return task_result

def loads_task_result(self, data: TaskResult) -> TaskResult:
def loads_task_result(self, data: TaskResult, **kwds) -> TaskResult:
return data

def dumps_event(self, event: Event, **kwds) -> Event:
return event

def loads_event(self, data: Event) -> Event:
def loads_event(self, data: Event, **kwds) -> Event:
return data

0 comments on commit 4cad63c

Please sign in to comment.