From ee4ce4ad48779e09cadebcd98b8bdfa9128fa8b0 Mon Sep 17 00:00:00 2001 From: Simon Holesch Date: Sat, 20 Apr 2024 04:28:59 +0200 Subject: [PATCH] util: Cancel task if background task fails Cancel the foreground task as soon as the background task fails and raise the exception of the background task. With that, background task failures can be noticed immediately, not only after stopping the process. --- not_my_board/_util/_asyncio.py | 65 ++++++++++++++++++++++---- tests/test_agent.py | 33 +++++++------ tests/test_hub.py | 22 ++++----- tests/test_jsonrpc.py | 84 +++++++++++++++++----------------- tests/test_util.py | 26 +++++++++++ 5 files changed, 154 insertions(+), 76 deletions(-) create mode 100644 tests/test_util.py diff --git a/not_my_board/_util/_asyncio.py b/not_my_board/_util/_asyncio.py index d0edfdd..bae50c4 100644 --- a/not_my_board/_util/_asyncio.py +++ b/not_my_board/_util/_asyncio.py @@ -48,16 +48,63 @@ async def run_concurrently(*coros): await cancel_tasks(tasks) -@contextlib.asynccontextmanager -async def background_task(coro): - """Runs the coro until leaving the context manager. +def background_task(coro): + """Runs coro as a background task until leaving the context manager. - The coro task is canceled when leaving the context.""" - task = asyncio.create_task(coro) - try: - yield task - finally: - await cancel_tasks([task]) + If the background task fails while the context manager is active, then the + foreground task is canceled and the context manager raises the exception of + the background task. + + If the context manager exits while the background task is still running, + then the background task is canceled.""" + + return _BackgroundTask(coro) + + +class _BackgroundTask: + def __init__(self, coro): + self._coro = coro + self._bg_exception = None + + async def __aenter__(self): + self._bg_task = asyncio.create_task(self._coro) + self._bg_task.add_done_callback(self._on_bg_task_done) + self._fg_task = asyncio.current_task() + self._num_cancel_requests = self._get_num_cancel_requests() + return self._bg_task + + async def __aexit__(self, exc_type, exc, tb): + self._bg_task.remove_done_callback(self._on_bg_task_done) + if self._bg_exception: + if ( + self._uncancel() <= self._num_cancel_requests + and exc_type is asyncio.CancelledError + ): + # foreground task was only canceled by this class, raise + # real error + raise self._bg_exception from exc + else: + await cancel_tasks([self._bg_task]) + + def _on_bg_task_done(self, task): + if not task.cancelled(): + self._bg_exception = task.exception() + if self._bg_exception: + self._fg_task.cancel() + + def _get_num_cancel_requests(self): + # remove, if Python version < 3.11 is no longer supported + if hasattr(self._fg_task, "cancelling"): + return self._fg_task.cancelling() + else: + return 0 + + def _uncancel(self): + # remove, if Python version < 3.11 is no longer supported + if hasattr(self._fg_task, "uncancel"): + return self._fg_task.uncancel() + else: + return 0 async def cancel_tasks(tasks): diff --git a/tests/test_agent.py b/tests/test_agent.py index 60ae41c..a0c07df 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -410,23 +410,26 @@ async def test_reserve_twice_concurrently(agent_io): # start two reserve tasks in parallel coro_1 = agent_io.agent_api.reserve(IMPORT_DESC_1.dict()) coro_2 = agent_io.agent_api.reserve(IMPORT_DESC_1.dict()) - async with util.background_task(coro_1) as task_1: - async with util.background_task(coro_2) as task_2: - # wait until one blocks - await agent_io.hub.reserve_pending.wait() - # the other one should block until the first one finishes - assert not task_1.done() - assert not task_2.done() + with pytest.raises(RuntimeError) as execinfo: + async with util.background_task(coro_1) as task_1: + async with util.background_task(coro_2) as task_2: + # wait until one blocks + await agent_io.hub.reserve_pending.wait() + + # the other one should block until the first one finishes + assert not task_1.done() + assert not task_2.done() - # unblock reserve call - agent_io.hub.reserve_continue.set() + # unblock reserve call + agent_io.hub.reserve_continue.set() - # now one should finish successfully and the other one should fail - results = await asyncio.gather(task_1, task_2, return_exceptions=True) + await asyncio.sleep(0.5) - assert None in results - exception = results[0] or results[1] - assert "is already reserved" in str(exception) + # now one should finish successfully and the other one should fail + results = await asyncio.gather(task_1, task_2, return_exceptions=True) + + assert None in results + assert "is already reserved" in str(execinfo.value) - assert len(await agent_io.agent_api.list()) == 1 + assert len(await agent_io.agent_api.list()) == 1 diff --git a/tests/test_hub.py b/tests/test_hub.py index d3226b0..4c282d7 100644 --- a/tests/test_hub.py +++ b/tests/test_hub.py @@ -147,18 +147,18 @@ async def test_all_places_disappear_while_trying_to_reserve(hub): candidate_ids = [places["places"][0]["id"]] await agent.reserve(candidate_ids) - # try to reserve same place again - coro = agent.reserve(candidate_ids) - async with util.background_task(coro) as reserve_task: - await asyncio.sleep(0.001) - # request should be in queue now + with pytest.raises(Exception) as execinfo: + # try to reserve same place again + coro = agent.reserve(candidate_ids) + async with util.background_task(coro): + await asyncio.sleep(0.001) + # request should be in queue now - # when the exporter disappears ... - await util.cancel_tasks([exporter_task]) - # ... then the queued reservation is canceled - with pytest.raises(Exception) as execinfo: - await reserve_task - assert "All candidate places are gone" in str(execinfo.value) + # when the exporter disappears ... + await util.cancel_tasks([exporter_task]) + await asyncio.sleep(0.5) + # ... then the queued reservation is canceled + assert "All candidate places are gone" in str(execinfo.value) async def test_one_place_disappears_while_trying_to_reserve(hub): diff --git a/tests/test_jsonrpc.py b/tests/test_jsonrpc.py index 30657e3..4789052 100644 --- a/tests/test_jsonrpc.py +++ b/tests/test_jsonrpc.py @@ -231,30 +231,30 @@ async def test_send_notification(fakes): async def test_receive_error(fakes): - async with util.background_task(fakes.channel.some_func()) as call: - message = await fakes.transport.receive_from_jsonrpc() - - # send error response - await fakes.transport.send_to_jsonrpc( - id=message["id"], - error={ - "code": 123, - "message": "fake error", - "data": { - "traceback": ( - "Traceback:\n" - ' File "fake", line 207, in fake_func\n' - " raise FakeError()\n" - "FakeError: fake error\n" - ), + with pytest.raises(jsonrpc.RemoteError) as execinfo: + async with util.background_task(fakes.channel.some_func()): + message = await fakes.transport.receive_from_jsonrpc() + + # send error response + await fakes.transport.send_to_jsonrpc( + id=message["id"], + error={ + "code": 123, + "message": "fake error", + "data": { + "traceback": ( + "Traceback:\n" + ' File "fake", line 207, in fake_func\n' + " raise FakeError()\n" + "FakeError: fake error\n" + ), + }, }, - }, - ) + ) + await asyncio.sleep(0.5) - # call should raise RemoteError - with pytest.raises(jsonrpc.RemoteError) as execinfo: - await call - assert "fake error" in str(execinfo.value) + # call should raise RemoteError + assert "fake error" in str(execinfo.value) async def test_send_cancellation(fakes): @@ -314,19 +314,20 @@ async def test_prevent_hidden_function_call(fakes): async def test_fail_call_on_parse_error(fakes): - async with util.background_task(fakes.channel.some_func()) as call: - # check sent request - message = await fakes.transport.receive_from_jsonrpc() + with pytest.raises(jsonrpc.ProtocolError) as execinfo: + async with util.background_task(fakes.channel.some_func()): + # check sent request + message = await fakes.transport.receive_from_jsonrpc() - # send invalid error response - await fakes.transport.send_to_jsonrpc( - id=message["id"], error={"code": "not int"} - ) + # send invalid error response + await fakes.transport.send_to_jsonrpc( + id=message["id"], error={"code": "not int"} + ) - # call with matching ID should still fail - with pytest.raises(jsonrpc.ProtocolError) as execinfo: - await call - assert '"error.code" must be an integer' in str(execinfo.value) + await asyncio.sleep(0.5) + + # call with matching ID should still fail + assert '"error.code" must be an integer' in str(execinfo.value) @contextlib.asynccontextmanager @@ -385,14 +386,15 @@ async def test_pending_call_is_canceled_on_shutdown(): transport = FakeTransport() channel = jsonrpc.Channel(transport.send_to_test, transport.receive_from_test()) async with util.background_task(channel.communicate_forever()) as com_task: - async with util.background_task(channel.some_func()) as call: - # wait for sent message - await transport.receive_from_jsonrpc() - await util.cancel_tasks([com_task]) - - with pytest.raises(RuntimeError) as execinfo: - await call - assert "Connection closed" in str(execinfo.value) + + with pytest.raises(RuntimeError) as execinfo: + async with util.background_task(channel.some_func()): + # wait for sent message + await transport.receive_from_jsonrpc() + await util.cancel_tasks([com_task]) + await asyncio.sleep(0.5) + + assert "Connection closed" in str(execinfo.value) async def test_log_error_response_with_unknown_id(fakes): diff --git a/tests/test_util.py b/tests/test_util.py new file mode 100644 index 0000000..a087dd3 --- /dev/null +++ b/tests/test_util.py @@ -0,0 +1,26 @@ +import asyncio + +import pytest + +import not_my_board._util as util + + +async def test_background_task(): + async with util.background_task(blocking_task()) as task: + assert not task.done() + assert task.cancelled() + + +async def test_background_task_failed(): + with pytest.raises(RuntimeError) as execinfo: + async with util.background_task(failing_task()): + await asyncio.sleep(3) + assert "Dummy Error" in str(execinfo.value) + + +async def blocking_task(): + await asyncio.Event().wait() + + +async def failing_task(): + raise RuntimeError("Dummy Error")