Skip to content

Commit 238a340

Browse files
committed
Fixed cancellation propagation when task group host is in a shielded scope
Fixes #642.
1 parent 84c1bb0 commit 238a340

File tree

3 files changed

+85
-45
lines changed

3 files changed

+85
-45
lines changed

docs/versionhistory.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@ This library adheres to `Semantic Versioning 2.0 <http://semver.org/>`_.
1313
from Egor Blagov)
1414
- Fixed ``loop_factory`` and ``use_uvloop`` options not being used on the asyncio
1515
backend (`#643 <https://github.com/agronholm/anyio/issues/643>`_)
16+
- Fixed cancellation propagating on asyncio from a task group to child tasks if the task
17+
hosting the task group is in a shielded cancel scope
18+
(`#642 <https://github.com/agronholm/anyio/issues/642>`_)
1619

1720
**4.1.0**
1821

src/anyio/_backends/_asyncio.py

Lines changed: 59 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -343,6 +343,7 @@ def __init__(self, deadline: float = math.inf, shield: bool = False):
343343
self._deadline = deadline
344344
self._shield = shield
345345
self._parent_scope: CancelScope | None = None
346+
self._child_scopes: set[CancelScope] = set()
346347
self._cancel_called = False
347348
self._cancelled_caught = False
348349
self._active = False
@@ -369,6 +370,9 @@ def __enter__(self) -> CancelScope:
369370
else:
370371
self._parent_scope = task_state.cancel_scope
371372
task_state.cancel_scope = self
373+
if self._parent_scope is not None:
374+
self._parent_scope._child_scopes.add(self)
375+
self._parent_scope._tasks.remove(host_task)
372376

373377
self._timeout()
374378
self._active = True
@@ -377,7 +381,7 @@ def __enter__(self) -> CancelScope:
377381

378382
# Start cancelling the host task if the scope was cancelled before entering
379383
if self._cancel_called:
380-
self._deliver_cancellation()
384+
self._deliver_cancellation(self)
381385

382386
return self
383387

