diff --git a/synapse/tests/test_tools_storm.py b/synapse/tests/test_tools_storm.py index 3dd3b9722e0..a93713a19be 100644 --- a/synapse/tests/test_tools_storm.py +++ b/synapse/tests/test_tools_storm.py @@ -1,5 +1,9 @@ import os +import sys +import signal import asyncio +import multiprocessing + import synapse.tests.utils as s_test from prompt_toolkit.document import Document @@ -7,10 +11,49 @@ import synapse.exc as s_exc import synapse.common as s_common +import synapse.telepath as s_telepath + +import synapse.lib.coro as s_coro import synapse.lib.output as s_output import synapse.lib.msgpack as s_msgpack import synapse.tools.storm as s_t_storm +def run_cli_till_print(url, evt1): + ''' + Run the stormCLI until we get a print mesg then set the event. + + This is a Process target. + ''' + async def main(): + outp = s_output.OutPutStr() # Capture output instead of sending it to stdout + async with await s_telepath.openurl(url) as proxy: + async with await s_t_storm.StormCli.anit(proxy, outp=outp) as scli: + cmdqueue = asyncio.Queue() + await cmdqueue.put('while (true) { $lib.print(go) $lib.time.sleep(1) }') + await cmdqueue.put('!quit') + + async def fake_prompt(): + return await cmdqueue.get() + + scli.prompt = fake_prompt + + d = {'evt1': False} + async def onmesg(event): + if d.get('evt1'): + return + mesg = event[1].get('mesg') + if mesg[0] != 'print': + return + evt1.set() + d['evt1'] = True + + with scli.onWith('storm:mesg', onmesg): + await scli.addSignalHandlers() + await scli.runCmdLoop() + + asyncio.run(main()) + sys.exit(137) + class StormCliTest(s_test.SynTest): async def test_tools_storm(self): @@ -410,3 +453,23 @@ async def fake_prompt(): outp.expect('') outp.expect('o/') self.true(scli.isfini) + + async def test_storm_cmdloop_sigint(self): + ''' + Test interrupting a long-running query in the command loop with a process target and SIGINT. + ''' + + async with self.getTestCore() as core: + url = core.getLocalUrl() + + ctx = multiprocessing.get_context('spawn') + + evt1 = ctx.Event() + + proc = ctx.Process(target=run_cli_till_print, args=(url, evt1,)) + proc.start() + + self.true(await s_coro.executor(evt1.wait, timeout=30)) + os.kill(proc.pid, signal.SIGINT) + proc.join(timeout=30) + self.eq(proc.exitcode, 137)