Skip to content

Commit

Permalink
Merge branch 'release/0.5.5'
Browse files Browse the repository at this point in the history
  • Loading branch information
s3rius committed Dec 12, 2023
2 parents 92ed3ae + e763dcd commit 8601636
Show file tree
Hide file tree
Showing 4 changed files with 235 additions and 4 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "taskiq-redis"
version = "0.5.4"
version = "0.5.5"
description = "Redis integration for taskiq"
authors = ["taskiq-team <taskiq@norely.com>"]
readme = "README.md"
Expand Down
6 changes: 5 additions & 1 deletion taskiq_redis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@
)
from taskiq_redis.redis_broker import ListQueueBroker, PubSubBroker
from taskiq_redis.redis_cluster_broker import ListQueueClusterBroker
from taskiq_redis.schedule_source import RedisScheduleSource
from taskiq_redis.schedule_source import (
RedisClusterScheduleSource,
RedisScheduleSource,
)

__all__ = [
"RedisAsyncClusterResultBackend",
Expand All @@ -14,4 +17,5 @@
"PubSubBroker",
"ListQueueClusterBroker",
"RedisScheduleSource",
"RedisClusterScheduleSource",
]
81 changes: 80 additions & 1 deletion taskiq_redis/schedule_source.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Any, List, Optional

from redis.asyncio import ConnectionPool, Redis
from redis.asyncio import ConnectionPool, Redis, RedisCluster
from taskiq import ScheduleSource
from taskiq.abc.serializer import TaskiqSerializer
from taskiq.compat import model_dump, model_validate
Expand Down Expand Up @@ -95,3 +95,82 @@ async def post_send(self, task: ScheduledTask) -> None:
async def shutdown(self) -> None:
"""Shut down the schedule source."""
await self.connection_pool.disconnect()


class RedisClusterScheduleSource(ScheduleSource):
"""
Source of schedules for redis cluster.
This class allows you to store schedules in redis.
Also it supports dynamic schedules.
:param url: url to redis cluster.
:param prefix: prefix for redis schedule keys.
:param buffer_size: buffer size for redis scan.
This is how many keys will be fetched at once.
:param max_connection_pool_size: maximum number of connections in pool.
:param serializer: serializer for data.
:param connection_kwargs: additional arguments for RedisCluster.
"""

def __init__(
self,
url: str,
prefix: str = "schedule",
buffer_size: int = 50,
serializer: Optional[TaskiqSerializer] = None,
**connection_kwargs: Any,
) -> None:
self.prefix = prefix
self.redis: RedisCluster[bytes] = RedisCluster.from_url(
url,
**connection_kwargs,
)
self.buffer_size = buffer_size
if serializer is None:
serializer = PickleSerializer()
self.serializer = serializer

async def delete_schedule(self, schedule_id: str) -> None:
"""Remove schedule by id."""
await self.redis.delete(f"{self.prefix}:{schedule_id}") # type: ignore[attr-defined]

async def add_schedule(self, schedule: ScheduledTask) -> None:
"""
Add schedule to redis.
:param schedule: schedule to add.
:param schedule_id: schedule id.
"""
await self.redis.set( # type: ignore[attr-defined]
f"{self.prefix}:{schedule.schedule_id}",
self.serializer.dumpb(model_dump(schedule)),
)

async def get_schedules(self) -> List[ScheduledTask]:
"""
Get all schedules from redis.
This method is used by scheduler to get all schedules.
:return: list of schedules.
"""
schedules = []
buffer = []
async for key in self.redis.scan_iter(f"{self.prefix}:*"): # type: ignore[attr-defined]
buffer.append(key)
if len(buffer) >= self.buffer_size:
schedules.extend(await self.redis.mget(buffer)) # type: ignore[attr-defined]
buffer = []
if buffer:
schedules.extend(await self.redis.mget(buffer)) # type: ignore[attr-defined]
return [
model_validate(ScheduledTask, self.serializer.loadb(schedule))
for schedule in schedules
if schedule
]

async def post_send(self, task: ScheduledTask) -> None:
"""Delete a task after it's completed."""
if task.time is not None:
await self.delete_schedule(task.schedule_id)
150 changes: 149 additions & 1 deletion tests/test_schedule_source.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import datetime as dt
import uuid

import pytest
from taskiq import ScheduledTask

from taskiq_redis import RedisScheduleSource
from taskiq_redis import RedisClusterScheduleSource, RedisScheduleSource


