Skip to content

Commit

Permalink
Remove RetryTimeout and fix json serializer
Browse files Browse the repository at this point in the history
  • Loading branch information
levsh committed Apr 16, 2024
1 parent 20333b2 commit 2d9156b
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 19 deletions.
22 changes: 13 additions & 9 deletions arrlio/backends/rabbitmq.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ class ResultQueueMode(StrEnum):
CONNECT_TIMEOUT = 15

PUSH_RETRY_TIMEOUTS = [5, 5, 5, 5] # pylint: disable=invalid-name
PULL_RETRY_TIMEOUTS = itertools.repeat(5) # pylint: disable=invalid-name
PULL_RETRY_TIMEOUT = 5 # pylint: disable=invalid-name

TASKS_EXCHANGE = "arrlio"
TASKS_EXCHANGE_DURABLE = False
Expand Down Expand Up @@ -592,6 +592,7 @@ async def consume(
callback: Callable[[aiormq.Channel, aiormq.abc.DeliveredMessage], Coroutine],
prefetch_count: int | None = None,
timeout: int | None = None,
retry_timeout: int = 5,
):
if self.consumer is None:
channel = await self.conn.new_channel()
Expand Down Expand Up @@ -620,7 +621,7 @@ async def consume(
"on_lost",
f"on_lost_queue_{self.name}_consume",
partial(
retry(retry_timeouts=itertools.repeat(5), exc_filter=lambda e: True)(self.consume),
retry(retry_timeouts=itertools.repeat(retry_timeout), exc_filter=lambda e: True)(self.consume),
callback,
prefetch_count=prefetch_count,
timeout=timeout,
Expand Down Expand Up @@ -652,10 +653,8 @@ class Config(base.Config):
"""See amqp [spec](https://www.rabbitmq.com/uri-spec.html)."""
timeout: Optional[Timeout] = Field(default_factory=lambda: TIMEOUT)
verify_ssl: Optional[bool] = Field(default_factory=lambda: True)
# push_retry_timeouts: Optional[RetryTimeout] = Field(default_factory=lambda: PUSH_RETRY_TIMEOUTS)
# pull_retry_timeouts: Optional[RetryTimeout] = Field(default_factory=lambda: PULL_RETRY_TIMEOUTS)
push_retry_timeouts: Optional[Any] = Field(default_factory=lambda: PUSH_RETRY_TIMEOUTS)
pull_retry_timeouts: Optional[Any] = Field(default_factory=lambda: PULL_RETRY_TIMEOUTS)
push_retry_timeouts: Optional[list[Timeout]] = Field(default_factory=lambda: PUSH_RETRY_TIMEOUTS)
pull_retry_timeout: Optional[Timeout] = Field(default_factory=lambda: PULL_RETRY_TIMEOUT)
tasks_exchange: str = Field(default_factory=lambda: TASKS_EXCHANGE)
tasks_exchange_durable: bool = Field(default_factory=lambda: TASKS_EXCHANGE_DURABLE)
tasks_queue_type: QueueType = Field(default_factory=lambda: TASKS_QUEUE_TYPE)
Expand Down Expand Up @@ -834,7 +833,8 @@ async def _on_conn_open_first_time(self):
"on_result_message",
lambda: self._on_result_message(*args, **kwds),
)
and None
and None,
retry_timeout=self.config.pull_retry_timeout,
)

