diff --git a/torchtitan/distributed/deepep/deepep.py b/torchtitan/distributed/deepep/deepep.py
index ce44fc232e..d6c82916c4 100644
--- a/torchtitan/distributed/deepep/deepep.py
+++ b/torchtitan/distributed/deepep/deepep.py
@@ -146,7 +146,7 @@ def _dispatch_backward(
     if grad_recv_x is None:
         return None, None, None, None, None, None, None

-    handle = _handle_cache.get(ctx.cache_id_int)
+    handle = ctx.saved_handle
     assert handle is not None, f"Handle not found for cache_id={ctx.cache_id_int}"

     previous_event = _create_event_if_async(True)
@@ -161,7 +161,6 @@ def _dispatch_backward(
     )

     _sync_stream_if_async(True, after_event)
-    _handle_cache.pop(ctx.cache_id_int, None)

     grad_x = grad_x.to(ctx.input_dtype)
     grad_topk_weights = (
@@ -176,6 +175,7 @@ def _dispatch_setup_context(ctx, inputs, output):
     recv_x, recv_indices, recv_scores, num_recv, cache_id = output
     ctx.cache_id_int = cache_id.item()
     ctx.input_dtype = x.dtype
+    ctx.saved_handle = _handle_cache.get(ctx.cache_id_int)


 def _combine_backward(ctx, grad_combined):
@@ -207,7 +207,8 @@ def _combine_backward(ctx, grad_combined):
 def _combine_setup_context(ctx, inputs, output):
     x, cache_id = inputs
     ctx.cache_id_int = cache_id.item()
-    ctx.saved_handle = _handle_cache.get(ctx.cache_id_int)
+    # Pop from cache - safe because dispatch_setup_context already saved handle for dispatch_backward
+    ctx.saved_handle = _handle_cache.pop(ctx.cache_id_int, None)


 torch.library.register_autograd(
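The diff moves handle ownership onto the autograd contexts: dispatch's setup_context takes a non-destructive `get` of the handle, and combine's setup_context `pop`s the cache entry, so dispatch_backward no longer depends on the global cache still holding it. Below is a minimal standalone sketch of that ownership pattern, not the torchtitan code itself; all names (`SimpleCtx`, the module-level `_handle_cache`, the `__main__` driver) are illustrative assumptions.

```python
# Minimal sketch (assumed names) of the handle-ownership pattern in the diff:
# dispatch setup takes a non-destructive reference, combine setup pops the
# cache entry, and dispatch's backward still works afterwards.

_handle_cache: dict[int, object] = {}


class SimpleCtx:
    """Stand-in for an autograd ctx object that can carry arbitrary attributes."""
    pass


def dispatch_setup_context(ctx: SimpleCtx, cache_id: int) -> None:
    # Non-destructive read: combine still needs the cache entry later.
    ctx.saved_handle = _handle_cache.get(cache_id)


def combine_setup_context(ctx: SimpleCtx, cache_id: int) -> None:
    # Destructive read: dispatch already copied the handle onto its own ctx,
    # so removing the entry here cannot break dispatch's backward.
    ctx.saved_handle = _handle_cache.pop(cache_id, None)


def dispatch_backward(ctx: SimpleCtx) -> None:
    # Uses the handle saved on the ctx, independent of the global cache.
    assert ctx.saved_handle is not None


if __name__ == "__main__":
    cache_id = 0
    _handle_cache[cache_id] = object()            # handle produced by forward dispatch

    dispatch_ctx, combine_ctx = SimpleCtx(), SimpleCtx()
    dispatch_setup_context(dispatch_ctx, cache_id)  # forward: dispatch, then combine
    combine_setup_context(combine_ctx, cache_id)    # cache entry removed here

    dispatch_backward(dispatch_ctx)               # backward still sees the handle
    assert cache_id not in _handle_cache          # no stale entry left behind
```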