From aee0a72da4f06c61023d125518a54617e7bc34d1 Mon Sep 17 00:00:00 2001 From: Tai Sakuma Date: Fri, 19 Jan 2024 17:03:21 -0500 Subject: [PATCH] Rewrite the class Monitor as a function relay_queue(). --- nextline/plugin/plugins/session/monitor.py | 94 ++++++++++------------ nextline/plugin/plugins/session/spawn.py | 5 +- 2 files changed, 44 insertions(+), 55 deletions(-) diff --git a/nextline/plugin/plugins/session/monitor.py b/nextline/plugin/plugins/session/monitor.py index f5cdeec6..a5561d4c 100644 --- a/nextline/plugin/plugins/session/monitor.py +++ b/nextline/plugin/plugins/session/monitor.py @@ -1,64 +1,54 @@ import asyncio import time +from collections.abc import AsyncIterator +from contextlib import asynccontextmanager from logging import getLogger from nextline import spawned from nextline.plugin.spec import Context from nextline.spawned import QueueOut -# from rich import print - -class Monitor: - def __init__(self, context: Context, queue: QueueOut): - self._context = context - self._queue = queue - self._logger = getLogger(__name__) - - async def open(self) -> None: - self._task = asyncio.create_task(self._monitor()) - - async def close(self) -> None: +@asynccontextmanager +async def relay_queue(context: Context, queue: QueueOut) -> AsyncIterator[None]: + task = asyncio.create_task(_monitor(context, queue)) + try: + yield + finally: up_to = 0.05 start = time.process_time() - while not self._queue.empty() and time.process_time() - start < up_to: + while not queue.empty() and time.process_time() - start < up_to: await asyncio.sleep(0) - await asyncio.to_thread(self._queue.put, None) # type: ignore - await self._task - - async def __aenter__(self) -> 'Monitor': - await self.open() - return self - - async def __aexit__(self, exc_type, exc_value, traceback): # type: ignore - del exc_type, exc_value, traceback - await self.close() - - async def _monitor(self) -> None: - while (event := await asyncio.to_thread(self._queue.get)) is not None: - await self._on_event(event) - - async def _on_event(self, event: spawned.Event) -> None: - context = self._context - ahook = context.hook.ahook - match event: - case spawned.OnStartTrace(): - await ahook.on_start_trace(context=context, event=event) - case spawned.OnEndTrace(): - await ahook.on_end_trace(context=context, event=event) - case spawned.OnStartTraceCall(): - await ahook.on_start_trace_call(context=context, event=event) - case spawned.OnEndTraceCall(): - await ahook.on_end_trace_call(context=context, event=event) - case spawned.OnStartCmdloop(): - await ahook.on_start_cmdloop(context=context, event=event) - case spawned.OnEndCmdloop(): - await ahook.on_end_cmdloop(context=context, event=event) - case spawned.OnStartPrompt(): - await ahook.on_start_prompt(context=context, event=event) - case spawned.OnEndPrompt(): - await ahook.on_end_prompt(context=context, event=event) - case spawned.OnWriteStdout(): - await ahook.on_write_stdout(context=context, event=event) - case _: - self._logger.warning(f'Unknown event: {event!r}') + await asyncio.to_thread(queue.put, None) # type: ignore + await task + + +async def _monitor(context: Context, queue: QueueOut) -> None: + while (event := await asyncio.to_thread(queue.get)) is not None: + await _on_event(context, event) + + +async def _on_event(context: Context, event: spawned.Event) -> None: + ahook = context.hook.ahook + match event: + case spawned.OnStartTrace(): + await ahook.on_start_trace(context=context, event=event) + case spawned.OnEndTrace(): + await ahook.on_end_trace(context=context, event=event) + case spawned.OnStartTraceCall(): + await ahook.on_start_trace_call(context=context, event=event) + case spawned.OnEndTraceCall(): + await ahook.on_end_trace_call(context=context, event=event) + case spawned.OnStartCmdloop(): + await ahook.on_start_cmdloop(context=context, event=event) + case spawned.OnEndCmdloop(): + await ahook.on_end_cmdloop(context=context, event=event) + case spawned.OnStartPrompt(): + await ahook.on_start_prompt(context=context, event=event) + case spawned.OnEndPrompt(): + await ahook.on_end_prompt(context=context, event=event) + case spawned.OnWriteStdout(): + await ahook.on_write_stdout(context=context, event=event) + case _: + logger = getLogger(__name__) + logger.warning(f'Unknown event: {event!r}') diff --git a/nextline/plugin/plugins/session/spawn.py b/nextline/plugin/plugins/session/spawn.py index 4e5d32f1..d8302780 100644 --- a/nextline/plugin/plugins/session/spawn.py +++ b/nextline/plugin/plugins/session/spawn.py @@ -13,7 +13,7 @@ from nextline.spawned import Command, QueueIn, QueueOut, RunArg, RunResult from nextline.utils import MultiprocessingLogging, RunningProcess, run_in_process -from .monitor import Monitor +from .monitor import relay_queue pickling_support.install() @@ -35,7 +35,6 @@ async def run_session( queue_in = cast(QueueIn, mp_context.Queue()) queue_out = cast(QueueOut, mp_context.Queue()) send_command = SendCommand(queue_in) - monitor = Monitor(context, queue_out) async with MultiprocessingLogging(mp_context=mp_context) as mp_logging: initializer = partial( _call_all, @@ -49,7 +48,7 @@ async def run_session( initializer=initializer, ) func = partial(spawned.main, run_arg) - async with monitor: + async with relay_queue(context, queue_out): running = await run_in_process(func, executor_factory) yield running, send_command