@pytest.mark.anyio
Expand Down Expand Up @@ -56,6 +57,153 @@ async def test_post_run_cron(redis_url: str) -> None:
cron="* * * * *",
)
await source.add_schedule(schedule)
assert await source.get_schedules() == [schedule]
await source.post_send(schedule)
assert await source.get_schedules() == [schedule]
await source.shutdown()


@pytest.mark.anyio
async def test_post_run_time(redis_url: str) -> None:
prefix = uuid.uuid4().hex
source = RedisScheduleSource(redis_url, prefix=prefix)
schedule = ScheduledTask(
task_name="test_task",
labels={},
args=[],
kwargs={},
time=dt.datetime(2000, 1, 1),
)
await source.add_schedule(schedule)
assert await source.get_schedules() == [schedule]
await source.post_send(schedule)
assert await source.get_schedules() == []
await source.shutdown()


@pytest.mark.anyio
async def test_buffer(redis_url: str) -> None:
prefix = uuid.uuid4().hex
source = RedisScheduleSource(redis_url, prefix=prefix, buffer_size=1)
schedule1 = ScheduledTask(
task_name="test_task1",
labels={},
args=[],
kwargs={},
cron="* * * * *",
)
schedule2 = ScheduledTask(
task_name="test_task2",
labels={},
args=[],
kwargs={},
cron="* * * * *",
)
await source.add_schedule(schedule1)
await source.add_schedule(schedule2)
schedules = await source.get_schedules()
assert len(schedules) == 2
assert schedule1 in schedules
assert schedule2 in schedules
await source.shutdown()


@pytest.mark.anyio
async def test_cluster_set_schedule(redis_cluster_url: str) -> None:
prefix = uuid.uuid4().hex
source = RedisClusterScheduleSource(redis_cluster_url, prefix=prefix)
schedule = ScheduledTask(
task_name="test_task",
labels={},
args=[],
kwargs={},
cron="* * * * *",
)
await source.add_schedule(schedule)
schedules = await source.get_schedules()
assert schedules == [schedule]
await source.shutdown()


@pytest.mark.anyio
async def test_cluster_delete_schedule(redis_cluster_url: str) -> None:
prefix = uuid.uuid4().hex
source = RedisClusterScheduleSource(redis_cluster_url, prefix=prefix)
schedule = ScheduledTask(
task_name="test_task",
labels={},
args=[],
kwargs={},
cron="* * * * *",
)
await source.add_schedule(schedule)
schedules = await source.get_schedules()
assert schedules == [schedule]
await source.delete_schedule(schedule.schedule_id)
schedules = await source.get_schedules()
# Schedules are empty.
assert not schedules
await source.shutdown()


@pytest.mark.anyio
async def test_cluster_post_run_cron(redis_cluster_url: str) -> None:
prefix = uuid.uuid4().hex
source = RedisClusterScheduleSource(redis_cluster_url, prefix=prefix)
schedule = ScheduledTask(
task_name="test_task",
labels={},
args=[],
kwargs={},
cron="* * * * *",
)
await source.add_schedule(schedule)
assert await source.get_schedules() == [schedule]
await source.post_send(schedule)
assert await source.get_schedules() == [schedule]
await source.shutdown()


@pytest.mark.anyio
async def test_cluster_post_run_time(redis_cluster_url: str) -> None:
prefix = uuid.uuid4().hex
source = RedisClusterScheduleSource(redis_cluster_url, prefix=prefix)
schedule = ScheduledTask(
task_name="test_task",
labels={},
args=[],
kwargs={},
time=dt.datetime(2000, 1, 1),
)
await source.add_schedule(schedule)
assert await source.get_schedules() == [schedule]
await source.post_send(schedule)
assert await source.get_schedules() == []
await source.shutdown()


@pytest.mark.anyio
async def test_cluster_buffer(redis_cluster_url: str) -> None:
prefix = uuid.uuid4().hex
source = RedisClusterScheduleSource(redis_cluster_url, prefix=prefix, buffer_size=1)
schedule1 = ScheduledTask(
task_name="test_task1",
labels={},
args=[],
kwargs={},
cron="* * * * *",
)
schedule2 = ScheduledTask(
task_name="test_task2",
labels={},
args=[],
kwargs={},
cron="* * * * *",
)
await source.add_schedule(schedule1)
await source.add_schedule(schedule2)
schedules = await source.get_schedules()
assert len(schedules) == 2
assert schedule1 in schedules
assert schedule2 in schedules
await source.shutdown()

0 comments on commit 8601636

Please sign in to comment.