Skip to content

Commit

Permalink
Rewrite the class Monitor as a function relay_queue().
Browse files Browse the repository at this point in the history
  • Loading branch information
TaiSakuma committed Jan 19, 2024
1 parent 2971755 commit aee0a72
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 55 deletions.
94 changes: 42 additions & 52 deletions nextline/plugin/plugins/session/monitor.py
Original file line number Diff line number Diff line change
@@ -1,64 +1,54 @@
import asyncio
import time
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from logging import getLogger

from nextline import spawned
from nextline.plugin.spec import Context
from nextline.spawned import QueueOut

# from rich import print


class Monitor:
def __init__(self, context: Context, queue: QueueOut):
self._context = context
self._queue = queue
self._logger = getLogger(__name__)

async def open(self) -> None:
self._task = asyncio.create_task(self._monitor())

async def close(self) -> None:
@asynccontextmanager
async def relay_queue(context: Context, queue: QueueOut) -> AsyncIterator[None]:
task = asyncio.create_task(_monitor(context, queue))
try:
yield
finally:
up_to = 0.05
start = time.process_time()
while not self._queue.empty() and time.process_time() - start < up_to:
while not queue.empty() and time.process_time() - start < up_to:
await asyncio.sleep(0)
await asyncio.to_thread(self._queue.put, None) # type: ignore
await self._task

async def __aenter__(self) -> 'Monitor':
await self.open()
return self

async def __aexit__(self, exc_type, exc_value, traceback): # type: ignore
del exc_type, exc_value, traceback
await self.close()

async def _monitor(self) -> None:
while (event := await asyncio.to_thread(self._queue.get)) is not None:
await self._on_event(event)

async def _on_event(self, event: spawned.Event) -> None:
context = self._context
ahook = context.hook.ahook
match event:
case spawned.OnStartTrace():
await ahook.on_start_trace(context=context, event=event)
case spawned.OnEndTrace():
await ahook.on_end_trace(context=context, event=event)
case spawned.OnStartTraceCall():
await ahook.on_start_trace_call(context=context, event=event)
case spawned.OnEndTraceCall():
await ahook.on_end_trace_call(context=context, event=event)
case spawned.OnStartCmdloop():
await ahook.on_start_cmdloop(context=context, event=event)
case spawned.OnEndCmdloop():
await ahook.on_end_cmdloop(context=context, event=event)
case spawned.OnStartPrompt():
await ahook.on_start_prompt(context=context, event=event)
case spawned.OnEndPrompt():
await ahook.on_end_prompt(context=context, event=event)
case spawned.OnWriteStdout():
await ahook.on_write_stdout(context=context, event=event)
case _:
self._logger.warning(f'Unknown event: {event!r}')
await asyncio.to_thread(queue.put, None) # type: ignore
await task


async def _monitor(context: Context, queue: QueueOut) -> None:
while (event := await asyncio.to_thread(queue.get)) is not None:
await _on_event(context, event)


async def _on_event(context: Context, event: spawned.Event) -> None:
ahook = context.hook.ahook
match event:
case spawned.OnStartTrace():
await ahook.on_start_trace(context=context, event=event)
case spawned.OnEndTrace():
await ahook.on_end_trace(context=context, event=event)
case spawned.OnStartTraceCall():
await ahook.on_start_trace_call(context=context, event=event)
case spawned.OnEndTraceCall():
await ahook.on_end_trace_call(context=context, event=event)
case spawned.OnStartCmdloop():
await ahook.on_start_cmdloop(context=context, event=event)
case spawned.OnEndCmdloop():
await ahook.on_end_cmdloop(context=context, event=event)
case spawned.OnStartPrompt():
await ahook.on_start_prompt(context=context, event=event)
case spawned.OnEndPrompt():
await ahook.on_end_prompt(context=context, event=event)
case spawned.OnWriteStdout():
await ahook.on_write_stdout(context=context, event=event)
case _:
logger = getLogger(__name__)
logger.warning(f'Unknown event: {event!r}')
5 changes: 2 additions & 3 deletions nextline/plugin/plugins/session/spawn.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from nextline.spawned import Command, QueueIn, QueueOut, RunArg, RunResult
from nextline.utils import MultiprocessingLogging, RunningProcess, run_in_process

from .monitor import Monitor
from .monitor import relay_queue

pickling_support.install()

Expand All @@ -35,7 +35,6 @@ async def run_session(
queue_in = cast(QueueIn, mp_context.Queue())
queue_out = cast(QueueOut, mp_context.Queue())
send_command = SendCommand(queue_in)
monitor = Monitor(context, queue_out)
async with MultiprocessingLogging(mp_context=mp_context) as mp_logging:
initializer = partial(
_call_all,
Expand All @@ -49,7 +48,7 @@ async def run_session(
initializer=initializer,
)
func = partial(spawned.main, run_arg)
async with monitor:
async with relay_queue(context, queue_out):
running = await run_in_process(func, executor_factory)
yield running, send_command

Expand Down

0 comments on commit aee0a72

Please sign in to comment.