Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/noob/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import logging
import multiprocessing as mp
import sys
from logging.handlers import RotatingFileHandler
from pathlib import Path
from typing import Any, Literal
Expand Down Expand Up @@ -193,5 +194,5 @@ def _get_console() -> Console:
current_pid = mp.current_process().pid
console = _console_by_pid.get(current_pid)
if console is None:
_console_by_pid[current_pid] = console = Console()
_console_by_pid[current_pid] = console = Console(file=sys.stdout)
return console
176 changes: 119 additions & 57 deletions src/noob/network/loop.py
Original file line number Diff line number Diff line change
@@ -1,84 +1,146 @@
import asyncio
import threading
import sys
from collections import defaultdict
from collections.abc import Callable, Coroutine
from typing import Any

try:
import zmq
from tornado.ioloop import IOLoop
from zmq.asyncio import Context, Socket
except ImportError as e:
raise ImportError(
"Attempted to import zmq runner, but zmq deps are not installed. install with `noob[zmq]`",
) from e

if sys.version_info < (3, 12):
from typing_extensions import TypedDict
else:
from typing import TypedDict

from noob.logging import init_logger
from noob.network.message import Message
from noob.utils import iscoroutinefunction_partial


class _CallbackDict(TypedDict):
sync: list[Callable[[Message], Any]]
asyncio: list[Callable[[Message], Coroutine]]


class EventloopMixin:
"""
Provide an eventloop in a separate thread to an inheriting class.
Any eventloop that is running in the current context is not used
because the inheriting classes are presumed to operate mostly synchronously for now,
pending a refactor to all async networking classes.
Mixin to provide common asyncio zmq scaffolding to networked classes.

Inheriting classes should, in order

* call the ``_init_loop`` method to create the eventloop, context, and poller
* populate the private ``_sockets`` and ``_receivers`` dicts
* await the ``_poll_sockets`` method, which polls indefinitely.

Inheriting classes **must** ensure that ``_init_loop``
is called in the thread it is intended to run in,
and that thread must already have a running eventloop.
asyncio eventloops (and most of asyncio) are **not** thread safe.

To help avoid cross-threading issues, the :meth:`.context` and :meth:`.loop`
properties do *not* automatically create the objects,
raising a :class:`.RuntimeError` if they are accessed before ``_init_loop`` is called.
"""

def __init__(self):
self._context = None
self._loop = None
self._quitting = asyncio.Event()
self._thread: threading.Thread | None = None
self._quitting: asyncio.Event = None # type: ignore[assignment]
self._sockets: dict[str, Socket] = {}
"""
All sockets, mapped from some common name to the socket.
The same key used here should be shared between _receivers and _callbacks
"""
self._receivers: dict[str, Socket] = {}
"""Sockets that should be polled for incoming messages"""
self._callbacks: dict[str, _CallbackDict] = defaultdict(
lambda: _CallbackDict(sync=[], asyncio=[])
)
"""Callbacks for each receiver socket"""
if not hasattr(self, "logger"):
self.logger = init_logger("eventloop")

@property
def context(self) -> zmq.Context:
def context(self) -> Context:
if self._context is None:
self._context = zmq.Context.instance()
raise RuntimeError("Loop has not been initialized with _init_loop!")
return self._context

@property
def loop(self) -> IOLoop:
# To ensure that the loop is always created in the spawned thread,
# we don't create it here (since this property could be accessed elsewhere)
# and throw to protect that.
def loop(self) -> asyncio.AbstractEventLoop:
if self._loop is None:
raise RuntimeError("Loop is not running")
raise RuntimeError("Loop has not been initialized with _init_loop!")
return self._loop

def start_loop(self) -> None:
if self._thread is not None:
raise RuntimeWarning("Node already started")

self._quitting.clear()

_ready = threading.Event()

def _signal_ready() -> None:
_ready.set()

def _run() -> None:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
self._loop = IOLoop.current()
if hasattr(self, "logger"):
self.logger.debug("Starting eventloop")
while not self._quitting.is_set():
try:
self.loop.add_callback(_signal_ready)
self.loop.start()

except RuntimeError:
# loop already started
if hasattr(self, "logger"):
self.logger.debug("Eventloop already started, quitting")
break
if hasattr(self, "logger"):
self.logger.debug("Eventloop stopped")
self._thread = None

