@@ -343,6 +343,7 @@ def __init__(self, deadline: float = math.inf, shield: bool = False):
343
343
self ._deadline = deadline
344
344
self ._shield = shield
345
345
self ._parent_scope : CancelScope | None = None
346
+ self ._child_scopes : set [CancelScope ] = set ()
346
347
self ._cancel_called = False
347
348
self ._cancelled_caught = False
348
349
self ._active = False
@@ -369,6 +370,9 @@ def __enter__(self) -> CancelScope:
369
370
else :
370
371
self ._parent_scope = task_state .cancel_scope
371
372
task_state .cancel_scope = self
373
+ if self ._parent_scope is not None :
374
+ self ._parent_scope ._child_scopes .add (self )
375
+ self ._parent_scope ._tasks .remove (host_task )
372
376
373
377
self ._timeout ()
374
378
self ._active = True
@@ -377,7 +381,7 @@ def __enter__(self) -> CancelScope:
377
381
378
382
# Start cancelling the host task if the scope was cancelled before entering
379
383
if self ._cancel_called :
380
- self ._deliver_cancellation ()
384
+ self ._deliver_cancellation (self )
381
385
382
386
return self
383
387
@@ -409,13 +413,15 @@ def __exit__(
409
413
self ._timeout_handle = None
410
414
411
415
self ._tasks .remove (self ._host_task )
416
+ if self ._parent_scope is not None :
417
+ self ._parent_scope ._child_scopes .remove (self )
418
+ self ._parent_scope ._tasks .add (self ._host_task )
412
419
413
420
host_task_state .cancel_scope = self ._parent_scope
414
421
415
422
# Restart the cancellation effort in the farthest directly cancelled parent
416
423
# scope if this one was shielded
417
- if self ._shield :
418
- self ._deliver_cancellation_to_parent ()
424
+ self ._restart_cancellation_in_parent ()
419
425
420
426
if self ._cancel_called and exc_val is not None :
421
427
for exc in iterate_exceptions (exc_val ):
@@ -451,65 +457,67 @@ def _timeout(self) -> None:
451
457
else :
452
458
self ._timeout_handle = loop .call_at (self ._deadline , self ._timeout )
453
459
454
- def _deliver_cancellation (self ) -> None :
460
+ def _deliver_cancellation (self , origin : CancelScope ) -> bool :
455
461
"""
456
462
Deliver cancellation to directly contained tasks and nested cancel scopes.
457
463
458
464
Schedule another run at the end if we still have tasks eligible for
459
465
cancellation.
466
+
467
+ :param origin: the cancel scope that originated the cancellation
468
+ :return: ``True`` if the delivery needs to be retried on the next cycle
469
+
460
470
"""
461
471
should_retry = False
462
472
current = current_task ()
463
473
for task in self ._tasks :
464
474
if task ._must_cancel : # type: ignore[attr-defined]
465
475
continue
466
476
467
- # The task is eligible for cancellation if it has started and is not in a
468
- # cancel scope shielded from this one
469
- cancel_scope = _task_states [task ].cancel_scope
470
- while cancel_scope is not self :
471
- if cancel_scope is None or cancel_scope ._shield :
472
- break
473
- else :
474
- cancel_scope = cancel_scope ._parent_scope
475
- else :
476
- should_retry = True
477
- if task is not current and (
478
- task is self ._host_task or _task_started (task )
479
- ):
480
- waiter = task ._fut_waiter # type: ignore[attr-defined]
481
- if not isinstance (waiter , asyncio .Future ) or not waiter .done ():
482
- self ._cancel_calls += 1
483
- if sys .version_info >= (3 , 9 ):
484
- task .cancel (f"Cancelled by cancel scope { id (self ):x} " )
485
- else :
486
- task .cancel ()
477
+ # The task is eligible for cancellation if it has started
478
+ should_retry = True
479
+ if task is not current and (task is self ._host_task or _task_started (task )):
480
+ waiter = task ._fut_waiter # type: ignore[attr-defined]
481
+ if not isinstance (waiter , asyncio .Future ) or not waiter .done ():
482
+ self ._cancel_calls += 1
483
+ if sys .version_info >= (3 , 9 ):
484
+ task .cancel (f"Cancelled by cancel scope { id (origin ):x} " )
485
+ else :
486
+ task .cancel ()
487
+
488
+ # Deliver cancellation to child scopes that aren't shielded or running their own
489
+ # cancellation callbacks
490
+ for scope in self ._child_scopes :
491
+ if not scope ._shield and not scope .cancel_called :
492
+ should_retry = scope ._deliver_cancellation (origin ) or should_retry
487
493
488
494
# Schedule another callback if there are still tasks left
489
- if should_retry :
490
- self ._cancel_handle = get_running_loop ().call_soon (
491
- self ._deliver_cancellation
492
- )
493
- else :
494
- self ._cancel_handle = None
495
+ if origin is self :
496
+ if should_retry :
497
+ self ._cancel_handle = get_running_loop ().call_soon (
498
+ self ._deliver_cancellation , origin
499
+ )
500
+ else :
501
+ self ._cancel_handle = None
502
+
503
+ return should_retry
495
504
496
- def _deliver_cancellation_to_parent (self ) -> None :
497
- """Start cancellation effort in the farthest directly cancelled parent scope"""
505
+ def _restart_cancellation_in_parent (self ) -> None :
506
+ """Start cancellation effort in the closest directly cancelled parent scope"""
498
507
scope = self ._parent_scope
499
- scope_to_cancel : CancelScope | None = None
500
508
while scope is not None :
501
- if scope ._cancel_called and scope ._cancel_handle is None :
502
- scope_to_cancel = scope
509
+ if scope ._cancel_called :
510
+ if scope ._cancel_handle is None :
511
+ scope ._deliver_cancellation (scope )
512
+
513
+ break
503
514
504
515
# No point in looking beyond any shielded scope
505
516
if scope ._shield :
506
517
break
507
518
508
519
scope = scope ._parent_scope
509
520
510
- if scope_to_cancel is not None :
511
- scope_to_cancel ._deliver_cancellation ()
512
-
513
521
def _parent_cancelled (self ) -> bool :
514
522
# Check whether any parent has been cancelled
515
523
cancel_scope = self ._parent_scope
@@ -529,7 +537,7 @@ def cancel(self) -> None:
529
537
530
538
self ._cancel_called = True
531
539
if self ._host_task is not None :
532
- self ._deliver_cancellation ()
540
+ self ._deliver_cancellation (self )
533
541
534
542
@property
535
543
def deadline (self ) -> float :
@@ -562,7 +570,7 @@ def shield(self, value: bool) -> None:
562
570
if self ._shield != value :
563
571
self ._shield = value
564
572
if not value :
565
- self ._deliver_cancellation_to_parent ()
573
+ self ._restart_cancellation_in_parent ()
566
574
567
575
568
576
#
@@ -623,6 +631,7 @@ def __init__(self) -> None:
623
631
self .cancel_scope : CancelScope = CancelScope ()
624
632
self ._active = False
625
633
self ._exceptions : list [BaseException ] = []
634
+ self ._tasks : set [asyncio .Task ] = set ()
626
635
627
636
async def __aenter__ (self ) -> TaskGroup :
628
637
self .cancel_scope .__enter__ ()
@@ -642,9 +651,9 @@ async def __aexit__(
642
651
self ._exceptions .append (exc_val )
643
652
644
653
cancelled_exc_while_waiting_tasks : CancelledError | None = None
645
- while self .cancel_scope . _tasks :
654
+ while self ._tasks :
646
655
try :
647
- await asyncio .wait (self .cancel_scope . _tasks )
656
+ await asyncio .wait (self ._tasks )
648
657
except CancelledError as exc :
649
658
# This task was cancelled natively; reraise the CancelledError later
650
659
# unless this task was already interrupted by another exception
@@ -676,8 +685,11 @@ def _spawn(
676
685
task_status_future : asyncio .Future | None = None ,
677
686
) -> asyncio .Task :
678
687
def task_done (_task : asyncio .Task ) -> None :
679
- assert _task in self .cancel_scope ._tasks
680
- self .cancel_scope ._tasks .remove (_task )
688
+ task_state = _task_states [_task ]
689
+ assert task_state .cancel_scope is not None
690
+ assert _task in task_state .cancel_scope ._tasks
691
+ task_state .cancel_scope ._tasks .remove (_task )
692
+ self ._tasks .remove (task )
681
693
del _task_states [_task ]
682
694
683
695
try :
@@ -693,7 +705,8 @@ def task_done(_task: asyncio.Task) -> None:
693
705
if not isinstance (exc , CancelledError ):
694
706
self ._exceptions .append (exc )
695
707
696
- self .cancel_scope .cancel ()
708
+ if not self .cancel_scope ._parent_cancelled ():
709
+ self .cancel_scope .cancel ()
697
710
else :
698
711
task_status_future .set_exception (exc )
699
712
elif task_status_future is not None and not task_status_future .done ():
@@ -732,6 +745,7 @@ def task_done(_task: asyncio.Task) -> None:
732
745
parent_id = parent_id , cancel_scope = self .cancel_scope
733
746
)
734
747
self .cancel_scope ._tasks .add (task )
748
+ self ._tasks .add (task )
735
749
return task
736
750
737
751
def start_soon (
0 commit comments