From 6ce65ddc30fe0cd05f65d85cc939112a9ee51d23 Mon Sep 17 00:00:00 2001 From: Roman Koshel Date: Thu, 26 Dec 2024 23:49:16 +0300 Subject: [PATCH] Wip --- arrlio/__init__.py | 7 +- arrlio/abc.py | 12 +-- arrlio/backends/brokers/rabbitmq.py | 16 ++-- arrlio/backends/event_backends/rabbitmq.py | 16 +++- arrlio/backends/rabbitmq.py | 2 +- arrlio/backends/result_backends/rabbitmq.py | 16 ++-- arrlio/core.py | 26 +++++-- arrlio/serializers/base.py | 23 +++--- arrlio/serializers/json.py | 45 +++-------- arrlio/types.py | 71 +++++------------ arrlio/utils.py | 61 ++++++++++++--- pyproject.toml | 4 +- tests/small/backends/brokers/test_rabbitmq.py | 4 +- .../backends/event_backends/test_rabbitmq.py | 4 +- .../backends/result_backends/test_rabbitmq.py | 4 +- tests/small/serializers/test_json.py | 54 ++++++++----- tests/small/test_types.py | 78 +++++++++++-------- tests/small/test_utils.py | 12 +-- 18 files changed, 252 insertions(+), 203 deletions(-) diff --git a/arrlio/__init__.py b/arrlio/__init__.py index a8e64ff..f01e5ae 100644 --- a/arrlio/__init__.py +++ b/arrlio/__init__.py @@ -22,11 +22,12 @@ log_hndl.setFormatter(log_frmt) logger.addHandler(log_hndl) +from arrlio import settings + # ruff: noqa: E402 from arrlio.configs import Config, TaskConfig # noqa -from arrlio.core import App, AsyncResult, registered_tasks, task # noqa +from arrlio.core import App, AsyncResult, get_app, registered_tasks, task # noqa from arrlio.models import Event, Graph, Task, TaskInstance, TaskResult # noqa -from arrlio.settings import LOG_LEVEL -logger.setLevel(LOG_LEVEL) +logger.setLevel(settings.LOG_LEVEL) diff --git a/arrlio/abc.py b/arrlio/abc.py index 5ca361b..e061e10 100644 --- a/arrlio/abc.py +++ b/arrlio/abc.py @@ -165,7 +165,7 @@ async def stop_consume_events(self, callback_id: str | None = None): class AbstractSerializer(ABC): @abstractmethod - def dumps_task_instance(self, task_instance: TaskInstance, **kwds) -> bytes | TaskInstance: + def dumps_task_instance(self, task_instance: TaskInstance, **kwds) -> tuple[bytes | TaskInstance, dict]: """ Dump `TaskInstance`. @@ -174,7 +174,7 @@ def dumps_task_instance(self, task_instance: TaskInstance, **kwds) -> bytes | Ta """ @abstractmethod - def loads_task_instance(self, data: bytes | TaskInstance, **kwds) -> TaskInstance: + def loads_task_instance(self, data: bytes | TaskInstance, headers: dict, **kwds) -> TaskInstance: """ Load `data` into `TaskInstance`. @@ -189,7 +189,7 @@ def dumps_task_result( *, task_instance: TaskInstance | None = None, **kwds, - ) -> bytes | TaskResult: + ) -> tuple[bytes | TaskResult, dict]: """ Dump `TaskResult`. @@ -199,7 +199,7 @@ def dumps_task_result( """ @abstractmethod - def loads_task_result(self, data: bytes | TaskResult, **kwds) -> TaskResult: + def loads_task_result(self, data: bytes | TaskResult, headers: dict, **kwds) -> TaskResult: """ Load data into `TaskResult`. @@ -208,7 +208,7 @@ def loads_task_result(self, data: bytes | TaskResult, **kwds) -> TaskResult: """ @abstractmethod - def dumps_event(self, event: Event, **kwds) -> bytes | Event: + def dumps_event(self, event: Event, **kwds) -> tuple[bytes | Event, dict]: """ Dump `arrlio.models.Event`. @@ -217,7 +217,7 @@ def dumps_event(self, event: Event, **kwds) -> bytes | Event: """ @abstractmethod - def loads_event(self, data: bytes | Event, **kwds) -> Event: + def loads_event(self, data: bytes | Event, headers: dict, **kwds) -> Event: """ Load `data` into `Event`. 
diff --git a/arrlio/backends/brokers/rabbitmq.py b/arrlio/backends/brokers/rabbitmq.py index 587fedf..e6fe8c4 100644 --- a/arrlio/backends/brokers/rabbitmq.py +++ b/arrlio/backends/brokers/rabbitmq.py @@ -1,6 +1,7 @@ import logging from datetime import datetime, timezone +from itertools import repeat from typing import Callable, Coroutine, Optional from uuid import uuid4 @@ -164,7 +165,11 @@ def __repr__(self): return self.__str__() async def init(self): - await self._conn.open() + await retry( + msg=f"{self} init error", + retry_timeouts=repeat(5), + exc_filter=exc_filter, + )(self._conn.open)() async def close(self): await self._exchange.close() @@ -207,7 +212,7 @@ async def _on_task_message( task_instance = self.serializer.loads_task_instance( message.body, - headers=message.header.properties.headers, + message.header.properties.headers, ) reply_to = message.header.properties.reply_to @@ -229,8 +234,8 @@ async def _on_task_message( if task_instance.ack_late: await channel.basic_ack(message.delivery.delivery_tag) - except Exception as e: - logger.exception(e) + except Exception: + logger.exception(message.header.properties) async def _send_task(self, task_instance: TaskInstance, **kwds): if is_debug_level(): @@ -240,8 +245,7 @@ async def _send_task(self, task_instance: TaskInstance, **kwds): task_instance.pretty_repr(sanitize=settings.LOG_SANITIZE), ) - headers = {} - data = self.serializer.dumps_task_instance(task_instance, headers=headers) + data, headers = self.serializer.dumps_task_instance(task_instance) task_headers = task_instance.headers reply_to = task_headers.get("rabbitmq:reply_to") diff --git a/arrlio/backends/event_backends/rabbitmq.py b/arrlio/backends/event_backends/rabbitmq.py index 3029ed2..e46b276 100644 --- a/arrlio/backends/event_backends/rabbitmq.py +++ b/arrlio/backends/event_backends/rabbitmq.py @@ -5,6 +5,7 @@ from datetime import datetime, timezone from functools import partial from inspect import iscoroutinefunction +from itertools import repeat from typing import Any, Callable, Optional from uuid import uuid4 @@ -177,19 +178,25 @@ def __repr__(self): return self.__str__() async def init(self): - await self._conn.open() + await retry( + msg=f"{self} init error", + retry_timeouts=repeat(5), + exc_filter=exc_filter, + )(self._conn.open)() async def close(self): await super().close() async def _send_event(self, event: Event): + data, headers = self.serializer.dumps_event(event) await self._exchange.publish( - self.serializer.dumps_event(event), + data, routing_key=event.type, properties={ "delivery_mode": 2, "timestamp": datetime.now(tz=timezone.utc), "expiration": f"{int(event.ttl * 1000)}" if event.ttl is not None else None, + "headers": headers, }, ) @@ -223,7 +230,10 @@ async def consume_events( async def on_message(channel: aiormq.Channel, message: aiormq.abc.DeliveredMessage): try: - event: Event = self.serializer.loads_event(message.body) + event: Event = self.serializer.loads_event( + message.body, + message.header.properties.headers, + ) if is_debug_level(): logger.debug(_("%s got event\n%s"), self, event.pretty_repr(sanitize=settings.LOG_SANITIZE)) diff --git a/arrlio/backends/rabbitmq.py b/arrlio/backends/rabbitmq.py index 571be93..159adae 100644 --- a/arrlio/backends/rabbitmq.py +++ b/arrlio/backends/rabbitmq.py @@ -63,4 +63,4 @@ def connection_factory(url: SecretAmqpDsn | list[SecretAmqpDsn]) -> Connection: if not isinstance(url, list): url = [url] - return Connection([u.get_secret_value() for u in url]) + return Connection([f"{u.get_secret_value()}" 
for u in url]) diff --git a/arrlio/backends/result_backends/rabbitmq.py b/arrlio/backends/result_backends/rabbitmq.py index 8c6e9f1..fe0ff4e 100644 --- a/arrlio/backends/result_backends/rabbitmq.py +++ b/arrlio/backends/result_backends/rabbitmq.py @@ -5,6 +5,7 @@ from datetime import datetime, timezone from functools import partial from inspect import isasyncgenfunction, isgeneratorfunction +from itertools import repeat from typing import AsyncGenerator, Optional from uuid import UUID, uuid4 @@ -58,10 +59,10 @@ QUEUE_DURABLE = False """ResultBackend queue `durable` option.""" -QUEUE_EXCLUSIVE = False +QUEUE_EXCLUSIVE = True """ResultBackend queue `excusive` option.""" -QUEUE_AUTO_DELETE = True +QUEUE_AUTO_DELETE = False """ResultBackend queue `auto-delete` option.""" PREFETCH_COUNT = 10 @@ -205,7 +206,11 @@ def __repr__(self): return self.__str__() async def init(self): - await self._conn.open() + await retry( + msg=f"{self} init error", + retry_timeouts=repeat(5), + exc_filter=exc_filter, + )(self._conn.open)() async def close(self): await super().close() @@ -289,7 +294,7 @@ async def _on_result_message( properties: aiormq.spec.Basic.Properties = message.header.properties task_id: UUID = UUID(properties.message_id) - task_result: TaskResult = self.serializer.loads_task_result(message.body, headers=properties.headers) + task_result: TaskResult = self.serializer.loads_task_result(message.body, properties.headers) if not no_ack: await channel.basic_ack(message.delivery.delivery_tag) @@ -348,8 +353,7 @@ async def _push_task_result( task_result.pretty_repr(sanitize=settings.LOG_SANITIZE), ) - headers = {} - data = self.serializer.dumps_task_result(task_result, task_instance=task_instance, headers=headers) + data, headers = self.serializer.dumps_task_result(task_result, task_instance=task_instance) properties = { "delivery_mode": 2, diff --git a/arrlio/core.py b/arrlio/core.py index e8e8209..dd61001 100644 --- a/arrlio/core.py +++ b/arrlio/core.py @@ -41,11 +41,14 @@ registered_tasks = rodict({}, nested=True) +_curr_app = ContextVar("curr_app", default=None) + + def task( func: FunctionType | MethodType | Type | None = None, name: str | None = None, base: Type[Task] | None = None, - **kwds: dict, + **kwds, ): """Task decorator. @@ -168,7 +171,7 @@ def executor(self) -> Executor: return self._executor @property - def context(self): + def context(self) -> dict: """Application current context.""" return self._context.get() @@ -300,7 +303,7 @@ async def send_task( if is_info_level(): logger.info( - _("%s send task instance\n%s"), + _("%s send task\n%s"), self, task_instance.pretty_repr(sanitize=settings.LOG_SANITIZE), ) @@ -378,12 +381,15 @@ async def cb(task_instance: TaskInstance): idx_0 = uuid4().hex idx_1 = 0 + self._context.set({}) + context = self.context + try: task_result: TaskResult = TaskResult() async with AsyncExitStack() as stack: try: - self.context["task_instance"] = task_instance + context["task_instance"] = task_instance for context_hook in self._hooks["task_context"]: await stack.enter_async_context(context_hook(task_instance)) @@ -461,8 +467,12 @@ async def execute_task(self, task_instance: TaskInstance) -> AsyncGenerator[Task Task result. 
""" - async for task_result in self._executor(task_instance): - yield task_result + token = _curr_app.set(self) + try: + async for task_result in self._executor(task_instance): + yield task_result + finally: + _curr_app.reset(token) async def consume_events( self, @@ -579,3 +589,7 @@ async def get(self) -> Any: if noresult: raise TaskClosedError(self.task_id) return self._result + + +def get_app() -> App | None: + return _curr_app.get() diff --git a/arrlio/serializers/base.py b/arrlio/serializers/base.py index 7324b1f..9870eb6 100644 --- a/arrlio/serializers/base.py +++ b/arrlio/serializers/base.py @@ -48,7 +48,7 @@ def loads(self, data: bytes | Any) -> Any: data: data to load. """ - def dumps_task_instance(self, task_instance: TaskInstance, **kwds) -> bytes: + def dumps_task_instance(self, task_instance: TaskInstance, **kwds) -> tuple[bytes, dict]: """ Dump `arrlio.models.TaskInstance` object as json encoded string. @@ -60,9 +60,9 @@ def dumps_task_instance(self, task_instance: TaskInstance, **kwds) -> bytes: headers = data["headers"] if graph := headers.get("graph:graph"): headers["graph:graph"] = graph.asdict() - return self.dumps({k: v for k, v in data.items() if v is not None}) + return self.dumps({k: v for k, v in data.items() if v is not None}), {} - def loads_task_instance(self, data: bytes, **kwds) -> TaskInstance: + def loads_task_instance(self, data: bytes, headers: dict, **kwds) -> TaskInstance: """ Load `arrlio.models.TaskInstance` object from json encoded string. @@ -133,7 +133,12 @@ def loads_trb(self, trb: str) -> str: return trb - def dumps_task_result(self, task_result: TaskResult, task_instance: TaskInstance | None = None, **kwds) -> bytes: + def dumps_task_result( + self, + task_result: TaskResult, + task_instance: TaskInstance | None = None, + **kwds, + ) -> tuple[bytes, dict]: """ Dump `arrlio.models.TaskResult` as json encoded string. @@ -148,9 +153,9 @@ def dumps_task_result(self, task_result: TaskResult, task_instance: TaskInstance data["trb"] = self.dumps_trb(data["trb"]) elif task_instance and task_instance.dumps: data["res"] = task_instance.dumps(data["res"]) - return self.dumps(data) + return self.dumps(data), {} - def loads_task_result(self, data: bytes, **kwds) -> TaskResult: + def loads_task_result(self, data: bytes, headers: dict, **kwds) -> TaskResult: """ Load `arrlio.models.TaskResult` from json encoded string. @@ -164,7 +169,7 @@ def loads_task_result(self, data: bytes, **kwds) -> TaskResult: data["trb"] = self.loads_trb(data["trb"]) return TaskResult(**data) - def dumps_event(self, event: Event, **kwds) -> bytes: + def dumps_event(self, event: Event, **kwds) -> tuple[bytes, dict]: """ Dump `arrlio.models.Event` as json encoded string. @@ -184,9 +189,9 @@ def dumps_event(self, event: Event, **kwds) -> bytes: if result["exc"]: result["exc"] = self.dumps_exc(result["exc"]) result["trb"] = self.dumps_trb(result["trb"]) - return self.dumps(data) + return self.dumps(data), {} - def loads_event(self, data: bytes, **kwds) -> Event: + def loads_event(self, data: bytes, headers: dict, **kwds) -> Event: """ Load `arrlio.models.Event` from json encoded string. 
diff --git a/arrlio/serializers/json.py b/arrlio/serializers/json.py index c12904a..022a5a0 100644 --- a/arrlio/serializers/json.py +++ b/arrlio/serializers/json.py @@ -1,46 +1,25 @@ import logging -from typing import Callable, Optional, Type - -from arrlio.utils import ExtendedJSONEncoder +from importlib.util import find_spec +from typing import Annotated, Any, Callable, Optional, Type +from pydantic import Field, PlainSerializer +from pydantic_settings import SettingsConfigDict -try: - import orjson +from arrlio.serializers import base +from arrlio.settings import ENV_PREFIX +from arrlio.utils import JSONEncoder, json_dumps_bytes, json_loads - _dumps = orjson.dumps - def json_dumps(obj, cls=None): - return _dumps(obj, default=cls) +logger = logging.getLogger("arrlio.serializers.json") - json_loads = orjson.loads +if find_spec("orjson") is not None: JSONEncoderType = Optional[Callable] - JSON_ENCODER = None - -except ImportError: +else: import json - _dumps = json.dumps - - def json_dumps(*args, **kwds): - return _dumps(*args, **kwds).encode() - - json_loads = json.loads - JSONEncoderType = Optional[Type[json.JSONEncoder]] - JSON_ENCODER = ExtendedJSONEncoder - -from typing import Annotated, Any - -from pydantic import Field, PlainSerializer -from pydantic_settings import SettingsConfigDict - -from arrlio.serializers import base -from arrlio.settings import ENV_PREFIX - - -logger = logging.getLogger("arrlio.serializers.json") Encoder = Annotated[ @@ -54,7 +33,7 @@ class Config(base.Config): model_config = SettingsConfigDict(env_prefix=f"{ENV_PREFIX}JSON_SERIALIZER_") - encoder: Encoder = Field(default=JSON_ENCODER) + encoder: Encoder = Field(default=JSONEncoder) """Encoder class.""" @@ -72,7 +51,7 @@ def dumps(self, data: Any, **kwds) -> bytes: data: Data to dump. """ - return json_dumps(data, cls=self.config.encoder) + return json_dumps_bytes(data, encoder=self.config.encoder) def loads(self, data: bytes) -> Any: """Load json encoded data to Python object. 
diff --git a/arrlio/types.py b/arrlio/types.py index 23f367d..78ceff7 100644 --- a/arrlio/types.py +++ b/arrlio/types.py @@ -1,11 +1,7 @@ -import ipaddress -import re - from dataclasses import dataclass from importlib import import_module from types import ModuleType from typing import Any, Callable, Coroutine, Optional, TypeVar, Union -from urllib.parse import urlparse from uuid import UUID from annotated_types import Ge, Le @@ -111,55 +107,27 @@ def validate_from_str(v): class SecretAnyUrl(AnyUrl): - def __new__(cls, url) -> object: - if hasattr(url, "get_secret_value"): - url = url.get_secret_value() - else: - url = f"{url}" - original = urlparse(url) - if original.username or original.password: - url = original._replace( - netloc=f"***:***@{original.hostname}" + (f":{original.port}" if original.port is not None else "") - ).geturl() - obj = super().__new__(cls, url) - if obj.host is None: - raise ValueError("invalid URL") - cls._validate_host(obj.host) - obj._original_str = str(AnyUrl(original.geturl())) - obj._username = SecretStr(original.username) if original.username else None - obj._password = SecretStr(original.password) if original.password else None - return obj - @property def username(self) -> SecretStr: - return self._username + return SecretStr(self._url.username) if self._url.username is not None else None @property def password(self) -> SecretStr: - return self._password - - @classmethod - def _validate_host(cls, host: str): - if 1 > len(host) > 255: - raise ValueError("invalid URL host length") - splitted = host.split(".") - if splitted[-1] and splitted[-1][0].isdigit(): - ipaddress.ip_address(host) - else: - for x in splitted: - if not re.match(r"(?!-)[a-zA-Z\d-]{1,63}(? str: + url = self._url + return str( + url.build( + scheme=url.scheme, + host=url.host, + username="***" if url.username is not None else None, + password="***" if url.password is not None else None, + port=url.port, + path=(url.path or "").lstrip("/"), + query=url.query, + fragment=url.fragment, + ) ) def __repr__(self) -> str: @@ -168,11 +136,8 @@ def __repr__(self) -> str: def __eq__(self, other: Any) -> bool: return isinstance(other, SecretAnyUrl) and self.get_secret_value() == other.get_secret_value() - def __hash__(self): - return hash(self._original_str) - - def get_secret_value(self) -> str: - return self._original_str + def get_secret_value(self): + return self._url @dataclass diff --git a/arrlio/utils.py b/arrlio/utils.py index 70a9b0e..e316e67 100644 --- a/arrlio/utils.py +++ b/arrlio/utils.py @@ -8,11 +8,10 @@ from functools import wraps from inspect import isasyncgenfunction from itertools import repeat +from types import FunctionType from typing import Callable, Coroutine, Iterable, cast from uuid import UUID -from pydantic import SecretBytes, SecretStr - from arrlio.models import Task from arrlio.types import ExceptionFilter, Timeout @@ -33,21 +32,61 @@ def is_info_level(): return isEnabledFor(INFO) -class ExtendedJSONEncoder(json.JSONEncoder): - """Extended JSONEncoder class.""" +try: + import orjson - def default(self, o): - if isinstance(o, datetime): - return o.isoformat() - if isinstance(o, (UUID, SecretStr, SecretBytes)): + def JSONEncoder(o): + if isinstance(o, UUID): return f"{o}" - if isinstance(o, set): - return list(o) + if get_secret_value := getattr(o, "get_secret_value", None): + return get_secret_value() + if isinstance(o, FunctionType): + return f"{o.__module__}.{o.__name__}" if isinstance(o, Task): o = o.asdict(exclude=["loads", "dumps"]) o["func"] = 
f"{o['func'].__module__}.{o['func'].__name__}" return o - return super().default(o) + raise TypeError + + def json_dumps_bytes(obj, encoder=None): + return orjson.dumps(obj, default=encoder or JSONEncoder) + + def json_dumps(obj, encoder=None): + return orjson.dumps(obj, default=encoder or JSONEncoder).decode() + + json_loads = orjson.loads + + +except ImportError: + import json + + class JSONEncoder(json.JSONEncoder): + """Extended JSONEncoder class.""" + + def default(self, o): + if isinstance(o, datetime): + return o.isoformat() + if isinstance(o, UUID): + return f"{o}" + if get_secret_value := getattr(o, "get_secret_value", None): + return get_secret_value() + if isinstance(o, FunctionType): + return f"{o.__module__}.{o.__name__}" + if isinstance(o, set): + return list(o) + if isinstance(o, Task): + o = o.asdict(exclude=["loads", "dumps"]) + o["func"] = f"{o['func'].__module__}.{o['func'].__name__}" + return o + return super().default(o) + + def json_dumps_bytes(*args, encoder=None, **kwds): + return json.dumps(*args, cls=encoder or JSONEncoder, **kwds).encode() + + def json_dumps(*args, encoder=None, **kwds): + return json.dumps(*args, encoder or JSONEncoder, **kwds) + + json_loads = json.loads def retry( diff --git a/pyproject.toml b/pyproject.toml index 47fc6cb..5ff77ba 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,9 +22,9 @@ packages = [{include = "arrlio"}] [tool.poetry.dependencies] python = ">=3.11, <4.0" aiormq = ">=6.7.6" -rmqaio = ">=0.9.0" +rmqaio = ">=0.10.0" msgpack = "*" -pydantic = "^2" +pydantic = "^2.10" pydantic_settings = "*" rich = "*" roview = "*" diff --git a/tests/small/backends/brokers/test_rabbitmq.py b/tests/small/backends/brokers/test_rabbitmq.py index 08559e7..5b69fde 100644 --- a/tests/small/backends/brokers/test_rabbitmq.py +++ b/tests/small/backends/brokers/test_rabbitmq.py @@ -23,7 +23,7 @@ def test__init(self, cleanup): config = rabbitmq.Config() assert config.id assert config.serializer.module == serializers.json - assert config.url.get_secret_value() == rabbitmq.URL + assert f"{config.url.get_secret_value()}" == rabbitmq.URL assert config.timeout == rabbitmq.TIMEOUT assert config.push_retry_timeouts assert config.pull_retry_timeout @@ -52,7 +52,7 @@ def test__init_custom(self, cleanup): ) assert config.serializer.module == serializers.json assert config.id == "id" - assert config.url.get_secret_value() == "amqps://admin@example.com" + assert f"{config.url.get_secret_value()}" == "amqps://admin@example.com" assert config.timeout == 123 assert next(iter(config.push_retry_timeouts)) == 2 assert config.pull_retry_timeout == 3 diff --git a/tests/small/backends/event_backends/test_rabbitmq.py b/tests/small/backends/event_backends/test_rabbitmq.py index 292e19e..aab984c 100644 --- a/tests/small/backends/event_backends/test_rabbitmq.py +++ b/tests/small/backends/event_backends/test_rabbitmq.py @@ -23,7 +23,7 @@ def test__init(self, cleanup): config = rabbitmq.Config() assert config.id assert config.serializer.module == serializers.json - assert config.url.get_secret_value() == rabbitmq.URL + assert f"{config.url.get_secret_value()}" == rabbitmq.URL assert config.timeout == rabbitmq.TIMEOUT assert config.push_retry_timeouts assert config.pull_retry_timeout @@ -54,7 +54,7 @@ def test__init_custom(self, cleanup): ) assert config.serializer.module == serializers.json assert config.id == "id" - assert config.url.get_secret_value() == "amqps://admin@example.com" + assert f"{config.url.get_secret_value()}" == "amqps://admin@example.com" assert 
config.timeout == 123 assert next(iter(config.push_retry_timeouts)) == 2 assert config.pull_retry_timeout == 3 diff --git a/tests/small/backends/result_backends/test_rabbitmq.py b/tests/small/backends/result_backends/test_rabbitmq.py index 95d8f80..01272e1 100644 --- a/tests/small/backends/result_backends/test_rabbitmq.py +++ b/tests/small/backends/result_backends/test_rabbitmq.py @@ -23,7 +23,7 @@ def test__init(self, cleanup): config = rabbitmq.Config() assert config.id assert config.serializer.module == serializers.json - assert config.url.get_secret_value() == rabbitmq.URL + assert f"{config.url.get_secret_value()}" == rabbitmq.URL assert config.timeout == rabbitmq.TIMEOUT assert config.push_retry_timeouts assert config.pull_retry_timeout @@ -49,7 +49,7 @@ def test__init_custom(self, cleanup): ) assert config.serializer.module == serializers.json assert config.id == "id" - assert config.url.get_secret_value() == "amqps://admin@example.com" + assert f"{config.url.get_secret_value()}" == "amqps://admin@example.com" assert config.timeout == 123 assert next(iter(config.push_retry_timeouts)) == 2 assert config.pull_retry_timeout == 3 diff --git a/tests/small/serializers/test_json.py b/tests/small/serializers/test_json.py index b063e93..07d5c74 100644 --- a/tests/small/serializers/test_json.py +++ b/tests/small/serializers/test_json.py @@ -19,9 +19,12 @@ def test_dumps_task_instance(self): serializer = serializers.json.Serializer(serializers.json.Config()) task_instance = Task(None, "test").instantiate(task_id="2d29459b-3245-492e-977b-09043c0f1f27", queue="queue") assert serializer.dumps_task_instance(task_instance) == ( - b'{"name": "test", "queue": "queue", "priority": 1, "timeout": 300, "ttl": 300, ' - b'"ack_late": false, "result_ttl": 300, "result_return": true, "events": false, "event_ttl": 300, ' - b'"headers": {}, "task_id": "2d29459b-3245-492e-977b-09043c0f1f27", "args": [], "kwds": {}, "meta": {}}' + ( + b'{"name": "test", "queue": "queue", "priority": 1, "timeout": 300, "ttl": 300, ' + b'"ack_late": false, "result_ttl": 300, "result_return": true, "events": false, "event_ttl": 300, ' + b'"headers": {}, "task_id": "2d29459b-3245-492e-977b-09043c0f1f27", "args": [], "kwds": {}, "meta": {}}' + ), + {}, ) def test_loads_task_instance(self): @@ -35,11 +38,14 @@ def foo(m: M): serializer = serializers.json.Serializer(serializers.json.Config()) task_instance = serializer.loads_task_instance( ( - b'{"name": "86e68", "queue": "queue", "priority": 1, "timeout": 300, "ttl": 300, ' - b'"ack_late": false, "result_ttl": 300, "result_return": true, "events": false, "event_ttl": 300, ' - b'"headers": {}, "task_id": "2d29459b-3245-492e-977b-09043c0f1f27", "args": [{"x": 1}], "kwds": {}, ' - b'"meta": {}}' - ) + ( + b'{"name": "86e68", "queue": "queue", "priority": 1, "timeout": 300, "ttl": 300, ' + b'"ack_late": false, "result_ttl": 300, "result_return": true, "events": false, "event_ttl": 300, ' + b'"headers": {}, "task_id": "2d29459b-3245-492e-977b-09043c0f1f27", "args": [{"x": 1}], "kwds": {}, ' + b'"meta": {}}' + ) + ), + {}, ) assert task_instance == arrlio.registered_tasks["86e68"].instantiate( task_id="2d29459b-3245-492e-977b-09043c0f1f27", @@ -63,9 +69,9 @@ def foo(): queue="queue", ) task_result = TaskResult(res=foo()) - assert ( - serializer.dumps_task_result(task_result, task_instance) - == b'{"res": {"x": 1}, "exc": null, "trb": null, "idx": null, "routes": null}' + assert serializer.dumps_task_result(task_result, task_instance) == ( + b'{"res": {"x": 1}, "exc": null, "trb": null, 
"idx": null, "routes": null}', + {}, ) try: @@ -84,17 +90,20 @@ def foo(): ) else: assert serializer.dumps_task_result(task_result, task_instance) == ( - b'{"res": null, "exc": ["builtins", "ZeroDivisionError", "division by zero"], ' - b'"trb": " File \\"%s\\", line 72, ' - b'in test_dumps_task_result\\n 1 / 0\\n ~~^~~\\n", "idx": null, "routes": null}' - % __file__.encode() + ( + b'{"res": null, "exc": ["builtins", "ZeroDivisionError", "division by zero"], ' + b'"trb": " File \\"%s\\", line 78, ' + b'in test_dumps_task_result\\n 1 / 0\\n ~~^~~\\n", "idx": null, "routes": null}' + % __file__.encode() + ), + {}, ) def test_loads_task_result(self): serializer = serializers.json.Serializer(serializers.json.Config()) assert serializer.loads_task_result( - b'{"res": "ABC", "exc": null, "idx": null, "trb": null, "routes": null}' + b'{"res": "ABC", "exc": null, "idx": null, "trb": null, "routes": null}', {} ) == TaskResult(res="ABC") result = serializer.loads_task_result( @@ -102,7 +111,8 @@ def test_loads_task_result(self): b'{"res": null, "exc": ["builtins", "ZeroDivisionError", "division by zero"], ' b'"trb": " File \\"%s\\", line 41, in ' b'test_dumps_task_result\\n 1 / 0\\n", "idx": null, "routes": null}' % __file__.encode() - ) + ), + {}, ) assert isinstance(result, TaskResult) assert result.res is None @@ -119,8 +129,11 @@ def test_dumps_event(self): data={"k": "v"}, ) assert serializer.dumps_event(event) == ( - b'{"type": "TP", "data": {"k": "v"}, "event_id": "f3410fd3-660c-4e26-b433-a6c2f5bdf700", ' - b'"dt": "2022-03-12T00:00:00", "ttl": 300}' + ( + b'{"type": "TP", "data": {"k": "v"}, "event_id": "f3410fd3-660c-4e26-b433-a6c2f5bdf700", ' + b'"dt": "2022-03-12T00:00:00", "ttl": 300}' + ), + {}, ) def test_loads_event(self): @@ -130,7 +143,8 @@ def test_loads_event(self): ( b'{"type": "TP", "data": {"k": "v"}, "event_id": "f3410fd3-660c-4e26-b433-a6c2f5bdf700", ' b'"dt": "2022-03-12T00:00:00", "ttl": 300}' - ) + ), + {}, ) assert event == Event( event_id="f3410fd3-660c-4e26-b433-a6c2f5bdf700", diff --git a/tests/small/test_types.py b/tests/small/test_types.py index eac50d0..0c5ed2b 100644 --- a/tests/small/test_types.py +++ b/tests/small/test_types.py @@ -2,26 +2,34 @@ import pydantic_settings import pytest +from pydantic import TypeAdapter + from arrlio.backends import brokers, event_backends, result_backends from arrlio.types import BrokerModule, EventBackendModule, ResultBackendModule, SecretAmqpDsn, SecretAnyUrl def test_SecretAnyUrl(): - url = SecretAnyUrl("http://example.org") - assert url.scheme == "http" - assert url.username is None - assert url.password is None - assert url.host == "example.org" - assert str(url) == "http://example.org/" - assert repr(url) == "SecretAnyUrl('http://example.org/')" - - url = SecretAnyUrl("http://user:pass@example.org") - assert url.scheme == "http" - assert url.username == pydantic.SecretStr("user") - assert url.password == pydantic.SecretStr("pass") - assert url.host == "example.org" - assert str(url) == "http://***:***@example.org/" - assert repr(url) == "SecretAnyUrl('http://***:***@example.org/')" + for url in [ + SecretAnyUrl("http://example.org"), + TypeAdapter(SecretAnyUrl).validate_python("http://example.org"), + ]: + assert url.scheme == "http" + assert url.username is None + assert url.password is None + assert url.host == "example.org" + assert str(url) == "http://example.org/" + assert repr(url) == "SecretAnyUrl('http://example.org/')" + + for url in [ + SecretAnyUrl("http://user:pass@example.org"), + 
TypeAdapter(SecretAnyUrl).validate_python("http://user:pass@example.org"), + ]: + assert url.scheme == "http" + assert url.username == pydantic.SecretStr("user") + assert url.password == pydantic.SecretStr("pass") + assert url.host == "example.org" + assert str(url) == "http://***:***@example.org/" + assert repr(url) == "SecretAnyUrl('http://***:***@example.org/')" class S(pydantic_settings.BaseSettings): url: SecretAnyUrl @@ -42,21 +50,27 @@ class S(pydantic_settings.BaseSettings): def test_SecretAmqpDsn(): - url = SecretAmqpDsn("amqp://example.org") - assert url.scheme == "amqp" - assert url.username is None - assert url.password is None - assert url.host == "example.org" - assert str(url) == "amqp://example.org" - assert repr(url) == "SecretAnyUrl('amqp://example.org')" - - url = SecretAnyUrl("amqp://user:pass@example.org") - assert url.scheme == "amqp" - assert url.username == pydantic.SecretStr("user") - assert url.password == pydantic.SecretStr("pass") - assert url.host == "example.org" - assert str(url) == "amqp://***:***@example.org" - assert repr(url) == "SecretAnyUrl('amqp://***:***@example.org')" + for url in [ + SecretAnyUrl("amqp://example.org"), + TypeAdapter(SecretAnyUrl).validate_python("amqp://example.org"), + ]: + assert url.scheme == "amqp" + assert url.username is None + assert url.password is None + assert url.host == "example.org" + assert str(url) == "amqp://example.org/" + assert repr(url) == "SecretAnyUrl('amqp://example.org/')" + + for url in [ + SecretAnyUrl("amqp://user:pass@example.org"), + TypeAdapter(SecretAnyUrl).validate_python("amqp://user:pass@example.org"), + ]: + assert url.scheme == "amqp" + assert url.username == pydantic.SecretStr("user") + assert url.password == pydantic.SecretStr("pass") + assert url.host == "example.org" + assert str(url) == "amqp://***:***@example.org/" + assert repr(url) == "SecretAnyUrl('amqp://***:***@example.org/')" class S(pydantic_settings.BaseSettings): url: SecretAmqpDsn @@ -72,8 +86,8 @@ class S(pydantic_settings.BaseSettings): assert m.url.username == pydantic.SecretStr("user") assert m.url.password == pydantic.SecretStr("pass") assert m.url.host == "example.org" - assert str(m.url) == "amqp://***:***@example.org" - assert repr(m.url) == "SecretAnyUrl('amqp://***:***@example.org')" + assert str(m.url) == "amqp://***:***@example.org/" + assert repr(m.url) == "SecretAnyUrl('amqp://***:***@example.org/')" def test_BrokerModule(): diff --git a/tests/small/test_utils.py b/tests/small/test_utils.py index e738150..04ba867 100644 --- a/tests/small/test_utils.py +++ b/tests/small/test_utils.py @@ -91,25 +91,25 @@ async def foo(): assert counter == 3 -def test_ExtendedJSONEncoder(): +def test_JSONEncoder(): assert ( json.dumps( "a", - cls=utils.ExtendedJSONEncoder, + cls=utils.JSONEncoder, ) == '"a"' ) assert ( json.dumps( datetime.datetime(2021, 1, 1), - cls=utils.ExtendedJSONEncoder, + cls=utils.JSONEncoder, ) == '"2021-01-01T00:00:00"' ) assert ( json.dumps( uuid.UUID("ea47d0af-c6b2-45d0-9a05-6bd1e34aa58c"), - cls=utils.ExtendedJSONEncoder, + cls=utils.JSONEncoder, ) == '"ea47d0af-c6b2-45d0-9a05-6bd1e34aa58c"' ) @@ -118,7 +118,7 @@ def test_ExtendedJSONEncoder(): def foo(): pass - assert json.dumps(foo, cls=utils.ExtendedJSONEncoder) == ( + assert json.dumps(foo, cls=utils.JSONEncoder) == ( """{\"func\": \"test_utils.foo\", \"name\": \"test_utils.foo\", """ """\"queue\": \"arrlio.tasks\", \"priority\": 1, \"timeout\": 300, \"ttl\": 300, """ """\"ack_late\": false, \"result_ttl\": 300, \"result_return\": true, """ @@ -129,4 
+129,4 @@ class C: pass with pytest.raises(TypeError): - json.dumps(C(), cls=utils.ExtendedJSONEncoder) + json.dumps(C(), cls=utils.JSONEncoder)
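A minimal usage sketch of the serializer contract exercised by this patch (illustrative only, not part of the commit): dumps_task_result now returns a (payload, headers) tuple and loads_task_result takes the headers mapping explicitly, mirroring how the RabbitMQ backends attach headers to message properties. The task name, task_id, and result values below are placeholders taken from or modeled on the unit tests above.

# Illustrative sketch (not part of the commit): round-tripping a TaskResult
# through the updated serializer API, where dumps_* returns (payload, headers)
# and loads_* accepts the headers mapping explicitly.
from arrlio.models import Task, TaskResult
from arrlio.serializers.json import Config, Serializer

serializer = Serializer(Config())

# Placeholder task and result values, mirroring the json serializer tests above.
task_instance = Task(None, "example").instantiate(
    task_id="2d29459b-3245-492e-977b-09043c0f1f27", queue="queue"
)
task_result = TaskResult(res={"x": 1})

# dumps_task_result returns the serialized body plus transport headers;
# a backend publishes `payload` as the message body and passes `headers`
# inside the message properties (see the RabbitMQ result backend changes).
payload, headers = serializer.dumps_task_result(task_result, task_instance=task_instance)

# The consuming side hands both the body and the received headers back
# to the serializer to rebuild the TaskResult.
restored = serializer.loads_task_result(payload, headers)
assert restored.res == {"x": 1}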