Skip to content

Commit

Permalink
Wip
Browse files Browse the repository at this point in the history
  • Loading branch information
levsh committed Apr 17, 2024
1 parent aaedd91 commit 44ab29c
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 40 deletions.
2 changes: 1 addition & 1 deletion arrlio/backends/rabbitmq.py
Original file line number Diff line number Diff line change
Expand Up @@ -1029,7 +1029,7 @@ async def _send_task(self, task_instance: TaskInstance, **kwds): # pylint: disa
headers = {}
data: bytes = self.serializer.dumps_task_instance(task_instance, headers=headers)

await self._ensure_task_queue(task_instance.queue)
# await self._ensure_task_queue(task_instance.queue)

properties = {
"delivery_mode": 2,
Expand Down
111 changes: 73 additions & 38 deletions arrlio/core.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import asyncio
import dataclasses
import logging
from asyncio import current_task, gather
from asyncio import TaskGroup, current_task, gather
from contextlib import AsyncExitStack
from contextvars import ContextVar
from inspect import isasyncgenfunction, isgeneratorfunction
Expand All @@ -15,7 +15,15 @@
from arrlio import settings
from arrlio.backends.base import Backend
from arrlio.configs import Config
from arrlio.exceptions import GraphError, NotFoundError, TaskClosedError, TaskError
from arrlio.exceptions import (
GraphError,
HooksError,
InternalError,
NotFoundError,
TaskCancelledError,
TaskClosedError,
TaskError,
)
from arrlio.executor import Executor
from arrlio.models import Event, Task, TaskInstance, TaskResult
from arrlio.plugins.base import Plugin
Expand Down Expand Up @@ -155,11 +163,11 @@ async def init(self):
if self.is_closed:
return

logger.info("%s: initializing with config\n%s", self, pretty_repr(self.config.model_dump()))
logger.info("%s initializing with config\n%s", self, pretty_repr(self.config.model_dump()))

await self._execute_hooks("on_init")

logger.info("%s: initialization done", self)
logger.info("%s initialization done", self)

async def close(self):
"""Close application."""
Expand All @@ -181,7 +189,7 @@ async def close(self):
await self._backend.close()

for task_id, aio_task in tuple(self._running_tasks.items()):
logger.warning("%s: cancel processing task '%s'", self, task_id)
logger.warning("%s cancel processing task '%s'", self, task_id)
aio_task.cancel()
try:
await aio_task
Expand All @@ -195,13 +203,19 @@ async def close(self):
async def _execute_hook(self, hook_fn, *args, **kwds):
try:
if is_debug_level():
logger.debug("%s: execute hook %s", self, hook_fn)
logger.debug("%s execute hook %s", self, hook_fn)
await hook_fn(*args, **kwds)
except Exception:
logger.exception("%s: hook %s error", self, hook_fn)
except Exception as e:
logger.exception("%s hook %s error", self, hook_fn)
raise e

async def _execute_hooks(self, hook: str, *args, **kwds):
await gather(*(self._execute_hook(hook_fn, *args, **kwds) for hook_fn in self._hooks[hook]))
try:
async with TaskGroup() as tg:
for hook_fn in self._hooks[hook]:
tg.create_task(self._execute_hook(hook_fn, *args, **kwds))
except ExceptionGroup as eg:
raise HooksError(exceptions=eg.exceptions)

async def send_task(
self,
Expand Down Expand Up @@ -250,7 +264,7 @@ async def send_task(

if is_info_level():
logger.info(
"%s: send task instance\n%s",
"%s send task instance\n%s",
self,
task_instance.pretty_repr(sanitize=settings.LOG_SANITIZE),
)
Expand All @@ -263,7 +277,7 @@ async def send_task(

async def send_event(self, event: Event):
if is_info_level():
logger.info("%s: send event\n%s", self, event.pretty_repr(sanitize=settings.LOG_SANITIZE))
logger.info("%s send event\n%s", self, event.pretty_repr(sanitize=settings.LOG_SANITIZE))

await self._backend.send_event(event)

Expand Down Expand Up @@ -301,57 +315,78 @@ async def consume_tasks(self, queues: list[str] | None = None):

async def cb(task_instance: TaskInstance):
task_id: UUID = task_instance.task_id

self._running_tasks[task_id] = current_task()
try:
async with AsyncExitStack() as stack:
self.context["task_instance"] = task_instance
for context_hook in self._hooks["task_context"]:
await stack.enter_async_context(context_hook(task_instance))

await self._execute_hooks("on_task_received", task_instance)
idx_0 = uuid4().hex
idx_1 = 0

task_result: TaskResult = TaskResult()
try:
task_result: TaskResult = TaskResult()

idx_0 = uuid4().hex
idx_1 = 0
async with AsyncExitStack() as stack:
try:
self.context["task_instance"] = task_instance

async for task_result in self.execute_task(task_instance):
idx_1 += 1
task_result.set_idx([idx_0, idx_1])
for context_hook in self._hooks["task_context"]:
await stack.enter_async_context(context_hook(task_instance))

if task_instance.result_return:
await self._backend.push_task_result(task_result, task_instance)
await self._execute_hooks("on_task_received", task_instance)

async for task_result in self.execute_task(task_instance):
task_result.set_idx((idx_0, idx_1 + 1))

await self._execute_hooks("on_task_result", task_instance, task_result)
if task_instance.result_return:
await self._backend.push_task_result(task_result, task_instance)

if task_instance.result_return and not task_instance.extra.get("graph:graph"):
func = task_instance.func
if isasyncgenfunction(func) or isgeneratorfunction(func):
await self._execute_hooks("on_task_result", task_instance, task_result)
idx_1 += 1
await self._backend.close_task(task_instance, idx=(idx_0, idx_1))

await self._execute_hooks("on_task_done", task_instance, task_result)
except (asyncio.CancelledError, Exception) as e:
if isinstance(e, asyncio.CancelledError):
logger.error("%s task %s[%s] cancelled", self, task_instance.name, task_id)
task_result = TaskResult(exc=TaskCancelledError(task_id))
raise e
if isinstance(e, HooksError):
# pylint: disable=no-member
if len(e.exceptions) == 1:
e = e.exceptions[0]
else:
e = TaskError(exceptions=e.exceptions)
logger.error("%s task %s[%s] %s: %s", self, task_instance.name, task_id, e.__class__, e)
task_result = TaskResult(exc=e)
else:
logger.exception(e)
task_result = TaskResult(exc=InternalError())
task_result.set_idx((idx_0, idx_1 + 1))
if task_instance.result_return:
await self._backend.push_task_result(task_result, task_instance)
idx_1 += 1
finally:
try:
if task_instance.result_return and not task_instance.extra.get("graph:graph"):
func = task_instance.func
if isasyncgenfunction(func) or isgeneratorfunction(func):
await self._backend.close_task(task_instance, idx=(idx_0, idx_1 + 1))
idx_1 += 1
finally:
await self._execute_hooks("on_task_done", task_instance, task_result)

except asyncio.CancelledError:
logger.error("%s: task %s(%s) cancelled", self, task_id, task_instance.name)
raise
except Exception as e:
logger.exception(e)
finally:
self._running_tasks.pop(task_id, None)

await self._backend.consume_tasks(queues, cb)
logger.info("%s: consuming task queues %s", self, queues)
logger.info("%s consuming task queues %s", self, queues)

async def stop_consume_tasks(self, queues: list[str] | None = None):
"""Stop consuming tasks."""

await self._backend.stop_consume_tasks(queues=queues)
if queues is not None:
logger.info("%s: stop consuming task queues %s", self, queues)
logger.info("%s stop consuming task queues %s", self, queues)
else:
logger.info("%s: stop consuming task queues", self)
logger.info("%s stop consuming task queues", self)

async def execute_task(self, task_instance: TaskInstance) -> AsyncGenerator[TaskResult, None]:
"""Execute the task instance locally by the application executor."""
Expand Down
38 changes: 37 additions & 1 deletion arrlio/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,42 @@ class ArrlioError(Exception):
"""Base exception."""


class TaskError(ArrlioError):
class InternalError(ArrlioError):
pass


class HooksError(ArrlioError):
def __init__(self, *args, exceptions: list[Exception] | tuple[Exception] | None = None):
super().__init__(*args)
self.exceptions = exceptions

def __str__(self):
if self.exceptions is not None:
return f"{self.exceptions}"
return super().__str__()

def __repr__(self):
if self.exceptions is not None:
return f"{self.exceptions}"
return super().__repr__()


class TaskError(ArrlioError):
def __init__(self, *args, exceptions: list[Exception] | None = None):
super().__init__(*args)
self.exceptions = exceptions

def __str__(self):
if self.exceptions is not None:
return f"{self.exceptions}"
return super().__str__()

def __repr__(self):
if self.exceptions is not None:
return f"{self.exceptions}"
return super().__repr__()


class TaskClosedError(ArrlioError):
pass

Expand All @@ -14,6 +46,10 @@ class TaskTimeoutError(ArrlioError):
pass


class TaskCancelledError(ArrlioError):
pass


class TaskResultError(ArrlioError):
pass

Expand Down

0 comments on commit 44ab29c

Please sign in to comment.