diff --git a/src/noob/logging.py b/src/noob/logging.py index 54a2719..55519d0 100644 --- a/src/noob/logging.py +++ b/src/noob/logging.py @@ -4,6 +4,7 @@ import logging import multiprocessing as mp +import sys from logging.handlers import RotatingFileHandler from pathlib import Path from typing import Any, Literal @@ -193,5 +194,5 @@ def _get_console() -> Console: current_pid = mp.current_process().pid console = _console_by_pid.get(current_pid) if console is None: - _console_by_pid[current_pid] = console = Console() + _console_by_pid[current_pid] = console = Console(file=sys.stdout) return console diff --git a/src/noob/network/loop.py b/src/noob/network/loop.py index 74e6e5e..efa49eb 100644 --- a/src/noob/network/loop.py +++ b/src/noob/network/loop.py @@ -1,84 +1,146 @@ import asyncio -import threading +import sys +from collections import defaultdict +from collections.abc import Callable, Coroutine +from typing import Any try: - import zmq - from tornado.ioloop import IOLoop + from zmq.asyncio import Context, Socket except ImportError as e: raise ImportError( "Attempted to import zmq runner, but zmq deps are not installed. install with `noob[zmq]`", ) from e +if sys.version_info < (3, 12): + from typing_extensions import TypedDict +else: + from typing import TypedDict + +from noob.logging import init_logger +from noob.network.message import Message +from noob.utils import iscoroutinefunction_partial + + +class _CallbackDict(TypedDict): + sync: list[Callable[[Message], Any]] + asyncio: list[Callable[[Message], Coroutine]] + class EventloopMixin: """ - Provide an eventloop in a separate thread to an inheriting class. - Any eventloop that is running in the current context is not used - because the inheriting classes are presumed to operate mostly synchronously for now, - pending a refactor to all async networking classes. + Mixin to provide common asyncio zmq scaffolding to networked classes. 
+ + Inheriting classes should, in order + + * call the ``_init_loop`` method to create the eventloop, context, and poller + * populate the private ``_sockets`` and ``_receivers`` dicts + * await the ``_poll_receivers`` method, which polls indefinitely. + + Inheriting classes **must** ensure that ``_init_loop`` + is called in the thread it is intended to run in, + and that thread must already have a running eventloop. + asyncio eventloops (and most of asyncio) are **not** thread safe. + + To help avoid cross-threading issues, the :meth:`.context` and :meth:`.loop` + properties do *not* automatically create the objects, + raising a :class:`.RuntimeError` if they are accessed before ``_init_loop`` is called. """ def __init__(self): self._context = None self._loop = None - self._quitting = asyncio.Event() - self._thread: threading.Thread | None = None + self._quitting: asyncio.Event = None # type: ignore[assignment] + self._sockets: dict[str, Socket] = {} + """ + All sockets, mapped from some common name to the socket. + The same key used here should be shared between _receivers and _callbacks + """ + self._receivers: dict[str, Socket] = {} + """Sockets that should be polled for incoming messages""" + self._callbacks: dict[str, _CallbackDict] = defaultdict( + lambda: _CallbackDict(sync=[], asyncio=[]) + ) + """Callbacks for each receiver socket""" + if not hasattr(self, "logger"): + self.logger = init_logger("eventloop") @property - def context(self) -> zmq.Context: + def context(self) -> Context: if self._context is None: - self._context = zmq.Context.instance() + raise RuntimeError("Loop has not been initialized with _init_loop!") return self._context @property - def loop(self) -> IOLoop: - # To ensure that the loop is always created in the spawned thread, - # we don't create it here (since this property could be accessed elsewhere) - # and throw to protect that. 
+ def loop(self) -> asyncio.AbstractEventLoop: if self._loop is None: - raise RuntimeError("Loop is not running") + raise RuntimeError("Loop has not been initialized with _init_loop!") return self._loop - def start_loop(self) -> None: - if self._thread is not None: - raise RuntimeWarning("Node already started") - - self._quitting.clear() - - _ready = threading.Event() - - def _signal_ready() -> None: - _ready.set() - - def _run() -> None: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - self._loop = IOLoop.current() - if hasattr(self, "logger"): - self.logger.debug("Starting eventloop") - while not self._quitting.is_set(): - try: - self.loop.add_callback(_signal_ready) - self.loop.start() - - except RuntimeError: - # loop already started - if hasattr(self, "logger"): - self.logger.debug("Eventloop already started, quitting") - break - if hasattr(self, "logger"): - self.logger.debug("Eventloop stopped") - self._thread = None - - self._thread = threading.Thread(target=_run) - self._thread.start() - # wait until the loop has started - _ready.wait(5) - if hasattr(self, "logger"): - self.logger.debug("Event loop started") - - def stop_loop(self) -> None: - if self._thread is None: + @property + def sockets(self) -> dict[str, Socket]: + return self._sockets + + def register_socket(self, name: str, socket: Socket, receiver: bool = False) -> None: + """Register a socket, optionally declaring it as a receiver socket to poll""" + if name in self._sockets: + raise KeyError(f"Socket {name} already declared!") + self._sockets[name] = socket + if receiver: + self._receivers[name] = socket + + def add_callback( + self, socket: str, callback: Callable[[Message], Any] | Callable[[Message], Coroutine] + ) -> None: + """ + Add a callback to be called when the socket receives a message. + Callbacks are called in the order in which they are added. 
+ """ + if socket not in self._receivers: + raise KeyError(f"Socket {socket} does not exist or is not a receiving socket") + if iscoroutinefunction_partial(callback): + self._callbacks[socket]["asyncio"].append(callback) + else: + self._callbacks[socket]["sync"].append(callback) + + def clear_callbacks(self) -> None: + self._callbacks = defaultdict(lambda: _CallbackDict(sync=[], asyncio=[])) + + def _init_loop(self) -> None: + self._loop = asyncio.get_running_loop() + self._context = Context.instance() + self._quitting = asyncio.Event() + + def _stop_loop(self) -> None: + if self._quitting is None: return self._quitting.set() - self.loop.add_callback(self.loop.stop) + + async def _poll_receivers(self) -> None: + """ + Rather than using the zmq.asyncio.Poller which wastes a ton of time, + it turns out doing it this way is roughly 4x as fast: + just manually poll the sockets, and if you have multiple sockets, + gather multiple coroutines where you're polling the sockets. + """ + if len(self._receivers) == 1: + await self._poll_receiver(next(iter(self._receivers.keys()))) + else: + await asyncio.gather(*[self._poll_receiver(name) for name in self._receivers]) + + async def _poll_receiver(self, name: str) -> None: + socket = self._receivers[name] + while not self._quitting.is_set(): + msg_bytes = await socket.recv_multipart() + try: + msg = Message.from_bytes(msg_bytes) + except Exception as e: + self.logger.exception( + "Exception decoding message for socket %s: %s, %s", name, msg_bytes, e + ) + continue + + # purposely don't catch errors here because we want them to bubble up into the caller + for acb in self._callbacks[name]["asyncio"]: + await acb(msg) + for cb in self._callbacks[name]["sync"]: + self.loop.run_in_executor(None, cb, msg) diff --git a/src/noob/network/message.py b/src/noob/network/message.py index 339986e..590a7d8 100644 --- a/src/noob/network/message.py +++ b/src/noob/network/message.py @@ -172,6 +172,15 @@ class ErrorMsg(Message): model_config = 
ConfigDict(arbitrary_types_allowed=True) + def to_exception(self) -> Exception: + err = self.value["err_type"](*self.value["err_args"]) + tb_message = "\nError re-raised from node runner process\n\n" + tb_message += "Original traceback:\n" + tb_message += "-" * 20 + "\n" + tb_message += self.value["traceback"] + err.add_note(tb_message) + return err + def _to_json(val: Event, handler: SerializerFunctionWrapHandler) -> Any: if val["signal"] == META_SIGNAL and val["value"] is MetaSignal.NoEvent: diff --git a/src/noob/runner/base.py b/src/noob/runner/base.py index 876dd65..3ec11ea 100644 --- a/src/noob/runner/base.py +++ b/src/noob/runner/base.py @@ -526,6 +526,10 @@ def call_async_from_sync( References: * https://github.com/django/asgiref/blob/2b28409ab83b3e4cf6fed9019403b71f8d7d1c51/asgiref/sync.py#L152 * https://stackoverflow.com/questions/79663750/call-async-code-inside-sync-code-inside-async-code + * https://github.com/python/cpython/issues/66435 + * https://github.com/python/cpython/issues/93462 + * https://discuss.python.org/t/support-for-running-async-functions-in-sync-functions/16220/3 + * https://github.com/fsspec/filesystem_spec/blob/2576617e5cbe441bcc53b021bccd85ff3489fde7/fsspec/asyn.py#L63 """ if not iscoroutinefunction_partial(fn): raise RuntimeError( diff --git a/src/noob/runner/zmq.py b/src/noob/runner/zmq.py index ac85991..0838043 100644 --- a/src/noob/runner/zmq.py +++ b/src/noob/runner/zmq.py @@ -22,22 +22,24 @@ """ +import asyncio +import concurrent.futures import math import multiprocessing as mp import os import signal import threading import traceback -from collections import defaultdict -from collections.abc import Callable, Generator +from collections.abc import AsyncGenerator, Generator, MutableSequence from dataclasses import dataclass, field +from datetime import datetime +from functools import partial from itertools import count from multiprocessing.synchronize import Event as EventType from time import time from types import FrameType 
-from typing import TYPE_CHECKING, Any, Literal, cast, overload - -from noob.network.loop import EventloopMixin +from typing import TYPE_CHECKING, Any, cast, overload +from uuid import uuid4 try: import zmq @@ -47,13 +49,12 @@ ) from e -from zmq.eventloop.zmqstream import ZMQStream - from noob.config import config -from noob.event import Event, MetaSignal +from noob.event import Event, MetaEvent, MetaEventType, MetaSignal from noob.exceptions import InputMissingError from noob.input import InputCollection, InputScope from noob.logging import init_logger +from noob.network.loop import EventloopMixin from noob.network.message import ( AnnounceMsg, AnnounceValue, @@ -73,7 +74,7 @@ StopMsg, ) from noob.node import Node, NodeSpecification, Return, Signal -from noob.runner.base import TubeRunner, call_async_from_sync +from noob.runner.base import TubeRunner from noob.scheduler import Scheduler from noob.store import EventStore from noob.types import NodeID, ReturnNodeType @@ -104,17 +105,15 @@ def __init__(self, runner_id: str, protocol: str = "ipc", port: int | None = Non protocol: port: """ - super().__init__() + self.runner_id = runner_id self.port = port self.protocol = protocol self.logger = init_logger(f"runner.node.{runner_id}.command") - self._outbox: zmq.Socket = None # type: ignore[assignment] - self._inbox: ZMQStream = None # type: ignore[assignment] - self._router: ZMQStream = None # type: ignore[assignment] self._nodes: dict[str, IdentifyValue] = {} - self._ready_condition = threading.Condition() - self._callbacks: dict[str, list[Callable[[Message], Any]]] = defaultdict(list) + self._ready_condition: threading.Condition = None # type: ignore[assignment] + self._init = threading.Event() + super().__init__() @property def pub_address(self) -> str: @@ -136,98 +135,107 @@ def router_address(self) -> str: else: raise NotImplementedError() + def run(self) -> None: + """ + Target for :class:`threading.Thread` + """ + asyncio.run(self._run()) + + async def 
_run(self) -> None: + self.init() + await self._poll_receivers() + def init(self) -> None: self.logger.debug("Starting command runner") - self.start_loop() + self._init.clear() + self._init_loop() + self._ready_condition = threading.Condition() self._init_sockets() + self._init.set() self.logger.debug("Command runner started") def deinit(self) -> None: """Close the eventloop, stop processing messages, reset state""" self.logger.debug("Deinitializing") - msg = DeinitMsg(node_id="command") - self._outbox.send_multipart([b"deinit", msg.to_bytes()]) - self.stop_loop() - self.logger.debug("Deinitialized") + + async def _deinit() -> None: + msg = DeinitMsg(node_id="command") + await self.sockets["outbox"].send_multipart([b"deinit", msg.to_bytes()]) + self._quitting.set() + + self.loop.create_task(_deinit()) + self.logger.debug("Queued loop for deinitialization") def stop(self) -> None: self.logger.debug("Stopping command runner") msg = StopMsg(node_id="command") - self._outbox.send_multipart([b"stop", msg.to_bytes()]) + self.loop.call_soon_threadsafe( + self.sockets["outbox"].send_multipart, [b"stop", msg.to_bytes()] + ) self.logger.debug("Command runner stopped") def _init_sockets(self) -> None: - self._outbox = self._init_outbox() - self._router = self._init_router() - self._inbox = self._init_inbox() + self._init_outbox() + self._init_router() + self._init_inbox() - def _init_outbox(self) -> zmq.Socket: + def _init_outbox(self) -> None: """Create the main control publisher""" pub = self.context.socket(zmq.PUB) pub.bind(self.pub_address) pub.setsockopt_string(zmq.IDENTITY, "command.outbox") - return pub + self.register_socket("outbox", pub) - def _init_router(self) -> ZMQStream: + def _init_router(self) -> None: """Create the inbox router""" router = self.context.socket(zmq.ROUTER) router.bind(self.router_address) router.setsockopt_string(zmq.IDENTITY, "command.router") - router = ZMQStream(router, self.loop) - router.on_recv(self.on_router) - self.logger.debug("Inbox 
bound to %s", self.router_address) - return router + self.register_socket("router", router, receiver=True) + self.add_callback("router", self.on_router) + self.logger.debug("Router bound to %s", self.router_address) - def _init_inbox(self) -> ZMQStream: + def _init_inbox(self) -> None: """Subscriber that receives all events from running nodes""" sub = self.context.socket(zmq.SUB) sub.setsockopt_string(zmq.IDENTITY, "command.inbox") sub.setsockopt_string(zmq.SUBSCRIBE, "") - sub = ZMQStream(sub, self.loop) - sub.on_recv(self.on_inbox) - return sub + self.register_socket("inbox", sub, receiver=True) - def announce(self) -> None: + async def announce(self) -> None: msg = AnnounceMsg( node_id="command", value=AnnounceValue(inbox=self.router_address, nodes=self._nodes) ) - self._outbox.send_multipart([b"announce", msg.to_bytes()]) + await self.sockets["outbox"].send_multipart([b"announce", msg.to_bytes()]) - def ping(self) -> None: + async def ping(self) -> None: """Send a ping message asking everyone to identify themselves""" msg = PingMsg(node_id="command") - self._outbox.send_multipart([b"ping", msg.to_bytes()]) + await self.sockets["outbox"].send_multipart([b"ping", msg.to_bytes()]) def start(self, n: int | None = None) -> None: """ Start running in free-run mode """ - self._outbox.send_multipart([b"start", StartMsg(node_id="command", value=n).to_bytes()]) + self.loop.call_soon_threadsafe( + self.sockets["outbox"].send_multipart, + [b"start", StartMsg(node_id="command", value=n).to_bytes()], + ) self.logger.debug("Sent start message") def process(self, epoch: int, input: dict | None = None) -> None: """Emit a ProcessMsg to process a single round through the graph""" # no empty dicts input = input if input else None - self._outbox.send_multipart( + self.loop.call_soon_threadsafe( + self.sockets["outbox"].send_multipart, [ b"process", ProcessMsg(node_id="command", value={"input": input, "epoch": epoch}).to_bytes(), - ] + ], ) self.logger.debug("Sent process message") 
- def add_callback(self, type_: Literal["inbox", "router"], cb: Callable[[Message], Any]) -> None: - """ - Add a callback called for message received - - by the inbox: the subscriber that receives all events from node runners - - by the router: direct messages sent by node runners to the command node - """ - self._callbacks[type_].append(cb) - - def clear_callbacks(self) -> None: - self._callbacks = defaultdict(list) - def await_ready(self, node_ids: list[NodeID], timeout: float = 10) -> None: """ Wait until all the node_ids have announced themselves @@ -253,7 +261,7 @@ def _is_ready() -> bool: while time() < start_time + timeout and not ready: ready = self._ready_condition.wait_for(_is_ready, timeout=1) if not ready: - self.ping() + self.loop.call_soon_threadsafe(self.loop.create_task, self.ping()) # if still not ready, timeout if not ready: @@ -263,51 +271,39 @@ def _is_ready() -> bool: f"ready: {_ready_nodes()}" ) - def on_router(self, msg: list[bytes]) -> None: - try: - message = Message.from_bytes(msg) - self.logger.debug("Received ROUTER message %s", message) - except Exception as e: - self.logger.exception("Exception decoding: %s, %s", msg, e) - raise e - - for cb in self._callbacks["router"]: - cb(message) + async def on_router(self, message: Message) -> None: + self.logger.debug("Received ROUTER message %s", message) if message.type_ == MessageType.identify: message = cast(IdentifyMsg, message) - self.on_identify(message) + await self.on_identify(message) elif message.type_ == MessageType.status: message = cast(StatusMsg, message) - self.on_status(message) - - def on_inbox(self, msg: list[bytes]) -> None: - message = Message.from_bytes(msg) - self.logger.debug("Received INBOX message: %s", message) - for cb in self._callbacks["inbox"]: - cb(message) + await self.on_status(message) - def on_identify(self, msg: IdentifyMsg) -> None: - with self._ready_condition: - self._nodes[msg.node_id] = msg.value - self._inbox.connect(msg.value["outbox"]) - 
self._ready_condition.notify_all() + async def on_identify(self, msg: IdentifyMsg) -> None: + self._nodes[msg.node_id] = msg.value + self.sockets["inbox"].connect(msg.value["outbox"]) try: - self.announce() + await self.announce() self.logger.debug("Announced") except Exception as e: self.logger.exception("Exception announced: %s", e) - def on_status(self, msg: StatusMsg) -> None: with self._ready_condition: - if msg.node_id not in self._nodes: - self.logger.warning( - "Node %s sent us a status before sending its full identify message, ignoring", - msg.node_id, - ) - return - self._nodes[msg.node_id]["status"] = msg.value + self._ready_condition.notify_all() + + async def on_status(self, msg: StatusMsg) -> None: + if msg.node_id not in self._nodes: + self.logger.warning( + "Node %s sent us a status before sending its full identify message, ignoring", + msg.node_id, + ) + return + self._nodes[msg.node_id]["status"] = msg.value + + with self._ready_condition: self._ready_condition.notify_all() @@ -329,7 +325,6 @@ def __init__( input_collection: InputCollection, protocol: str = "ipc", ): - super().__init__() self.spec = spec self.runner_id = runner_id self.input_collection = input_collection @@ -340,20 +335,19 @@ def __init__( self.scheduler: Scheduler = None # type: ignore[assignment] self.logger = init_logger(f"runner.node.{runner_id}.{self.spec.id}") - self._dealer: ZMQStream = None # type: ignore[assignment] - self._outbox: zmq.Socket = None # type: ignore[assignment] - self._inbox: ZMQStream = None # type: ignore[assignment] self._node: Node | None = None self._depends: tuple[tuple[str, str], ...] 
| None = None self._has_input: bool | None = None self._nodes: dict[str, IdentifyValue] = {} self._counter = count() - self._process_quitting = mp.Event() - self._freerun = mp.Event() - self._process_one = mp.Event() + self._freerun = asyncio.Event() + self._process_one = asyncio.Event() self._status: NodeStatus = NodeStatus.stopped - self._status_lock = mp.RLock() + self._status_lock = asyncio.Lock() + self._ready_condition = asyncio.Condition() self._to_process = 0 + super().__init__() + self._quitting = asyncio.Event() @property def outbox_address(self) -> str: @@ -386,13 +380,11 @@ def has_input(self) -> bool: @property def status(self) -> NodeStatus: - with self._status_lock: - return self._status + return self._status @status.setter def status(self, status: NodeStatus) -> None: - with self._status_lock: - self._status = status + self._status = status @classmethod def run(cls, spec: NodeSpecification, **kwargs: Any) -> None: @@ -400,7 +392,16 @@ def run(cls, spec: NodeSpecification, **kwargs: Any) -> None: Target for multiprocessing.run, init the class and start it! 
""" - runner = NodeRunner(spec=spec, **kwargs) + + # ensure that events and conditions are bound to the eventloop created in the process + async def _run_inner() -> None: + nonlocal spec, kwargs + runner = NodeRunner(spec=spec, **kwargs) + await runner._run() + + asyncio.run(_run_inner()) + + async def _run(self) -> None: try: def _handler(sig: int, frame: FrameType | None = None) -> None: @@ -408,48 +409,53 @@ def _handler(sig: int, frame: FrameType | None = None) -> None: raise KeyboardInterrupt() signal.signal(signal.SIGTERM, _handler) - runner.init() - runner._node = cast(Node, runner._node) - runner._process_quitting.clear() - runner._freerun.clear() - runner._process_one.clear() - - is_async = iscoroutinefunction_partial(runner._node.process) + await self.init() + self._node = cast(Node, self._node) + self._freerun.clear() + self._process_one.clear() + await asyncio.gather(self._poll_receivers(), self._process_loop()) + except KeyboardInterrupt: + self.logger.debug("Got keyboard interrupt, quitting") + except Exception as e: + await self.error(e) + finally: + await self.deinit() - for args, kwargs, epoch in runner.await_inputs(): - runner.logger.debug( + async def _process_loop(self) -> None: + self._node = cast(Node, self._node) + is_async = iscoroutinefunction_partial(self._node.process) + loop = asyncio.get_running_loop() + with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor: + async for args, kwargs, epoch in self.await_inputs(): + self.logger.debug( "Running with args: %s, kwargs: %s, epoch: %s", args, kwargs, epoch ) if is_async: # mypy fails here because it can't propagate the type guard above - value = call_async_from_sync(runner._node.process, *args, **kwargs) # type: ignore[arg-type] + value = await self._node.process(*args, **kwargs) # type: ignore[misc] else: - value = runner._node.process(*args, **kwargs) - events = runner.store.add_value(runner._node.signals, value, runner._node.id, epoch) - runner.scheduler.add_epoch() - - # 
node runners should not report epoch endings + part = partial(self._node.process, *args, **kwargs) + value = await loop.run_in_executor(executor, part) + events = self.store.add_value(self._node.signals, value, self._node.id, epoch) + async with self._ready_condition: + self.scheduler.add_epoch() + self._ready_condition.notify_all() + + # nodes should not report epoch endings since they don't know about the full tube events = [e for e in events if e["node_id"] != "meta"] if events: - runner.update_graph(events) - runner.publish_events(events) - - except KeyboardInterrupt: - runner.logger.debug("Got keyboard interrupt, quitting") - except Exception as e: - runner.error(e) - finally: - runner.deinit() + await self.update_graph(events) + await self.publish_events(events) - def await_inputs(self) -> Generator[tuple[tuple[Any], dict[str, Any], int]]: + async def await_inputs(self) -> AsyncGenerator[tuple[tuple[Any], dict[str, Any], int]]: self._node = cast(Node, self._node) - while not self._process_quitting.is_set(): + while not self._quitting.is_set(): # if we are not freerunning, keep track of how many times we are supposed to run, # and run until we aren't supposed to anymore! 
if not self._freerun.is_set(): if self._to_process <= 0: self._to_process = 0 - self._process_one.wait() + await self._process_one.wait() self._to_process -= 1 if self._to_process <= 0: self._to_process = 0 @@ -457,7 +463,7 @@ def await_inputs(self) -> Generator[tuple[tuple[Any], dict[str, Any], int]]: epoch = next(self._counter) if self._node.stateful else None - ready = self.scheduler.await_node(self.spec.id, epoch=epoch) + ready = await self.await_node(epoch=epoch) edges = self._node.edges inputs = self.store.collect(edges, ready["epoch"]) if inputs is None: @@ -467,35 +473,43 @@ def await_inputs(self) -> Generator[tuple[tuple[Any], dict[str, Any], int]]: self.store.clear(ready["epoch"]) yield args, kwargs, ready["epoch"] - def update_graph(self, events: list[Event]) -> None: - self.scheduler.update(events) + async def update_graph(self, events: list[Event]) -> None: + async with self._ready_condition: + self.scheduler.update(events) + self._ready_condition.notify_all() - def publish_events(self, events: list[Event]) -> None: + async def publish_events(self, events: list[Event]) -> None: msg = EventMsg(node_id=self.spec.id, value=events) - self._outbox.send_multipart([b"event", msg.to_bytes()]) + await self.sockets["outbox"].send_multipart([b"event", msg.to_bytes()]) - def init(self) -> None: + async def init(self) -> None: self.logger.debug("Initializing") - - self.init_node() - self.start_sockets() + await self.init_node() + self._init_sockets() + self._quitting.clear() self.status = ( NodeStatus.waiting if self.depends and [d for d in self.depends if d[0] != "input"] else NodeStatus.ready ) - self.identify() + await self.identify() self.logger.debug("Initialization finished") - def deinit(self) -> None: + async def deinit(self) -> None: + """ + Deinitialize the node class after receiving on_deinit message and + draining out the end of the _process_loop. 
+ """ self.logger.debug("Deinitializing") if self._node is not None: self._node.deinit() - self.update_status(NodeStatus.closed) - self.stop_loop() + + # should have already been called in on_deinit, but just to make sure we're killed dead... + self._quitting.set() + self.logger.debug("Deinitialization finished") - def identify(self) -> None: + async def identify(self) -> None: """ Send the command node an announce to say we're alive """ @@ -506,7 +520,7 @@ def identify(self) -> None: ) self.logger.debug("Identifying") - with self._status_lock: + async with self._status_lock: ann = IdentifyMsg( node_id=self.spec.id, value=IdentifyValue( @@ -519,43 +533,40 @@ def identify(self) -> None: ), ), ) - self._dealer.send_multipart([ann.to_bytes()]) + await self.sockets["dealer"].send_multipart([ann.to_bytes()]) self.logger.debug("Sent identification message: %s", ann) - def update_status(self, status: NodeStatus) -> None: + async def update_status(self, status: NodeStatus) -> None: """Update our internal status and announce it to the command node""" self.logger.debug("Updating status as %s", status) - with self._status_lock: + async with self._status_lock: self.status = status msg = StatusMsg(node_id=self.spec.id, value=status) - self._dealer.send_multipart([msg.to_bytes()]) + await self.sockets["dealer"].send_multipart([msg.to_bytes()]) self.logger.debug("Updated status") - def start_sockets(self) -> None: - self.start_loop() - self._init_sockets() - - def init_node(self) -> None: + async def init_node(self) -> None: self._node = Node.from_specification(self.spec, self.input_collection) self._node.init() self.scheduler = Scheduler(nodes={self.spec.id: self.spec}, edges=self._node.edges) - self.scheduler.add_epoch() + async with self._ready_condition: + self.scheduler.add_epoch() + self._ready_condition.notify_all() def _init_sockets(self) -> None: - self._dealer = self._init_dealer() - self._outbox = self._init_outbox() - self._inbox = self._init_inbox() + 
self._init_loop() + self._init_dealer() + self._init_outbox() + self._init_inbox() - def _init_dealer(self) -> ZMQStream: + def _init_dealer(self) -> None: dealer = self.context.socket(zmq.DEALER) dealer.setsockopt_string(zmq.IDENTITY, self.spec.id) dealer.connect(self.command_router) - dealer = ZMQStream(dealer, self.loop) - dealer.on_recv(self.on_dealer) + self.register_socket("dealer", dealer) self.logger.debug("Connected to command node at %s", self.command_router) - return dealer - def _init_outbox(self) -> zmq.Socket: + def _init_outbox(self) -> None: pub = self.context.socket(zmq.PUB) pub.setsockopt_string(zmq.IDENTITY, self.spec.id) if self.protocol == "ipc": @@ -564,10 +575,9 @@ def _init_outbox(self) -> zmq.Socket: raise NotImplementedError() # something like: # port = pub.bind_to_random_port(self.protocol) + self.register_socket("outbox", pub) - return pub - - def _init_inbox(self) -> ZMQStream: + def _init_inbox(self) -> None: """ Init the subscriber, but don't attempt to subscribe to anything but the command yet! 
we do that when we get node Announces @@ -576,79 +586,67 @@ def _init_inbox(self) -> ZMQStream: sub.setsockopt_string(zmq.IDENTITY, self.spec.id) sub.setsockopt_string(zmq.SUBSCRIBE, "") sub.connect(self.command_outbox) - sub = ZMQStream(sub, self.loop) - sub.on_recv(self.on_inbox) + self.register_socket("inbox", sub, receiver=True) + self.add_callback("inbox", self.on_inbox) self.logger.debug("Subscribed to command outbox %s", self.command_outbox) - return sub - - def on_dealer(self, msg: list[bytes]) -> None: - self.logger.debug("DEALER received %s", msg) - - def on_inbox(self, msg: list[bytes]) -> None: - try: - message = Message.from_bytes(msg) - - self.logger.debug("INBOX received %s", msg) - except Exception as e: - self.logger.exception("Error decoding message %s %s", msg, e) - return + async def on_inbox(self, message: Message) -> None: # FIXME: all this switching sux, # just have a decorator to register a handler for a given message type if message.type_ == MessageType.announce: message = cast(AnnounceMsg, message) - self.on_announce(message) + await self.on_announce(message) elif message.type_ == MessageType.event: message = cast(EventMsg, message) - self.on_event(message) + await self.on_event(message) elif message.type_ == MessageType.process: message = cast(ProcessMsg, message) - self.on_process(message) + await self.on_process(message) elif message.type_ == MessageType.start: message = cast(StartMsg, message) - self.on_start(message) + await self.on_start(message) elif message.type_ == MessageType.stop: message = cast(StopMsg, message) - self.on_stop(message) + await self.on_stop(message) elif message.type_ == MessageType.deinit: message = cast(DeinitMsg, message) - self.on_deinit(message) + await self.on_deinit(message) elif message.type_ == MessageType.ping: - self.identify() + await self.identify() else: # log but don't throw - other nodes shouldn't be able to crash us self.logger.error(f"{message.type_} not implemented!") self.logger.debug("%s", 
message) - def on_announce(self, msg: AnnounceMsg) -> None: + async def on_announce(self, msg: AnnounceMsg) -> None: """ Store map, connect to the nodes we depend on """ self._node = cast(Node, self._node) self.logger.debug("Processing announce") - with self._status_lock: - depended_nodes = {edge.source_node for edge in self._node.edges} - if depended_nodes: - self.logger.debug("Should subscribe to %s", depended_nodes) - for node_id in msg.value["nodes"]: - if node_id in depended_nodes and node_id not in self._nodes: - # TODO: a way to check if we're already connected, without storing it locally? - outbox = msg.value["nodes"][node_id]["outbox"] - self.logger.debug("Subscribing to %s at %s", node_id, outbox) - self._inbox.connect(outbox) - self.logger.debug("Subscribed to %s at %s", node_id, outbox) - self._nodes = msg.value["nodes"] - if set(self._nodes) >= depended_nodes - {"input"} and self.status == NodeStatus.waiting: - self.update_status(NodeStatus.ready) - # status and announce messages can be received out of order, - # so if we observe the command node being out of sync, we update it. - elif ( - self._node.id in msg.value["nodes"] - and msg.value["nodes"][self._node.id]["status"] != self.status.value - ): - self.update_status(self.status) - def on_event(self, msg: EventMsg) -> None: + depended_nodes = {edge.source_node for edge in self._node.edges} + if depended_nodes: + self.logger.debug("Should subscribe to %s", depended_nodes) + for node_id in msg.value["nodes"]: + if node_id in depended_nodes and node_id not in self._nodes: + # TODO: a way to check if we're already connected, without storing it locally? 
+ outbox = msg.value["nodes"][node_id]["outbox"] + self.logger.debug("Subscribing to %s at %s", node_id, outbox) + self.sockets["inbox"].connect(outbox) + self.logger.debug("Subscribed to %s at %s", node_id, outbox) + self._nodes = msg.value["nodes"] + if set(self._nodes) >= depended_nodes - {"input"} and self.status == NodeStatus.waiting: + await self.update_status(NodeStatus.ready) + # status and announce messages can be received out of order, + # so if we observe the command node being out of sync, we update it. + elif ( + self._node.id in msg.value["nodes"] + and msg.value["nodes"][self._node.id]["status"] != self.status.value + ): + await self.update_status(self.status) + + async def on_event(self, msg: EventMsg) -> None: events = msg.value if not self.depends: self.logger.debug("No dependencies, not storing events") @@ -657,21 +655,24 @@ def on_event(self, msg: EventMsg) -> None: to_add = [e for e in events if (e["node_id"], e["signal"]) in self.depends] for event in to_add: self.store.add(event) + self.logger.debug("scheduler updating") + async with self._ready_condition: + self.scheduler.update(events) + self._ready_condition.notify_all() + self.logger.debug("scheduler updated") - self.scheduler.update(events) - - def on_start(self, msg: StartMsg) -> None: + async def on_start(self, msg: StartMsg) -> None: """ Start running in free mode """ - self.update_status(NodeStatus.running) + await self.update_status(NodeStatus.running) if msg.value is None: self._freerun.set() else: self._to_process += msg.value self._process_one.set() - def on_process(self, msg: ProcessMsg) -> None: + async def on_process(self, msg: ProcessMsg) -> None: """ Process a single graph iteration """ @@ -691,39 +692,44 @@ def on_process(self, msg: ProcessMsg) -> None: value = combined[next(iter(combined.keys()))] else: value = list(combined.values()) - events = self.store.add_value( - [Signal(name=k, type_=None) for k in combined], - value, - node_id="input", - epoch=msg.value["epoch"], - 
) - scheduler_events = self.scheduler.update(events) + async with self._ready_condition: + events = self.store.add_value( + [Signal(name=k, type_=None) for k in combined], + value, + node_id="input", + epoch=msg.value["epoch"], + ) + scheduler_events = self.scheduler.update(events) + self._ready_condition.notify_all() + self.logger.debug("Updated scheduler with process events: %s", scheduler_events) - self.logger.debug("Updated scheduler with process events: %s", scheduler_events) self._process_one.set() - def on_stop(self, msg: StopMsg) -> None: + async def on_stop(self, msg: StopMsg) -> None: """Stop processing (but stay responsive)""" self._process_one.clear() self._to_process = 0 self._freerun.clear() - self.update_status(NodeStatus.stopped) + await self.update_status(NodeStatus.stopped) self.logger.debug("Stopped") - def on_deinit(self, msg: DeinitMsg) -> None: + async def on_deinit(self, msg: DeinitMsg) -> None: """ Deinitialize the node, close networking thread. Cause the main loop to end, which calls deinit """ - self._process_quitting.set() + await self.update_status(NodeStatus.closed) + self._quitting.set() + pid = mp.current_process().pid if pid is None: return self.logger.debug("Emitting sigterm to self %s", msg) os.kill(pid, signal.SIGTERM) + raise asyncio.CancelledError() - def error(self, err: Exception) -> None: + async def error(self, err: Exception) -> None: """ Capture the error and traceback context from an exception using :class:`traceback.TracebackException` and send to command node to re-raise @@ -738,7 +744,51 @@ def error(self, err: Exception) -> None: traceback=tbexception, ), ) - self._dealer.send_multipart([msg.to_bytes()]) + await self.sockets["dealer"].send_multipart([msg.to_bytes()]) + + async def await_node(self, epoch: int | None = None) -> MetaEvent: + """ + Block until a node is ready + + Args: + node_id: + epoch (int, None): if `int` , wait until the node is ready in the given epoch, + otherwise wait until the node is ready in 
any epoch + + Returns: + + """ + async with self._ready_condition: + await self._ready_condition.wait_for( + lambda: self.scheduler.node_is_ready(self.spec.id, epoch) + ) + + # be FIFO-like and get the earliest epoch the node is ready in + if epoch is None: + for ep in self.scheduler._epochs: + if self.scheduler.node_is_ready(self.spec.id, ep): + epoch = ep + break + + if epoch is None: + raise RuntimeError( + "Could not find ready epoch even though node ready condition passed, " + "something is wrong with the way node status checking is " + "locked between threads." + ) + + # mark just one event as "out." + # threadsafe because we are holding the lock that protects graph mutation + self.scheduler[epoch].mark_out(self.spec.id) + + return MetaEvent( + id=uuid4().int, + timestamp=datetime.now(), + node_id="meta", + signal=MetaEventType.NodeReady, + epoch=epoch, + value=self.spec.id, + ) @dataclass @@ -766,6 +816,7 @@ class ZMQRunner(TubeRunner): _return_node: Return | None = None _to_throw: ErrorValue | None = None _current_epoch: int = 0 + _epoch_futures: dict[int, concurrent.futures.Future] = field(default_factory=dict) @property def running(self) -> bool: @@ -783,9 +834,10 @@ def init(self) -> None: with self._init_lock: self._logger.debug("Initializing ZMQ runner") self.command = CommandNode(runner_id=self.runner_id) + threading.Thread(target=self.command.run, daemon=True).start() + self.command._init.wait() self.command.add_callback("inbox", self.on_event) self.command.add_callback("router", self.on_router) - self.command.init() self._logger.debug("Command node initialized") for node_id, node in self.tube.nodes.items(): @@ -875,9 +927,7 @@ def process(self, **kwargs: Any) -> ReturnNodeType: self.command = cast(CommandNode, self.command) self.command.process(self._current_epoch, input) self._logger.debug("awaiting epoch %s", self._current_epoch) - self.tube.scheduler.await_epoch(self._current_epoch) - if self._to_throw: - self._throw_error() + 
self.await_epoch(self._current_epoch) self._logger.debug("collecting return") return self.collect_return(self._current_epoch) @@ -935,7 +985,7 @@ def iter(self, n: int | None = None) -> Generator[ReturnNodeType, None, None]: loop = 0 while ret is MetaSignal.NoEvent: self._logger.debug("Awaiting epoch %s", epoch) - self.tube.scheduler.await_epoch(epoch) + self.await_epoch(epoch) ret = self.collect_return(epoch) epoch += 1 self._current_epoch = epoch @@ -996,12 +1046,20 @@ def run(self, n: int | None = None) -> None | list[ReturnNodeType]: self._running.set() return None - else: + elif self.tube.has_return: + # run until n return values results = [] for res in self.iter(n): results.append(res) return results + else: + # run n epochs + self.command.start(n) + self._running.set() + self._current_epoch = self.await_epoch(self._current_epoch + n) + return None + def stop(self) -> None: """ Stop running the tube. @@ -1022,7 +1080,8 @@ def on_event(self, msg: Message) -> None: if not self._ignore_events: for event in msg.value: self.store.add(event) - self.tube.scheduler.update(msg.value) + events = self.tube.scheduler.update(msg.value) + events = cast(MutableSequence[Event | MetaEvent], events) if self._return_node is not None: # mark the return node done if we've received the expected events for an epoch # do it here since we don't really run the return node like a real node @@ -1034,7 +1093,17 @@ def on_event(self, msg: Message) -> None: and epoch in self.tube.scheduler._epochs ): self._logger.debug("Marking return node ready in epoch %s", epoch) - self.tube.scheduler.done(epoch, self._return_node.id) + ep_ended = self.tube.scheduler.done(epoch, self._return_node.id) + if ep_ended is not None: + events.append(ep_ended) + for e in events: + if ( + e["node_id"] == "meta" + and e["signal"] == MetaEventType.EpochEnded + and e["value"] in self._epoch_futures + ): + self._epoch_futures[e["value"]].set_result(e["value"]) + del self._epoch_futures[e["value"]] def 
on_router(self, msg: Message) -> None: if isinstance(msg, ErrorMsg): @@ -1059,36 +1128,21 @@ def collect_return(self, epoch: int | None = None) -> Any: def _handle_error(self, msg: ErrorMsg) -> None: """Cancel current epoch, stash error for process method to throw""" self._logger.error("Received error from node: %s", msg) + exception = msg.to_exception() self._to_throw = msg.value if self._current_epoch is not None: # if we're waiting in the process method, # end epoch and raise error there self.tube.scheduler.end_epoch(self._current_epoch) + self.deinit() + if self._current_epoch in self._epoch_futures: + self._epoch_futures[self._current_epoch].set_exception(exception) + del self._epoch_futures[self._current_epoch] + else: + raise exception else: # e.g. errors during init, raise here. - self._throw_error() - - def _throw_error(self) -> None: - errval = self._to_throw - if errval is None: - return - # clear instance object and store locally, we aren't locked here. - self._to_throw = None - self._logger.debug( - "Deinitializing before throwing error", - ) - self.deinit() - - # add the traceback as a note, - # sort of the best we can do without using tblib - err = errval["err_type"](*errval["err_args"]) - tb_message = "\nError re-raised from node runner process\n\n" - tb_message += "Original traceback:\n" - tb_message += "-" * 20 + "\n" - tb_message += errval["traceback"] - err.add_note(tb_message) - - raise err + raise exception def _request_more(self, n: int, current_iter: int, n_epochs: int) -> int: """ @@ -1128,3 +1182,11 @@ def enable_node(self, node_id: str) -> None: def disable_node(self, node_id: str) -> None: raise NotImplementedError() + + def await_epoch(self, epoch: int) -> int: + if self.tube.scheduler.epoch_completed(epoch): + return epoch + + if epoch not in self._epoch_futures: + self._epoch_futures[epoch] = concurrent.futures.Future() + return self._epoch_futures[epoch].result() diff --git a/src/noob/scheduler.py b/src/noob/scheduler.py index 
8c2efff..42e33da 100644 --- a/src/noob/scheduler.py +++ b/src/noob/scheduler.py @@ -3,7 +3,6 @@ from collections.abc import MutableSequence from datetime import UTC, datetime from itertools import count -from threading import Condition from typing import Self from uuid import uuid4 @@ -33,8 +32,6 @@ class Scheduler(BaseModel): _clock: count = PrivateAttr(default_factory=count) _epochs: dict[int, TopoSorter] = PrivateAttr(default_factory=dict) - _ready_condition: Condition = PrivateAttr(default_factory=Condition) - _epoch_condition: Condition = PrivateAttr(default_factory=Condition) _epoch_log: deque = PrivateAttr(default_factory=lambda: deque(maxlen=100)) model_config = ConfigDict(arbitrary_types_allowed=True) @@ -65,23 +62,21 @@ def add_epoch(self, epoch: int | None = None) -> int: """ Add another epoch with a prepared graph to the scheduler. """ - with self._ready_condition: - if epoch is not None: - this_epoch = epoch - # ensure that the next iteration of the clock will return the next number - # if we create epochs out of order - self._clock = count(max([this_epoch, *self._epochs.keys(), *self._epoch_log]) + 1) - else: - this_epoch = next(self._clock) + if epoch is not None: + this_epoch = epoch + # ensure that the next iteration of the clock will return the next number + # if we create epochs out of order + self._clock = count(max([this_epoch, *self._epochs.keys(), *self._epoch_log]) + 1) + else: + this_epoch = next(self._clock) - if this_epoch in self._epochs: - raise EpochExistsError(f"Epoch {this_epoch} is already scheduled") - elif this_epoch in self._epoch_log: - raise EpochCompletedError(f"Epoch {this_epoch} has already been completed!") + if this_epoch in self._epochs: + raise EpochExistsError(f"Epoch {this_epoch} is already scheduled") + elif this_epoch in self._epoch_log: + raise EpochCompletedError(f"Epoch {this_epoch} has already been completed!") - graph = self._init_graph(nodes=self.nodes, edges=self.edges) - self._epochs[this_epoch] = graph - 
self._ready_condition.notify_all() + graph = self._init_graph(nodes=self.nodes, edges=self.edges) + self._epochs[this_epoch] = graph return this_epoch def is_active(self, epoch: int | None = None) -> bool: @@ -109,20 +104,19 @@ def get_ready(self, epoch: int | None = None) -> list[MetaEvent]: graphs = self._epochs.items() if epoch is None else [(epoch, self._epochs[epoch])] - with self._ready_condition: - ready_nodes = [ - MetaEvent( - id=uuid4().int, - timestamp=datetime.now(), - node_id="meta", - signal=MetaEventType.NodeReady, - epoch=epoch, - value=node_id, - ) - for epoch, graph in graphs - for node_id in graph.get_ready() - if node_id in _VIRTUAL_NODES or self.nodes[node_id].enabled - ] + ready_nodes = [ + MetaEvent( + id=uuid4().int, + timestamp=datetime.now(), + node_id="meta", + signal=MetaEventType.NodeReady, + epoch=epoch, + value=node_id, + ) + for epoch, graph in graphs + for node_id in graph.get_ready() + if node_id in _VIRTUAL_NODES or self.nodes[node_id].enabled + ] return ready_nodes @@ -177,27 +171,20 @@ def update( return events end_events: MutableSequence[MetaEvent] = [] - with self._ready_condition, self._epoch_condition: - marked_done = set() - for e in events: - if (done_marker := (e["epoch"], e["node_id"])) in marked_done or e[ - "node_id" - ] == "meta": - continue - else: - marked_done.add(done_marker) - - if e["signal"] == META_SIGNAL and e["value"] == MetaSignal.NoEvent: - epoch_ended = self.expire(epoch=e["epoch"], node_id=e["node_id"]) - else: - epoch_ended = self.done(epoch=e["epoch"], node_id=e["node_id"]) - - if epoch_ended: - end_events.append(epoch_ended) - - # condition uses an RLock, so waiters only run here, - # even though `done` also notifies. 
- self._ready_condition.notify_all() + marked_done = set() + for e in events: + if (done_marker := (e["epoch"], e["node_id"])) in marked_done or e["node_id"] == "meta": + continue + else: + marked_done.add(done_marker) + + if e["signal"] == META_SIGNAL and e["value"] == MetaSignal.NoEvent: + epoch_ended = self.expire(epoch=e["epoch"], node_id=e["node_id"]) + else: + epoch_ended = self.done(epoch=e["epoch"], node_id=e["node_id"]) + + if epoch_ended: + end_events.append(epoch_ended) ret_events = [*events, *end_events] @@ -208,27 +195,24 @@ def done(self, epoch: int, node_id: str) -> MetaEvent | None: Mark a node in a given epoch as done. """ - with self._ready_condition, self._epoch_condition: - if epoch in self._epoch_log: - self.logger.debug( - "Marking node %s as done in epoch %s, " - "but epoch was already completed. ignoring", - node_id, - epoch, - ) - return None - - try: - self[epoch].done(node_id) - except NotOutYetError: - # in parallel mode, we don't `get_ready` the preceding ready nodes - # so we have to manually mark them as "out" - self[epoch].mark_out(node_id) - self[epoch].done(node_id) - - self._ready_condition.notify_all() - if not self[epoch].is_active(): - return self.end_epoch(epoch) + if epoch in self._epoch_log: + self.logger.debug( + "Marking node %s as done in epoch %s, " "but epoch was already completed. ignoring", + node_id, + epoch, + ) + return None + + try: + self[epoch].done(node_id) + except NotOutYetError: + # in parallel mode, we don't `get_ready` the preceding ready nodes + # so we have to manually mark them as "out" + self[epoch].mark_out(node_id) + self[epoch].done(node_id) + + if not self[epoch].is_active(): + return self.end_epoch(epoch) return None def expire(self, epoch: int, node_id: str) -> MetaEvent | None: @@ -236,92 +220,23 @@ def expire(self, epoch: int, node_id: str) -> MetaEvent | None: Mark a node as having been completed without making its dependent nodes ready. i.e. 
when the node emitted ``NoEvent`` """ - with self._ready_condition, self._epoch_condition: - self[epoch].mark_expired(node_id) - self._ready_condition.notify_all() - if not self[epoch].is_active(): - return self.end_epoch(epoch) + self[epoch].mark_expired(node_id) + if not self[epoch].is_active(): + return self.end_epoch(epoch) return None - def await_node(self, node_id: NodeID, epoch: int | None = None) -> MetaEvent: - """ - Block until a node is ready - - Args: - node_id: - epoch (int, None): if `int` , wait until the node is ready in the given epoch, - otherwise wait until the node is ready in any epoch - - Returns: - - """ - with self._ready_condition: - if not self.node_is_ready(node_id, epoch): - self._ready_condition.wait_for(lambda: self.node_is_ready(node_id, epoch)) - - # be FIFO-like and get the earliest epoch the node is ready in - if epoch is None: - for ep in self._epochs: - if self.node_is_ready(node_id, ep): - epoch = ep - break - - if epoch is None: - raise RuntimeError( - "Could not find ready epoch even though node ready condition passed, " - "something is wrong with the way node status checking is " - "locked between threads." - ) - - # mark just one event as "out." - # threadsafe because we are holding the lock that protects graph mutation - self._epochs[epoch].mark_out(node_id) - - return MetaEvent( - id=uuid4().int, - timestamp=datetime.now(), - node_id="meta", - signal=MetaEventType.NodeReady, - epoch=epoch, - value=node_id, - ) - - def await_epoch(self, epoch: int | None = None) -> int: - """ - Block until an epoch is completed. - - Args: - epoch (int, None): if `int` , wait until the epoch is ready, - otherwise wait until the next epoch is finished, in whatever order. - - Returns: - int: the epoch that was completed. 
- """ - with self._epoch_condition: - # check if we have already completed this epoch - if isinstance(epoch, int) and self.epoch_completed(epoch): - return epoch - - if epoch is None: - self._epoch_condition.wait() - return self._epoch_log[-1] - else: - self._epoch_condition.wait_for(lambda: self.epoch_completed(epoch)) - return epoch - def epoch_completed(self, epoch: int) -> bool: """ Check if the epoch has been completed. """ - with self._epoch_condition: - previously_completed = ( - len(self._epoch_log) > 0 - and epoch not in self._epochs - and (epoch in self._epoch_log or epoch < min(self._epoch_log)) - ) - active_completed = epoch in self._epochs and not self._epochs[epoch].is_active() - return previously_completed or active_completed + previously_completed = ( + len(self._epoch_log) > 0 + and epoch not in self._epochs + and (epoch in self._epoch_log or epoch < min(self._epoch_log)) + ) + active_completed = epoch in self._epochs and not self._epochs[epoch].is_active() + return previously_completed or active_completed def end_epoch(self, epoch: int | None = None) -> MetaEvent | None: if epoch is None or epoch == -1: @@ -329,10 +244,8 @@ def end_epoch(self, epoch: int | None = None) -> MetaEvent | None: return None epoch = list(self._epochs)[-1] - with self._epoch_condition: - self._epoch_condition.notify_all() - self._epoch_log.append(epoch) - del self._epochs[epoch] + self._epoch_log.append(epoch) + del self._epochs[epoch] return MetaEvent( id=uuid4().int, diff --git a/src/noob/testing/nodes.py b/src/noob/testing/nodes.py index 1996f88..35b8a08 100644 --- a/src/noob/testing/nodes.py +++ b/src/noob/testing/nodes.py @@ -14,7 +14,9 @@ from noob.node import Node -def count_source(limit: int = 1000, start: int = 0) -> Generator[A[int, Name("index")], None, None]: +def count_source( + limit: int = 10000, start: int = 0 +) -> Generator[A[int, Name("index")], None, None]: counter = count(start=start) if limit == 0: while True: @@ -156,8 +158,8 @@ def input_party( 
return True -def long_add(value: float) -> float: - sleep(0.25) +def long_add(value: float, sleep_for: float = 0.25) -> float: + sleep(sleep_for) return value + 1 diff --git a/src/noob/tube.py b/src/noob/tube.py index 592fd6f..c003ead 100644 --- a/src/noob/tube.py +++ b/src/noob/tube.py @@ -238,6 +238,11 @@ def from_specification( context={"skip_input_presence": True}, ) + @property + def has_return(self) -> bool: + """Whether the tube has a :class:`.Return` node present""" + return any(isinstance(node, Return) for node in self.nodes.values()) + @classmethod def _init_nodes( cls, specs: TubeSpecification, input_collection: InputCollection diff --git a/tests/bench.py b/tests/bench.py index 8e1097a..57bc359 100644 --- a/tests/bench.py +++ b/tests/bench.py @@ -3,7 +3,6 @@ from noob import Tube from noob.runner.base import TubeRunner -from noob.runner.zmq import ZMQRunner def test_load_tube(benchmark: BenchmarkFixture) -> None: @@ -17,8 +16,6 @@ def test_kitchen_sink_process(benchmark: BenchmarkFixture, runner: TubeRunner) - @pytest.mark.parametrize("loaded_tube", ["testing-kitchen-sink"], indirect=True) def test_kitchen_sink_run(benchmark: BenchmarkFixture, runner: TubeRunner) -> None: - if isinstance(runner, ZMQRunner): - pytest.skip("ZMQ runner freerun mode not supported yet") benchmark(lambda: runner.run(n=10)) diff --git a/tests/test_runners/test_zmq.py b/tests/test_runners/test_zmq.py index f49f211..abd3e21 100644 --- a/tests/test_runners/test_zmq.py +++ b/tests/test_runners/test_zmq.py @@ -346,7 +346,8 @@ def test_iter_gather(mocker): assert spy.spy_return == 50 -def test_noderunner_stores_clear(): +@pytest.mark.asyncio +async def test_noderunner_stores_clear(): """ Stores in the noderunners should clear after they use the events from an epoch """ @@ -362,7 +363,7 @@ def test_noderunner_stores_clear(): command_router="/notreal/unused", input_collection=InputCollection(), ) - runner.init_node() + await runner.init_node() # fake a few events events = [] @@ -389,13 
+390,14 @@ def test_noderunner_stores_clear(): ), ], ) - runner.on_event(msg) + await runner.on_event(msg) events.append(msg) runner._freerun.set() assert len(runner.store.events) == 3 - args, kwargs, epoch = next(runner.await_inputs()) + _, _, epoch = await anext(runner.await_inputs()) assert len(runner.store.events) == 2 + assert epoch != -1 assert epoch not in runner.store.events