Skip to content

Commit

Permalink
Merge pull request #57 from taskiq-python/feature/backend-serializer
Browse files Browse the repository at this point in the history
  • Loading branch information
s3rius authored Jun 11, 2024
2 parents 386b5bb + f536b4f commit f66a2c2
Show file tree
Hide file tree
Showing 7 changed files with 456 additions and 437 deletions.
772 changes: 389 additions & 383 deletions poetry.lock

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ keywords = [

[tool.poetry.dependencies]
python = "^3.8.1"
taskiq = ">=0.10.3,<1"
taskiq = ">=0.11.1,<1"
redis = "^5"

[tool.poetry.group.dev.dependencies]
Expand All @@ -40,7 +40,7 @@ fakeredis = "^2"
pre-commit = "^2.20.0"
pytest-xdist = { version = "^2.5.0", extras = ["psutil"] }
ruff = "^0.1.0"
types-redis = "^4.6.0.7"
types-redis = "^4.6.0.20240425"

[tool.mypy]
strict = true
Expand Down
45 changes: 29 additions & 16 deletions taskiq_redis/redis_backend.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import pickle
import sys
from contextlib import asynccontextmanager
from typing import (
Expand All @@ -15,16 +14,18 @@

from redis.asyncio import BlockingConnectionPool, Redis, Sentinel
from redis.asyncio.cluster import RedisCluster
from redis.asyncio.connection import Connection
from taskiq import AsyncResultBackend
from taskiq.abc.result_backend import TaskiqResult
from taskiq.abc.serializer import TaskiqSerializer
from taskiq.compat import model_dump, model_validate
from taskiq.serializers import PickleSerializer

from taskiq_redis.exceptions import (
DuplicateExpireTimeSelectedError,
ExpireTimeMustBeMoreThanZeroError,
ResultIsMissingError,
)
from taskiq_redis.serializer import PickleSerializer

if sys.version_info >= (3, 10):
from typing import TypeAlias
Expand All @@ -33,8 +34,10 @@

if TYPE_CHECKING:
_Redis: TypeAlias = Redis[bytes]
_BlockingConnectionPool: TypeAlias = BlockingConnectionPool[Connection]
else:
_Redis: TypeAlias = Redis
_BlockingConnectionPool: TypeAlias = BlockingConnectionPool

_ReturnType = TypeVar("_ReturnType")

Expand All @@ -49,6 +52,7 @@ def __init__(
result_ex_time: Optional[int] = None,
result_px_time: Optional[int] = None,
max_connection_pool_size: Optional[int] = None,
serializer: Optional[TaskiqSerializer] = None,
**connection_kwargs: Any,
) -> None:
"""
Expand All @@ -66,11 +70,12 @@ def __init__(
:raises ExpireTimeMustBeMoreThanZeroError: if result_ex_time
and result_px_time are equal zero.
"""
self.redis_pool = BlockingConnectionPool.from_url(
self.redis_pool: _BlockingConnectionPool = BlockingConnectionPool.from_url(
url=redis_url,
max_connections=max_connection_pool_size,
**connection_kwargs,
)
self.serializer = serializer or PickleSerializer()
self.keep_results = keep_results
self.result_ex_time = result_ex_time
self.result_px_time = result_px_time
Expand Down Expand Up @@ -110,9 +115,9 @@ async def set_result(
:param task_id: ID of the task.
:param result: TaskiqResult instance.
"""
redis_set_params: Dict[str, Union[str, bytes, int]] = {
redis_set_params: Dict[str, Union[str, int, bytes]] = {
"name": task_id,
"value": pickle.dumps(result),
"value": self.serializer.dumpb(model_dump(result)),
}
if self.result_ex_time:
redis_set_params["ex"] = self.result_ex_time
Expand Down Expand Up @@ -159,8 +164,9 @@ async def get_result(
if result_value is None:
raise ResultIsMissingError

taskiq_result: TaskiqResult[_ReturnType] = pickle.loads( # noqa: S301
result_value,
taskiq_result = model_validate(
TaskiqResult[_ReturnType],
self.serializer.loadb(result_value),
)

if not with_logs:
Expand All @@ -178,6 +184,7 @@ def __init__(
keep_results: bool = True,
result_ex_time: Optional[int] = None,
result_px_time: Optional[int] = None,
serializer: Optional[TaskiqSerializer] = None,
**connection_kwargs: Any,
) -> None:
"""
Expand All @@ -198,6 +205,7 @@ def __init__(
redis_url,
**connection_kwargs,
)
self.serializer = serializer or PickleSerializer()
self.keep_results = keep_results
self.result_ex_time = result_ex_time
self.result_px_time = result_px_time
Expand Down Expand Up @@ -239,7 +247,7 @@ async def set_result(
"""
redis_set_params: Dict[str, Union[str, bytes, int]] = {
"name": task_id,
"value": pickle.dumps(result),
"value": self.serializer.dumpb(model_dump(result)),
}
if self.result_ex_time:
redis_set_params["ex"] = self.result_ex_time
Expand Down Expand Up @@ -283,8 +291,9 @@ async def get_result(
if result_value is None:
raise ResultIsMissingError

taskiq_result: TaskiqResult[_ReturnType] = pickle.loads( # noqa: S301
result_value,
taskiq_result: TaskiqResult[_ReturnType] = model_validate(
TaskiqResult[_ReturnType],
self.serializer.loadb(result_value),
)

if not with_logs:
Expand Down Expand Up @@ -331,9 +340,7 @@ def __init__(
**connection_kwargs,
)
self.master_name = master_name
if serializer is None:
serializer = PickleSerializer()
self.serializer = serializer
self.serializer = serializer or PickleSerializer()
self.keep_results = keep_results
self.result_ex_time = result_ex_time
self.result_px_time = result_px_time
Expand Down Expand Up @@ -375,7 +382,7 @@ async def set_result(
"""
redis_set_params: Dict[str, Union[str, bytes, int]] = {
"name": task_id,
"value": self.serializer.dumpb(result),
"value": self.serializer.dumpb(model_dump(result)),
}
if self.result_ex_time:
redis_set_params["ex"] = self.result_ex_time
Expand Down Expand Up @@ -422,11 +429,17 @@ async def get_result(
if result_value is None:
raise ResultIsMissingError

taskiq_result: TaskiqResult[_ReturnType] = pickle.loads( # noqa: S301
result_value,
taskiq_result = model_validate(
TaskiqResult[_ReturnType],
self.serializer.loadb(result_value),
)

if not with_logs:
taskiq_result.log = None

return taskiq_result

async def shutdown(self) -> None:
    """
    Shutdown sentinel connections.

    Closes the async connection held by every sentinel node in
    ``self.sentinel.sentinels`` so no sockets are leaked on exit.
    """
    for sentinel in self.sentinel.sentinels:
        # aclose() is the redis-py asyncio close method; the ignore is
        # needed because older redis type stubs do not declare it.
        await sentinel.aclose()  # type: ignore[attr-defined]
17 changes: 14 additions & 3 deletions taskiq_redis/redis_broker.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import sys
from logging import getLogger
from typing import Any, AsyncGenerator, Callable, Optional, TypeVar
from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Optional, TypeVar

from redis.asyncio import BlockingConnectionPool, ConnectionPool, Redis
from redis.asyncio import BlockingConnectionPool, Connection, Redis
from taskiq.abc.broker import AsyncBroker
from taskiq.abc.result_backend import AsyncResultBackend
from taskiq.message import BrokerMessage
Expand All @@ -10,6 +11,16 @@

logger = getLogger("taskiq.redis_broker")

if sys.version_info >= (3, 10):
from typing import TypeAlias
else:
from typing_extensions import TypeAlias

if TYPE_CHECKING:
_BlockingConnectionPool: TypeAlias = BlockingConnectionPool[Connection]
else:
_BlockingConnectionPool: TypeAlias = BlockingConnectionPool


class BaseRedisBroker(AsyncBroker):
"""Base broker that works with Redis."""
Expand Down Expand Up @@ -40,7 +51,7 @@ def __init__(
task_id_generator=task_id_generator,
)

self.connection_pool: ConnectionPool = BlockingConnectionPool.from_url(
self.connection_pool: _BlockingConnectionPool = BlockingConnectionPool.from_url(
url=url,
max_connections=max_connection_pool_size,
**connection_kwargs,
Expand Down
18 changes: 14 additions & 4 deletions taskiq_redis/schedule_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from redis.asyncio import (
BlockingConnectionPool,
ConnectionPool,
Connection,
Redis,
RedisCluster,
Sentinel,
Expand All @@ -13,8 +13,7 @@
from taskiq.abc.serializer import TaskiqSerializer
from taskiq.compat import model_dump, model_validate
from taskiq.scheduler.scheduled_task import ScheduledTask

from taskiq_redis.serializer import PickleSerializer
from taskiq.serializers import PickleSerializer

if sys.version_info >= (3, 10):
from typing import TypeAlias
Expand All @@ -23,8 +22,10 @@

if TYPE_CHECKING:
_Redis: TypeAlias = Redis[bytes]
_BlockingConnectionPool: TypeAlias = BlockingConnectionPool[Connection]
else:
_Redis: TypeAlias = Redis
_BlockingConnectionPool: TypeAlias = BlockingConnectionPool


class RedisScheduleSource(ScheduleSource):
Expand Down Expand Up @@ -53,7 +54,7 @@ def __init__(
**connection_kwargs: Any,
) -> None:
self.prefix = prefix
self.connection_pool: ConnectionPool = BlockingConnectionPool.from_url(
self.connection_pool: _BlockingConnectionPool = BlockingConnectionPool.from_url(
url=url,
max_connections=max_connection_pool_size,
**connection_kwargs,
Expand Down Expand Up @@ -186,6 +187,10 @@ async def post_send(self, task: ScheduledTask) -> None:
if task.time is not None:
await self.delete_schedule(task.schedule_id)

async def shutdown(self) -> None:
    """
    Shut down the schedule source.

    Closes the underlying Redis client connection so the event loop can
    exit cleanly without unclosed-connection warnings.
    """
    # aclose() exists at runtime in redis-py asyncio; the ignore is for
    # older type stubs that do not declare it.
    await self.redis.aclose()  # type: ignore[attr-defined]


class RedisSentinelScheduleSource(ScheduleSource):
"""
Expand Down Expand Up @@ -279,3 +284,8 @@ async def post_send(self, task: ScheduledTask) -> None:
"""Delete a task after it's completed."""
if task.time is not None:
await self.delete_schedule(task.schedule_id)

async def shutdown(self) -> None:
    """
    Shut down the schedule source.

    Iterates over every sentinel node tracked by ``self.sentinel`` and
    closes its async connection, releasing all sockets held by this source.
    """
    for sentinel in self.sentinel.sentinels:
        # aclose() is the redis-py asyncio close method; the ignore is
        # needed because older redis type stubs do not declare it.
        await sentinel.aclose()  # type: ignore[attr-defined]
16 changes: 0 additions & 16 deletions taskiq_redis/serializer.py

This file was deleted.

21 changes: 8 additions & 13 deletions tests/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,34 +90,29 @@ async def test_success_backend_default_result(


@pytest.mark.anyio
async def test_success_backend_custom_result(
async def test_error_backend_custom_result(
custom_taskiq_result: TaskiqResult[_ReturnType],
task_id: str,
redis_url: str,
) -> None:
"""
Tests normal behavior with custom result in TaskiqResult.
Setting custom class as a result should raise an error.
:param custom_taskiq_result: TaskiqResult with custom result.
:param task_id: ID for task.
:param redis_url: url to redis.
"""
backend: RedisAsyncResultBackend[_ReturnType] = RedisAsyncResultBackend(
redis_url,
)
await backend.set_result(
task_id=task_id,
result=custom_taskiq_result,
)
result = await backend.get_result(task_id=task_id)
with pytest.raises(ValueError):
await backend.set_result(
task_id=task_id,
result=custom_taskiq_result,
)

assert (
result.return_value.test_arg # type: ignore
== custom_taskiq_result.return_value.test_arg # type: ignore
)
assert result.is_err == custom_taskiq_result.is_err
assert result.execution_time == custom_taskiq_result.execution_time
assert result.log == custom_taskiq_result.log
await backend.shutdown()


Expand Down

0 comments on commit f66a2c2

Please sign in to comment.