self._thread = threading.Thread(target=_run)
self._thread.start()
# wait until the loop has started
_ready.wait(5)
if hasattr(self, "logger"):
self.logger.debug("Event loop started")

def stop_loop(self) -> None:
if self._thread is None:
@property
def sockets(self) -> dict[str, Socket]:
return self._sockets

def register_socket(self, name: str, socket: Socket, receiver: bool = False) -> None:
"""Register a socket, optionally declaring it as a receiver socket to poll"""
if name in self._sockets:
raise KeyError(f"Socket {name} already declared!")
self._sockets[name] = socket
if receiver:
self._receivers[name] = socket

def add_callback(
self, socket: str, callback: Callable[[Message], Any] | Callable[[Message], Coroutine]
) -> None:
"""
Add a callback to be called when the socket receives a message.
Callbacks are called in the order in which they are added.
"""
if socket not in self._receivers:
raise KeyError(f"Socket {socket} does not exist or is not a receiving socket")
if iscoroutinefunction_partial(callback):
self._callbacks[socket]["asyncio"].append(callback)
else:
self._callbacks[socket]["sync"].append(callback)

def clear_callbacks(self) -> None:
self._callbacks = defaultdict(lambda: _CallbackDict(sync=[], asyncio=[]))

def _init_loop(self) -> None:
self._loop = asyncio.get_running_loop()
self._context = Context.instance()
self._quitting = asyncio.Event()

def _stop_loop(self) -> None:
if self._quitting is None:
return
self._quitting.set()
self.loop.add_callback(self.loop.stop)

async def _poll_receivers(self) -> None:
"""
Rather than using the zmq.asyncio.Poller which wastes a ton of time,
it turns out doing it this way is roughly 4x as fast:
just manually poll the sockets, and if you have multiple sockets,
gather multiple coroutines where you're polling the sockets.
"""
if len(self._receivers) == 1:
await self._poll_receiver(next(iter(self._receivers.keys())))
else:
await asyncio.gather(*[self._poll_receiver(name) for name in self._receivers])

async def _poll_receiver(self, name: str) -> None:
socket = self._receivers[name]
while not self._quitting.is_set():
msg_bytes = await socket.recv_multipart()
try:
msg = Message.from_bytes(msg_bytes)
except Exception as e:
self.logger.exception(
"Exception decoding message for socket %s: %s, %s", name, msg_bytes, e
)
continue

# purposely don't catch errors here because we want them to bubble up into the caller
for acb in self._callbacks[name]["asyncio"]:
await acb(msg)
for cb in self._callbacks[name]["sync"]:
self.loop.run_in_executor(None, cb, msg)
9 changes: 9 additions & 0 deletions src/noob/network/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,15 @@ class ErrorMsg(Message):

model_config = ConfigDict(arbitrary_types_allowed=True)

def to_exception(self) -> Exception:
err = self.value["err_type"](*self.value["err_args"])
tb_message = "\nError re-raised from node runner process\n\n"
tb_message += "Original traceback:\n"
tb_message += "-" * 20 + "\n"
tb_message += self.value["traceback"]
err.add_note(tb_message)
return err


def _to_json(val: Event, handler: SerializerFunctionWrapHandler) -> Any:
if val["signal"] == META_SIGNAL and val["value"] is MetaSignal.NoEvent:
Expand Down
4 changes: 4 additions & 0 deletions src/noob/runner/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,6 +526,10 @@ def call_async_from_sync(
References:
* https://github.com/django/asgiref/blob/2b28409ab83b3e4cf6fed9019403b71f8d7d1c51/asgiref/sync.py#L152
* https://stackoverflow.com/questions/79663750/call-async-code-inside-sync-code-inside-async-code
* https://github.com/python/cpython/issues/66435
* https://github.com/python/cpython/issues/93462
* https://discuss.python.org/t/support-for-running-async-functions-in-sync-functions/16220/3
* https://github.com/fsspec/filesystem_spec/blob/2576617e5cbe441bcc53b021bccd85ff3489fde7/fsspec/asyn.py#L63
"""
if not iscoroutinefunction_partial(fn):
raise RuntimeError(
Expand Down
Loading