From e543eff09189ab5cc5a1c35ca42aea438f048038 Mon Sep 17 00:00:00 2001
From: Thomas Viehmann
Date: Sun, 22 Dec 2024 21:22:28 +0100
Subject: [PATCH] split ThunderFunction to deallocate grad_outs while computing backward

---
 thunder/__init__.py                 |  9 +++--
 thunder/executors/torch_autograd.py | 52 +++++++++++++++++++++++++----
 thunder/tests/test_core.py          |  4 +--
 thunder/tests/test_dynamo.py        |  4 +--
 thunder/tests/test_grad.py          |  6 ++--
 thunder/tests/test_jit_general.py   |  2 +-
 6 files changed, 61 insertions(+), 16 deletions(-)

diff --git a/thunder/__init__.py b/thunder/__init__.py
index eceacc4cad..8f0fcb0a5d 100644
--- a/thunder/__init__.py
+++ b/thunder/__init__.py
@@ -72,7 +72,7 @@
 )
 from thunder.core.interpreter import print_interpreter_log, print_to_log
 from thunder.core.jit_ext import thunder_general_jit
-from thunder.executors.torch_autograd import split_forward_backward, ThunderFunction
+from thunder.executors.torch_autograd import split_forward_backward, ThunderFunction1, ThunderFunction2
 
 # NOTE This import is intentionally pytorch so that it thunder.torch doesn't import this
 import torch as pytorch
@@ -751,14 +751,19 @@ def maybe_connect_to_autograd(cache_entry, result):
         # resulting tensors to PyTorch's Autograd graph using the
         # ThunderFunction (which is a torch.autograd.Function subclass)
         data_for_autograd, (saved_tensors, saved_other) = result
-        ThunderFunction.apply(
+        side_channel = {}
+        dummy_res = ThunderFunction1.apply(
             cache_entry.return_none_instead_of_grads,
             cache_entry.backward_fn,
+            side_channel,
             saved_tensors,
             saved_other,
             data_for_autograd["flat_output"],
             *data_for_autograd["flat_args"],
         )
+        # we need to pass the inputs to avoid "leaf has moved inside the graph"
+        # if the function returns an argument as is
+        ThunderFunction2.apply(dummy_res, side_channel, *data_for_autograd["flat_args"])
         result = data_for_autograd["output"]
 
     return result
diff --git a/thunder/executors/torch_autograd.py b/thunder/executors/torch_autograd.py
index ce9497125b..5ac1cab8df 100644
--- a/thunder/executors/torch_autograd.py
+++ b/thunder/executors/torch_autograd.py
@@ -60,10 +60,24 @@ def rename_bwd_trace_outputs(bwd_trace: TraceCtx, fwd_trace: TraceCtx) -> TraceC
     return renamed_bwd_trace
 
 
-class ThunderFunction(torch.autograd.Function):
+# We split the autograd.Function into two parts so that the args to
+# ThunderFunction2.backward can go out of scope, letting the tensors
+# (the grad_outs matching the flattened output) be deallocated as soon as
+# they have been processed by the compiled backward function.
+# To pass data between the two functions without autograd tracking it, we use
+# a side channel (an empty dict) passed as an argument. To link the two
+# functions in the autograd graph, we use a dummy tensor on the meta device.
+class ThunderFunction1(torch.autograd.Function):
     @staticmethod
     def forward(
-        ctx, return_none_instead_of_grads, compiled_backward, saved_tensors, saved_other, flat_output, *flat_args
+        ctx,
+        return_none_instead_of_grads,
+        compiled_backward,
+        side_channel,
+        saved_tensors,
+        saved_other,
+        flat_output,
+        *flat_args,
     ):
         # Here we just propagate the tensors through the autograd graph
         ctx.return_none_instead_of_grads = return_none_instead_of_grads
@@ -85,17 +99,22 @@ def detach_if_tensor(t):
             return t
 
         saved_tensors = tuple(map(detach_if_tensor, saved_tensors))
+        assert not side_channel
+        ctx.side_channel = side_channel
+        ctx.side_channel["fw"] = flat_output
 
         # We must save tensors using ctx.save_for_backward
         ctx.save_for_backward(*saved_tensors)
-        return flat_output
+        return torch.randn(1, device="meta", requires_grad=True)
 
     # NOTE: If `torch.autograd.function.once_differentiable` is to be removed,
     # one must take care of correctly removing the `detach_if_tensor` above.
     # For more context, see NOTE [Saved view of output of torch.autograd.Function leaks] above.
     @staticmethod
     @torch.autograd.function.once_differentiable
-    def backward(ctx, *args):
+    def backward(ctx, _):
+        args = ctx.side_channel.pop("bw")
+        assert not ctx.side_channel
         # ctx.saved_tensors is a tuple of tensors saved in forward. Our compiled
         # backward is a really long function that takes all the tensors saved in
         # forward and gradually uses them to compute the gradients of the
@@ -114,16 +133,33 @@ def backward(ctx, *args):
         ctx.maybe_clear_saved_tensors()  # Delete the reference to all saved tensors in the context
         grads = ctx.compiled_backward([saved_tensors_list, ctx.saved_other], args)
+        assert not args
         # Inside the compiled backward we must clear the saved_tensors_list
         assert not saved_tensors_list, "saved_tensors_list must be empty after calling compiled_backward"
 
         # TODO(crcrpar): Remove if-else once `dist_prims.stash_grad_for_fsdp` starts to return `None`
         # NOTE(crcrpar): In fsdp no-sync, unsharded gradients are attached and accumulated to their parameters as the attr of `_thunder_fsdp_unsharded_grad` in order to avoid shape mismatch of a param and its grad. When exiting the no_sync context, the accumulated, unsharded gradients are reduce-scattered into the attr of `grad` and `_thunder_fsdp_unsharded_grad` is removed.
         if not ctx.return_none_instead_of_grads:
-            return (None, None, None, None, None, *grads)
+            return (None, None, None, None, None, None, *grads)
         else:
             n_grads = len(grads)
             del grads
-            return (None, None, None, None, None, *([None] * n_grads))
+            return (None, None, None, None, None, None, *([None] * n_grads))
+
+
+class ThunderFunction2(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, dummy, side_channel, *args):
+        ctx.side_channel = side_channel
+        ctx.num_args = len(args)
+        res = ctx.side_channel.pop("fw")
+        assert not ctx.side_channel
+        return res
+
+    @staticmethod
+    def backward(ctx, *args):
+        assert not ctx.side_channel
+        ctx.side_channel["bw"] = list(args)
+        return torch.randn(1, device="meta"), None, *([None] * ctx.num_args)
 
 
 def split_forward_backward(computation_trc: TraceCtx, compile_data, compile_stats, /, *flat_args):
@@ -345,4 +381,8 @@ def split_forward_backward(computation_trc: TraceCtx, compile_data, compile_stat
         # We only want the forward function to be called with `te.fp8_autocast` manager.
         bw_extrace._include_te_fp8_autocast = False
 
+    if len(bw_extrace.bound_symbols) == 1:
+        # the backward trace is only a return with no unpacking, so no gradients are computed
+        bw_extrace = None
+
     return fw_extrace, bw_extrace
diff --git a/thunder/tests/test_core.py b/thunder/tests/test_core.py
index bf6e9bd7d3..0905389008 100644
--- a/thunder/tests/test_core.py
+++ b/thunder/tests/test_core.py
@@ -2510,8 +2510,8 @@ def foo2(x):
         return x + 1
 
     x = torch.randn(3, 3, requires_grad=True)
-    thunder.jit(foo2)(x).sum().backward()
-    assert x.grad is None
+    res = thunder.jit(foo2)(x)
+    assert not res.requires_grad
 
     # Test `no_grad` ctx correctly disable gradient computation
     def foo3(x):
diff --git a/thunder/tests/test_dynamo.py b/thunder/tests/test_dynamo.py
index 4bcd2333f4..b9dd10c79f 100644
--- a/thunder/tests/test_dynamo.py
+++ b/thunder/tests/test_dynamo.py
@@ -65,7 +65,7 @@ def func(x):
 
     # out should have grad_fn and its name should be ThunderFunctionBackward
     assert out.grad_fn is not None
-    assert out.grad_fn.name() == "ThunderFunctionBackward"
+    assert out.grad_fn.name() == "ThunderFunction2Backward"
 
     # We record the GraphModules that was compiled by ThunderCompiler
     backend = compiled._backend
@@ -317,7 +317,7 @@ def func(x):
 
     # out should have grad_fn and its name should be ThunderFunctionBackward
     assert out.grad_fn is not None
-    assert out.grad_fn.name() == "ThunderFunctionBackward"
+    assert out.grad_fn.name() == "ThunderFunction2Backward"
 
     backend = cfunc._backend
     # We record the GraphModules that was compiled by ThunderCompiler
diff --git a/thunder/tests/test_grad.py b/thunder/tests/test_grad.py
index 9b68b37e7f..73e3453e9a 100644
--- a/thunder/tests/test_grad.py
+++ b/thunder/tests/test_grad.py
@@ -1752,10 +1752,10 @@ def f(x, y):
 
     # With activation checkpointing, we are saving only the original input.
    # The intermediate values are recomputed during backward pass.
-    assert len(out.grad_fn.saved_tensors) == 2
+    assert len(out.grad_fn.next_functions[0][0].saved_tensors) == 2
     # We detach the saved tensors (which returns a new Python tensor backed by same storage)
-    assert out.grad_fn.saved_tensors[0].data_ptr() == x.data_ptr()
-    assert out.grad_fn.saved_tensors[1].data_ptr() == y.data_ptr()
+    assert out.grad_fn.next_functions[0][0].saved_tensors[0].data_ptr() == x.data_ptr()
+    assert out.grad_fn.next_functions[0][0].saved_tensors[1].data_ptr() == y.data_ptr()
 
     g = torch.ones_like(out)
     out.backward(g)
diff --git a/thunder/tests/test_jit_general.py b/thunder/tests/test_jit_general.py
index 0d578cbea0..e1be6d86b8 100644
--- a/thunder/tests/test_jit_general.py
+++ b/thunder/tests/test_jit_general.py
@@ -1194,7 +1194,7 @@ def forward(self, x):
     x = torch.randn((2, 2), dtype=torch.float64, requires_grad=True)
     model = Model().to(dtype=torch.float64)
     jitted = thunder.jit(model)
-    gradcheck(jitted, (x,))
+    gradcheck(jitted, (x,), check_batched_grad=False)
 
     jitted.zero_grad()
     x = torch.randn((2, 2), dtype=torch.float64, requires_grad=True)
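
For illustration, here is a minimal standalone PyTorch sketch of the two-Function pattern this patch introduces; the names Fn1 and Fn2 and the toy y = 2 * x computation are invented stand-ins for Thunder's compiled forward and backward, while the side-channel dict and the dummy meta tensor mirror the patch:

# Standalone sketch of the split; Fn1/Fn2 and y = 2 * x are invented for
# illustration and are not the Thunder implementation.
import torch


class Fn1(torch.autograd.Function):
    @staticmethod
    def forward(ctx, side_channel, flat_output, *flat_args):
        # Stash the real outputs for Fn2 and hand autograd only a dummy tensor.
        ctx.side_channel = side_channel
        ctx.side_channel["fw"] = flat_output
        return torch.randn(1, device="meta", requires_grad=True)

    @staticmethod
    @torch.autograd.function.once_differentiable
    def backward(ctx, _):
        # The grad_outs arrive via the side channel, not as backward arguments,
        # so no outer frame keeps them alive while they are consumed.
        grad_outs = ctx.side_channel.pop("bw")
        grads = []
        while grad_outs:
            g = grad_outs.pop(0)  # drop each grad_out as soon as it has been used
            grads.append(2.0 * g)  # d(2 * x)/dx == 2; stand-in for the compiled backward
        return (None, None, *grads)


class Fn2(torch.autograd.Function):
    @staticmethod
    def forward(ctx, dummy, side_channel, *flat_args):
        # Consume the dummy tensor (this links Fn2 to Fn1 in the autograd graph)
        # and return the real outputs stashed by Fn1.
        ctx.side_channel = side_channel
        ctx.num_args = len(flat_args)
        return ctx.side_channel.pop("fw")

    @staticmethod
    def backward(ctx, *grad_outs):
        # Move the incoming grad_outs into the side channel and return a dummy
        # gradient for the meta tensor so the engine proceeds to Fn1.
        ctx.side_channel["bw"] = list(grad_outs)
        return torch.randn(1, device="meta"), None, *([None] * ctx.num_args)


x = torch.randn(3, requires_grad=True)
with torch.no_grad():  # mimic the compiled forward, which autograd does not see
    y = x * 2
side_channel = {}
dummy = Fn1.apply(side_channel, (y,), x)
(out,) = Fn2.apply(dummy, side_channel, x)
out.sum().backward()
print(x.grad)  # tensor of twos

The split matters because once Fn2.backward has stashed the grads and returned, the autograd engine no longer holds them in an argument tuple; Fn1.backward (the compiled backward in Thunder) owns the only remaining reference and can drop each grad_out as soon as it has been consumed.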