Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
243 changes: 135 additions & 108 deletions pylock.toml

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ tests = [
"pytest-timeout>=2.4.0",
"pytest-asyncio>=1.3.0",
"pytest-codspeed>=4.2.0",
"pytest-mock>=3.15.1",
]
docs = [
"sphinx>=8.2.3,<9.0.0", # until myst parser catches up and stops emitting invalid references
Expand Down
71 changes: 41 additions & 30 deletions src/noob/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from pathlib import Path
from typing import Any, Literal

from rich import get_console
from rich.console import Console
from rich.logging import RichHandler

from noob.config import LOG_LEVELS, config
Expand All @@ -29,6 +29,13 @@ def init_logger(
Log to a set of rotating files in the ``log_dir`` according to ``name`` ,
as well as using the :class:`~rich.RichHandler` for pretty-formatted stdout logs.

If this method is called from a process that isn't the root process,
it will create new rich and file handlers in the root noob logger to avoid
deadlocks from threading locks that are copied on forked processes.
Since the handlers will be different across processes,
to avoid file access conflicts, logging files will have the process's ``pid``
appended (e.g. ``noob_12345.log`` )

Args:
name (str): Name of this logger. Ideally names are hierarchical
and indicate what they are logging for, eg. ``noob.api.auth``
Expand Down Expand Up @@ -85,30 +92,6 @@ def init_logger(
logger = logging.getLogger(name)
logger.setLevel(min_level)

# if run from a forked process, need to add different handlers to not collide
if mp.parent_process() is not None:
handler_name = f"{name}_{mp.current_process().pid}"
if log_dir is not False and not any([h.name == handler_name for h in logger.handlers]):
logger.addHandler(
_file_handler(
name=f"{name}_{mp.current_process().pid}",
file_level=file_level,
log_dir=log_dir,
log_file_n=log_file_n,
log_file_size=log_file_size,
)
)

if not any(
[
handler_name in h.keywords
for h in logger.handlers
if isinstance(h, RichHandler) and h.keywords is not None
]
):
logger.addHandler(_rich_handler(level, keywords=[handler_name], width=width))
logger.propagate = False

return logger


Expand All @@ -121,6 +104,18 @@ def _init_root(
width: int | None = None,
) -> None:
root_logger = logging.getLogger("noob")

# ensure each root logger has fresh handlers in subprocesses
if mp.parent_process() is not None:
current_pid = mp.current_process().pid
file_name = f"noob_{current_pid}"
rich_name = f"{file_name}_rich"
else:
file_name = "noob"
rich_name = "noob_rich"

root_logger.handlers = [h for h in root_logger.handlers if h.name in (rich_name, file_name)]

file_handlers = [
handler for handler in root_logger.handlers if isinstance(handler, RotatingFileHandler)
]
Expand All @@ -131,7 +126,7 @@ def _init_root(
if log_dir is not False and not file_handlers:
root_logger.addHandler(
_file_handler(
"noob",
file_name,
file_level,
log_dir,
log_file_n,
Expand All @@ -143,7 +138,7 @@ def _init_root(
file_handler.setLevel(file_level)

if not stream_handlers:
root_logger.addHandler(_rich_handler(stdout_level, width=width))
root_logger.addHandler(_rich_handler(stdout_level, name=rich_name, width=width))
else:
for stream_handler in stream_handlers:
stream_handler.setLevel(stdout_level)
Expand Down Expand Up @@ -171,16 +166,32 @@ def _file_handler(
return file_handler


def _rich_handler(level: LOG_LEVELS, width: int | None = None, **kwargs: Any) -> RichHandler:
console = get_console()
def _rich_handler(
    level: LOG_LEVELS, name: str, width: int | None = None, **kwargs: Any
) -> RichHandler:
    """
    Create a named :class:`~rich.logging.RichHandler` bound to a console
    owned by the current process (see :func:`_get_console`).

    Args:
        level: Logging level to set on the handler.
        name: Handler name, used by callers to detect already-added handlers.
        width: If truthy, override the console's width (``0`` is treated
            the same as ``None``, i.e. no override).
        **kwargs: Passed through to :class:`~rich.logging.RichHandler`.

    Returns:
        RichHandler: configured handler with a ``[name] message`` rich format.
    """
    console = _get_console()
    if width:
        console.width = width

    # construct exactly once, with the per-process console, so forked
    # processes don't share a parent console's locks
    rich_handler = RichHandler(console=console, rich_tracebacks=True, markup=True, **kwargs)
    rich_handler.name = name
    rich_formatter = logging.Formatter(
        r"[bold green]\[%(name)s][/bold green] %(message)s",
        datefmt="[%y-%m-%dT%H:%M:%S]",
    )
    rich_handler.setFormatter(rich_formatter)
    rich_handler.setLevel(level)
    return rich_handler


# One Console per process: rich consoles hold threading locks that are
# copied on fork, so a console created in the parent must not be reused
# in a child process.
_console_by_pid: dict[int | None, Console] = {}


def _get_console() -> Console:
    """Get (creating on first use) a console that was spawned in this process."""
    # mutating the module-level dict never rebinds the name, so no
    # ``global`` declaration is needed here
    current_pid = mp.current_process().pid
    console = _console_by_pid.get(current_pid)
    if console is None:
        _console_by_pid[current_pid] = console = Console()
    return console
31 changes: 29 additions & 2 deletions src/noob/network/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ class MessageType(StrEnum):
announce = "announce"
identify = "identify"
process = "process"
init = "init"
deinit = "deinit"
ping = "ping"
start = "start"
status = "status"
stop = "stop"
Expand Down Expand Up @@ -111,6 +114,13 @@ class IdentifyMsg(Message):
value: IdentifyValue


class PingMsg(Message):
    """Request other nodes to identify themselves and report their status"""

    # serialized as the "type" field (via alias); discriminates the message union
    type_: Literal[MessageType.ping] = Field(MessageType.ping, alias="type")
    # pings carry no payload
    value: None = None


class ProcessMsg(Message):
"""Process a single iteration of the graph"""

Expand All @@ -119,11 +129,25 @@ class ProcessMsg(Message):
"""Any process-scoped input passed to the `process` call"""


class InitMsg(Message):
    """Initialize nodes within node runners"""

    # serialized as the "type" field (via alias); discriminates the message union
    type_: Literal[MessageType.init] = Field(MessageType.init, alias="type")
    # init requests carry no payload
    value: None = None


class DeinitMsg(Message):
    """Deinitialize nodes within node runners"""

    # serialized as the "type" field (via alias); discriminates the message union
    type_: Literal[MessageType.deinit] = Field(MessageType.deinit, alias="type")
    # deinit requests carry no payload
    value: None = None


class StartMsg(Message):
"""Start free running nodes"""
"""Start free-running nodes"""

type_: Literal[MessageType.start] = Field(MessageType.start, alias="type")
value: None = None
value: int | None = None


class StatusMsg(Message):
Expand Down Expand Up @@ -196,6 +220,9 @@ def _type_discriminator(v: dict | Message) -> str:
A[AnnounceMsg, Tag("announce")]
| A[IdentifyMsg, Tag("identify")]
| A[ProcessMsg, Tag("process")]
| A[InitMsg, Tag("init")]
| A[DeinitMsg, Tag("deinit")]
| A[PingMsg, Tag("ping")]
| A[StartMsg, Tag("start")]
| A[StatusMsg, Tag("status")]
| A[StopMsg, Tag("stop")]
Expand Down
10 changes: 1 addition & 9 deletions src/noob/node/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
cast,
get_args,
get_origin,
overload,
)

from pydantic import (
Expand All @@ -26,7 +25,6 @@

from noob.introspection import is_optional, is_union
from noob.node.spec import NodeSpecification
from noob.types import RunnerContext
from noob.utils import resolve_python_identifier

if TYPE_CHECKING:
Expand Down Expand Up @@ -195,14 +193,8 @@ def model_post_init(self, __context: Any) -> None:
if inspect.isgeneratorfunction(self.process):
self._wrap_generator(self.process)

@overload
def init(self) -> None: ...

@overload
def init(self, context: RunnerContext) -> None: ...

# TODO: Support dependency injection in mypy plugin
def init(self) -> None: # type: ignore[misc]
def init(self) -> None:
"""
Start producing, processing, or receiving data.

Expand Down
27 changes: 24 additions & 3 deletions src/noob/runner/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from datetime import UTC, datetime
from functools import partial
from logging import Logger
from typing import TYPE_CHECKING, Any, ParamSpec, Self, TypeVar
from typing import TYPE_CHECKING, Any, ParamSpec, Self, TypeVar, overload

from noob import Tube, init_logger
from noob.asset import AssetScope
Expand Down Expand Up @@ -180,7 +180,20 @@ def iter(self, n: int | None = None) -> Generator[ReturnNodeType, None, None]:
finally:
self.deinit()

@overload
def run(self, n: int) -> list[ReturnNodeType]: ...

@overload
def run(self, n: None) -> None: ...

def run(self, n: int | None = None) -> None | list[ReturnNodeType]:
"""
Run the tube infinitely or for a fixed number of iterations in a row.

Returns results if ``n`` is not ``None`` -
If ``n`` is ``None`` , we assume we are going to be running for a very long time,
and don't want to have an infinitely-growing collection in memory.
"""
try:
_ = self.tube.input_collection.validate_input(InputScope.process, {})
except InputMissingError as e:
Expand Down Expand Up @@ -540,25 +553,33 @@ def call_async_from_sync(

result_future: asyncio.Future[_TReturn] = asyncio.Future()
work_ready = threading.Condition()
finished = False

# Closures because this code should never escape the containment tomb of this crime against god
async def _wrap(call_result: asyncio.Future[_TReturn], fn: Coroutine) -> None:
nonlocal finished
try:
result = await fn
call_result.set_result(result)
except Exception as e:
call_result.set_exception(e)
finally:
finished = True

def _done(_: ConcurrentFuture) -> None:
nonlocal finished

finished = True
with work_ready:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this ok to be not nonlocal? clearly it's working since tests are passing but why

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not sure what you mean? this is nonlocal because it's referring to finished, which is defined in the outer function scope, and we are modifying it, so it needs to be nonlocal (otherwise assigning to it would just create finished in the inner _done scope).

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh sorry i was asking about work_ready

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Aha, my bad. No, that's fine since we don't rebind work_ready.

work_ready.notify_all()

future_inner = executor.submit(asyncio.run, _wrap(result_future, coro))
future_inner.add_done_callback(_done)

with work_ready:
work_ready.wait()
try:
while not finished and not future_inner.done():
with work_ready:
work_ready.wait(timeout=1)
res = result_future.result()
return res
finally:
Expand Down
Loading