Skip to content

Commit 953cf67

Browse files
committed
Fix task cancellation propagation to subtasks when using sync middleware
1 parent 3f0147d commit 953cf67

File tree

2 files changed

+199
-17
lines changed

2 files changed

+199
-17
lines changed

asgiref/sync.py

Lines changed: 45 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,10 @@ def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R:
203203
# `main_wrap`.
204204
context = [contextvars.copy_context()]
205205

206+
# Get task context so that parent task knows which task to propagate
207+
# an asyncio.CancelledError to.
208+
task_context = getattr(SyncToAsync.threadlocal, "task_context", None)
209+
206210
loop = None
207211
# Use call_soon_threadsafe to schedule a synchronous callback on the
208212
# main event loop's thread if it's there, otherwise make a new loop
@@ -211,6 +215,7 @@ def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R:
211215
awaitable = self.main_wrap(
212216
call_result,
213217
sys.exc_info(),
218+
task_context,
214219
context,
215220
*args,
216221
**kwargs,
@@ -295,6 +300,7 @@ async def main_wrap(
295300
self,
296301
call_result: "Future[_R]",
297302
exc_info: "OptExcInfo",
303+
task_context: "Optional[List[asyncio.Task[Any]]]",
298304
context: List[contextvars.Context],
299305
*args: _P.args,
300306
**kwargs: _P.kwargs,
@@ -309,6 +315,10 @@ async def main_wrap(
309315
if context is not None:
310316
_restore_context(context[0])
311317

318+
current_task = asyncio.current_task()
319+
if current_task is not None and task_context is not None:
320+
task_context.append(current_task)
321+
312322
try:
313323
# If we have an exception, run the function inside the except block
314324
# after raising it so exc_info is correctly populated.
@@ -324,6 +334,8 @@ async def main_wrap(
324334
else:
325335
call_result.set_result(result)
326336
finally:
337+
if current_task is not None and task_context is not None:
338+
task_context.remove(current_task)
327339
context[0] = contextvars.copy_context()
328340

329341

@@ -437,20 +449,38 @@ async def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R:
437449
context = contextvars.copy_context()
438450
child = functools.partial(self.func, *args, **kwargs)
439451
func = context.run
440-
452+
task_context: List[asyncio.Task[Any]] = []
453+
454+
# Run the code in the right thread
455+
exec_coro = loop.run_in_executor(
456+
executor,
457+
functools.partial(
458+
self.thread_handler,
459+
loop,
460+
sys.exc_info(),
461+
task_context,
462+
func,
463+
child,
464+
),
465+
)
466+
ret: _R
441467
try:
442-
# Run the code in the right thread
443-
ret: _R = await loop.run_in_executor(
444-
executor,
445-
functools.partial(
446-
self.thread_handler,
447-
loop,
448-
sys.exc_info(),
449-
func,
450-
child,
451-
),
452-
)
453-
468+
ret = await asyncio.shield(exec_coro)
469+
except asyncio.CancelledError:
470+
cancel_parent = True
471+
try:
472+
task = task_context[0]
473+
task.cancel()
474+
try:
475+
await task
476+
cancel_parent = False
477+
except asyncio.CancelledError:
478+
pass
479+
except IndexError:
480+
pass
481+
if cancel_parent:
482+
exec_coro.cancel()
483+
ret = await exec_coro
454484
finally:
455485
_restore_context(context)
456486
self.deadlock_context.set(False)
@@ -466,7 +496,7 @@ def __get__(
466496
func = functools.partial(self.__call__, parent)
467497
return functools.update_wrapper(func, self.func)
468498

469-
def thread_handler(self, loop, exc_info, func, *args, **kwargs):
499+
def thread_handler(self, loop, exc_info, task_context, func, *args, **kwargs):
470500
"""
471501
Wraps the sync application with exception handling.
472502
"""
@@ -476,6 +506,7 @@ def thread_handler(self, loop, exc_info, func, *args, **kwargs):
476506
# Set the threadlocal for AsyncToSync
477507
self.threadlocal.main_event_loop = loop
478508
self.threadlocal.main_event_loop_pid = os.getpid()
509+
self.threadlocal.task_context = task_context
479510

480511
# Run the function
481512
# If we have an exception, run the function inside the except block

tests/test_sync.py

Lines changed: 154 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -852,13 +852,10 @@ def sync_task():
852852

853853

854854
@pytest.mark.asyncio
855-
@pytest.mark.skip(reason="deadlocks")
856855
async def test_inner_shield_sync_middleware():
857856
"""
858857
Tests that asyncio.shield is capable of preventing http.disconnect from
859858
cancelling a django request task when using sync middleware.
860-
861-
Currently this tests is skipped as it causes a deadlock.
862859
"""
863860

864861
# Hypothetical Django scenario - middleware function is sync
@@ -968,3 +965,157 @@ async def async_task():
968965
assert task_complete
969966

970967
assert task_executed
968+
969+
970+
@pytest.mark.asyncio
971+
async def test_inner_shield_sync_and_async_middleware():
972+
"""
973+
Tests that asyncio.shield is capable of preventing http.disconnect from
974+
cancelling a django request task when using sync and middleware chained
975+
together.
976+
"""
977+
978+
# Hypothetical Django scenario - middleware function is sync
979+
def sync_middleware_1():
980+
async_to_sync(async_middleware_2)()
981+
982+
# Hypothetical Django scenario - middleware function is async
983+
async def async_middleware_2():
984+
await sync_to_async(sync_middleware_3)()
985+
986+
# Hypothetical Django scenario - middleware function is sync
987+
def sync_middleware_3():
988+
async_to_sync(async_middleware_4)()
989+
990+
# Hypothetical Django scenario - middleware function is async
991+
async def async_middleware_4():
992+
await sync_to_async(sync_middleware_5)()
993+
994+
# Hypothetical Django scenario - middleware function is sync
995+
def sync_middleware_5():
996+
async_to_sync(async_view)()
997+
998+
task_complete = False
999+
task_cancel_caught = False
1000+
1001+
# Future that completes when subtask cancellation attempt is caught
1002+
task_blocker = asyncio.Future()
1003+
1004+
async def async_view():
1005+
"""Async view with a task that is shielded from cancellation."""
1006+
nonlocal task_complete, task_cancel_caught, task_blocker
1007+
task = asyncio.create_task(async_task())
1008+
try:
1009+
await asyncio.shield(task)
1010+
except asyncio.CancelledError:
1011+
task_cancel_caught = True
1012+
task_blocker.set_result(True)
1013+
await task
1014+
task_complete = True
1015+
1016+
task_executed = False
1017+
1018+
# Future that completes after subtask is created
1019+
task_started_future = asyncio.Future()
1020+
1021+
async def async_task():
1022+
"""Async subtask that should not be canceled when parent is canceled."""
1023+
nonlocal task_started_future, task_executed, task_blocker
1024+
task_started_future.set_result(True)
1025+
await task_blocker
1026+
task_executed = True
1027+
1028+
task_cancel_propagated = False
1029+
1030+
async with ThreadSensitiveContext():
1031+
task = asyncio.create_task(sync_to_async(sync_middleware_1)())
1032+
await task_started_future
1033+
task.cancel()
1034+
try:
1035+
await task
1036+
except asyncio.CancelledError:
1037+
task_cancel_propagated = True
1038+
assert not task_cancel_propagated
1039+
assert task_cancel_caught
1040+
assert task_complete
1041+
1042+
assert task_executed
1043+
1044+
1045+
@pytest.mark.asyncio
1046+
async def test_inner_shield_sync_and_async_middleware_sync_task():
1047+
"""
1048+
Tests that asyncio.shield is capable of preventing http.disconnect from
1049+
cancelling a django request task when using sync and middleware chained
1050+
together with an async view calling a sync calling an async task through
1051+
a sync parent.
1052+
"""
1053+
1054+
# Hypothetical Django scenario - middleware function is sync
1055+
def sync_middleware_1():
1056+
async_to_sync(async_middleware_2)()
1057+
1058+
# Hypothetical Django scenario - middleware function is async
1059+
async def async_middleware_2():
1060+
await sync_to_async(sync_middleware_3)()
1061+
1062+
# Hypothetical Django scenario - middleware function is sync
1063+
def sync_middleware_3():
1064+
async_to_sync(async_middleware_4)()
1065+
1066+
# Hypothetical Django scenario - middleware function is async
1067+
async def async_middleware_4():
1068+
await sync_to_async(sync_middleware_5)()
1069+
1070+
# Hypothetical Django scenario - middleware function is sync
1071+
def sync_middleware_5():
1072+
async_to_sync(async_view)()
1073+
1074+
task_complete = False
1075+
task_cancel_caught = False
1076+
1077+
# Future that completes when subtask cancellation attempt is caught
1078+
task_blocker = asyncio.Future()
1079+
1080+
async def async_view():
1081+
"""Async view with a task that is shielded from cancellation."""
1082+
nonlocal task_complete, task_cancel_caught, task_blocker
1083+
task = asyncio.create_task(sync_to_async(sync_parent)())
1084+
try:
1085+
await asyncio.shield(task)
1086+
except asyncio.CancelledError:
1087+
task_cancel_caught = True
1088+
task_blocker.set_result(True)
1089+
await task
1090+
task_complete = True
1091+
1092+
task_executed = False
1093+
1094+
# Future that completes after subtask is created
1095+
task_started_future = asyncio.Future()
1096+
1097+
def sync_parent():
1098+
async_to_sync(async_task)()
1099+
1100+
async def async_task():
1101+
"""Async subtask that should not be canceled when parent is canceled."""
1102+
nonlocal task_started_future, task_executed, task_blocker
1103+
task_started_future.set_result(True)
1104+
await task_blocker
1105+
task_executed = True
1106+
1107+
task_cancel_propagated = False
1108+
1109+
async with ThreadSensitiveContext():
1110+
task = asyncio.create_task(sync_to_async(sync_middleware_1)())
1111+
await task_started_future
1112+
task.cancel()
1113+
try:
1114+
await task
1115+
except asyncio.CancelledError:
1116+
task_cancel_propagated = True
1117+
assert not task_cancel_propagated
1118+
assert task_cancel_caught
1119+
assert task_complete
1120+
1121+
assert task_executed

0 commit comments

Comments
 (0)