diff --git a/changes/5433695fd5f2187f2cd9f7a4ba0ae4d9.yaml b/changes/5433695fd5f2187f2cd9f7a4ba0ae4d9.yaml new file mode 100644 index 00000000000..08ef0211ca3 --- /dev/null +++ b/changes/5433695fd5f2187f2cd9f7a4ba0ae4d9.yaml @@ -0,0 +1,5 @@ +--- +desc: Fixed SIGINT handling in the ``synapse.tools.storm`` CLI tool. +prs: [] +type: bug +... diff --git a/pyproject.toml b/pyproject.toml index d05e5c234de..bbb5acdabc4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,7 +35,7 @@ dependencies = [ 'aiohttp-socks>=0.9.0,<0.10.0', 'aioimaplib>=1.1.0,<1.2.0', 'aiosmtplib>=3.0.0,<3.1.0', - 'prompt-toolkit>=3.0.4,<3.1.0', + 'prompt_toolkit>=3.0.29,<3.1.0', 'lark==1.2.2', 'Pygments>=2.7.4,<2.18.0', 'packaging>=20.0,<25.0', diff --git a/requirements.txt b/requirements.txt index 45b4f71eba4..bd52fb24e1f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,7 +9,7 @@ aiohttp>=3.10.0,<4.0 aiohttp-socks>=0.9.0,<0.10.0 aioimaplib>=1.1.0,<1.2.0 aiosmtplib>=3.0.0,<3.1.0 -prompt-toolkit>=3.0.4,<3.1.0 +prompt_toolkit>=3.0.29,<3.1.0 lark==1.2.2 Pygments>=2.7.4,<2.18.0 fastjsonschema>=2.18.0,<2.20.0 diff --git a/synapse/lib/cli.py b/synapse/lib/cli.py index 4feb1d217ff..ed5ed77d93c 100644 --- a/synapse/lib/cli.py +++ b/synapse/lib/cli.py @@ -281,18 +281,26 @@ async def _onItemFini(self): await self.fini() - async def addSignalHandlers(self): + async def addSignalHandlers(self): # pragma: no cover ''' Register SIGINT signal handler with the ioloop to cancel the currently running cmdloop task. + Removes the handler when the cli is fini'd. ''' - def sigint(): - self.printf('') if self.cmdtask is not None: self.cmdtask.cancel() self.loop.add_signal_handler(signal.SIGINT, sigint) + def onfini(): + # N.B. This is reaches into some loop / handle internals but + # prevents us from removing a handler that overwrote our own. + hndl = self.loop._signal_handlers.get(signal.SIGINT, None) # type: asyncio.Handle + if hndl is not None and hndl._callback is sigint: + self.loop.remove_signal_handler(signal.SIGINT) + + self.onfini(onfini) + def get(self, name, defval=None): return self.locs.get(name, defval) @@ -324,8 +332,12 @@ async def prompt(self, text=None): if text is None: text = self.cmdprompt - with patch_stdout(): - retn = await self.sess.prompt_async(text, vi_mode=self.vi_mode, enable_open_in_editor=True) + with patch_stdout(): # pragma: no cover + retn = await self.sess.prompt_async(text, + vi_mode=self.vi_mode, + enable_open_in_editor=True, + handle_sigint=False # We handle sigint in the loop + ) return retn def printf(self, mesg, addnl=True, color=None): @@ -390,7 +402,7 @@ async def runCmdLoop(self): self.cmdtask = self.schedCoro(coro) await self.cmdtask - except KeyboardInterrupt: + except (KeyboardInterrupt, asyncio.CancelledError): if self.isfini: return @@ -408,11 +420,8 @@ async def runCmdLoop(self): if self.cmdtask is not None: self.cmdtask.cancel() try: - self.cmdtask.result() - except asyncio.CancelledError: - # Wait a beat to let any remaining nodes to print out before we print the prompt - await asyncio.sleep(1) - except Exception: + await asyncio.wait_for(self.cmdtask, timeout=0.1) + except (asyncio.CancelledError, asyncio.TimeoutError): pass async def runCmdLine(self, line): diff --git a/synapse/tests/test_tools_storm.py b/synapse/tests/test_tools_storm.py index 92464d94ae6..a93713a19be 100644 --- a/synapse/tests/test_tools_storm.py +++ b/synapse/tests/test_tools_storm.py @@ -1,4 +1,9 @@ import os +import sys +import signal +import asyncio +import multiprocessing + import synapse.tests.utils as s_test from prompt_toolkit.document import Document @@ -6,10 +11,49 @@ import synapse.exc as s_exc import synapse.common as s_common +import synapse.telepath as s_telepath + +import synapse.lib.coro as s_coro import synapse.lib.output as s_output import synapse.lib.msgpack as s_msgpack import synapse.tools.storm as s_t_storm +def run_cli_till_print(url, evt1): + ''' + Run the stormCLI until we get a print mesg then set the event. + + This is a Process target. + ''' + async def main(): + outp = s_output.OutPutStr() # Capture output instead of sending it to stdout + async with await s_telepath.openurl(url) as proxy: + async with await s_t_storm.StormCli.anit(proxy, outp=outp) as scli: + cmdqueue = asyncio.Queue() + await cmdqueue.put('while (true) { $lib.print(go) $lib.time.sleep(1) }') + await cmdqueue.put('!quit') + + async def fake_prompt(): + return await cmdqueue.get() + + scli.prompt = fake_prompt + + d = {'evt1': False} + async def onmesg(event): + if d.get('evt1'): + return + mesg = event[1].get('mesg') + if mesg[0] != 'print': + return + evt1.set() + d['evt1'] = True + + with scli.onWith('storm:mesg', onmesg): + await scli.addSignalHandlers() + await scli.runCmdLoop() + + asyncio.run(main()) + sys.exit(137) + class StormCliTest(s_test.SynTest): async def test_tools_storm(self): @@ -378,3 +422,54 @@ async def get_completions(text): ), vals ) + + async def test_storm_cmdloop_interrupt(self): + ''' + Test interrupting a long-running query in the command loop + ''' + async with self.getTestCore() as core: + + async with core.getLocalProxy() as proxy: + + outp = s_test.TstOutPut() + async with await s_t_storm.StormCli.anit(proxy, outp=outp) as scli: + + cmdqueue = asyncio.Queue() + await cmdqueue.put('while (true) { $lib.time.sleep(1) }') + await cmdqueue.put('!quit') + + async def fake_prompt(): + return await cmdqueue.get() + scli.prompt = fake_prompt + + cmdloop_task = asyncio.create_task(scli.runCmdLoop()) + await asyncio.sleep(0.1) + + if scli.cmdtask is not None: + scli.cmdtask.cancel() + + await cmdloop_task + + outp.expect('') + outp.expect('o/') + self.true(scli.isfini) + + async def test_storm_cmdloop_sigint(self): + ''' + Test interrupting a long-running query in the command loop with a process target and SIGINT. + ''' + + async with self.getTestCore() as core: + url = core.getLocalUrl() + + ctx = multiprocessing.get_context('spawn') + + evt1 = ctx.Event() + + proc = ctx.Process(target=run_cli_till_print, args=(url, evt1,)) + proc.start() + + self.true(await s_coro.executor(evt1.wait, timeout=30)) + os.kill(proc.pid, signal.SIGINT) + proc.join(timeout=30) + self.eq(proc.exitcode, 137)