Skip to content

Commit

Permalink
util: Cancel task if background task fails
Browse files Browse the repository at this point in the history
Cancel the foreground task as soon as the background task fails and
raise the exception of the background task. With that, background task
failures can be noticed immediately, not only after stopping the
process.
  • Loading branch information
holesch committed Apr 20, 2024
1 parent fc37583 commit ee4ce4a
Show file tree
Hide file tree
Showing 5 changed files with 154 additions and 76 deletions.
65 changes: 56 additions & 9 deletions not_my_board/_util/_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,16 +48,63 @@ async def run_concurrently(*coros):
await cancel_tasks(tasks)


@contextlib.asynccontextmanager
async def background_task(coro):
"""Runs the coro until leaving the context manager.
def background_task(coro):
"""Runs coro as a background task until leaving the context manager.
The coro task is canceled when leaving the context."""
task = asyncio.create_task(coro)
try:
yield task
finally:
await cancel_tasks([task])
If the background task fails while the context manager is active, then the
foreground task is canceled and the context manager raises the exception of
the background task.
If the context manager exits while the background task is still running,
then the background task is canceled."""

return _BackgroundTask(coro)


class _BackgroundTask:
def __init__(self, coro):
self._coro = coro
self._bg_exception = None

async def __aenter__(self):
self._bg_task = asyncio.create_task(self._coro)
self._bg_task.add_done_callback(self._on_bg_task_done)
self._fg_task = asyncio.current_task()
self._num_cancel_requests = self._get_num_cancel_requests()
return self._bg_task

async def __aexit__(self, exc_type, exc, tb):
self._bg_task.remove_done_callback(self._on_bg_task_done)
if self._bg_exception:
if (
self._uncancel() <= self._num_cancel_requests
and exc_type is asyncio.CancelledError
):
# foreground task was only canceled by this class, raise
# real error
raise self._bg_exception from exc
else:
await cancel_tasks([self._bg_task])

def _on_bg_task_done(self, task):
if not task.cancelled():
self._bg_exception = task.exception()
if self._bg_exception:
self._fg_task.cancel()

def _get_num_cancel_requests(self):
# remove, if Python version < 3.11 is no longer supported
if hasattr(self._fg_task, "cancelling"):
return self._fg_task.cancelling()
else:
return 0

def _uncancel(self):
# remove, if Python version < 3.11 is no longer supported
if hasattr(self._fg_task, "uncancel"):
return self._fg_task.uncancel()
else:
return 0


async def cancel_tasks(tasks):
Expand Down
33 changes: 18 additions & 15 deletions tests/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,23 +410,26 @@ async def test_reserve_twice_concurrently(agent_io):
# start two reserve tasks in parallel
coro_1 = agent_io.agent_api.reserve(IMPORT_DESC_1.dict())
coro_2 = agent_io.agent_api.reserve(IMPORT_DESC_1.dict())
async with util.background_task(coro_1) as task_1:
async with util.background_task(coro_2) as task_2:
# wait until one blocks
await agent_io.hub.reserve_pending.wait()

# the other one should block until the first one finishes
assert not task_1.done()
assert not task_2.done()
with pytest.raises(RuntimeError) as execinfo:
async with util.background_task(coro_1) as task_1:
async with util.background_task(coro_2) as task_2:
# wait until one blocks
await agent_io.hub.reserve_pending.wait()

# the other one should block until the first one finishes
assert not task_1.done()
assert not task_2.done()

# unblock reserve call
agent_io.hub.reserve_continue.set()
# unblock reserve call
agent_io.hub.reserve_continue.set()

# now one should finish successfully and the other one should fail
results = await asyncio.gather(task_1, task_2, return_exceptions=True)
await asyncio.sleep(0.5)

assert None in results
exception = results[0] or results[1]
assert "is already reserved" in str(exception)
# now one should finish successfully and the other one should fail
results = await asyncio.gather(task_1, task_2, return_exceptions=True)

assert None in results
assert "is already reserved" in str(execinfo.value)

assert len(await agent_io.agent_api.list()) == 1
assert len(await agent_io.agent_api.list()) == 1
22 changes: 11 additions & 11 deletions tests/test_hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,18 +147,18 @@ async def test_all_places_disappear_while_trying_to_reserve(hub):
candidate_ids = [places["places"][0]["id"]]
await agent.reserve(candidate_ids)

# try to reserve same place again
coro = agent.reserve(candidate_ids)
async with util.background_task(coro) as reserve_task:
await asyncio.sleep(0.001)
# request should be in queue now
with pytest.raises(Exception) as execinfo:
# try to reserve same place again
coro = agent.reserve(candidate_ids)
async with util.background_task(coro):
await asyncio.sleep(0.001)
# request should be in queue now