self._conn.remove_callback("on_open", "on_conn_open_first_time")
Expand Down Expand Up @@ -1028,7 +1028,8 @@ async def consume_tasks(self, queues: list[str], callback: Callable[[TaskInstanc
"on_task_message",
lambda: self._on_task_message(callback, *args, **kwds),
)
and None
and None,
retry_timeout=self.config.pull_retry_timeout,
)

async def stop_consume_tasks(self, queues: list[str] | None = None):
Expand Down Expand Up @@ -1234,7 +1235,10 @@ async def on_message(channel: aiormq.Channel, message: aiormq.abc.DeliveredMessa
await self._events_exchange.declare(restore=True, force=True)
await self._events_queue.declare(restore=True, force=True)
await self._events_queue.bind(self._events_exchange, routing_key="events", restore=True)
await self._events_queue.consume(lambda *args, **kwds: create_task(on_message(*args, **kwds)) and None)
await self._events_queue.consume(
lambda *args, **kwds: create_task(on_message(*args, **kwds)) and None,
retry_timeout=self.config.pull_retry_timeout,
)

async def stop_consume_events(self, callback_id: str | None = None):
if callback_id:
Expand Down
12 changes: 9 additions & 3 deletions arrlio/serializers/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
from json import loads as json_loads
from traceback import format_tb
from types import TracebackType
from typing import Any, Type
from typing import Annotated, Any, Type

from pydantic import Field
from pydantic import Field, PlainSerializer
from pydantic_settings import SettingsConfigDict

from arrlio import registered_tasks
Expand All @@ -20,10 +20,16 @@
logger = logging.getLogger("arrlio.serializers.json")


Encoder = Annotated[
Type[json.JSONEncoder],
PlainSerializer(lambda x: f"{x}", return_type=str, when_used="json"),
]


class Config(base.Config):
model_config = SettingsConfigDict(env_prefix=f"{ENV_PREFIX}JSON_SERIALIZER_")

encoder: Type[json.JSONEncoder] = Field(default=ExtendedJSONEncoder)
encoder: Encoder = Field(default=ExtendedJSONEncoder)


class Serializer(base.Serializer):
Expand Down
3 changes: 0 additions & 3 deletions arrlio/types.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from collections.abc import Generator
from dataclasses import dataclass
from importlib import import_module
from types import ModuleType
Expand All @@ -20,8 +19,6 @@

Timeout = Annotated[int, Ge(0)]

RetryTimeout = list[Timeout] | Generator[Timeout]

Ttl = Annotated[int, Ge(1)]

TaskPriority = Annotated[int, Ge(TASK_MIN_PRIORITY), Le(TASK_MAX_PRIORITY)]
Expand Down
6 changes: 3 additions & 3 deletions tests/small/backends/test_rabbitmq.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def test__init(self, cleanup):
assert config.timeout == rabbitmq.TIMEOUT
assert config.verify_ssl is True
assert config.push_retry_timeouts
assert config.pull_retry_timeouts
assert config.pull_retry_timeout
assert config.tasks_exchange == rabbitmq.TASKS_EXCHANGE
assert config.tasks_exchange_durable == rabbitmq.TASKS_EXCHANGE_DURABLE
assert config.tasks_queue_type == rabbitmq.TASKS_QUEUE_TYPE
Expand Down Expand Up @@ -46,7 +46,7 @@ def test__init_custom(self, cleanup):
timeout=123,
verify_ssl=False,
push_retry_timeouts=[2],
pull_retry_timeouts=[3],
pull_retry_timeout=3,
tasks_exchange="tasks_exchange",
tasks_exchange_durable=True,
tasks_queue_type="quorum",
Expand Down Expand Up @@ -74,7 +74,7 @@ def test__init_custom(self, cleanup):
assert config.timeout == 123
assert config.verify_ssl is False
assert next(iter(config.push_retry_timeouts)) == 2
assert next(iter(config.pull_retry_timeouts)) == 3
assert config.pull_retry_timeout == 3
assert config.tasks_exchange == "tasks_exchange"
assert config.tasks_exchange_durable is True
assert config.tasks_queue_type == "quorum"
Expand Down
5 changes: 4 additions & 1 deletion tests/small/test_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,10 @@ def test_backend_config():
assert isinstance(config.module, ModuleType)
assert config.module.__name__ == "arrlio.backends.rabbitmq"
assert isinstance(config.config, config.module.Config)
# assert config.config.push_retry_timeouts == [5, 5, 5, 5]
assert config.config.push_retry_timeouts == [5, 5, 5, 5]
assert isinstance(config.model_dump()["module"], ModuleType)
assert isinstance(config.model_dump()["config"]["serializer"]["module"], ModuleType)
assert config.model_dump_json()

with pytest.raises(ValidationError):
config = configs.BackendConfig(module="arrlio.backends.invalid")
Expand Down

0 comments on commit 2d9156b

Please sign in to comment.