From ebafcdd140f5f7cf1e10485e8ff06e21298e987e Mon Sep 17 00:00:00 2001 From: sneakers-the-rat Date: Mon, 26 Jan 2026 20:27:53 -0800 Subject: [PATCH 01/13] begin refactoring noderunner to async --- src/noob/logging.py | 3 +- src/noob/runner/zmq.py | 254 +++++++++++++++++++++-------------------- src/noob/scheduler.py | 218 +++++++++++++++++++---------------- 3 files changed, 250 insertions(+), 225 deletions(-) diff --git a/src/noob/logging.py b/src/noob/logging.py index 54a27191..55519d05 100644 --- a/src/noob/logging.py +++ b/src/noob/logging.py @@ -4,6 +4,7 @@ import logging import multiprocessing as mp +import sys from logging.handlers import RotatingFileHandler from pathlib import Path from typing import Any, Literal @@ -193,5 +194,5 @@ def _get_console() -> Console: current_pid = mp.current_process().pid console = _console_by_pid.get(current_pid) if console is None: - _console_by_pid[current_pid] = console = Console() + _console_by_pid[current_pid] = console = Console(file=sys.stdout) return console diff --git a/src/noob/runner/zmq.py b/src/noob/runner/zmq.py index ac85991d..ca734536 100644 --- a/src/noob/runner/zmq.py +++ b/src/noob/runner/zmq.py @@ -22,6 +22,7 @@ """ +import asyncio import math import multiprocessing as mp import os @@ -29,7 +30,7 @@ import threading import traceback from collections import defaultdict -from collections.abc import Callable, Generator +from collections.abc import Callable, Generator, AsyncGenerator from dataclasses import dataclass, field from itertools import count from multiprocessing.synchronize import Event as EventType @@ -46,7 +47,7 @@ "Attempted to import zmq runner, but zmq deps are not installed. install with `noob[zmq]`", ) from e - +from zmq.asyncio import Context, Poller, Socket from zmq.eventloop.zmqstream import ZMQStream from noob.config import config @@ -73,8 +74,8 @@ StopMsg, ) from noob.node import Node, NodeSpecification, Return, Signal -from noob.runner.base import TubeRunner, call_async_from_sync -from noob.scheduler import Scheduler +from noob.runner.base import TubeRunner +from noob.scheduler import AsyncScheduler from noob.store import EventStore from noob.types import NodeID, ReturnNodeType from noob.utils import iscoroutinefunction_partial @@ -311,7 +312,7 @@ def on_status(self, msg: StatusMsg) -> None: self._ready_condition.notify_all() -class NodeRunner(EventloopMixin): +class NodeRunner: """ Runner for a single node @@ -329,7 +330,7 @@ def __init__( input_collection: InputCollection, protocol: str = "ipc", ): - super().__init__() + self.context = Context.instance() self.spec = spec self.runner_id = runner_id self.input_collection = input_collection @@ -337,22 +338,23 @@ def __init__( self.command_router = command_router self.protocol = protocol self.store = EventStore() - self.scheduler: Scheduler = None # type: ignore[assignment] + self.scheduler: AsyncScheduler = None # type: ignore[assignment] self.logger = init_logger(f"runner.node.{runner_id}.{self.spec.id}") - self._dealer: ZMQStream = None # type: ignore[assignment] - self._outbox: zmq.Socket = None # type: ignore[assignment] - self._inbox: ZMQStream = None # type: ignore[assignment] + self._dealer: Socket = None # type: ignore[assignment] + self._outbox: Socket = None # type: ignore[assignment] + self._inbox: Socket = None # type: ignore[assignment] + self._inbox_poller: Poller = Poller() self._node: Node | None = None self._depends: tuple[tuple[str, str], ...] | None = None self._has_input: bool | None = None self._nodes: dict[str, IdentifyValue] = {} self._counter = count() - self._process_quitting = mp.Event() - self._freerun = mp.Event() - self._process_one = mp.Event() + self._process_quitting = asyncio.Event() + self._freerun = asyncio.Event() + self._process_one = asyncio.Event() self._status: NodeStatus = NodeStatus.stopped - self._status_lock = mp.RLock() + self._status_lock = asyncio.Lock() self._to_process = 0 @property @@ -386,13 +388,11 @@ def has_input(self) -> bool: @property def status(self) -> NodeStatus: - with self._status_lock: - return self._status + return self._status @status.setter def status(self, status: NodeStatus) -> None: - with self._status_lock: - self._status = status + self._status = status @classmethod def run(cls, spec: NodeSpecification, **kwargs: Any) -> None: @@ -401,6 +401,10 @@ def run(cls, spec: NodeSpecification, **kwargs: Any) -> None: init the class and start it! """ runner = NodeRunner(spec=spec, **kwargs) + asyncio.run(runner._run()) + + async def _run(self) -> None: + try: def _handler(sig: int, frame: FrameType | None = None) -> None: @@ -408,40 +412,46 @@ def _handler(sig: int, frame: FrameType | None = None) -> None: raise KeyboardInterrupt() signal.signal(signal.SIGTERM, _handler) - runner.init() - runner._node = cast(Node, runner._node) - runner._process_quitting.clear() - runner._freerun.clear() - runner._process_one.clear() - - is_async = iscoroutinefunction_partial(runner._node.process) - - for args, kwargs, epoch in runner.await_inputs(): - runner.logger.debug( - "Running with args: %s, kwargs: %s, epoch: %s", args, kwargs, epoch - ) - if is_async: - # mypy fails here because it can't propagate the type guard above - value = call_async_from_sync(runner._node.process, *args, **kwargs) # type: ignore[arg-type] - else: - value = runner._node.process(*args, **kwargs) - events = runner.store.add_value(runner._node.signals, value, runner._node.id, epoch) - runner.scheduler.add_epoch() - - # node runners should not report epoch endings - events = [e for e in events if e["node_id"] != "meta"] - if events: - runner.update_graph(events) - runner.publish_events(events) - + await self.init() + self._node = cast(Node, self._node) + self._process_quitting.clear() + self._freerun.clear() + self._process_one.clear() + await asyncio.gather(self._poll_inbox(), self._loop()) except KeyboardInterrupt: - runner.logger.debug("Got keyboard interrupt, quitting") + self.logger.debug("Got keyboard interrupt, quitting") except Exception as e: - runner.error(e) + await self.error(e) finally: - runner.deinit() + await self.deinit() + + async def _loop(self) -> None: + is_async = iscoroutinefunction_partial(self._node.process) + + async for args, kwargs, epoch in self.await_inputs(): + self.logger.debug("Running with args: %s, kwargs: %s, epoch: %s", args, kwargs, epoch) + if is_async: + # mypy fails here because it can't propagate the type guard above + value = await self._node.process(*args, **kwargs) # type: ignore[arg-type] + else: + value = self._node.process(*args, **kwargs) + events = self.store.add_value(self._node.signals, value, self._node.id, epoch) + await self.scheduler.add_epoch() + + # nodes should not report epoch endings since they don't know about the full tube + events = [e for e in events if e["node_id"] != "meta"] + if events: + await self.update_graph(events) + await self.publish_events(events) + + async def _poll_inbox(self) -> None: + while not self._process_quitting.is_set(): + events = await self._inbox_poller.poll() + if self._inbox in dict(events): + msg = await self._inbox.recv_multipart() + await self.on_inbox(msg) - def await_inputs(self) -> Generator[tuple[tuple[Any], dict[str, Any], int]]: + async def await_inputs(self) -> AsyncGenerator[tuple[tuple[Any], dict[str, Any], int]]: self._node = cast(Node, self._node) while not self._process_quitting.is_set(): # if we are not freerunning, keep track of how many times we are supposed to run, @@ -449,7 +459,7 @@ def await_inputs(self) -> Generator[tuple[tuple[Any], dict[str, Any], int]]: if not self._freerun.is_set(): if self._to_process <= 0: self._to_process = 0 - self._process_one.wait() + await self._process_one.wait() self._to_process -= 1 if self._to_process <= 0: self._to_process = 0 @@ -457,7 +467,7 @@ def await_inputs(self) -> Generator[tuple[tuple[Any], dict[str, Any], int]]: epoch = next(self._counter) if self._node.stateful else None - ready = self.scheduler.await_node(self.spec.id, epoch=epoch) + ready = await self.scheduler.await_node(self.spec.id, epoch=epoch) edges = self._node.edges inputs = self.store.collect(edges, ready["epoch"]) if inputs is None: @@ -467,35 +477,35 @@ def await_inputs(self) -> Generator[tuple[tuple[Any], dict[str, Any], int]]: self.store.clear(ready["epoch"]) yield args, kwargs, ready["epoch"] - def update_graph(self, events: list[Event]) -> None: - self.scheduler.update(events) + async def update_graph(self, events: list[Event]) -> None: + await self.scheduler.update(events) - def publish_events(self, events: list[Event]) -> None: + async def publish_events(self, events: list[Event]) -> None: msg = EventMsg(node_id=self.spec.id, value=events) - self._outbox.send_multipart([b"event", msg.to_bytes()]) + await self._outbox.send_multipart([b"event", msg.to_bytes()]) - def init(self) -> None: + async def init(self) -> None: self.logger.debug("Initializing") - self.init_node() - self.start_sockets() + await self.init_node() + self._init_sockets() self.status = ( NodeStatus.waiting if self.depends and [d for d in self.depends if d[0] != "input"] else NodeStatus.ready ) - self.identify() + await self.identify() self.logger.debug("Initialization finished") - def deinit(self) -> None: + async def deinit(self) -> None: self.logger.debug("Deinitializing") if self._node is not None: self._node.deinit() - self.update_status(NodeStatus.closed) - self.stop_loop() + await self.update_status(NodeStatus.closed) + self.logger.debug("Deinitialization finished") - def identify(self) -> None: + async def identify(self) -> None: """ Send the command node an announce to say we're alive """ @@ -506,7 +516,7 @@ def identify(self) -> None: ) self.logger.debug("Identifying") - with self._status_lock: + async with self._status_lock: ann = IdentifyMsg( node_id=self.spec.id, value=IdentifyValue( @@ -519,43 +529,37 @@ def identify(self) -> None: ), ), ) - self._dealer.send_multipart([ann.to_bytes()]) + await self._dealer.send_multipart([ann.to_bytes()]) self.logger.debug("Sent identification message: %s", ann) - def update_status(self, status: NodeStatus) -> None: + async def update_status(self, status: NodeStatus) -> None: """Update our internal status and announce it to the command node""" self.logger.debug("Updating status as %s", status) - with self._status_lock: + async with self._status_lock: self.status = status msg = StatusMsg(node_id=self.spec.id, value=status) - self._dealer.send_multipart([msg.to_bytes()]) + await self._dealer.send_multipart([msg.to_bytes()]) self.logger.debug("Updated status") - def start_sockets(self) -> None: - self.start_loop() - self._init_sockets() - - def init_node(self) -> None: + async def init_node(self) -> None: self._node = Node.from_specification(self.spec, self.input_collection) self._node.init() - self.scheduler = Scheduler(nodes={self.spec.id: self.spec}, edges=self._node.edges) - self.scheduler.add_epoch() + self.scheduler = AsyncScheduler(nodes={self.spec.id: self.spec}, edges=self._node.edges) + await self.scheduler.add_epoch() def _init_sockets(self) -> None: self._dealer = self._init_dealer() self._outbox = self._init_outbox() self._inbox = self._init_inbox() - def _init_dealer(self) -> ZMQStream: + def _init_dealer(self) -> Socket: dealer = self.context.socket(zmq.DEALER) dealer.setsockopt_string(zmq.IDENTITY, self.spec.id) dealer.connect(self.command_router) - dealer = ZMQStream(dealer, self.loop) - dealer.on_recv(self.on_dealer) self.logger.debug("Connected to command node at %s", self.command_router) return dealer - def _init_outbox(self) -> zmq.Socket: + def _init_outbox(self) -> Socket: pub = self.context.socket(zmq.PUB) pub.setsockopt_string(zmq.IDENTITY, self.spec.id) if self.protocol == "ipc": @@ -567,7 +571,7 @@ def _init_outbox(self) -> zmq.Socket: return pub - def _init_inbox(self) -> ZMQStream: + def _init_inbox(self) -> Socket: """ Init the subscriber, but don't attempt to subscribe to anything but the command yet! we do that when we get node Announces @@ -576,15 +580,11 @@ def _init_inbox(self) -> ZMQStream: sub.setsockopt_string(zmq.IDENTITY, self.spec.id) sub.setsockopt_string(zmq.SUBSCRIBE, "") sub.connect(self.command_outbox) - sub = ZMQStream(sub, self.loop) - sub.on_recv(self.on_inbox) + self._inbox_poller.register(sub, zmq.POLLIN) self.logger.debug("Subscribed to command outbox %s", self.command_outbox) return sub - def on_dealer(self, msg: list[bytes]) -> None: - self.logger.debug("DEALER received %s", msg) - - def on_inbox(self, msg: list[bytes]) -> None: + async def on_inbox(self, msg: list[bytes]) -> None: try: message = Message.from_bytes(msg) @@ -597,58 +597,58 @@ def on_inbox(self, msg: list[bytes]) -> None: # just have a decorator to register a handler for a given message type if message.type_ == MessageType.announce: message = cast(AnnounceMsg, message) - self.on_announce(message) + await self.on_announce(message) elif message.type_ == MessageType.event: message = cast(EventMsg, message) - self.on_event(message) + await self.on_event(message) elif message.type_ == MessageType.process: message = cast(ProcessMsg, message) - self.on_process(message) + await self.on_process(message) elif message.type_ == MessageType.start: message = cast(StartMsg, message) - self.on_start(message) + await self.on_start(message) elif message.type_ == MessageType.stop: message = cast(StopMsg, message) - self.on_stop(message) + await self.on_stop(message) elif message.type_ == MessageType.deinit: message = cast(DeinitMsg, message) - self.on_deinit(message) + await self.on_deinit(message) elif message.type_ == MessageType.ping: - self.identify() + await self.identify() else: # log but don't throw - other nodes shouldn't be able to crash us self.logger.error(f"{message.type_} not implemented!") self.logger.debug("%s", message) - def on_announce(self, msg: AnnounceMsg) -> None: + async def on_announce(self, msg: AnnounceMsg) -> None: """ Store map, connect to the nodes we depend on """ self._node = cast(Node, self._node) self.logger.debug("Processing announce") - with self._status_lock: - depended_nodes = {edge.source_node for edge in self._node.edges} - if depended_nodes: - self.logger.debug("Should subscribe to %s", depended_nodes) - for node_id in msg.value["nodes"]: - if node_id in depended_nodes and node_id not in self._nodes: - # TODO: a way to check if we're already connected, without storing it locally? - outbox = msg.value["nodes"][node_id]["outbox"] - self.logger.debug("Subscribing to %s at %s", node_id, outbox) - self._inbox.connect(outbox) - self.logger.debug("Subscribed to %s at %s", node_id, outbox) - self._nodes = msg.value["nodes"] - if set(self._nodes) >= depended_nodes - {"input"} and self.status == NodeStatus.waiting: - self.update_status(NodeStatus.ready) - # status and announce messages can be received out of order, - # so if we observe the command node being out of sync, we update it. - elif ( - self._node.id in msg.value["nodes"] - and msg.value["nodes"][self._node.id]["status"] != self.status.value - ): - self.update_status(self.status) - - def on_event(self, msg: EventMsg) -> None: + + depended_nodes = {edge.source_node for edge in self._node.edges} + if depended_nodes: + self.logger.debug("Should subscribe to %s", depended_nodes) + for node_id in msg.value["nodes"]: + if node_id in depended_nodes and node_id not in self._nodes: + # TODO: a way to check if we're already connected, without storing it locally? + outbox = msg.value["nodes"][node_id]["outbox"] + self.logger.debug("Subscribing to %s at %s", node_id, outbox) + self._inbox.connect(outbox) + self.logger.debug("Subscribed to %s at %s", node_id, outbox) + self._nodes = msg.value["nodes"] + if set(self._nodes) >= depended_nodes - {"input"} and self.status == NodeStatus.waiting: + await self.update_status(NodeStatus.ready) + # status and announce messages can be received out of order, + # so if we observe the command node being out of sync, we update it. + elif ( + self._node.id in msg.value["nodes"] + and msg.value["nodes"][self._node.id]["status"] != self.status.value + ): + await self.update_status(self.status) + + async def on_event(self, msg: EventMsg) -> None: events = msg.value if not self.depends: self.logger.debug("No dependencies, not storing events") @@ -658,20 +658,20 @@ def on_event(self, msg: EventMsg) -> None: for event in to_add: self.store.add(event) - self.scheduler.update(events) + await self.scheduler.update(events) - def on_start(self, msg: StartMsg) -> None: + async def on_start(self, msg: StartMsg) -> None: """ Start running in free mode """ - self.update_status(NodeStatus.running) + await self.update_status(NodeStatus.running) if msg.value is None: self._freerun.set() else: self._to_process += msg.value self._process_one.set() - def on_process(self, msg: ProcessMsg) -> None: + async def on_process(self, msg: ProcessMsg) -> None: """ Process a single graph iteration """ @@ -697,20 +697,20 @@ def on_process(self, msg: ProcessMsg) -> None: node_id="input", epoch=msg.value["epoch"], ) - scheduler_events = self.scheduler.update(events) + scheduler_events = await self.scheduler.update(events) self.logger.debug("Updated scheduler with process events: %s", scheduler_events) self._process_one.set() - def on_stop(self, msg: StopMsg) -> None: + async def on_stop(self, msg: StopMsg) -> None: """Stop processing (but stay responsive)""" self._process_one.clear() self._to_process = 0 self._freerun.clear() - self.update_status(NodeStatus.stopped) + await self.update_status(NodeStatus.stopped) self.logger.debug("Stopped") - def on_deinit(self, msg: DeinitMsg) -> None: + async def on_deinit(self, msg: DeinitMsg) -> None: """ Deinitialize the node, close networking thread. @@ -722,8 +722,9 @@ def on_deinit(self, msg: DeinitMsg) -> None: return self.logger.debug("Emitting sigterm to self %s", msg) os.kill(pid, signal.SIGTERM) + raise asyncio.CancelledError() - def error(self, err: Exception) -> None: + async def error(self, err: Exception) -> None: """ Capture the error and traceback context from an exception using :class:`traceback.TracebackException` and send to command node to re-raise @@ -738,7 +739,7 @@ def error(self, err: Exception) -> None: traceback=tbexception, ), ) - self._dealer.send_multipart([msg.to_bytes()]) + await self._dealer.send_multipart([msg.to_bytes()]) @dataclass @@ -936,6 +937,7 @@ def iter(self, n: int | None = None) -> Generator[ReturnNodeType, None, None]: while ret is MetaSignal.NoEvent: self._logger.debug("Awaiting epoch %s", epoch) self.tube.scheduler.await_epoch(epoch) + self._logger.debug("epoch %s completed", epoch) ret = self.collect_return(epoch) epoch += 1 self._current_epoch = epoch diff --git a/src/noob/scheduler.py b/src/noob/scheduler.py index 8c2efffa..d0a3dc2c 100644 --- a/src/noob/scheduler.py +++ b/src/noob/scheduler.py @@ -1,3 +1,4 @@ +import asyncio import logging from collections import deque from collections.abc import MutableSequence @@ -33,7 +34,6 @@ class Scheduler(BaseModel): _clock: count = PrivateAttr(default_factory=count) _epochs: dict[int, TopoSorter] = PrivateAttr(default_factory=dict) - _ready_condition: Condition = PrivateAttr(default_factory=Condition) _epoch_condition: Condition = PrivateAttr(default_factory=Condition) _epoch_log: deque = PrivateAttr(default_factory=lambda: deque(maxlen=100)) @@ -65,23 +65,21 @@ def add_epoch(self, epoch: int | None = None) -> int: """ Add another epoch with a prepared graph to the scheduler. """ - with self._ready_condition: - if epoch is not None: - this_epoch = epoch - # ensure that the next iteration of the clock will return the next number - # if we create epochs out of order - self._clock = count(max([this_epoch, *self._epochs.keys(), *self._epoch_log]) + 1) - else: - this_epoch = next(self._clock) + if epoch is not None: + this_epoch = epoch + # ensure that the next iteration of the clock will return the next number + # if we create epochs out of order + self._clock = count(max([this_epoch, *self._epochs.keys(), *self._epoch_log]) + 1) + else: + this_epoch = next(self._clock) - if this_epoch in self._epochs: - raise EpochExistsError(f"Epoch {this_epoch} is already scheduled") - elif this_epoch in self._epoch_log: - raise EpochCompletedError(f"Epoch {this_epoch} has already been completed!") + if this_epoch in self._epochs: + raise EpochExistsError(f"Epoch {this_epoch} is already scheduled") + elif this_epoch in self._epoch_log: + raise EpochCompletedError(f"Epoch {this_epoch} has already been completed!") - graph = self._init_graph(nodes=self.nodes, edges=self.edges) - self._epochs[this_epoch] = graph - self._ready_condition.notify_all() + graph = self._init_graph(nodes=self.nodes, edges=self.edges) + self._epochs[this_epoch] = graph return this_epoch def is_active(self, epoch: int | None = None) -> bool: @@ -109,20 +107,19 @@ def get_ready(self, epoch: int | None = None) -> list[MetaEvent]: graphs = self._epochs.items() if epoch is None else [(epoch, self._epochs[epoch])] - with self._ready_condition: - ready_nodes = [ - MetaEvent( - id=uuid4().int, - timestamp=datetime.now(), - node_id="meta", - signal=MetaEventType.NodeReady, - epoch=epoch, - value=node_id, - ) - for epoch, graph in graphs - for node_id in graph.get_ready() - if node_id in _VIRTUAL_NODES or self.nodes[node_id].enabled - ] + ready_nodes = [ + MetaEvent( + id=uuid4().int, + timestamp=datetime.now(), + node_id="meta", + signal=MetaEventType.NodeReady, + epoch=epoch, + value=node_id, + ) + for epoch, graph in graphs + for node_id in graph.get_ready() + if node_id in _VIRTUAL_NODES or self.nodes[node_id].enabled + ] return ready_nodes @@ -143,6 +140,7 @@ def node_is_ready(self, node: NodeID, epoch: int | None = None) -> bool: graphs = self._epochs.items() if epoch is None else [(epoch, self[epoch])] is_ready = any(node_id == node for epoch, graph in graphs for node_id in graph.ready_nodes) + self.logger.debug("Node %s ready status in epoch %s: %s", node, epoch, is_ready) return is_ready def __getitem__(self, epoch: int) -> TopoSorter: @@ -177,27 +175,20 @@ def update( return events end_events: MutableSequence[MetaEvent] = [] - with self._ready_condition, self._epoch_condition: - marked_done = set() - for e in events: - if (done_marker := (e["epoch"], e["node_id"])) in marked_done or e[ - "node_id" - ] == "meta": - continue - else: - marked_done.add(done_marker) - - if e["signal"] == META_SIGNAL and e["value"] == MetaSignal.NoEvent: - epoch_ended = self.expire(epoch=e["epoch"], node_id=e["node_id"]) - else: - epoch_ended = self.done(epoch=e["epoch"], node_id=e["node_id"]) - - if epoch_ended: - end_events.append(epoch_ended) - - # condition uses an RLock, so waiters only run here, - # even though `done` also notifies. - self._ready_condition.notify_all() + marked_done = set() + for e in events: + if (done_marker := (e["epoch"], e["node_id"])) in marked_done or e["node_id"] == "meta": + continue + else: + marked_done.add(done_marker) + + if e["signal"] == META_SIGNAL and e["value"] == MetaSignal.NoEvent: + epoch_ended = self.expire(epoch=e["epoch"], node_id=e["node_id"]) + else: + epoch_ended = self.done(epoch=e["epoch"], node_id=e["node_id"]) + + if epoch_ended: + end_events.append(epoch_ended) ret_events = [*events, *end_events] @@ -208,7 +199,7 @@ def done(self, epoch: int, node_id: str) -> MetaEvent | None: Mark a node in a given epoch as done. """ - with self._ready_condition, self._epoch_condition: + with self._epoch_condition: if epoch in self._epoch_log: self.logger.debug( "Marking node %s as done in epoch %s, " @@ -226,7 +217,6 @@ def done(self, epoch: int, node_id: str) -> MetaEvent | None: self[epoch].mark_out(node_id) self[epoch].done(node_id) - self._ready_condition.notify_all() if not self[epoch].is_active(): return self.end_epoch(epoch) return None @@ -236,57 +226,13 @@ def expire(self, epoch: int, node_id: str) -> MetaEvent | None: Mark a node as having been completed without making its dependent nodes ready. i.e. when the node emitted ``NoEvent`` """ - with self._ready_condition, self._epoch_condition: + with self._epoch_condition: self[epoch].mark_expired(node_id) - self._ready_condition.notify_all() if not self[epoch].is_active(): return self.end_epoch(epoch) return None - def await_node(self, node_id: NodeID, epoch: int | None = None) -> MetaEvent: - """ - Block until a node is ready - - Args: - node_id: - epoch (int, None): if `int` , wait until the node is ready in the given epoch, - otherwise wait until the node is ready in any epoch - - Returns: - - """ - with self._ready_condition: - if not self.node_is_ready(node_id, epoch): - self._ready_condition.wait_for(lambda: self.node_is_ready(node_id, epoch)) - - # be FIFO-like and get the earliest epoch the node is ready in - if epoch is None: - for ep in self._epochs: - if self.node_is_ready(node_id, ep): - epoch = ep - break - - if epoch is None: - raise RuntimeError( - "Could not find ready epoch even though node ready condition passed, " - "something is wrong with the way node status checking is " - "locked between threads." - ) - - # mark just one event as "out." - # threadsafe because we are holding the lock that protects graph mutation - self._epochs[epoch].mark_out(node_id) - - return MetaEvent( - id=uuid4().int, - timestamp=datetime.now(), - node_id="meta", - signal=MetaEventType.NodeReady, - epoch=epoch, - value=node_id, - ) - def await_epoch(self, epoch: int | None = None) -> int: """ Block until an epoch is completed. @@ -415,3 +361,79 @@ def generations(self) -> list[tuple[str, ...]]: generations.append(ready) sorter.done(*ready) return generations + + +class AsyncScheduler(Scheduler): + _ready_condition: asyncio.Condition = PrivateAttr(default_factory=asyncio.Condition) + + async def add_epoch(self, epoch: int | None = None) -> int: + async with self._ready_condition: + this_epoch = super().add_epoch(epoch) + self._ready_condition.notify_all() + return this_epoch + + def __getitem__(self, epoch: int) -> TopoSorter: + if epoch == -1: + return self._epochs[max(self._epochs.keys())] + + if epoch not in self._epochs: + asyncio.run_coroutine_threadsafe( + self.add_epoch(epoch), asyncio.get_running_loop() + ).result() + return self._epochs[epoch] + + async def update( + self, events: MutableSequence[Event | MetaEvent] | MutableSequence[Event] + ) -> MutableSequence[Event] | MutableSequence[Event | MetaEvent]: + async with self._ready_condition: + ret_events = super().update(events) + self._ready_condition.notify_all() + return ret_events + + # async def done(self, epoch: int, node_id: str) -> MetaEvent | None: + # async with self._ready_condition: + # ret = super().done(epoch, node_id) + # self._ready_condition.notify_all() + # return ret + + async def await_node(self, node_id: NodeID, epoch: int | None = None) -> MetaEvent: + """ + Block until a node is ready + + Args: + node_id: + epoch (int, None): if `int` , wait until the node is ready in the given epoch, + otherwise wait until the node is ready in any epoch + + Returns: + + """ + async with self._ready_condition: + await self._ready_condition.wait_for(lambda: self.node_is_ready(node_id, epoch)) + + # be FIFO-like and get the earliest epoch the node is ready in + if epoch is None: + for ep in self._epochs: + if self.node_is_ready(node_id, ep): + epoch = ep + break + + if epoch is None: + raise RuntimeError( + "Could not find ready epoch even though node ready condition passed, " + "something is wrong with the way node status checking is " + "locked between threads." + ) + + # mark just one event as "out." + # threadsafe because we are holding the lock that protects graph mutation + self._epochs[epoch].mark_out(node_id) + + return MetaEvent( + id=uuid4().int, + timestamp=datetime.now(), + node_id="meta", + signal=MetaEventType.NodeReady, + epoch=epoch, + value=node_id, + ) From 01441650ead6ea2ad04ef5a507294361fdc40c29 Mon Sep 17 00:00:00 2001 From: sneakers-the-rat Date: Tue, 27 Jan 2026 17:57:13 -0800 Subject: [PATCH 02/13] messy but running refactor to fully async node runners --- src/noob/network/message.py | 9 ++ src/noob/runner/base.py | 4 + src/noob/runner/zmq.py | 204 +++++++++++++++++++++--------- src/noob/scheduler.py | 224 +++++++++++++++------------------ src/noob/testing/nodes.py | 6 +- src/noob/tube.py | 5 + tests/bench.py | 15 +++ tests/test_runners/test_zmq.py | 13 +- 8 files changed, 290 insertions(+), 190 deletions(-) diff --git a/src/noob/network/message.py b/src/noob/network/message.py index 339986ec..590a7d87 100644 --- a/src/noob/network/message.py +++ b/src/noob/network/message.py @@ -172,6 +172,15 @@ class ErrorMsg(Message): model_config = ConfigDict(arbitrary_types_allowed=True) + def to_exception(self) -> Exception: + err = self.value["err_type"](*self.value["err_args"]) + tb_message = "\nError re-raised from node runner process\n\n" + tb_message += "Original traceback:\n" + tb_message += "-" * 20 + "\n" + tb_message += self.value["traceback"] + err.add_note(tb_message) + return err + def _to_json(val: Event, handler: SerializerFunctionWrapHandler) -> Any: if val["signal"] == META_SIGNAL and val["value"] is MetaSignal.NoEvent: diff --git a/src/noob/runner/base.py b/src/noob/runner/base.py index 876dd65c..3ec11ea2 100644 --- a/src/noob/runner/base.py +++ b/src/noob/runner/base.py @@ -526,6 +526,10 @@ def call_async_from_sync( References: * https://github.com/django/asgiref/blob/2b28409ab83b3e4cf6fed9019403b71f8d7d1c51/asgiref/sync.py#L152 * https://stackoverflow.com/questions/79663750/call-async-code-inside-sync-code-inside-async-code + * https://github.com/python/cpython/issues/66435 + * https://github.com/python/cpython/issues/93462 + * https://discuss.python.org/t/support-for-running-async-functions-in-sync-functions/16220/3 + * https://github.com/fsspec/filesystem_spec/blob/2576617e5cbe441bcc53b021bccd85ff3489fde7/fsspec/asyn.py#L63 """ if not iscoroutinefunction_partial(fn): raise RuntimeError( diff --git a/src/noob/runner/zmq.py b/src/noob/runner/zmq.py index ca734536..471cca40 100644 --- a/src/noob/runner/zmq.py +++ b/src/noob/runner/zmq.py @@ -23,12 +23,15 @@ """ import asyncio +import concurrent.futures import math import multiprocessing as mp import os +from datetime import datetime import signal import threading import traceback +from functools import partial from collections import defaultdict from collections.abc import Callable, Generator, AsyncGenerator from dataclasses import dataclass, field @@ -37,6 +40,7 @@ from time import time from types import FrameType from typing import TYPE_CHECKING, Any, Literal, cast, overload +from uuid import uuid4 from noob.network.loop import EventloopMixin @@ -51,7 +55,7 @@ from zmq.eventloop.zmqstream import ZMQStream from noob.config import config -from noob.event import Event, MetaSignal +from noob.event import Event, MetaSignal, MetaEvent, MetaEventType from noob.exceptions import InputMissingError from noob.input import InputCollection, InputScope from noob.logging import init_logger @@ -75,7 +79,7 @@ ) from noob.node import Node, NodeSpecification, Return, Signal from noob.runner.base import TubeRunner -from noob.scheduler import AsyncScheduler +from noob.scheduler import Scheduler from noob.store import EventStore from noob.types import NodeID, ReturnNodeType from noob.utils import iscoroutinefunction_partial @@ -338,7 +342,7 @@ def __init__( self.command_router = command_router self.protocol = protocol self.store = EventStore() - self.scheduler: AsyncScheduler = None # type: ignore[assignment] + self.scheduler: Scheduler = None # type: ignore[assignment] self.logger = init_logger(f"runner.node.{runner_id}.{self.spec.id}") self._dealer: Socket = None # type: ignore[assignment] @@ -355,6 +359,7 @@ def __init__( self._process_one = asyncio.Event() self._status: NodeStatus = NodeStatus.stopped self._status_lock = asyncio.Lock() + self._ready_condition = asyncio.Condition() self._to_process = 0 @property @@ -400,8 +405,13 @@ def run(cls, spec: NodeSpecification, **kwargs: Any) -> None: Target for multiprocessing.run, init the class and start it! """ - runner = NodeRunner(spec=spec, **kwargs) - asyncio.run(runner._run()) + # ensure that events and conditions are bound to the eventloop created in the process + async def _run_inner() -> None: + nonlocal spec, kwargs + runner = NodeRunner(spec=spec, **kwargs) + await runner._run() + + asyncio.run(_run_inner()) async def _run(self) -> None: @@ -427,22 +437,26 @@ def _handler(sig: int, frame: FrameType | None = None) -> None: async def _loop(self) -> None: is_async = iscoroutinefunction_partial(self._node.process) - - async for args, kwargs, epoch in self.await_inputs(): - self.logger.debug("Running with args: %s, kwargs: %s, epoch: %s", args, kwargs, epoch) - if is_async: - # mypy fails here because it can't propagate the type guard above - value = await self._node.process(*args, **kwargs) # type: ignore[arg-type] - else: - value = self._node.process(*args, **kwargs) - events = self.store.add_value(self._node.signals, value, self._node.id, epoch) - await self.scheduler.add_epoch() - - # nodes should not report epoch endings since they don't know about the full tube - events = [e for e in events if e["node_id"] != "meta"] - if events: - await self.update_graph(events) - await self.publish_events(events) + loop = asyncio.get_running_loop() + with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor: + async for args, kwargs, epoch in self.await_inputs(): + self.logger.debug("Running with args: %s, kwargs: %s, epoch: %s", args, kwargs, epoch) + if is_async: + # mypy fails here because it can't propagate the type guard above + value = await self._node.process(*args, **kwargs) # type: ignore[arg-type] + else: + part = partial(self._node.process, *args, **kwargs) + value = await loop.run_in_executor(executor, part) + events = self.store.add_value(self._node.signals, value, self._node.id, epoch) + async with self._ready_condition: + self.scheduler.add_epoch() + self._ready_condition.notify_all() + + # nodes should not report epoch endings since they don't know about the full tube + events = [e for e in events if e["node_id"] != "meta"] + if events: + await self.update_graph(events) + await self.publish_events(events) async def _poll_inbox(self) -> None: while not self._process_quitting.is_set(): @@ -467,7 +481,7 @@ async def await_inputs(self) -> AsyncGenerator[tuple[tuple[Any], dict[str, Any], epoch = next(self._counter) if self._node.stateful else None - ready = await self.scheduler.await_node(self.spec.id, epoch=epoch) + ready = await self.await_node(epoch=epoch) edges = self._node.edges inputs = self.store.collect(edges, ready["epoch"]) if inputs is None: @@ -478,7 +492,9 @@ async def await_inputs(self) -> AsyncGenerator[tuple[tuple[Any], dict[str, Any], yield args, kwargs, ready["epoch"] async def update_graph(self, events: list[Event]) -> None: - await self.scheduler.update(events) + async with self._ready_condition: + self.scheduler.update(events) + self._ready_condition.notify_all() async def publish_events(self, events: list[Event]) -> None: msg = EventMsg(node_id=self.spec.id, value=events) @@ -544,8 +560,10 @@ async def update_status(self, status: NodeStatus) -> None: async def init_node(self) -> None: self._node = Node.from_specification(self.spec, self.input_collection) self._node.init() - self.scheduler = AsyncScheduler(nodes={self.spec.id: self.spec}, edges=self._node.edges) - await self.scheduler.add_epoch() + self.scheduler = Scheduler(nodes={self.spec.id: self.spec}, edges=self._node.edges) + async with self._ready_condition: + self.scheduler.add_epoch() + self._ready_condition.notify_all() def _init_sockets(self) -> None: self._dealer = self._init_dealer() @@ -657,8 +675,11 @@ async def on_event(self, msg: EventMsg) -> None: to_add = [e for e in events if (e["node_id"], e["signal"]) in self.depends] for event in to_add: self.store.add(event) - - await self.scheduler.update(events) + self.logger.debug("scheduler updating") + async with self._ready_condition: + self.scheduler.update(events) + self._ready_condition.notify_all() + self.logger.debug('scheduler updated') async def on_start(self, msg: StartMsg) -> None: """ @@ -697,7 +718,7 @@ async def on_process(self, msg: ProcessMsg) -> None: node_id="input", epoch=msg.value["epoch"], ) - scheduler_events = await self.scheduler.update(events) + scheduler_events = self.scheduler.update(events) self.logger.debug("Updated scheduler with process events: %s", scheduler_events) self._process_one.set() @@ -741,6 +762,48 @@ async def error(self, err: Exception) -> None: ) await self._dealer.send_multipart([msg.to_bytes()]) + async def await_node(self, epoch: int | None = None) -> MetaEvent: + """ + Block until a node is ready + + Args: + node_id: + epoch (int, None): if `int` , wait until the node is ready in the given epoch, + otherwise wait until the node is ready in any epoch + + Returns: + + """ + async with self._ready_condition: + await self._ready_condition.wait_for(lambda: self.scheduler.node_is_ready(self.spec.id, epoch)) + + # be FIFO-like and get the earliest epoch the node is ready in + if epoch is None: + for ep in self.scheduler._epochs: + if self.scheduler.node_is_ready(self.spec.id, ep): + epoch = ep + break + + if epoch is None: + raise RuntimeError( + "Could not find ready epoch even though node ready condition passed, " + "something is wrong with the way node status checking is " + "locked between threads." + ) + + # mark just one event as "out." + # threadsafe because we are holding the lock that protects graph mutation + self.scheduler[epoch].mark_out(self.spec.id) + + return MetaEvent( + id=uuid4().int, + timestamp=datetime.now(), + node_id="meta", + signal=MetaEventType.NodeReady, + epoch=epoch, + value=self.spec.id, + ) + @dataclass class ZMQRunner(TubeRunner): @@ -767,6 +830,7 @@ class ZMQRunner(TubeRunner): _return_node: Return | None = None _to_throw: ErrorValue | None = None _current_epoch: int = 0 + _epoch_futures: dict[int, concurrent.futures.Future] = field(default_factory=dict) @property def running(self) -> bool: @@ -876,9 +940,9 @@ def process(self, **kwargs: Any) -> ReturnNodeType: self.command = cast(CommandNode, self.command) self.command.process(self._current_epoch, input) self._logger.debug("awaiting epoch %s", self._current_epoch) - self.tube.scheduler.await_epoch(self._current_epoch) - if self._to_throw: - self._throw_error() + future = self.await_epoch(self._current_epoch) + # waiting on the result will also raise a result when one is set + future.result() self._logger.debug("collecting return") return self.collect_return(self._current_epoch) @@ -936,8 +1000,7 @@ def iter(self, n: int | None = None) -> Generator[ReturnNodeType, None, None]: loop = 0 while ret is MetaSignal.NoEvent: self._logger.debug("Awaiting epoch %s", epoch) - self.tube.scheduler.await_epoch(epoch) - self._logger.debug("epoch %s completed", epoch) + self.await_epoch(epoch).result() ret = self.collect_return(epoch) epoch += 1 self._current_epoch = epoch @@ -998,12 +1061,20 @@ def run(self, n: int | None = None) -> None | list[ReturnNodeType]: self._running.set() return None - else: + elif self.tube.has_return: + # run until n return values results = [] for res in self.iter(n): results.append(res) return results + else: + # run n epochs + self.command.start(n) + self._running.set() + self._current_epoch = self.await_epoch(self._current_epoch + n).result() + return None + def stop(self) -> None: """ Stop running the tube. @@ -1024,7 +1095,7 @@ def on_event(self, msg: Message) -> None: if not self._ignore_events: for event in msg.value: self.store.add(event) - self.tube.scheduler.update(msg.value) + events = self.tube.scheduler.update(msg.value) if self._return_node is not None: # mark the return node done if we've received the expected events for an epoch # do it here since we don't really run the return node like a real node @@ -1036,7 +1107,13 @@ def on_event(self, msg: Message) -> None: and epoch in self.tube.scheduler._epochs ): self._logger.debug("Marking return node ready in epoch %s", epoch) - self.tube.scheduler.done(epoch, self._return_node.id) + ep_ended = self.tube.scheduler.done(epoch, self._return_node.id) + if ep_ended is not None: + events.append(ep_ended) + for e in events: + if e['node_id'] == "meta" and e['signal'] == MetaEventType.EpochEnded and e['value'] in self._epoch_futures: + self._epoch_futures[e['value']].set_result(e['value']) + del self._epoch_futures[e['value']] def on_router(self, msg: Message) -> None: if isinstance(msg, ErrorMsg): @@ -1061,36 +1138,39 @@ def collect_return(self, epoch: int | None = None) -> Any: def _handle_error(self, msg: ErrorMsg) -> None: """Cancel current epoch, stash error for process method to throw""" self._logger.error("Received error from node: %s", msg) + exception = msg.to_exception() self._to_throw = msg.value if self._current_epoch is not None: # if we're waiting in the process method, # end epoch and raise error there self.tube.scheduler.end_epoch(self._current_epoch) + self.deinit() + if self._current_epoch in self._epoch_futures: + self._epoch_futures[self._current_epoch].set_exception(exception) + del self._epoch_futures[self._current_epoch] + else: + raise exception else: # e.g. errors during init, raise here. - self._throw_error() - - def _throw_error(self) -> None: - errval = self._to_throw - if errval is None: - return - # clear instance object and store locally, we aren't locked here. - self._to_throw = None - self._logger.debug( - "Deinitializing before throwing error", - ) - self.deinit() - - # add the traceback as a note, - # sort of the best we can do without using tblib - err = errval["err_type"](*errval["err_args"]) - tb_message = "\nError re-raised from node runner process\n\n" - tb_message += "Original traceback:\n" - tb_message += "-" * 20 + "\n" - tb_message += errval["traceback"] - err.add_note(tb_message) - - raise err + raise exception + + # + # def _throw_error(self, e) -> None: + # errval = self._to_throw + # if errval is None: + # return + # # clear instance object and store locally, we aren't locked here. + # self._to_throw = None + # self._logger.debug( + # "Deinitializing before throwing error", + # ) + # self.deinit() + # + # # add the traceback as a note, + # # sort of the best we can do without using tblib + # err = self._rehydrate_error(errval) + # + # raise err def _request_more(self, n: int, current_iter: int, n_epochs: int) -> int: """ @@ -1130,3 +1210,9 @@ def enable_node(self, node_id: str) -> None: def disable_node(self, node_id: str) -> None: raise NotImplementedError() + + def await_epoch(self, epoch: int) -> concurrent.futures.Future: + if epoch in self._epoch_futures: + return self._epoch_futures[epoch] + self._epoch_futures[epoch] = concurrent.futures.Future() + return self._epoch_futures[epoch] diff --git a/src/noob/scheduler.py b/src/noob/scheduler.py index d0a3dc2c..73760307 100644 --- a/src/noob/scheduler.py +++ b/src/noob/scheduler.py @@ -34,7 +34,7 @@ class Scheduler(BaseModel): _clock: count = PrivateAttr(default_factory=count) _epochs: dict[int, TopoSorter] = PrivateAttr(default_factory=dict) - _epoch_condition: Condition = PrivateAttr(default_factory=Condition) + # _epoch_condition: Condition = PrivateAttr(default_factory=Condition) _epoch_log: deque = PrivateAttr(default_factory=lambda: deque(maxlen=100)) model_config = ConfigDict(arbitrary_types_allowed=True) @@ -140,7 +140,6 @@ def node_is_ready(self, node: NodeID, epoch: int | None = None) -> bool: graphs = self._epochs.items() if epoch is None else [(epoch, self[epoch])] is_ready = any(node_id == node for epoch, graph in graphs for node_id in graph.ready_nodes) - self.logger.debug("Node %s ready status in epoch %s: %s", node, epoch, is_ready) return is_ready def __getitem__(self, epoch: int) -> TopoSorter: @@ -199,26 +198,25 @@ def done(self, epoch: int, node_id: str) -> MetaEvent | None: Mark a node in a given epoch as done. """ - with self._epoch_condition: - if epoch in self._epoch_log: - self.logger.debug( - "Marking node %s as done in epoch %s, " - "but epoch was already completed. ignoring", - node_id, - epoch, - ) - return None - - try: - self[epoch].done(node_id) - except NotOutYetError: - # in parallel mode, we don't `get_ready` the preceding ready nodes - # so we have to manually mark them as "out" - self[epoch].mark_out(node_id) - self[epoch].done(node_id) - - if not self[epoch].is_active(): - return self.end_epoch(epoch) + if epoch in self._epoch_log: + self.logger.debug( + "Marking node %s as done in epoch %s, " + "but epoch was already completed. ignoring", + node_id, + epoch, + ) + return None + + try: + self[epoch].done(node_id) + except NotOutYetError: + # in parallel mode, we don't `get_ready` the preceding ready nodes + # so we have to manually mark them as "out" + self[epoch].mark_out(node_id) + self[epoch].done(node_id) + + if not self[epoch].is_active(): + return self.end_epoch(epoch) return None def expire(self, epoch: int, node_id: str) -> MetaEvent | None: @@ -226,10 +224,9 @@ def expire(self, epoch: int, node_id: str) -> MetaEvent | None: Mark a node as having been completed without making its dependent nodes ready. i.e. when the node emitted ``NoEvent`` """ - with self._epoch_condition: - self[epoch].mark_expired(node_id) - if not self[epoch].is_active(): - return self.end_epoch(epoch) + self[epoch].mark_expired(node_id) + if not self[epoch].is_active(): + return self.end_epoch(epoch) return None @@ -244,30 +241,29 @@ def await_epoch(self, epoch: int | None = None) -> int: Returns: int: the epoch that was completed. """ - with self._epoch_condition: - # check if we have already completed this epoch - if isinstance(epoch, int) and self.epoch_completed(epoch): - return epoch - if epoch is None: - self._epoch_condition.wait() - return self._epoch_log[-1] - else: - self._epoch_condition.wait_for(lambda: self.epoch_completed(epoch)) - return epoch + # check if we have already completed this epoch + if isinstance(epoch, int) and self.epoch_completed(epoch): + return epoch + + if epoch is None: + self._epoch_condition.wait() + return self._epoch_log[-1] + else: + self._epoch_condition.wait_for(lambda: self.epoch_completed(epoch)) + return epoch def epoch_completed(self, epoch: int) -> bool: """ Check if the epoch has been completed. """ - with self._epoch_condition: - previously_completed = ( - len(self._epoch_log) > 0 - and epoch not in self._epochs - and (epoch in self._epoch_log or epoch < min(self._epoch_log)) - ) - active_completed = epoch in self._epochs and not self._epochs[epoch].is_active() - return previously_completed or active_completed + previously_completed = ( + len(self._epoch_log) > 0 + and epoch not in self._epochs + and (epoch in self._epoch_log or epoch < min(self._epoch_log)) + ) + active_completed = epoch in self._epochs and not self._epochs[epoch].is_active() + return previously_completed or active_completed def end_epoch(self, epoch: int | None = None) -> MetaEvent | None: if epoch is None or epoch == -1: @@ -275,10 +271,8 @@ def end_epoch(self, epoch: int | None = None) -> MetaEvent | None: return None epoch = list(self._epochs)[-1] - with self._epoch_condition: - self._epoch_condition.notify_all() - self._epoch_log.append(epoch) - del self._epochs[epoch] + self._epoch_log.append(epoch) + del self._epochs[epoch] return MetaEvent( id=uuid4().int, @@ -362,78 +356,62 @@ def generations(self) -> list[tuple[str, ...]]: sorter.done(*ready) return generations - -class AsyncScheduler(Scheduler): - _ready_condition: asyncio.Condition = PrivateAttr(default_factory=asyncio.Condition) - - async def add_epoch(self, epoch: int | None = None) -> int: - async with self._ready_condition: - this_epoch = super().add_epoch(epoch) - self._ready_condition.notify_all() - return this_epoch - - def __getitem__(self, epoch: int) -> TopoSorter: - if epoch == -1: - return self._epochs[max(self._epochs.keys())] - - if epoch not in self._epochs: - asyncio.run_coroutine_threadsafe( - self.add_epoch(epoch), asyncio.get_running_loop() - ).result() - return self._epochs[epoch] - - async def update( - self, events: MutableSequence[Event | MetaEvent] | MutableSequence[Event] - ) -> MutableSequence[Event] | MutableSequence[Event | MetaEvent]: - async with self._ready_condition: - ret_events = super().update(events) - self._ready_condition.notify_all() - return ret_events - - # async def done(self, epoch: int, node_id: str) -> MetaEvent | None: - # async with self._ready_condition: - # ret = super().done(epoch, node_id) - # self._ready_condition.notify_all() - # return ret - - async def await_node(self, node_id: NodeID, epoch: int | None = None) -> MetaEvent: - """ - Block until a node is ready - - Args: - node_id: - epoch (int, None): if `int` , wait until the node is ready in the given epoch, - otherwise wait until the node is ready in any epoch - - Returns: - - """ - async with self._ready_condition: - await self._ready_condition.wait_for(lambda: self.node_is_ready(node_id, epoch)) - - # be FIFO-like and get the earliest epoch the node is ready in - if epoch is None: - for ep in self._epochs: - if self.node_is_ready(node_id, ep): - epoch = ep - break - - if epoch is None: - raise RuntimeError( - "Could not find ready epoch even though node ready condition passed, " - "something is wrong with the way node status checking is " - "locked between threads." - ) - - # mark just one event as "out." - # threadsafe because we are holding the lock that protects graph mutation - self._epochs[epoch].mark_out(node_id) - - return MetaEvent( - id=uuid4().int, - timestamp=datetime.now(), - node_id="meta", - signal=MetaEventType.NodeReady, - epoch=epoch, - value=node_id, - ) +# +# class AsyncScheduler(Scheduler): +# +# async def update( +# self, events: MutableSequence[Event | MetaEvent] | MutableSequence[Event] +# ) -> MutableSequence[Event] | MutableSequence[Event | MetaEvent]: +# async with self._ready_condition: +# self.logger.debug("Inside update lock") +# ret_events = super().update(events) +# self._ready_condition.notify_all() +# return ret_events +# +# # async def done(self, epoch: int, node_id: str) -> MetaEvent | None: +# # async with self._ready_condition: +# # ret = super().done(epoch, node_id) +# # self._ready_condition.notify_all() +# # return ret +# +# async def await_node(self, node_id: NodeID, epoch: int | None = None) -> MetaEvent: +# """ +# Block until a node is ready +# +# Args: +# node_id: +# epoch (int, None): if `int` , wait until the node is ready in the given epoch, +# otherwise wait until the node is ready in any epoch +# +# Returns: +# +# """ +# async with self._ready_condition: +# await self._ready_condition.wait_for(lambda: self.node_is_ready(node_id, epoch)) +# +# # be FIFO-like and get the earliest epoch the node is ready in +# if epoch is None: +# for ep in self._epochs: +# if self.node_is_ready(node_id, ep): +# epoch = ep +# break +# +# if epoch is None: +# raise RuntimeError( +# "Could not find ready epoch even though node ready condition passed, " +# "something is wrong with the way node status checking is " +# "locked between threads." +# ) +# +# # mark just one event as "out." +# # threadsafe because we are holding the lock that protects graph mutation +# self._epochs[epoch].mark_out(node_id) +# +# return MetaEvent( +# id=uuid4().int, +# timestamp=datetime.now(), +# node_id="meta", +# signal=MetaEventType.NodeReady, +# epoch=epoch, +# value=node_id, +# ) diff --git a/src/noob/testing/nodes.py b/src/noob/testing/nodes.py index 1996f88f..2754aa54 100644 --- a/src/noob/testing/nodes.py +++ b/src/noob/testing/nodes.py @@ -14,7 +14,7 @@ from noob.node import Node -def count_source(limit: int = 1000, start: int = 0) -> Generator[A[int, Name("index")], None, None]: +def count_source(limit: int = 10000, start: int = 0) -> Generator[A[int, Name("index")], None, None]: counter = count(start=start) if limit == 0: while True: @@ -156,8 +156,8 @@ def input_party( return True -def long_add(value: float) -> float: - sleep(0.25) +def long_add(value: float, sleep_for: float = 0.25) -> float: + sleep(sleep_for) return value + 1 diff --git a/src/noob/tube.py b/src/noob/tube.py index 592fd6f7..c003ead3 100644 --- a/src/noob/tube.py +++ b/src/noob/tube.py @@ -238,6 +238,11 @@ def from_specification( context={"skip_input_presence": True}, ) + @property + def has_return(self) -> bool: + """Whether the tube has a :class:`.Return` node present""" + return any(isinstance(node, Return) for node in self.nodes.values()) + @classmethod def _init_nodes( cls, specs: TubeSpecification, input_collection: InputCollection diff --git a/tests/bench.py b/tests/bench.py index 8e1097a7..02499e8f 100644 --- a/tests/bench.py +++ b/tests/bench.py @@ -1,8 +1,10 @@ import pytest +import asyncio from pytest_codspeed.plugin import BenchmarkFixture from noob import Tube from noob.runner.base import TubeRunner +from noob.runner import SynchronousRunner from noob.runner.zmq import ZMQRunner @@ -30,6 +32,19 @@ def test_long_add(benchmark: BenchmarkFixture, runner: TubeRunner) -> None: """ benchmark(lambda: runner.process()) +@pytest.mark.asyncio +@pytest.mark.parametrize("loaded_tube", ["testing-long-add"], indirect=True) +async def test_long_add_run(benchmark: BenchmarkFixture, runner: TubeRunner) -> None: + """ + ZMQ runner should be faster for tubes where nodes take a long time + and there's lots of concurrency possibilities + """ + if isinstance(runner, SynchronousRunner): + pytest.skip() + runner.run() + await asyncio.sleep(10) + runner.stop() + @pytest.mark.parametrize("loaded_tube", ["testing-kitchen-sink"], indirect=True) def test_topo_sorter(benchmark: BenchmarkFixture, loaded_tube: Tube) -> None: diff --git a/tests/test_runners/test_zmq.py b/tests/test_runners/test_zmq.py index f49f2114..7c4a1d06 100644 --- a/tests/test_runners/test_zmq.py +++ b/tests/test_runners/test_zmq.py @@ -345,8 +345,8 @@ def test_iter_gather(mocker): # ceil((11/2)*9) = 50 assert spy.spy_return == 50 - -def test_noderunner_stores_clear(): +@pytest.mark.asyncio +async def test_noderunner_stores_clear(): """ Stores in the noderunners should clear after they use the events from an epoch """ @@ -362,7 +362,7 @@ def test_noderunner_stores_clear(): command_router="/notreal/unused", input_collection=InputCollection(), ) - runner.init_node() + await runner.init_node() # fake a few events events = [] @@ -389,13 +389,16 @@ def test_noderunner_stores_clear(): ), ], ) - runner.on_event(msg) + await runner.on_event(msg) events.append(msg) runner._freerun.set() assert len(runner.store.events) == 3 - args, kwargs, epoch = next(runner.await_inputs()) + epoch = -1 + async for args, kwargs, epoch in runner.await_inputs(): + break assert len(runner.store.events) == 2 + assert epoch != -1 assert epoch not in runner.store.events From d8f4d863998e7c75380242d8e52e4390e4397f17 Mon Sep 17 00:00:00 2001 From: sneakers-the-rat Date: Tue, 27 Jan 2026 20:21:22 -0800 Subject: [PATCH 03/13] working on refactoring command node, seeing if we can do any consolidation, before any cleanup so this is still messy as all hell and the statefulness test is not working bc the stateless nodes aren't returning anything --- src/noob/runner/zmq.py | 230 +++++++++++++++++++++------------ src/noob/scheduler.py | 6 +- tests/bench.py | 15 --- tests/test_runners/test_zmq.py | 5 +- 4 files changed, 153 insertions(+), 103 deletions(-) diff --git a/src/noob/runner/zmq.py b/src/noob/runner/zmq.py index 471cca40..d776341d 100644 --- a/src/noob/runner/zmq.py +++ b/src/noob/runner/zmq.py @@ -24,17 +24,18 @@ import asyncio import concurrent.futures +import contextlib import math import multiprocessing as mp import os -from datetime import datetime import signal import threading import traceback -from functools import partial from collections import defaultdict -from collections.abc import Callable, Generator, AsyncGenerator +from collections.abc import AsyncGenerator, Callable, Generator from dataclasses import dataclass, field +from datetime import datetime +from functools import partial from itertools import count from multiprocessing.synchronize import Event as EventType from time import time @@ -42,8 +43,6 @@ from typing import TYPE_CHECKING, Any, Literal, cast, overload from uuid import uuid4 -from noob.network.loop import EventloopMixin - try: import zmq except ImportError as e: @@ -55,7 +54,7 @@ from zmq.eventloop.zmqstream import ZMQStream from noob.config import config -from noob.event import Event, MetaSignal, MetaEvent, MetaEventType +from noob.event import Event, MetaEvent, MetaEventType, MetaSignal from noob.exceptions import InputMissingError from noob.input import InputCollection, InputScope from noob.logging import init_logger @@ -88,7 +87,7 @@ pass -class CommandNode(EventloopMixin): +class CommandNode: """ Pub node that controls the state of the other nodes/announces addresses @@ -114,12 +113,38 @@ def __init__(self, runner_id: str, protocol: str = "ipc", port: int | None = Non self.port = port self.protocol = protocol self.logger = init_logger(f"runner.node.{runner_id}.command") - self._outbox: zmq.Socket = None # type: ignore[assignment] - self._inbox: ZMQStream = None # type: ignore[assignment] - self._router: ZMQStream = None # type: ignore[assignment] + self._context: Context = None # type: ignore[assignment] + self._poller: Poller = None # type: ignore[assignment] + self._outbox: Socket = None # type: ignore[assignment] + self._inbox: Socket = None # type: ignore[assignment] + self._router: Socket = None # type: ignore[assignment] self._nodes: dict[str, IdentifyValue] = {} self._ready_condition = threading.Condition() self._callbacks: dict[str, list[Callable[[Message], Any]]] = defaultdict(list) + self._ready_future: concurrent.futures.Future | None = None + self._waiting_for: set[str] = set() + self._quitting: asyncio.Event = None # type: ignore[assignment] + self._tasks = set() + self._init = threading.Event() + self._loop = None + + @property + def context(self) -> Context: + if self._context is None: + self._context = Context.instance() + return self._context + + @property + def poller(self) -> Poller: + if self._poller is None: + self._poller = Poller() + return self._poller + + @property + def loop(self) -> asyncio.AbstractEventLoop: + if self._loop is None: + self._loop = asyncio.get_running_loop() + return self._loop @property def pub_address(self) -> str: @@ -141,9 +166,40 @@ def router_address(self) -> str: else: raise NotImplementedError() + def run(self) -> None: + """ + Target for :class:`threading.Thread` + """ + asyncio.run(self._run()) + + async def _run(self) -> None: + self._init.clear() + self._loop = asyncio.get_running_loop() + self._quitting = asyncio.Event() + self.init() + self._init.set() + with contextlib.suppress(asyncio.CancelledError): + await asyncio.gather(self._poll_input(), self._await_quit()) + + async def _await_quit(self) -> None: + self._quitting.clear() + await self._quitting.wait() + self.logger.debug("QUITTING!!!!") + raise asyncio.CancelledError() + + async def _poll_input(self) -> None: + while not self._quitting.is_set(): + events = await self._poller.poll() + events = dict(events) + if self._inbox in events: + msg = await self._inbox.recv_multipart() + await self.on_inbox(msg) + if self._router in events: + msg = await self._router.recv_multipart() + await self.on_router(msg) + def init(self) -> None: self.logger.debug("Starting command runner") - self.start_loop() self._init_sockets() self.logger.debug("Command runner started") @@ -151,9 +207,10 @@ def deinit(self) -> None: """Close the eventloop, stop processing messages, reset state""" self.logger.debug("Deinitializing") msg = DeinitMsg(node_id="command") - self._outbox.send_multipart([b"deinit", msg.to_bytes()]) - self.stop_loop() - self.logger.debug("Deinitialized") + future = self._outbox.send_multipart([b"deinit", msg.to_bytes()]) + future = cast(asyncio.Future, future) + future.add_done_callback(lambda x: self._quitting.set()) + self.logger.debug("Queued loop for deinitialization") def stop(self) -> None: self.logger.debug("Stopping command runner") @@ -178,9 +235,8 @@ def _init_router(self) -> ZMQStream: router = self.context.socket(zmq.ROUTER) router.bind(self.router_address) router.setsockopt_string(zmq.IDENTITY, "command.router") - router = ZMQStream(router, self.loop) - router.on_recv(self.on_router) - self.logger.debug("Inbox bound to %s", self.router_address) + self.poller.register(router, zmq.POLLIN) + self.logger.debug("Router bound to %s", self.router_address) return router def _init_inbox(self) -> ZMQStream: @@ -188,20 +244,19 @@ def _init_inbox(self) -> ZMQStream: sub = self.context.socket(zmq.SUB) sub.setsockopt_string(zmq.IDENTITY, "command.inbox") sub.setsockopt_string(zmq.SUBSCRIBE, "") - sub = ZMQStream(sub, self.loop) - sub.on_recv(self.on_inbox) + self.poller.register(sub, zmq.POLLIN) return sub - def announce(self) -> None: + async def announce(self) -> None: msg = AnnounceMsg( node_id="command", value=AnnounceValue(inbox=self.router_address, nodes=self._nodes) ) - self._outbox.send_multipart([b"announce", msg.to_bytes()]) + await self._outbox.send_multipart([b"announce", msg.to_bytes()]) - def ping(self) -> None: + async def ping(self) -> None: """Send a ping message asking everyone to identify themselves""" msg = PingMsg(node_id="command") - self._outbox.send_multipart([b"ping", msg.to_bytes()]) + await self._outbox.send_multipart([b"ping", msg.to_bytes()]) def start(self, n: int | None = None) -> None: """ @@ -233,42 +288,42 @@ def add_callback(self, type_: Literal["inbox", "router"], cb: Callable[[Message] def clear_callbacks(self) -> None: self._callbacks = defaultdict(list) - def await_ready(self, node_ids: list[NodeID], timeout: float = 10) -> None: + def await_ready(self, node_ids: list[NodeID], timeout: float = 10) -> concurrent.futures.Future: """ Wait until all the node_ids have announced themselves """ + future = concurrent.futures.Future() + self._waiting_for = set(node_ids) + self._ready_future = future + wait_until = time() + timeout - def _ready_nodes() -> set[str]: - return {node_id for node_id, state in self._nodes.items() if state["status"] == "ready"} + async def _ping() -> None: + if self._ready_future is None: + return + if wait_until < time(): + raise TimeoutError("Nodes were not ready after the timeout. ") + await self.ping() + self.loop.call_later(1, asyncio.create_task, _ping()) - def _is_ready() -> bool: - ready_nodes = _ready_nodes() - waiting_for = set(node_ids) - self.logger.debug( - "Checking if ready, ready nodes are: %s, waiting for %s", - ready_nodes, - waiting_for, - ) - return waiting_for.issubset(ready_nodes) - - with self._ready_condition: - # ping periodically for identifications in case we have slow subscribers - start_time = time() - ready = False - while time() < start_time + timeout and not ready: - ready = self._ready_condition.wait_for(_is_ready, timeout=1) - if not ready: - self.ping() - - # if still not ready, timeout - if not ready: - raise TimeoutError( - f"Nodes were not ready after the timeout. " - f"Waiting for: {set(node_ids)}, " - f"ready: {_ready_nodes()}" - ) + self.loop.call_later(1, asyncio.create_task, _ping()) + return future + + def _check_ready(self) -> None: + if self._ready_future is None: + return + ready_nodes = { + node_id for node_id, state in self._nodes.items() if state["status"] == "ready" + } + self.logger.debug( + "Checking if ready, ready nodes are: %s, waiting for %s", + ready_nodes, + self._waiting_for, + ) + if self._ready_future is not None and self._waiting_for.issubset(ready_nodes): + self._ready_future.set_result(True) + self._ready_future = None - def on_router(self, msg: list[bytes]) -> None: + async def on_router(self, msg: list[bytes]) -> None: try: message = Message.from_bytes(msg) self.logger.debug("Received ROUTER message %s", message) @@ -281,39 +336,41 @@ def on_router(self, msg: list[bytes]) -> None: if message.type_ == MessageType.identify: message = cast(IdentifyMsg, message) - self.on_identify(message) + await self.on_identify(message) elif message.type_ == MessageType.status: message = cast(StatusMsg, message) - self.on_status(message) + await self.on_status(message) - def on_inbox(self, msg: list[bytes]) -> None: + async def on_inbox(self, msg: list[bytes]) -> None: message = Message.from_bytes(msg) self.logger.debug("Received INBOX message: %s", message) for cb in self._callbacks["inbox"]: - cb(message) + if iscoroutinefunction_partial(cb): + await cb(message) + else: + self.loop.run_in_executor(None, cb, message) - def on_identify(self, msg: IdentifyMsg) -> None: - with self._ready_condition: - self._nodes[msg.node_id] = msg.value - self._inbox.connect(msg.value["outbox"]) - self._ready_condition.notify_all() + async def on_identify(self, msg: IdentifyMsg) -> None: + self._nodes[msg.node_id] = msg.value + self._inbox.connect(msg.value["outbox"]) try: - self.announce() + await self.announce() self.logger.debug("Announced") except Exception as e: self.logger.exception("Exception announced: %s", e) - def on_status(self, msg: StatusMsg) -> None: - with self._ready_condition: - if msg.node_id not in self._nodes: - self.logger.warning( - "Node %s sent us a status before sending its full identify message, ignoring", - msg.node_id, - ) - return - self._nodes[msg.node_id]["status"] = msg.value - self._ready_condition.notify_all() + self._check_ready() + + async def on_status(self, msg: StatusMsg) -> None: + if msg.node_id not in self._nodes: + self.logger.warning( + "Node %s sent us a status before sending its full identify message, ignoring", + msg.node_id, + ) + return + self._nodes[msg.node_id]["status"] = msg.value + self._check_ready() class NodeRunner: @@ -405,6 +462,7 @@ def run(cls, spec: NodeSpecification, **kwargs: Any) -> None: Target for multiprocessing.run, init the class and start it! """ + # ensure that events and conditions are bound to the eventloop created in the process async def _run_inner() -> None: nonlocal spec, kwargs @@ -440,7 +498,9 @@ async def _loop(self) -> None: loop = asyncio.get_running_loop() with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor: async for args, kwargs, epoch in self.await_inputs(): - self.logger.debug("Running with args: %s, kwargs: %s, epoch: %s", args, kwargs, epoch) + self.logger.debug( + "Running with args: %s, kwargs: %s, epoch: %s", args, kwargs, epoch + ) if is_async: # mypy fails here because it can't propagate the type guard above value = await self._node.process(*args, **kwargs) # type: ignore[arg-type] @@ -679,7 +739,7 @@ async def on_event(self, msg: EventMsg) -> None: async with self._ready_condition: self.scheduler.update(events) self._ready_condition.notify_all() - self.logger.debug('scheduler updated') + self.logger.debug("scheduler updated") async def on_start(self, msg: StartMsg) -> None: """ @@ -775,7 +835,9 @@ async def await_node(self, epoch: int | None = None) -> MetaEvent: """ async with self._ready_condition: - await self._ready_condition.wait_for(lambda: self.scheduler.node_is_ready(self.spec.id, epoch)) + await self._ready_condition.wait_for( + lambda: self.scheduler.node_is_ready(self.spec.id, epoch) + ) # be FIFO-like and get the earliest epoch the node is ready in if epoch is None: @@ -850,7 +912,8 @@ def init(self) -> None: self.command = CommandNode(runner_id=self.runner_id) self.command.add_callback("inbox", self.on_event) self.command.add_callback("router", self.on_router) - self.command.init() + threading.Thread(target=self.command.run, daemon=True).start() + self.command._init.wait() self._logger.debug("Command node initialized") for node_id, node in self.tube.nodes.items(): @@ -872,9 +935,10 @@ def init(self) -> None: self.node_procs[node_id].start() self._logger.debug("Started node processes, awaiting ready") try: - self.command.await_ready( + future = self.command.await_ready( [k for k, v in self.tube.nodes.items() if not isinstance(v, Return)] ) + future.result() except TimeoutError as e: self._logger.debug("Timeouterror, deinitializing before throwing") self._initialized.set() @@ -1111,9 +1175,13 @@ def on_event(self, msg: Message) -> None: if ep_ended is not None: events.append(ep_ended) for e in events: - if e['node_id'] == "meta" and e['signal'] == MetaEventType.EpochEnded and e['value'] in self._epoch_futures: - self._epoch_futures[e['value']].set_result(e['value']) - del self._epoch_futures[e['value']] + if ( + e["node_id"] == "meta" + and e["signal"] == MetaEventType.EpochEnded + and e["value"] in self._epoch_futures + ): + self._epoch_futures[e["value"]].set_result(e["value"]) + del self._epoch_futures[e["value"]] def on_router(self, msg: Message) -> None: if isinstance(msg, ErrorMsg): diff --git a/src/noob/scheduler.py b/src/noob/scheduler.py index 73760307..3a2c2268 100644 --- a/src/noob/scheduler.py +++ b/src/noob/scheduler.py @@ -1,10 +1,8 @@ -import asyncio import logging from collections import deque from collections.abc import MutableSequence from datetime import UTC, datetime from itertools import count -from threading import Condition from typing import Self from uuid import uuid4 @@ -200,8 +198,7 @@ def done(self, epoch: int, node_id: str) -> MetaEvent | None: """ if epoch in self._epoch_log: self.logger.debug( - "Marking node %s as done in epoch %s, " - "but epoch was already completed. ignoring", + "Marking node %s as done in epoch %s, " "but epoch was already completed. ignoring", node_id, epoch, ) @@ -356,6 +353,7 @@ def generations(self) -> list[tuple[str, ...]]: sorter.done(*ready) return generations + # # class AsyncScheduler(Scheduler): # diff --git a/tests/bench.py b/tests/bench.py index 02499e8f..8e1097a7 100644 --- a/tests/bench.py +++ b/tests/bench.py @@ -1,10 +1,8 @@ import pytest -import asyncio from pytest_codspeed.plugin import BenchmarkFixture from noob import Tube from noob.runner.base import TubeRunner -from noob.runner import SynchronousRunner from noob.runner.zmq import ZMQRunner @@ -32,19 +30,6 @@ def test_long_add(benchmark: BenchmarkFixture, runner: TubeRunner) -> None: """ benchmark(lambda: runner.process()) -@pytest.mark.asyncio -@pytest.mark.parametrize("loaded_tube", ["testing-long-add"], indirect=True) -async def test_long_add_run(benchmark: BenchmarkFixture, runner: TubeRunner) -> None: - """ - ZMQ runner should be faster for tubes where nodes take a long time - and there's lots of concurrency possibilities - """ - if isinstance(runner, SynchronousRunner): - pytest.skip() - runner.run() - await asyncio.sleep(10) - runner.stop() - @pytest.mark.parametrize("loaded_tube", ["testing-kitchen-sink"], indirect=True) def test_topo_sorter(benchmark: BenchmarkFixture, loaded_tube: Tube) -> None: diff --git a/tests/test_runners/test_zmq.py b/tests/test_runners/test_zmq.py index 7c4a1d06..abd3e213 100644 --- a/tests/test_runners/test_zmq.py +++ b/tests/test_runners/test_zmq.py @@ -345,6 +345,7 @@ def test_iter_gather(mocker): # ceil((11/2)*9) = 50 assert spy.spy_return == 50 + @pytest.mark.asyncio async def test_noderunner_stores_clear(): """ @@ -394,9 +395,7 @@ async def test_noderunner_stores_clear(): runner._freerun.set() assert len(runner.store.events) == 3 - epoch = -1 - async for args, kwargs, epoch in runner.await_inputs(): - break + _, _, epoch = await anext(runner.await_inputs()) assert len(runner.store.events) == 2 assert epoch != -1 assert epoch not in runner.store.events From e775dd90e67e8cbd772f2101ba9e4a4e1c0e9785 Mon Sep 17 00:00:00 2001 From: sneakers-the-rat Date: Thu, 29 Jan 2026 20:02:36 -0800 Subject: [PATCH 04/13] notify when inputs ready --- src/noob/runner/zmq.py | 41 +++++++++++++++++++++------------------ src/noob/testing/nodes.py | 4 +++- 2 files changed, 25 insertions(+), 20 deletions(-) diff --git a/src/noob/runner/zmq.py b/src/noob/runner/zmq.py index d776341d..c8a789eb 100644 --- a/src/noob/runner/zmq.py +++ b/src/noob/runner/zmq.py @@ -298,14 +298,15 @@ def await_ready(self, node_ids: list[NodeID], timeout: float = 10) -> concurrent wait_until = time() + timeout async def _ping() -> None: + await asyncio.sleep(1) if self._ready_future is None: return if wait_until < time(): raise TimeoutError("Nodes were not ready after the timeout. ") await self.ping() - self.loop.call_later(1, asyncio.create_task, _ping()) + await _ping() - self.loop.call_later(1, asyncio.create_task, _ping()) + self.loop.create_task(_ping()) return future def _check_ready(self) -> None: @@ -772,15 +773,17 @@ async def on_process(self, msg: ProcessMsg) -> None: value = combined[next(iter(combined.keys()))] else: value = list(combined.values()) - events = self.store.add_value( - [Signal(name=k, type_=None) for k in combined], - value, - node_id="input", - epoch=msg.value["epoch"], - ) - scheduler_events = self.scheduler.update(events) + async with self._ready_condition: + events = self.store.add_value( + [Signal(name=k, type_=None) for k in combined], + value, + node_id="input", + epoch=msg.value["epoch"], + ) + scheduler_events = self.scheduler.update(events) + self._ready_condition.notify_all() + self.logger.debug("Updated scheduler with process events: %s", scheduler_events) - self.logger.debug("Updated scheduler with process events: %s", scheduler_events) self._process_one.set() async def on_stop(self, msg: StopMsg) -> None: @@ -1004,9 +1007,7 @@ def process(self, **kwargs: Any) -> ReturnNodeType: self.command = cast(CommandNode, self.command) self.command.process(self._current_epoch, input) self._logger.debug("awaiting epoch %s", self._current_epoch) - future = self.await_epoch(self._current_epoch) - # waiting on the result will also raise a result when one is set - future.result() + self.await_epoch(self._current_epoch) self._logger.debug("collecting return") return self.collect_return(self._current_epoch) @@ -1064,7 +1065,7 @@ def iter(self, n: int | None = None) -> Generator[ReturnNodeType, None, None]: loop = 0 while ret is MetaSignal.NoEvent: self._logger.debug("Awaiting epoch %s", epoch) - self.await_epoch(epoch).result() + self.await_epoch(epoch) ret = self.collect_return(epoch) epoch += 1 self._current_epoch = epoch @@ -1279,8 +1280,10 @@ def enable_node(self, node_id: str) -> None: def disable_node(self, node_id: str) -> None: raise NotImplementedError() - def await_epoch(self, epoch: int) -> concurrent.futures.Future: - if epoch in self._epoch_futures: - return self._epoch_futures[epoch] - self._epoch_futures[epoch] = concurrent.futures.Future() - return self._epoch_futures[epoch] + def await_epoch(self, epoch: int) -> int: + if self.tube.scheduler.epoch_completed(epoch): + return epoch + + if epoch not in self._epoch_futures: + self._epoch_futures[epoch] = concurrent.futures.Future() + return self._epoch_futures[epoch].result() diff --git a/src/noob/testing/nodes.py b/src/noob/testing/nodes.py index 2754aa54..35b8a081 100644 --- a/src/noob/testing/nodes.py +++ b/src/noob/testing/nodes.py @@ -14,7 +14,9 @@ from noob.node import Node -def count_source(limit: int = 10000, start: int = 0) -> Generator[A[int, Name("index")], None, None]: +def count_source( + limit: int = 10000, start: int = 0 +) -> Generator[A[int, Name("index")], None, None]: counter = count(start=start) if limit == 0: while True: From 93d8f74f18b17e9aa4f49b36d8a7995bc5c48936 Mon Sep 17 00:00:00 2001 From: sneakers-the-rat Date: Thu, 29 Jan 2026 20:25:11 -0800 Subject: [PATCH 05/13] call public methods from main thread --- src/noob/runner/zmq.py | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/src/noob/runner/zmq.py b/src/noob/runner/zmq.py index c8a789eb..2291511e 100644 --- a/src/noob/runner/zmq.py +++ b/src/noob/runner/zmq.py @@ -206,16 +206,19 @@ def init(self) -> None: def deinit(self) -> None: """Close the eventloop, stop processing messages, reset state""" self.logger.debug("Deinitializing") - msg = DeinitMsg(node_id="command") - future = self._outbox.send_multipart([b"deinit", msg.to_bytes()]) - future = cast(asyncio.Future, future) - future.add_done_callback(lambda x: self._quitting.set()) + + async def _deinit() -> None: + msg = DeinitMsg(node_id="command") + await self._outbox.send_multipart([b"deinit", msg.to_bytes()]) + self._quitting.set() + + self.loop.create_task(_deinit()) self.logger.debug("Queued loop for deinitialization") def stop(self) -> None: self.logger.debug("Stopping command runner") msg = StopMsg(node_id="command") - self._outbox.send_multipart([b"stop", msg.to_bytes()]) + self.loop.call_soon_threadsafe(self._outbox.send_multipart, [b"stop", msg.to_bytes()]) self.logger.debug("Command runner stopped") def _init_sockets(self) -> None: @@ -262,18 +265,21 @@ def start(self, n: int | None = None) -> None: """ Start running in free-run mode """ - self._outbox.send_multipart([b"start", StartMsg(node_id="command", value=n).to_bytes()]) + self.loop.call_soon_threadsafe( + self._outbox.send_multipart, [b"start", StartMsg(node_id="command", value=n).to_bytes()] + ) self.logger.debug("Sent start message") def process(self, epoch: int, input: dict | None = None) -> None: """Emit a ProcessMsg to process a single round through the graph""" # no empty dicts input = input if input else None - self._outbox.send_multipart( + self.loop.call_soon_threadsafe( + self._outbox.send_multipart, [ b"process", ProcessMsg(node_id="command", value={"input": input, "epoch": epoch}).to_bytes(), - ] + ], ) self.logger.debug("Sent process message") From 77ceacd98e638f1bf9bd0b4966fb1734d84461ed Mon Sep 17 00:00:00 2001 From: sneakers-the-rat Date: Wed, 4 Feb 2026 15:33:44 -0800 Subject: [PATCH 06/13] move common loop stuff back into the mixin, refactor command node --- src/noob/network/loop.py | 174 ++++++++++++++++++++---------- src/noob/runner/zmq.py | 222 ++++++++++++++------------------------- 2 files changed, 196 insertions(+), 200 deletions(-) diff --git a/src/noob/network/loop.py b/src/noob/network/loop.py index 74e6e5e3..cfa9afd5 100644 --- a/src/noob/network/loop.py +++ b/src/noob/network/loop.py @@ -1,84 +1,146 @@ import asyncio -import threading +import sys +from collections import defaultdict +from collections.abc import Callable, Coroutine +from typing import Any try: import zmq - from tornado.ioloop import IOLoop + from zmq.asyncio import Context, Poller, Socket except ImportError as e: raise ImportError( "Attempted to import zmq runner, but zmq deps are not installed. install with `noob[zmq]`", ) from e +if sys.version_info < (3, 12): + from typing_extensions import TypedDict +else: + from typing import TypedDict + +from noob.logging import init_logger +from noob.network.message import Message +from noob.utils import iscoroutinefunction_partial + + +class _CallbackDict(TypedDict): + sync: list[Callable[[Message], Any]] + asyncio: list[Callable[[Message], Coroutine]] + class EventloopMixin: """ - Provide an eventloop in a separate thread to an inheriting class. - Any eventloop that is running in the current context is not used - because the inheriting classes are presumed to operate mostly synchronously for now, - pending a refactor to all async networking classes. + Mixin to provide common asyncio zmq scaffolding to networked classes. + + Inheriting classes should, in order + + * call the ``_init_loop`` method to create the eventloop, context, and poller + * populate the private ``_sockets`` and ``_receivers`` dicts + * await the ``_poll_sockets`` method, which polls indefinitely. + + Inheriting classes **must** ensure that ``_init_loop`` + is called in the thread it is intended to run in, + and that thread must already have a running eventloop. + asyncio eventloops (and most of asyncio) are **not** thread safe. + + To help avoid cross-threading issues, the :meth:`.context` , :meth:`.poller` , :meth:`.loop` + properties do *not* automatically create the objects, + raising a :class:`.RuntimeError` if they are accessed before ``_init_loop`` is called. """ def __init__(self): self._context = None self._loop = None - self._quitting = asyncio.Event() - self._thread: threading.Thread | None = None + self._poller = None + self._quitting = None + self._sockets: dict[str, Socket] = {} + """ + All sockets, mapped from some common name to the socket. + The same key used here should be shared between _receivers and _callbacks + """ + self._receivers: dict[str, Socket] = {} + """Sockets that should be polled for incoming messages""" + self._callbacks: dict[str, _CallbackDict] = defaultdict( + lambda: _CallbackDict(sync=[], asyncio=[]) + ) + """Callbacks for each receiver socket""" + if not hasattr(self, "logger") or self.logger is None: + self.logger = init_logger("eventloop") @property - def context(self) -> zmq.Context: + def context(self) -> Context: if self._context is None: - self._context = zmq.Context.instance() + raise RuntimeError("Loop has not been initialized with _init_loop!") return self._context @property - def loop(self) -> IOLoop: - # To ensure that the loop is always created in the spawned thread, - # we don't create it here (since this property could be accessed elsewhere) - # and throw to protect that. + def poller(self) -> Poller: + if self._poller is None: + raise RuntimeError("Loop has not been initialized with _init_loop!") + return self._poller + + @property + def loop(self) -> asyncio.AbstractEventLoop: if self._loop is None: - raise RuntimeError("Loop is not running") + raise RuntimeError("Loop has not been initialized with _init_loop!") return self._loop - def start_loop(self) -> None: - if self._thread is not None: - raise RuntimeWarning("Node already started") - - self._quitting.clear() - - _ready = threading.Event() - - def _signal_ready() -> None: - _ready.set() - - def _run() -> None: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - self._loop = IOLoop.current() - if hasattr(self, "logger"): - self.logger.debug("Starting eventloop") - while not self._quitting.is_set(): - try: - self.loop.add_callback(_signal_ready) - self.loop.start() - - except RuntimeError: - # loop already started - if hasattr(self, "logger"): - self.logger.debug("Eventloop already started, quitting") - break - if hasattr(self, "logger"): - self.logger.debug("Eventloop stopped") - self._thread = None - - self._thread = threading.Thread(target=_run) - self._thread.start() - # wait until the loop has started - _ready.wait(5) - if hasattr(self, "logger"): - self.logger.debug("Event loop started") - - def stop_loop(self) -> None: - if self._thread is None: + @property + def sockets(self) -> dict[str, Socket]: + return self._sockets + + def register_socket(self, name: str, socket: Socket, receiver: bool = False) -> None: + """Register a socket, optionally declaring it as a receiver socket to poll""" + if name in self._sockets: + raise KeyError(f"Socket {name} already declared!") + self._sockets[name] = socket + if receiver: + self._receivers[name] = socket + self.poller.register(socket, zmq.POLLIN) + + def add_callback( + self, socket: str, callback: Callable[[Message], Any] | Callable[[Message], Coroutine] + ) -> None: + """ + Add a callback to be called when the socket receives a message. + Callbacks are called in the order in which they are added. + """ + if socket not in self._receivers: + raise KeyError(f"Socket {socket} does not exist or is not a receiving socket") + if iscoroutinefunction_partial(callback): + self._callbacks[socket]["asyncio"].append(callback) + else: + self._callbacks[socket]["sync"].append(callback) + + def clear_callbacks(self) -> None: + self._callbacks = defaultdict(lambda: _CallbackDict(sync=[], asyncio=[])) + + def _init_loop(self) -> None: + self._loop = asyncio.get_running_loop() + self._poller = Poller() + self._context = Context.instance() + self._quitting = asyncio.Event() + + def _stop_loop(self) -> None: + if self._quitting is None: return self._quitting.set() - self.loop.add_callback(self.loop.stop) + + async def _poll_receivers(self) -> None: + while not self._quitting.is_set(): + # timeout to avoid hanging here when quitting + events = await self.poller.poll(1) + events = dict(events) + for name, socket in self._receivers.items(): + if socket in events: + msg_bytes = await socket.recv_multipart() + try: + msg = Message.from_bytes(msg_bytes) + except Exception as e: + self.logger.exception( + "Exception decoding message for socket %s: %s, %s", name, msg_bytes, e + ) + continue + for acb in self._callbacks[name]["asyncio"]: + await acb(msg) + for cb in self._callbacks[name]["sync"]: + self.loop.run_in_executor(None, cb, msg) diff --git a/src/noob/runner/zmq.py b/src/noob/runner/zmq.py index 2291511e..06248a1a 100644 --- a/src/noob/runner/zmq.py +++ b/src/noob/runner/zmq.py @@ -24,15 +24,13 @@ import asyncio import concurrent.futures -import contextlib import math import multiprocessing as mp import os import signal import threading import traceback -from collections import defaultdict -from collections.abc import AsyncGenerator, Callable, Generator +from collections.abc import AsyncGenerator, Generator from dataclasses import dataclass, field from datetime import datetime from functools import partial @@ -40,7 +38,7 @@ from multiprocessing.synchronize import Event as EventType from time import time from types import FrameType -from typing import TYPE_CHECKING, Any, Literal, cast, overload +from typing import TYPE_CHECKING, Any, cast, overload from uuid import uuid4 try: @@ -51,13 +49,13 @@ ) from e from zmq.asyncio import Context, Poller, Socket -from zmq.eventloop.zmqstream import ZMQStream from noob.config import config from noob.event import Event, MetaEvent, MetaEventType, MetaSignal from noob.exceptions import InputMissingError from noob.input import InputCollection, InputScope from noob.logging import init_logger +from noob.network.loop import EventloopMixin from noob.network.message import ( AnnounceMsg, AnnounceValue, @@ -87,7 +85,7 @@ pass -class CommandNode: +class CommandNode(EventloopMixin): """ Pub node that controls the state of the other nodes/announces addresses @@ -108,43 +106,19 @@ def __init__(self, runner_id: str, protocol: str = "ipc", port: int | None = Non protocol: port: """ - super().__init__() + self.runner_id = runner_id self.port = port self.protocol = protocol self.logger = init_logger(f"runner.node.{runner_id}.command") - self._context: Context = None # type: ignore[assignment] - self._poller: Poller = None # type: ignore[assignment] - self._outbox: Socket = None # type: ignore[assignment] - self._inbox: Socket = None # type: ignore[assignment] - self._router: Socket = None # type: ignore[assignment] self._nodes: dict[str, IdentifyValue] = {} - self._ready_condition = threading.Condition() - self._callbacks: dict[str, list[Callable[[Message], Any]]] = defaultdict(list) - self._ready_future: concurrent.futures.Future | None = None + self._ready_condition: threading.Condition | None = None self._waiting_for: set[str] = set() - self._quitting: asyncio.Event = None # type: ignore[assignment] + self._waiting: threading.Event = threading.Event() self._tasks = set() self._init = threading.Event() - self._loop = None - - @property - def context(self) -> Context: - if self._context is None: - self._context = Context.instance() - return self._context - - @property - def poller(self) -> Poller: - if self._poller is None: - self._poller = Poller() - return self._poller - - @property - def loop(self) -> asyncio.AbstractEventLoop: - if self._loop is None: - self._loop = asyncio.get_running_loop() - return self._loop + self._waiting.set() + super().__init__() @property def pub_address(self) -> str: @@ -173,34 +147,16 @@ def run(self) -> None: asyncio.run(self._run()) async def _run(self) -> None: - self._init.clear() - self._loop = asyncio.get_running_loop() - self._quitting = asyncio.Event() self.init() - self._init.set() - with contextlib.suppress(asyncio.CancelledError): - await asyncio.gather(self._poll_input(), self._await_quit()) - - async def _await_quit(self) -> None: - self._quitting.clear() - await self._quitting.wait() - self.logger.debug("QUITTING!!!!") - raise asyncio.CancelledError() - - async def _poll_input(self) -> None: - while not self._quitting.is_set(): - events = await self._poller.poll() - events = dict(events) - if self._inbox in events: - msg = await self._inbox.recv_multipart() - await self.on_inbox(msg) - if self._router in events: - msg = await self._router.recv_multipart() - await self.on_router(msg) + await self._poll_receivers() def init(self) -> None: self.logger.debug("Starting command runner") + self._init.clear() + self._init_loop() + self._ready_condition = threading.Condition() self._init_sockets() + self._init.set() self.logger.debug("Command runner started") def deinit(self) -> None: @@ -209,7 +165,7 @@ def deinit(self) -> None: async def _deinit() -> None: msg = DeinitMsg(node_id="command") - await self._outbox.send_multipart([b"deinit", msg.to_bytes()]) + await self.sockets["outbox"].send_multipart([b"deinit", msg.to_bytes()]) self._quitting.set() self.loop.create_task(_deinit()) @@ -218,55 +174,57 @@ async def _deinit() -> None: def stop(self) -> None: self.logger.debug("Stopping command runner") msg = StopMsg(node_id="command") - self.loop.call_soon_threadsafe(self._outbox.send_multipart, [b"stop", msg.to_bytes()]) + self.loop.call_soon_threadsafe( + self.sockets["outbox"].send_multipart, [b"stop", msg.to_bytes()] + ) self.logger.debug("Command runner stopped") def _init_sockets(self) -> None: - self._outbox = self._init_outbox() - self._router = self._init_router() - self._inbox = self._init_inbox() + self._init_outbox() + self._init_router() + self._init_inbox() - def _init_outbox(self) -> zmq.Socket: + def _init_outbox(self) -> None: """Create the main control publisher""" pub = self.context.socket(zmq.PUB) pub.bind(self.pub_address) pub.setsockopt_string(zmq.IDENTITY, "command.outbox") - return pub + self.register_socket("outbox", pub) - def _init_router(self) -> ZMQStream: + def _init_router(self) -> None: """Create the inbox router""" router = self.context.socket(zmq.ROUTER) router.bind(self.router_address) router.setsockopt_string(zmq.IDENTITY, "command.router") - self.poller.register(router, zmq.POLLIN) + self.register_socket("router", router, receiver=True) + self.add_callback("router", self.on_router) self.logger.debug("Router bound to %s", self.router_address) - return router - def _init_inbox(self) -> ZMQStream: + def _init_inbox(self) -> None: """Subscriber that receives all events from running nodes""" sub = self.context.socket(zmq.SUB) sub.setsockopt_string(zmq.IDENTITY, "command.inbox") sub.setsockopt_string(zmq.SUBSCRIBE, "") - self.poller.register(sub, zmq.POLLIN) - return sub + self.register_socket("inbox", sub, receiver=True) async def announce(self) -> None: msg = AnnounceMsg( node_id="command", value=AnnounceValue(inbox=self.router_address, nodes=self._nodes) ) - await self._outbox.send_multipart([b"announce", msg.to_bytes()]) + await self.sockets["outbox"].send_multipart([b"announce", msg.to_bytes()]) async def ping(self) -> None: """Send a ping message asking everyone to identify themselves""" msg = PingMsg(node_id="command") - await self._outbox.send_multipart([b"ping", msg.to_bytes()]) + await self.sockets["outbox"].send_multipart([b"ping", msg.to_bytes()]) def start(self, n: int | None = None) -> None: """ Start running in free-run mode """ self.loop.call_soon_threadsafe( - self._outbox.send_multipart, [b"start", StartMsg(node_id="command", value=n).to_bytes()] + self.sockets["outbox"].send_multipart, + [b"start", StartMsg(node_id="command", value=n).to_bytes()], ) self.logger.debug("Sent start message") @@ -275,7 +233,7 @@ def process(self, epoch: int, input: dict | None = None) -> None: # no empty dicts input = input if input else None self.loop.call_soon_threadsafe( - self._outbox.send_multipart, + self.sockets["outbox"].send_multipart, [ b"process", ProcessMsg(node_id="command", value={"input": input, "epoch": epoch}).to_bytes(), @@ -283,63 +241,45 @@ def process(self, epoch: int, input: dict | None = None) -> None: ) self.logger.debug("Sent process message") - def add_callback(self, type_: Literal["inbox", "router"], cb: Callable[[Message], Any]) -> None: - """ - Add a callback called for message received - - by the inbox: the subscriber that receives all events from node runners - - by the router: direct messages sent by node runners to the command node - """ - self._callbacks[type_].append(cb) - - def clear_callbacks(self) -> None: - self._callbacks = defaultdict(list) - - def await_ready(self, node_ids: list[NodeID], timeout: float = 10) -> concurrent.futures.Future: + def await_ready(self, node_ids: list[NodeID], timeout: float = 10) -> None: """ Wait until all the node_ids have announced themselves """ - future = concurrent.futures.Future() - self._waiting_for = set(node_ids) - self._ready_future = future - wait_until = time() + timeout - - async def _ping() -> None: - await asyncio.sleep(1) - if self._ready_future is None: - return - if wait_until < time(): - raise TimeoutError("Nodes were not ready after the timeout. ") - await self.ping() - await _ping() - - self.loop.create_task(_ping()) - return future - - def _check_ready(self) -> None: - if self._ready_future is None: - return - ready_nodes = { - node_id for node_id, state in self._nodes.items() if state["status"] == "ready" - } - self.logger.debug( - "Checking if ready, ready nodes are: %s, waiting for %s", - ready_nodes, - self._waiting_for, - ) - if self._ready_future is not None and self._waiting_for.issubset(ready_nodes): - self._ready_future.set_result(True) - self._ready_future = None - - async def on_router(self, msg: list[bytes]) -> None: - try: - message = Message.from_bytes(msg) - self.logger.debug("Received ROUTER message %s", message) - except Exception as e: - self.logger.exception("Exception decoding: %s, %s", msg, e) - raise e + self._waiting.clear() + + def _ready_nodes() -> set[str]: + return {node_id for node_id, state in self._nodes.items() if state["status"] == "ready"} + + def _is_ready() -> bool: + ready_nodes = _ready_nodes() + waiting_for = set(node_ids) + self.logger.debug( + "Checking if ready, ready nodes are: %s, waiting for %s", + ready_nodes, + waiting_for, + ) + return waiting_for.issubset(ready_nodes) + + with self._ready_condition: + # ping periodically for identifications in case we have slow subscribers + start_time = time() + ready = False + while time() < start_time + timeout and not ready: + ready = self._ready_condition.wait_for(_is_ready, timeout=1) + if not ready: + self.loop.call_soon_threadsafe(self.loop.create_task, self.ping()) + + # if still not ready, timeout + if not ready: + raise TimeoutError( + f"Nodes were not ready after the timeout. " + f"Waiting for: {set(node_ids)}, " + f"ready: {_ready_nodes()}" + ) + self._waiting.set() - for cb in self._callbacks["router"]: - cb(message) + async def on_router(self, message: Message) -> None: + self.logger.debug("Received ROUTER message %s", message) if message.type_ == MessageType.identify: message = cast(IdentifyMsg, message) @@ -348,18 +288,9 @@ async def on_router(self, msg: list[bytes]) -> None: message = cast(StatusMsg, message) await self.on_status(message) - async def on_inbox(self, msg: list[bytes]) -> None: - message = Message.from_bytes(msg) - self.logger.debug("Received INBOX message: %s", message) - for cb in self._callbacks["inbox"]: - if iscoroutinefunction_partial(cb): - await cb(message) - else: - self.loop.run_in_executor(None, cb, message) - async def on_identify(self, msg: IdentifyMsg) -> None: self._nodes[msg.node_id] = msg.value - self._inbox.connect(msg.value["outbox"]) + self.sockets["inbox"].connect(msg.value["outbox"]) try: await self.announce() @@ -367,7 +298,9 @@ async def on_identify(self, msg: IdentifyMsg) -> None: except Exception as e: self.logger.exception("Exception announced: %s", e) - self._check_ready() + if not self._waiting.is_set(): + with self._ready_condition: + self._ready_condition.notify_all() async def on_status(self, msg: StatusMsg) -> None: if msg.node_id not in self._nodes: @@ -377,7 +310,9 @@ async def on_status(self, msg: StatusMsg) -> None: ) return self._nodes[msg.node_id]["status"] = msg.value - self._check_ready() + if not self._waiting.is_set(): + with self._ready_condition: + self._ready_condition.notify_all() class NodeRunner: @@ -919,10 +854,10 @@ def init(self) -> None: with self._init_lock: self._logger.debug("Initializing ZMQ runner") self.command = CommandNode(runner_id=self.runner_id) - self.command.add_callback("inbox", self.on_event) - self.command.add_callback("router", self.on_router) threading.Thread(target=self.command.run, daemon=True).start() self.command._init.wait() + self.command.add_callback("inbox", self.on_event) + self.command.add_callback("router", self.on_router) self._logger.debug("Command node initialized") for node_id, node in self.tube.nodes.items(): @@ -944,10 +879,9 @@ def init(self) -> None: self.node_procs[node_id].start() self._logger.debug("Started node processes, awaiting ready") try: - future = self.command.await_ready( + self.command.await_ready( [k for k, v in self.tube.nodes.items() if not isinstance(v, Return)] ) - future.result() except TimeoutError as e: self._logger.debug("Timeouterror, deinitializing before throwing") self._initialized.set() From f66f37726509f61376aa4d86c290d39b8edad8fb Mon Sep 17 00:00:00 2001 From: sneakers-the-rat Date: Wed, 4 Feb 2026 15:45:04 -0800 Subject: [PATCH 07/13] refactor zmqrunner to eventloopmixin --- src/noob/runner/zmq.py | 65 +++++++++++++++--------------------------- 1 file changed, 23 insertions(+), 42 deletions(-) diff --git a/src/noob/runner/zmq.py b/src/noob/runner/zmq.py index 06248a1a..06266302 100644 --- a/src/noob/runner/zmq.py +++ b/src/noob/runner/zmq.py @@ -48,7 +48,6 @@ "Attempted to import zmq runner, but zmq deps are not installed. install with `noob[zmq]`", ) from e -from zmq.asyncio import Context, Poller, Socket from noob.config import config from noob.event import Event, MetaEvent, MetaEventType, MetaSignal @@ -315,7 +314,7 @@ async def on_status(self, msg: StatusMsg) -> None: self._ready_condition.notify_all() -class NodeRunner: +class NodeRunner(EventloopMixin): """ Runner for a single node @@ -333,7 +332,6 @@ def __init__( input_collection: InputCollection, protocol: str = "ipc", ): - self.context = Context.instance() self.spec = spec self.runner_id = runner_id self.input_collection = input_collection @@ -344,10 +342,6 @@ def __init__( self.scheduler: Scheduler = None # type: ignore[assignment] self.logger = init_logger(f"runner.node.{runner_id}.{self.spec.id}") - self._dealer: Socket = None # type: ignore[assignment] - self._outbox: Socket = None # type: ignore[assignment] - self._inbox: Socket = None # type: ignore[assignment] - self._inbox_poller: Poller = Poller() self._node: Node | None = None self._depends: tuple[tuple[str, str], ...] | None = None self._has_input: bool | None = None @@ -360,6 +354,7 @@ def __init__( self._status_lock = asyncio.Lock() self._ready_condition = asyncio.Condition() self._to_process = 0 + super().__init__() @property def outbox_address(self) -> str: @@ -414,7 +409,6 @@ async def _run_inner() -> None: asyncio.run(_run_inner()) async def _run(self) -> None: - try: def _handler(sig: int, frame: FrameType | None = None) -> None: @@ -427,7 +421,7 @@ def _handler(sig: int, frame: FrameType | None = None) -> None: self._process_quitting.clear() self._freerun.clear() self._process_one.clear() - await asyncio.gather(self._poll_inbox(), self._loop()) + await asyncio.gather(self._poll_receivers(), self._process_loop()) except KeyboardInterrupt: self.logger.debug("Got keyboard interrupt, quitting") except Exception as e: @@ -435,7 +429,7 @@ def _handler(sig: int, frame: FrameType | None = None) -> None: finally: await self.deinit() - async def _loop(self) -> None: + async def _process_loop(self) -> None: is_async = iscoroutinefunction_partial(self._node.process) loop = asyncio.get_running_loop() with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor: @@ -460,13 +454,6 @@ async def _loop(self) -> None: await self.update_graph(events) await self.publish_events(events) - async def _poll_inbox(self) -> None: - while not self._process_quitting.is_set(): - events = await self._inbox_poller.poll() - if self._inbox in dict(events): - msg = await self._inbox.recv_multipart() - await self.on_inbox(msg) - async def await_inputs(self) -> AsyncGenerator[tuple[tuple[Any], dict[str, Any], int]]: self._node = cast(Node, self._node) while not self._process_quitting.is_set(): @@ -500,13 +487,14 @@ async def update_graph(self, events: list[Event]) -> None: async def publish_events(self, events: list[Event]) -> None: msg = EventMsg(node_id=self.spec.id, value=events) - await self._outbox.send_multipart([b"event", msg.to_bytes()]) + await self.sockets["outbox"].send_multipart([b"event", msg.to_bytes()]) async def init(self) -> None: self.logger.debug("Initializing") await self.init_node() self._init_sockets() + self._quitting.clear() self.status = ( NodeStatus.waiting if self.depends and [d for d in self.depends if d[0] != "input"] @@ -520,6 +508,7 @@ async def deinit(self) -> None: if self._node is not None: self._node.deinit() await self.update_status(NodeStatus.closed) + self._quitting.set() self.logger.debug("Deinitialization finished") @@ -547,7 +536,7 @@ async def identify(self) -> None: ), ), ) - await self._dealer.send_multipart([ann.to_bytes()]) + await self.sockets["dealer"].send_multipart([ann.to_bytes()]) self.logger.debug("Sent identification message: %s", ann) async def update_status(self, status: NodeStatus) -> None: @@ -556,7 +545,7 @@ async def update_status(self, status: NodeStatus) -> None: async with self._status_lock: self.status = status msg = StatusMsg(node_id=self.spec.id, value=status) - await self._dealer.send_multipart([msg.to_bytes()]) + await self.sockets["dealer"].send_multipart([msg.to_bytes()]) self.logger.debug("Updated status") async def init_node(self) -> None: @@ -568,18 +557,19 @@ async def init_node(self) -> None: self._ready_condition.notify_all() def _init_sockets(self) -> None: - self._dealer = self._init_dealer() - self._outbox = self._init_outbox() - self._inbox = self._init_inbox() + self._init_loop() + self._init_dealer() + self._init_outbox() + self._init_inbox() - def _init_dealer(self) -> Socket: + def _init_dealer(self) -> None: dealer = self.context.socket(zmq.DEALER) dealer.setsockopt_string(zmq.IDENTITY, self.spec.id) dealer.connect(self.command_router) + self.register_socket("dealer", dealer) self.logger.debug("Connected to command node at %s", self.command_router) - return dealer - def _init_outbox(self) -> Socket: + def _init_outbox(self) -> None: pub = self.context.socket(zmq.PUB) pub.setsockopt_string(zmq.IDENTITY, self.spec.id) if self.protocol == "ipc": @@ -588,10 +578,9 @@ def _init_outbox(self) -> Socket: raise NotImplementedError() # something like: # port = pub.bind_to_random_port(self.protocol) + self.register_socket("outbox", pub) - return pub - - def _init_inbox(self) -> Socket: + def _init_inbox(self) -> None: """ Init the subscriber, but don't attempt to subscribe to anything but the command yet! we do that when we get node Announces @@ -600,19 +589,11 @@ def _init_inbox(self) -> Socket: sub.setsockopt_string(zmq.IDENTITY, self.spec.id) sub.setsockopt_string(zmq.SUBSCRIBE, "") sub.connect(self.command_outbox) - self._inbox_poller.register(sub, zmq.POLLIN) + self.register_socket("inbox", sub, receiver=True) + self.add_callback("inbox", self.on_inbox) self.logger.debug("Subscribed to command outbox %s", self.command_outbox) - return sub - - async def on_inbox(self, msg: list[bytes]) -> None: - try: - message = Message.from_bytes(msg) - - self.logger.debug("INBOX received %s", msg) - except Exception as e: - self.logger.exception("Error decoding message %s %s", msg, e) - return + async def on_inbox(self, message: Message) -> None: # FIXME: all this switching sux, # just have a decorator to register a handler for a given message type if message.type_ == MessageType.announce: @@ -655,7 +636,7 @@ async def on_announce(self, msg: AnnounceMsg) -> None: # TODO: a way to check if we're already connected, without storing it locally? outbox = msg.value["nodes"][node_id]["outbox"] self.logger.debug("Subscribing to %s at %s", node_id, outbox) - self._inbox.connect(outbox) + self.sockets["inbox"].connect(outbox) self.logger.debug("Subscribed to %s at %s", node_id, outbox) self._nodes = msg.value["nodes"] if set(self._nodes) >= depended_nodes - {"input"} and self.status == NodeStatus.waiting: @@ -764,7 +745,7 @@ async def error(self, err: Exception) -> None: traceback=tbexception, ), ) - await self._dealer.send_multipart([msg.to_bytes()]) + await self.sockets["dealer"].send_multipart([msg.to_bytes()]) async def await_node(self, epoch: int | None = None) -> MetaEvent: """ From c9524223dffb64287b24d8ed5cfe2c93a7d6e6d1 Mon Sep 17 00:00:00 2001 From: sneakers-the-rat Date: Wed, 4 Feb 2026 15:55:00 -0800 Subject: [PATCH 08/13] only use one event to signal quitting, clearer signposting about on_deinit vs deinit --- src/noob/runner/zmq.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/src/noob/runner/zmq.py b/src/noob/runner/zmq.py index 06266302..9398be12 100644 --- a/src/noob/runner/zmq.py +++ b/src/noob/runner/zmq.py @@ -347,7 +347,6 @@ def __init__( self._has_input: bool | None = None self._nodes: dict[str, IdentifyValue] = {} self._counter = count() - self._process_quitting = asyncio.Event() self._freerun = asyncio.Event() self._process_one = asyncio.Event() self._status: NodeStatus = NodeStatus.stopped @@ -418,7 +417,6 @@ def _handler(sig: int, frame: FrameType | None = None) -> None: signal.signal(signal.SIGTERM, _handler) await self.init() self._node = cast(Node, self._node) - self._process_quitting.clear() self._freerun.clear() self._process_one.clear() await asyncio.gather(self._poll_receivers(), self._process_loop()) @@ -456,7 +454,7 @@ async def _process_loop(self) -> None: async def await_inputs(self) -> AsyncGenerator[tuple[tuple[Any], dict[str, Any], int]]: self._node = cast(Node, self._node) - while not self._process_quitting.is_set(): + while not self._quitting.is_set(): # if we are not freerunning, keep track of how many times we are supposed to run, # and run until we aren't supposed to anymore! if not self._freerun.is_set(): @@ -491,7 +489,6 @@ async def publish_events(self, events: list[Event]) -> None: async def init(self) -> None: self.logger.debug("Initializing") - await self.init_node() self._init_sockets() self._quitting.clear() @@ -504,10 +501,15 @@ async def init(self) -> None: self.logger.debug("Initialization finished") async def deinit(self) -> None: + """ + Deinitialize the node class after receiving on_deinit message and + draining out the end of the _process_loop. + """ self.logger.debug("Deinitializing") if self._node is not None: self._node.deinit() - await self.update_status(NodeStatus.closed) + + # should have already been called in on_deinit, but just to make sure we're killed dead... self._quitting.set() self.logger.debug("Deinitialization finished") @@ -722,7 +724,9 @@ async def on_deinit(self, msg: DeinitMsg) -> None: Cause the main loop to end, which calls deinit """ - self._process_quitting.set() + await self.update_status(NodeStatus.closed) + self._quitting.set() + pid = mp.current_process().pid if pid is None: return From 74cf737d544a21781393ae6ef9f5e428f7c81e90 Mon Sep 17 00:00:00 2001 From: sneakers-the-rat Date: Wed, 4 Feb 2026 15:57:17 -0800 Subject: [PATCH 09/13] ensure quitting event is created --- src/noob/runner/zmq.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/noob/runner/zmq.py b/src/noob/runner/zmq.py index 9398be12..6c7a0881 100644 --- a/src/noob/runner/zmq.py +++ b/src/noob/runner/zmq.py @@ -354,6 +354,7 @@ def __init__( self._ready_condition = asyncio.Condition() self._to_process = 0 super().__init__() + self._quitting = asyncio.Event() @property def outbox_address(self) -> str: From 5cafd3a7f307f09506f63d2a9c5f0638ce7754b0 Mon Sep 17 00:00:00 2001 From: sneakers-the-rat Date: Fri, 6 Feb 2026 20:01:18 -0800 Subject: [PATCH 10/13] manual polling is faster i guess --- src/noob/network/loop.py | 58 ++++++++++++++++++++-------------------- tests/bench.py | 3 --- 2 files changed, 29 insertions(+), 32 deletions(-) diff --git a/src/noob/network/loop.py b/src/noob/network/loop.py index cfa9afd5..a212f792 100644 --- a/src/noob/network/loop.py +++ b/src/noob/network/loop.py @@ -5,8 +5,7 @@ from typing import Any try: - import zmq - from zmq.asyncio import Context, Poller, Socket + from zmq.asyncio import Context, Socket except ImportError as e: raise ImportError( "Attempted to import zmq runner, but zmq deps are not installed. install with `noob[zmq]`", @@ -42,7 +41,7 @@ class EventloopMixin: and that thread must already have a running eventloop. asyncio eventloops (and most of asyncio) are **not** thread safe. - To help avoid cross-threading issues, the :meth:`.context` , :meth:`.poller` , :meth:`.loop` + To help avoid cross-threading issues, the :meth:`.context` and :meth:`.loop` properties do *not* automatically create the objects, raising a :class:`.RuntimeError` if they are accessed before ``_init_loop`` is called. """ @@ -50,7 +49,6 @@ class EventloopMixin: def __init__(self): self._context = None self._loop = None - self._poller = None self._quitting = None self._sockets: dict[str, Socket] = {} """ @@ -72,12 +70,6 @@ def context(self) -> Context: raise RuntimeError("Loop has not been initialized with _init_loop!") return self._context - @property - def poller(self) -> Poller: - if self._poller is None: - raise RuntimeError("Loop has not been initialized with _init_loop!") - return self._poller - @property def loop(self) -> asyncio.AbstractEventLoop: if self._loop is None: @@ -95,7 +87,6 @@ def register_socket(self, name: str, socket: Socket, receiver: bool = False) -> self._sockets[name] = socket if receiver: self._receivers[name] = socket - self.poller.register(socket, zmq.POLLIN) def add_callback( self, socket: str, callback: Callable[[Message], Any] | Callable[[Message], Coroutine] @@ -116,7 +107,6 @@ def clear_callbacks(self) -> None: def _init_loop(self) -> None: self._loop = asyncio.get_running_loop() - self._poller = Poller() self._context = Context.instance() self._quitting = asyncio.Event() @@ -126,21 +116,31 @@ def _stop_loop(self) -> None: self._quitting.set() async def _poll_receivers(self) -> None: + """ + Rather than using the zmq.asyncio.Poller which wastes a ton of time, + it turns out doing it this way is roughly 4x as fast: + just manually poll the sockets, and if you have multiple sockets, + gather multiple coroutines where you're polling the sockets. + """ + if len(self._receivers) == 1: + await self._poll_receiver(next(iter(self._receivers.keys()))) + else: + await asyncio.gather(*[self._poll_receiver(name) for name in self._receivers]) + + async def _poll_receiver(self, name: str) -> None: + socket = self._receivers[name] while not self._quitting.is_set(): - # timeout to avoid hanging here when quitting - events = await self.poller.poll(1) - events = dict(events) - for name, socket in self._receivers.items(): - if socket in events: - msg_bytes = await socket.recv_multipart() - try: - msg = Message.from_bytes(msg_bytes) - except Exception as e: - self.logger.exception( - "Exception decoding message for socket %s: %s, %s", name, msg_bytes, e - ) - continue - for acb in self._callbacks[name]["asyncio"]: - await acb(msg) - for cb in self._callbacks[name]["sync"]: - self.loop.run_in_executor(None, cb, msg) + msg_bytes = await socket.recv_multipart() + try: + msg = Message.from_bytes(msg_bytes) + except Exception as e: + self.logger.exception( + "Exception decoding message for socket %s: %s, %s", name, msg_bytes, e + ) + continue + + # purposely don't catch errors here because we want them to bubble up into the caller + for acb in self._callbacks[name]["asyncio"]: + await acb(msg) + for cb in self._callbacks[name]["sync"]: + self.loop.run_in_executor(None, cb, msg) diff --git a/tests/bench.py b/tests/bench.py index 8e1097a7..57bc359d 100644 --- a/tests/bench.py +++ b/tests/bench.py @@ -3,7 +3,6 @@ from noob import Tube from noob.runner.base import TubeRunner -from noob.runner.zmq import ZMQRunner def test_load_tube(benchmark: BenchmarkFixture) -> None: @@ -17,8 +16,6 @@ def test_kitchen_sink_process(benchmark: BenchmarkFixture, runner: TubeRunner) - @pytest.mark.parametrize("loaded_tube", ["testing-kitchen-sink"], indirect=True) def test_kitchen_sink_run(benchmark: BenchmarkFixture, runner: TubeRunner) -> None: - if isinstance(runner, ZMQRunner): - pytest.skip("ZMQ runner freerun mode not supported yet") benchmark(lambda: runner.run(n=10)) From dcbd37813c8c8e2c84d77ada8c498dc54cd02a5a Mon Sep 17 00:00:00 2001 From: sneakers-the-rat Date: Fri, 6 Feb 2026 21:33:12 -0800 Subject: [PATCH 11/13] mypy --- src/noob/network/loop.py | 4 ++-- src/noob/runner/zmq.py | 11 ++++++----- src/noob/scheduler.py | 24 ------------------------ 3 files changed, 8 insertions(+), 31 deletions(-) diff --git a/src/noob/network/loop.py b/src/noob/network/loop.py index a212f792..efa49ebf 100644 --- a/src/noob/network/loop.py +++ b/src/noob/network/loop.py @@ -49,7 +49,7 @@ class EventloopMixin: def __init__(self): self._context = None self._loop = None - self._quitting = None + self._quitting: asyncio.Event = None # type: ignore[assignment] self._sockets: dict[str, Socket] = {} """ All sockets, mapped from some common name to the socket. @@ -61,7 +61,7 @@ def __init__(self): lambda: _CallbackDict(sync=[], asyncio=[]) ) """Callbacks for each receiver socket""" - if not hasattr(self, "logger") or self.logger is None: + if not hasattr(self, "logger"): self.logger = init_logger("eventloop") @property diff --git a/src/noob/runner/zmq.py b/src/noob/runner/zmq.py index 6c7a0881..2034d60f 100644 --- a/src/noob/runner/zmq.py +++ b/src/noob/runner/zmq.py @@ -30,7 +30,7 @@ import signal import threading import traceback -from collections.abc import AsyncGenerator, Generator +from collections.abc import AsyncGenerator, Generator, MutableSequence from dataclasses import dataclass, field from datetime import datetime from functools import partial @@ -111,10 +111,9 @@ def __init__(self, runner_id: str, protocol: str = "ipc", port: int | None = Non self.protocol = protocol self.logger = init_logger(f"runner.node.{runner_id}.command") self._nodes: dict[str, IdentifyValue] = {} - self._ready_condition: threading.Condition | None = None + self._ready_condition: threading.Condition = None # type: ignore[assignment] self._waiting_for: set[str] = set() self._waiting: threading.Event = threading.Event() - self._tasks = set() self._init = threading.Event() self._waiting.set() super().__init__() @@ -429,6 +428,7 @@ def _handler(sig: int, frame: FrameType | None = None) -> None: await self.deinit() async def _process_loop(self) -> None: + self._node = cast(Node, self._node) is_async = iscoroutinefunction_partial(self._node.process) loop = asyncio.get_running_loop() with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor: @@ -438,7 +438,7 @@ async def _process_loop(self) -> None: ) if is_async: # mypy fails here because it can't propagate the type guard above - value = await self._node.process(*args, **kwargs) # type: ignore[arg-type] + value = await self._node.process(*args, **kwargs) # type: ignore[misc] else: part = partial(self._node.process, *args, **kwargs) value = await loop.run_in_executor(executor, part) @@ -1063,7 +1063,7 @@ def run(self, n: int | None = None) -> None | list[ReturnNodeType]: # run n epochs self.command.start(n) self._running.set() - self._current_epoch = self.await_epoch(self._current_epoch + n).result() + self._current_epoch = self.await_epoch(self._current_epoch + n) return None def stop(self) -> None: @@ -1087,6 +1087,7 @@ def on_event(self, msg: Message) -> None: for event in msg.value: self.store.add(event) events = self.tube.scheduler.update(msg.value) + events = cast(MutableSequence[Event | MetaEvent], events) if self._return_node is not None: # mark the return node done if we've received the expected events for an epoch # do it here since we don't really run the return node like a real node diff --git a/src/noob/scheduler.py b/src/noob/scheduler.py index 3a2c2268..19aab843 100644 --- a/src/noob/scheduler.py +++ b/src/noob/scheduler.py @@ -32,7 +32,6 @@ class Scheduler(BaseModel): _clock: count = PrivateAttr(default_factory=count) _epochs: dict[int, TopoSorter] = PrivateAttr(default_factory=dict) - # _epoch_condition: Condition = PrivateAttr(default_factory=Condition) _epoch_log: deque = PrivateAttr(default_factory=lambda: deque(maxlen=100)) model_config = ConfigDict(arbitrary_types_allowed=True) @@ -227,29 +226,6 @@ def expire(self, epoch: int, node_id: str) -> MetaEvent | None: return None - def await_epoch(self, epoch: int | None = None) -> int: - """ - Block until an epoch is completed. - - Args: - epoch (int, None): if `int` , wait until the epoch is ready, - otherwise wait until the next epoch is finished, in whatever order. - - Returns: - int: the epoch that was completed. - """ - - # check if we have already completed this epoch - if isinstance(epoch, int) and self.epoch_completed(epoch): - return epoch - - if epoch is None: - self._epoch_condition.wait() - return self._epoch_log[-1] - else: - self._epoch_condition.wait_for(lambda: self.epoch_completed(epoch)) - return epoch - def epoch_completed(self, epoch: int) -> bool: """ Check if the epoch has been completed. From abaffc7be903b523386843eb211b4c89060e6cd6 Mon Sep 17 00:00:00 2001 From: sneakers-the-rat Date: Fri, 6 Feb 2026 21:48:05 -0800 Subject: [PATCH 12/13] cleanup --- src/noob/runner/zmq.py | 32 ++++------------------ src/noob/scheduler.py | 61 ------------------------------------------ 2 files changed, 5 insertions(+), 88 deletions(-) diff --git a/src/noob/runner/zmq.py b/src/noob/runner/zmq.py index 2034d60f..c05dfd93 100644 --- a/src/noob/runner/zmq.py +++ b/src/noob/runner/zmq.py @@ -112,10 +112,7 @@ def __init__(self, runner_id: str, protocol: str = "ipc", port: int | None = Non self.logger = init_logger(f"runner.node.{runner_id}.command") self._nodes: dict[str, IdentifyValue] = {} self._ready_condition: threading.Condition = None # type: ignore[assignment] - self._waiting_for: set[str] = set() - self._waiting: threading.Event = threading.Event() self._init = threading.Event() - self._waiting.set() super().__init__() @property @@ -296,9 +293,8 @@ async def on_identify(self, msg: IdentifyMsg) -> None: except Exception as e: self.logger.exception("Exception announced: %s", e) - if not self._waiting.is_set(): - with self._ready_condition: - self._ready_condition.notify_all() + with self._ready_condition: + self._ready_condition.notify_all() async def on_status(self, msg: StatusMsg) -> None: if msg.node_id not in self._nodes: @@ -308,9 +304,9 @@ async def on_status(self, msg: StatusMsg) -> None: ) return self._nodes[msg.node_id]["status"] = msg.value - if not self._waiting.is_set(): - with self._ready_condition: - self._ready_condition.notify_all() + + with self._ready_condition: + self._ready_condition.notify_all() class NodeRunner(EventloopMixin): @@ -1150,24 +1146,6 @@ def _handle_error(self, msg: ErrorMsg) -> None: # e.g. errors during init, raise here. raise exception - # - # def _throw_error(self, e) -> None: - # errval = self._to_throw - # if errval is None: - # return - # # clear instance object and store locally, we aren't locked here. - # self._to_throw = None - # self._logger.debug( - # "Deinitializing before throwing error", - # ) - # self.deinit() - # - # # add the traceback as a note, - # # sort of the best we can do without using tblib - # err = self._rehydrate_error(errval) - # - # raise err - def _request_more(self, n: int, current_iter: int, n_epochs: int) -> int: """ During iteration with cardinality-reducing nodes, diff --git a/src/noob/scheduler.py b/src/noob/scheduler.py index 19aab843..42e33da1 100644 --- a/src/noob/scheduler.py +++ b/src/noob/scheduler.py @@ -328,64 +328,3 @@ def generations(self) -> list[tuple[str, ...]]: generations.append(ready) sorter.done(*ready) return generations - - -# -# class AsyncScheduler(Scheduler): -# -# async def update( -# self, events: MutableSequence[Event | MetaEvent] | MutableSequence[Event] -# ) -> MutableSequence[Event] | MutableSequence[Event | MetaEvent]: -# async with self._ready_condition: -# self.logger.debug("Inside update lock") -# ret_events = super().update(events) -# self._ready_condition.notify_all() -# return ret_events -# -# # async def done(self, epoch: int, node_id: str) -> MetaEvent | None: -# # async with self._ready_condition: -# # ret = super().done(epoch, node_id) -# # self._ready_condition.notify_all() -# # return ret -# -# async def await_node(self, node_id: NodeID, epoch: int | None = None) -> MetaEvent: -# """ -# Block until a node is ready -# -# Args: -# node_id: -# epoch (int, None): if `int` , wait until the node is ready in the given epoch, -# otherwise wait until the node is ready in any epoch -# -# Returns: -# -# """ -# async with self._ready_condition: -# await self._ready_condition.wait_for(lambda: self.node_is_ready(node_id, epoch)) -# -# # be FIFO-like and get the earliest epoch the node is ready in -# if epoch is None: -# for ep in self._epochs: -# if self.node_is_ready(node_id, ep): -# epoch = ep -# break -# -# if epoch is None: -# raise RuntimeError( -# "Could not find ready epoch even though node ready condition passed, " -# "something is wrong with the way node status checking is " -# "locked between threads." -# ) -# -# # mark just one event as "out." -# # threadsafe because we are holding the lock that protects graph mutation -# self._epochs[epoch].mark_out(node_id) -# -# return MetaEvent( -# id=uuid4().int, -# timestamp=datetime.now(), -# node_id="meta", -# signal=MetaEventType.NodeReady, -# epoch=epoch, -# value=node_id, -# ) From 3ee3578b120c393c358a3c35bf13093b2844ca3a Mon Sep 17 00:00:00 2001 From: sneakers-the-rat Date: Fri, 6 Feb 2026 21:50:51 -0800 Subject: [PATCH 13/13] removing old waiting event --- src/noob/runner/zmq.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/noob/runner/zmq.py b/src/noob/runner/zmq.py index c05dfd93..08380439 100644 --- a/src/noob/runner/zmq.py +++ b/src/noob/runner/zmq.py @@ -240,7 +240,6 @@ def await_ready(self, node_ids: list[NodeID], timeout: float = 10) -> None: """ Wait until all the node_ids have announced themselves """ - self._waiting.clear() def _ready_nodes() -> set[str]: return {node_id for node_id, state in self._nodes.items() if state["status"] == "ready"} @@ -271,7 +270,6 @@ def _is_ready() -> bool: f"Waiting for: {set(node_ids)}, " f"ready: {_ready_nodes()}" ) - self._waiting.set() async def on_router(self, message: Message) -> None: self.logger.debug("Received ROUTER message %s", message)