Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added serializer option to redis backend. #57

Merged
merged 2 commits into from
Jun 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
772 changes: 389 additions & 383 deletions poetry.lock

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ keywords = [

[tool.poetry.dependencies]
python = "^3.8.1"
taskiq = ">=0.10.3,<1"
taskiq = ">=0.11.1,<1"
redis = "^5"

[tool.poetry.group.dev.dependencies]
Expand All @@ -40,7 +40,7 @@ fakeredis = "^2"
pre-commit = "^2.20.0"
pytest-xdist = { version = "^2.5.0", extras = ["psutil"] }
ruff = "^0.1.0"
types-redis = "^4.6.0.7"
types-redis = "^4.6.0.20240425"

[tool.mypy]
strict = true
Expand Down
45 changes: 29 additions & 16 deletions taskiq_redis/redis_backend.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import pickle
import sys
from contextlib import asynccontextmanager
from typing import (
Expand All @@ -15,16 +14,18 @@

from redis.asyncio import BlockingConnectionPool, Redis, Sentinel
from redis.asyncio.cluster import RedisCluster
from redis.asyncio.connection import Connection
from taskiq import AsyncResultBackend
from taskiq.abc.result_backend import TaskiqResult
from taskiq.abc.serializer import TaskiqSerializer
from taskiq.compat import model_dump, model_validate
from taskiq.serializers import PickleSerializer

from taskiq_redis.exceptions import (
DuplicateExpireTimeSelectedError,
ExpireTimeMustBeMoreThanZeroError,
ResultIsMissingError,
)
from taskiq_redis.serializer import PickleSerializer

if sys.version_info >= (3, 10):
from typing import TypeAlias
Expand All @@ -33,8 +34,10 @@

if TYPE_CHECKING:
_Redis: TypeAlias = Redis[bytes]
_BlockingConnectionPool: TypeAlias = BlockingConnectionPool[Connection]
else:
_Redis: TypeAlias = Redis
_BlockingConnectionPool: TypeAlias = BlockingConnectionPool

_ReturnType = TypeVar("_ReturnType")

Expand All @@ -49,6 +52,7 @@ def __init__(
result_ex_time: Optional[int] = None,
result_px_time: Optional[int] = None,
max_connection_pool_size: Optional[int] = None,
serializer: Optional[TaskiqSerializer] = None,
**connection_kwargs: Any,
) -> None:
"""
Expand All @@ -66,11 +70,12 @@ def __init__(
:raises ExpireTimeMustBeMoreThanZeroError: if result_ex_time
and result_px_time are equal zero.
"""
self.redis_pool = BlockingConnectionPool.from_url(
self.redis_pool: _BlockingConnectionPool = BlockingConnectionPool.from_url(
url=redis_url,
max_connections=max_connection_pool_size,
**connection_kwargs,
)
self.serializer = serializer or PickleSerializer()
self.keep_results = keep_results
self.result_ex_time = result_ex_time
self.result_px_time = result_px_time
Expand Down Expand Up @@ -110,9 +115,9 @@ async def set_result(
:param task_id: ID of the task.
:param result: TaskiqResult instance.
"""
redis_set_params: Dict[str, Union[str, bytes, int]] = {
redis_set_params: Dict[str, Union[str, int, bytes]] = {
"name": task_id,
"value": pickle.dumps(result),
"value": self.serializer.dumpb(model_dump(result)),
}
if self.result_ex_time:
redis_set_params["ex"] = self.result_ex_time
Expand Down Expand Up @@ -159,8 +164,9 @@ async def get_result(
if result_value is None:
raise ResultIsMissingError

taskiq_result: TaskiqResult[_ReturnType] = pickle.loads( # noqa: S301
result_value,
taskiq_result = model_validate(
TaskiqResult[_ReturnType],
self.serializer.loadb(result_value),
)

if not with_logs:
Expand All @@ -178,6 +184,7 @@ def __init__(
keep_results: bool = True,
result_ex_time: Optional[int] = None,
result_px_time: Optional[int] = None,
serializer: Optional[TaskiqSerializer] = None,
**connection_kwargs: Any,
) -> None:
"""
Expand All @@ -198,6 +205,7 @@ def __init__(
redis_url,
**connection_kwargs,
)
self.serializer = serializer or PickleSerializer()
self.keep_results = keep_results
self.result_ex_time = result_ex_time
self.result_px_time = result_px_time
Expand Down Expand Up @@ -239,7 +247,7 @@ async def set_result(
"""
redis_set_params: Dict[str, Union[str, bytes, int]] = {
"name": task_id,
"value": pickle.dumps(result),
"value": self.serializer.dumpb(model_dump(result)),
}
if self.result_ex_time:
redis_set_params["ex"] = self.result_ex_time
Expand Down Expand Up @@ -283,8 +291,9 @@ async def get_result(
if result_value is None:
raise ResultIsMissingError

taskiq_result: TaskiqResult[_ReturnType] = pickle.loads( # noqa: S301
result_value,
taskiq_result: TaskiqResult[_ReturnType] = model_validate(
TaskiqResult[_ReturnType],
self.serializer.loadb(result_value),
)

if not with_logs:
Expand Down Expand Up @@ -331,9 +340,7 @@ def __init__(
**connection_kwargs,
)
self.master_name = master_name
if serializer is None:
serializer = PickleSerializer()
self.serializer = serializer
self.serializer = serializer or PickleSerializer()
self.keep_results = keep_results
self.result_ex_time = result_ex_time
self.result_px_time = result_px_time
Expand Down Expand Up @@ -375,7 +382,7 @@ async def set_result(
"""
redis_set_params: Dict[str, Union[str, bytes, int]] = {
"name": task_id,
"value": self.serializer.dumpb(result),
"value": self.serializer.dumpb(model_dump(result)),
}
if self.result_ex_time:
redis_set_params["ex"] = self.result_ex_time
Expand Down Expand Up @@ -422,11 +429,17 @@ async def get_result(
if result_value is None:
raise ResultIsMissingError

taskiq_result: TaskiqResult[_ReturnType] = pickle.loads( # noqa: S301
result_value,
taskiq_result = model_validate(
TaskiqResult[_ReturnType],
self.serializer.loadb(result_value),
)

if not with_logs:
taskiq_result.log = None

return taskiq_result

async def shutdown(self) -> None:
"""Shutdown sentinel connections."""
for sentinel in self.sentinel.sentinels:
await sentinel.aclose() # type: ignore[attr-defined]
17 changes: 14 additions & 3 deletions taskiq_redis/redis_broker.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import sys
from logging import getLogger
from typing import Any, AsyncGenerator, Callable, Optional, TypeVar
from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Optional, TypeVar

from redis.asyncio import BlockingConnectionPool, ConnectionPool, Redis
from redis.asyncio import BlockingConnectionPool, Connection, Redis
from taskiq.abc.broker import AsyncBroker
from taskiq.abc.result_backend import AsyncResultBackend
from taskiq.message import BrokerMessage
Expand All @@ -10,6 +11,16 @@

logger = getLogger("taskiq.redis_broker")

if sys.version_info >= (3, 10):
from typing import TypeAlias
else:
from typing_extensions import TypeAlias

if TYPE_CHECKING:
_BlockingConnectionPool: TypeAlias = BlockingConnectionPool[Connection]
else:
_BlockingConnectionPool: TypeAlias = BlockingConnectionPool


class BaseRedisBroker(AsyncBroker):
"""Base broker that works with Redis."""
Expand Down Expand Up @@ -40,7 +51,7 @@ def __init__(
task_id_generator=task_id_generator,
)

self.connection_pool: ConnectionPool = BlockingConnectionPool.from_url(
self.connection_pool: _BlockingConnectionPool = BlockingConnectionPool.from_url(
url=url,
max_connections=max_connection_pool_size,
**connection_kwargs,
Expand Down
18 changes: 14 additions & 4 deletions taskiq_redis/schedule_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from redis.asyncio import (
BlockingConnectionPool,
ConnectionPool,
Connection,
Redis,
RedisCluster,
Sentinel,
Expand All @@ -13,8 +13,7 @@
from taskiq.abc.serializer import TaskiqSerializer
from taskiq.compat import model_dump, model_validate
from taskiq.scheduler.scheduled_task import ScheduledTask

from taskiq_redis.serializer import PickleSerializer
from taskiq.serializers import PickleSerializer

if sys.version_info >= (3, 10):
from typing import TypeAlias
Expand All @@ -23,8 +22,10 @@

if TYPE_CHECKING:
_Redis: TypeAlias = Redis[bytes]
_BlockingConnectionPool: TypeAlias = BlockingConnectionPool[Connection]
else:
_Redis: TypeAlias = Redis
_BlockingConnectionPool: TypeAlias = BlockingConnectionPool


class RedisScheduleSource(ScheduleSource):
Expand Down Expand Up @@ -53,7 +54,7 @@ def __init__(
**connection_kwargs: Any,
) -> None:
self.prefix = prefix
self.connection_pool: ConnectionPool = BlockingConnectionPool.from_url(
self.connection_pool: _BlockingConnectionPool = BlockingConnectionPool.from_url(
url=url,
max_connections=max_connection_pool_size,
**connection_kwargs,
Expand Down Expand Up @@ -186,6 +187,10 @@ async def post_send(self, task: ScheduledTask) -> None:
if task.time is not None:
await self.delete_schedule(task.schedule_id)

async def shutdown(self) -> None:
"""Shut down the schedule source."""
await self.redis.aclose() # type: ignore[attr-defined]


class RedisSentinelScheduleSource(ScheduleSource):
"""
Expand Down Expand Up @@ -279,3 +284,8 @@ async def post_send(self, task: ScheduledTask) -> None:
"""Delete a task after it's completed."""
if task.time is not None:
await self.delete_schedule(task.schedule_id)

async def shutdown(self) -> None:
"""Shut down the schedule source."""
for sentinel in self.sentinel.sentinels:
await sentinel.aclose() # type: ignore[attr-defined]
16 changes: 0 additions & 16 deletions taskiq_redis/serializer.py

This file was deleted.

21 changes: 8 additions & 13 deletions tests/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,34 +90,29 @@ async def test_success_backend_default_result(


@pytest.mark.anyio
async def test_success_backend_custom_result(
async def test_error_backend_custom_result(
custom_taskiq_result: TaskiqResult[_ReturnType],
task_id: str,
redis_url: str,
) -> None:
"""
Tests normal behavior with custom result in TaskiqResult.

Setting custom class as a result should raise an error.

:param custom_taskiq_result: TaskiqResult with custom result.
:param task_id: ID for task.
:param redis_url: url to redis.
"""
backend: RedisAsyncResultBackend[_ReturnType] = RedisAsyncResultBackend(
redis_url,
)
await backend.set_result(
task_id=task_id,
result=custom_taskiq_result,
)
result = await backend.get_result(task_id=task_id)
with pytest.raises(ValueError):
await backend.set_result(
task_id=task_id,
result=custom_taskiq_result,
)

assert (
result.return_value.test_arg # type: ignore
== custom_taskiq_result.return_value.test_arg # type: ignore
)
assert result.is_err == custom_taskiq_result.is_err
assert result.execution_time == custom_taskiq_result.execution_time
assert result.log == custom_taskiq_result.log
await backend.shutdown()


Expand Down
Loading