# when the exporter disappears ...
await util.cancel_tasks([exporter_task])
# ... then the queued reservation is canceled
with pytest.raises(Exception) as execinfo:
await reserve_task
assert "All candidate places are gone" in str(execinfo.value)
# when the exporter disappears ...
await util.cancel_tasks([exporter_task])
await asyncio.sleep(0.5)
# ... then the queued reservation is canceled
assert "All candidate places are gone" in str(execinfo.value)


async def test_one_place_disappears_while_trying_to_reserve(hub):
Expand Down
84 changes: 43 additions & 41 deletions tests/test_jsonrpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,30 +231,30 @@ async def test_send_notification(fakes):


async def test_receive_error(fakes):
async with util.background_task(fakes.channel.some_func()) as call:
message = await fakes.transport.receive_from_jsonrpc()

# send error response
await fakes.transport.send_to_jsonrpc(
id=message["id"],
error={
"code": 123,
"message": "fake error",
"data": {
"traceback": (
"Traceback:\n"
' File "fake", line 207, in fake_func\n'
" raise FakeError()\n"
"FakeError: fake error\n"
),
with pytest.raises(jsonrpc.RemoteError) as execinfo:
async with util.background_task(fakes.channel.some_func()):
message = await fakes.transport.receive_from_jsonrpc()

# send error response
await fakes.transport.send_to_jsonrpc(
id=message["id"],
error={
"code": 123,
"message": "fake error",
"data": {
"traceback": (
"Traceback:\n"
' File "fake", line 207, in fake_func\n'
" raise FakeError()\n"
"FakeError: fake error\n"
),
},
},
},
)
)
await asyncio.sleep(0.5)

# call should raise RemoteError
with pytest.raises(jsonrpc.RemoteError) as execinfo:
await call
assert "fake error" in str(execinfo.value)
# call should raise RemoteError
assert "fake error" in str(execinfo.value)


async def test_send_cancellation(fakes):
Expand Down Expand Up @@ -314,19 +314,20 @@ async def test_prevent_hidden_function_call(fakes):


async def test_fail_call_on_parse_error(fakes):
async with util.background_task(fakes.channel.some_func()) as call:
# check sent request
message = await fakes.transport.receive_from_jsonrpc()
with pytest.raises(jsonrpc.ProtocolError) as execinfo:
async with util.background_task(fakes.channel.some_func()):
# check sent request
message = await fakes.transport.receive_from_jsonrpc()

# send invalid error response
await fakes.transport.send_to_jsonrpc(
id=message["id"], error={"code": "not int"}
)
# send invalid error response
await fakes.transport.send_to_jsonrpc(
id=message["id"], error={"code": "not int"}
)

# call with matching ID should still fail
with pytest.raises(jsonrpc.ProtocolError) as execinfo:
await call
assert '"error.code" must be an integer' in str(execinfo.value)
await asyncio.sleep(0.5)

# call with matching ID should still fail
assert '"error.code" must be an integer' in str(execinfo.value)


@contextlib.asynccontextmanager
Expand Down Expand Up @@ -385,14 +386,15 @@ async def test_pending_call_is_canceled_on_shutdown():
transport = FakeTransport()
channel = jsonrpc.Channel(transport.send_to_test, transport.receive_from_test())
async with util.background_task(channel.communicate_forever()) as com_task:
async with util.background_task(channel.some_func()) as call:
# wait for sent message
await transport.receive_from_jsonrpc()
await util.cancel_tasks([com_task])

with pytest.raises(RuntimeError) as execinfo:
await call
assert "Connection closed" in str(execinfo.value)

with pytest.raises(RuntimeError) as execinfo:
async with util.background_task(channel.some_func()):
# wait for sent message
await transport.receive_from_jsonrpc()
await util.cancel_tasks([com_task])
await asyncio.sleep(0.5)

assert "Connection closed" in str(execinfo.value)


async def test_log_error_response_with_unknown_id(fakes):
Expand Down
26 changes: 26 additions & 0 deletions tests/test_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import asyncio

import pytest

import not_my_board._util as util


async def test_background_task():
async with util.background_task(blocking_task()) as task:
assert not task.done()
assert task.cancelled()


async def test_background_task_failed():
with pytest.raises(RuntimeError) as execinfo:
async with util.background_task(failing_task()):
await asyncio.sleep(3)
assert "Dummy Error" in str(execinfo.value)


async def blocking_task():
await asyncio.Event().wait()


async def failing_task():
raise RuntimeError("Dummy Error")

0 comments on commit ee4ce4a

Please sign in to comment.