@@ -409,13 +413,15 @@ def __exit__(
409413
self._timeout_handle = None
410414

411415
self._tasks.remove(self._host_task)
416+
if self._parent_scope is not None:
417+
self._parent_scope._child_scopes.remove(self)
418+
self._parent_scope._tasks.add(self._host_task)
412419

413420
host_task_state.cancel_scope = self._parent_scope
414421

415422
# Restart the cancellation effort in the farthest directly cancelled parent
416423
# scope if this one was shielded
417-
if self._shield:
418-
self._deliver_cancellation_to_parent()
424+
self._restart_cancellation_in_parent()
419425

420426
if self._cancel_called and exc_val is not None:
421427
for exc in iterate_exceptions(exc_val):
@@ -451,65 +457,67 @@ def _timeout(self) -> None:
451457
else:
452458
self._timeout_handle = loop.call_at(self._deadline, self._timeout)
453459

454-
def _deliver_cancellation(self) -> None:
460+
def _deliver_cancellation(self, origin: CancelScope) -> bool:
455461
"""
456462
Deliver cancellation to directly contained tasks and nested cancel scopes.
457463
458464
Schedule another run at the end if we still have tasks eligible for
459465
cancellation.
466+
467+
:param origin: the cancel scope that originated the cancellation
468+
:return: ``True`` if the delivery needs to be retried on the next cycle
469+
460470
"""
461471
should_retry = False
462472
current = current_task()
463473
for task in self._tasks:
464474
if task._must_cancel: # type: ignore[attr-defined]
465475
continue
466476

467-
# The task is eligible for cancellation if it has started and is not in a
468-
# cancel scope shielded from this one
469-
cancel_scope = _task_states[task].cancel_scope
470-
while cancel_scope is not self:
471-
if cancel_scope is None or cancel_scope._shield:
472-
break
473-
else:
474-
cancel_scope = cancel_scope._parent_scope
475-
else:
476-
should_retry = True
477-
if task is not current and (
478-
task is self._host_task or _task_started(task)
479-
):
480-
waiter = task._fut_waiter # type: ignore[attr-defined]
481-
if not isinstance(waiter, asyncio.Future) or not waiter.done():
482-
self._cancel_calls += 1
483-
if sys.version_info >= (3, 9):
484-
task.cancel(f"Cancelled by cancel scope {id(self):x}")
485-
else:
486-
task.cancel()
477+
# The task is eligible for cancellation if it has started
478+
should_retry = True
479+
if task is not current and (task is self._host_task or _task_started(task)):
480+
waiter = task._fut_waiter # type: ignore[attr-defined]
481+
if not isinstance(waiter, asyncio.Future) or not waiter.done():
482+
self._cancel_calls += 1
483+
if sys.version_info >= (3, 9):
484+
task.cancel(f"Cancelled by cancel scope {id(origin):x}")
485+
else:
486+
task.cancel()
487+
488+
# Deliver cancellation to child scopes that aren't shielded or running their own
489+
# cancellation callbacks
490+
for scope in self._child_scopes:
491+
if not scope._shield and not scope.cancel_called:
492+
should_retry = scope._deliver_cancellation(origin) or should_retry
487493

488494
# Schedule another callback if there are still tasks left
489-
if should_retry:
490-
self._cancel_handle = get_running_loop().call_soon(
491-
self._deliver_cancellation
492-
)
493-
else:
494-
self._cancel_handle = None
495+
if origin is self:
496+
if should_retry:
497+
self._cancel_handle = get_running_loop().call_soon(
498+
self._deliver_cancellation, origin
499+
)
500+
else:
501+
self._cancel_handle = None
502+
503+
return should_retry
495504

496-
def _deliver_cancellation_to_parent(self) -> None:
497-
"""Start cancellation effort in the farthest directly cancelled parent scope"""
505+
def _restart_cancellation_in_parent(self) -> None:
506+
"""Start cancellation effort in the closest directly cancelled parent scope"""
498507
scope = self._parent_scope
499-
scope_to_cancel: CancelScope | None = None
500508
while scope is not None:
501-
if scope._cancel_called and scope._cancel_handle is None:
502-
scope_to_cancel = scope
509+
if scope._cancel_called:
510+
if scope._cancel_handle is None:
511+
scope._deliver_cancellation(scope)
512+
513+
break
503514

504515
# No point in looking beyond any shielded scope
505516
if scope._shield:
506517
break
507518

508519
scope = scope._parent_scope
509520

510-
if scope_to_cancel is not None:
511-
scope_to_cancel._deliver_cancellation()
512-
513521
def _parent_cancelled(self) -> bool:
514522
# Check whether any parent has been cancelled
515523
cancel_scope = self._parent_scope
@@ -529,7 +537,7 @@ def cancel(self) -> None:
529537

530538
self._cancel_called = True
531539
if self._host_task is not None:
532-
self._deliver_cancellation()
540+
self._deliver_cancellation(self)
533541

534542
@property
535543
def deadline(self) -> float:
@@ -562,7 +570,7 @@ def shield(self, value: bool) -> None:
562570
if self._shield != value:
563571
self._shield = value
564572
if not value:
565-
self._deliver_cancellation_to_parent()
573+
self._restart_cancellation_in_parent()
566574

567575

568576
#
@@ -623,6 +631,7 @@ def __init__(self) -> None:
623631
self.cancel_scope: CancelScope = CancelScope()
624632
self._active = False
625633
self._exceptions: list[BaseException] = []
634+
self._tasks: set[asyncio.Task] = set()
626635

627636
async def __aenter__(self) -> TaskGroup:
628637
self.cancel_scope.__enter__()
@@ -642,9 +651,9 @@ async def __aexit__(
642651
self._exceptions.append(exc_val)
643652

644653
cancelled_exc_while_waiting_tasks: CancelledError | None = None
645-
while self.cancel_scope._tasks:
654+
while self._tasks:
646655
try:
647-
await asyncio.wait(self.cancel_scope._tasks)
656+
await asyncio.wait(self._tasks)
648657
except CancelledError as exc:
649658
# This task was cancelled natively; reraise the CancelledError later
650659
# unless this task was already interrupted by another exception
@@ -676,8 +685,11 @@ def _spawn(
676685
task_status_future: asyncio.Future | None = None,
677686
) -> asyncio.Task:
678687
def task_done(_task: asyncio.Task) -> None:
679-
assert _task in self.cancel_scope._tasks
680-
self.cancel_scope._tasks.remove(_task)
688+
task_state = _task_states[_task]
689+
assert task_state.cancel_scope is not None
690+
assert _task in task_state.cancel_scope._tasks
691+
task_state.cancel_scope._tasks.remove(_task)
692+
self._tasks.remove(task)
681693
del _task_states[_task]
682694

683695
try:
@@ -693,7 +705,8 @@ def task_done(_task: asyncio.Task) -> None:
693705
if not isinstance(exc, CancelledError):
694706
self._exceptions.append(exc)
695707

696-
self.cancel_scope.cancel()
708+
if not self.cancel_scope._parent_cancelled():
709+
self.cancel_scope.cancel()
697710
else:
698711
task_status_future.set_exception(exc)
699712
elif task_status_future is not None and not task_status_future.done():
@@ -732,6 +745,7 @@ def task_done(_task: asyncio.Task) -> None:
732745
parent_id=parent_id, cancel_scope=self.cancel_scope
733746
)
734747
self.cancel_scope._tasks.add(task)
748+
self._tasks.add(task)
735749
return task
736750

737751
def start_soon(

tests/test_taskgroups.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1293,6 +1293,29 @@ def handler(excgrp: BaseExceptionGroup) -> None:
12931293
await anyio.sleep_forever()
12941294

12951295

1296+
async def test_cancel_child_task_when_host_is_shielded() -> None:
1297+
# Regression test for #642
1298+
# Tests that cancellation propagates to a child task even if the host task is within
1299+
# a shielded cancel scope.
1300+
cancelled = anyio.Event()
1301+
1302+
async def wait_cancel() -> None:
1303+
try:
1304+
await anyio.sleep_forever()
1305+
except anyio.get_cancelled_exc_class():
1306+
cancelled.set()
1307+
raise
1308+
1309+
with CancelScope() as parent_scope:
1310+
async with anyio.create_task_group() as task_group:
1311+
task_group.start_soon(wait_cancel)
1312+
await wait_all_tasks_blocked()
1313+
1314+
with CancelScope(shield=True), fail_after(1):
1315+
parent_scope.cancel()
1316+
await cancelled.wait()
1317+
1318+
12961319
class TestTaskStatusTyping:
12971320
"""
12981321
These tests do not do anything at run time, but since the test suite is also checked

0 commit comments

Comments
 (0)