Skip to content

Commit

Permalink
Pydantic v2
Browse files Browse the repository at this point in the history
  • Loading branch information
levsh committed Apr 12, 2024
1 parent 957286c commit 91b457a
Show file tree
Hide file tree
Showing 36 changed files with 1,272 additions and 876 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ["3.8", "3.9", "3.10", "3.11"]
python-version: ["3.11"]

steps:
- uses: actions/checkout@v2
Expand Down
6 changes: 4 additions & 2 deletions arrlio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,16 @@
import sys

logger = logging.getLogger("arrlio")
logger.setLevel(logging.ERROR)

log_frmt = logging.Formatter("%(asctime)s %(levelname)-8s %(name)-27s lineno:%(lineno)4d -- %(message)s")
log_hndl = logging.StreamHandler(stream=sys.stderr)
log_hndl.setFormatter(log_frmt)
logger.addHandler(log_hndl)

# pylint: disable=wrong-import-position
from arrlio.configs import Config, TaskConfig # noqa
from arrlio.core import App, AsyncResult, registered_tasks, task # noqa
from arrlio.models import Graph, Task, TaskInstance, TaskResult # noqa
from arrlio.settings import Config, TaskConfig # noqa
from arrlio.settings import LOG_LEVEL

logger.setLevel(LOG_LEVEL)
63 changes: 40 additions & 23 deletions arrlio/backends/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,46 +3,53 @@
import logging
from asyncio import create_task, current_task
from collections import defaultdict
from typing import Any, Callable, Dict, List, Set, Type, Union
from typing import Any, Callable, Coroutine, cast
from uuid import uuid4

from pydantic import Field
from pydantic_settings import BaseSettings, SettingsConfigDict

from arrlio.configs import ModuleConfigValidatorMixIn
from arrlio.models import Event, TaskInstance, TaskResult
from arrlio.serializers.base import Serializer
from arrlio.settings import ENV_PREFIX, BaseConfig, ConfigValidatorMixIn
from arrlio.types import AsyncFunction, SerializerModule
from arrlio.settings import ENV_PREFIX
from arrlio.types import ModuleConfig, SerializerModule

logger = logging.getLogger("arrlio.backends.base")


class SerializerConfig(ConfigValidatorMixIn, BaseConfig):
module: SerializerModule
config: Any = Field(default_factory=dict)
SERIALIZER = "arrlio.serializers.nop"

class Config:
env_prefix = [f"{ENV_PREFIX}SERIALIZER_"]

class SerializerConfig(BaseSettings, ModuleConfigValidatorMixIn):
"""Config for backend serializer."""

class Config(BaseConfig):
model_config = SettingsConfigDict(env_prefix=f"{ENV_PREFIX}SERIALIZER_")

module: SerializerModule = SERIALIZER
config: ModuleConfig = Field(default_factory=BaseSettings)


class Config(BaseSettings):
"""Config for backend."""

id: str = Field(default_factory=lambda: f"{uuid4()}")
serializer: SerializerConfig = Field(default_factory=lambda: SerializerConfig(module="arrlio.serializers.nop"))
serializer: SerializerConfig = Field(default_factory=SerializerConfig)


class Backend(abc.ABC):
__slots__ = ("config", "_serializer", "_closed", "_internal_tasks")
__slots__ = ("config", "serializer", "_internal_tasks", "_closed")

def __init__(self, config: Type[Config]):
def __init__(self, config: Config):
"""
Args:
config: Backend config.
"""
self.config: Type[Config] = config
self._serializer: Type[Serializer] = config.serializer.module.Serializer(config.serializer.config)

self.config: Config = config
self.serializer: Serializer = config.serializer.module.Serializer(config.serializer.config)
self._internal_tasks: dict[str, set[asyncio.Task]] = defaultdict(set)
self._closed: asyncio.Future = asyncio.Future()
self._internal_tasks: Dict[str, Set[asyncio.Task]] = defaultdict(set)

def __repr__(self):
return self.__str__()
Expand All @@ -61,7 +68,7 @@ def _create_internal_task(self, key: str, coro_factory: Callable) -> asyncio.Tas
raise Exception(f"{self} closed")

async def fn():
task: asyncio.Task = current_task()
task = cast(asyncio.Task, current_task())
internal_tasks = self._internal_tasks[key]
internal_tasks.add(task)
try:
Expand All @@ -87,7 +94,9 @@ async def close(self):
if self.is_closed:
return
try:
await asyncio.gather(self.stop_consume_tasks(), self.stop_consume_events())
async with asyncio.TaskGroup() as tg:
tg.create_task(self.stop_consume_tasks())
tg.create_task(self.stop_consume_events())
finally:
self._cancel_all_internal_tasks()
self._closed.set_result(None)
Expand All @@ -101,48 +110,56 @@ async def __aexit__(self, exc_type, exc, tb):
@abc.abstractmethod
async def send_task(self, task_instance: TaskInstance, **kwds):
"""Send task to backend."""

return

@abc.abstractmethod
async def close_task(self, task_instance: TaskInstance):
return

@abc.abstractmethod
async def consume_tasks(self, queues: List[str], callback: AsyncFunction):
async def consume_tasks(self, queues: list[str], callback: Callable[[TaskInstance], Coroutine]):
"""Consume tasks from the queues and invoke `callback` on `arrlio.models.TaskInstance` received."""

return

@abc.abstractmethod
async def stop_consume_tasks(self, queues: List[str] = None):
async def stop_consume_tasks(self, queues: list[str] | None = None):
"""Stop consuming tasks."""

return

@abc.abstractmethod
async def push_task_result(self, task_instance: TaskInstance, task_result: TaskResult):
async def push_task_result(self, task_result: TaskResult, task_instance: TaskInstance):
"""Push task result to backend."""

return

@abc.abstractmethod
async def pop_task_result(self, task_instance: TaskInstance) -> TaskResult:
"""Pop task result for `arrlio.models.TaskInstance` from backend."""

return

@abc.abstractmethod
async def send_event(self, event: Event):
"""Send event to backend."""

return

@abc.abstractmethod
async def consume_events(
self,
callback_id: str,
callback: Union[Callable, AsyncFunction],
event_types: List[str] = None,
callback: Callable[[Event], Any],
event_types: list[str] | None = None,
):
"""Consume event from the queues."""

return

@abc.abstractmethod
async def stop_consume_events(self, callback_id: str = None):
async def stop_consume_events(self, callback_id: str | None = None):
"""Stop consuming events."""

return
Loading

0 comments on commit 91b457a

Please sign in to comment.