Commit
split ThunderFunction to deallocate grad_outs while computing backward
t-vi committed Dec 22, 2024
1 parent c392c35 commit e543eff
Showing 6 changed files with 61 additions and 16 deletions.
9 changes: 7 additions & 2 deletions thunder/__init__.py
@@ -72,7 +72,7 @@
)
from thunder.core.interpreter import print_interpreter_log, print_to_log
from thunder.core.jit_ext import thunder_general_jit
-from thunder.executors.torch_autograd import split_forward_backward, ThunderFunction
+from thunder.executors.torch_autograd import split_forward_backward, ThunderFunction1, ThunderFunction2

# NOTE This import is intentionally pytorch so that it thunder.torch doesn't import this
import torch as pytorch
@@ -751,14 +751,19 @@ def maybe_connect_to_autograd(cache_entry, result):
# resulting tensors to PyTorch's Autograd graph using the
# ThunderFunction (which is a torch.autograd.Function subclass)
data_for_autograd, (saved_tensors, saved_other) = result
-        ThunderFunction.apply(
+        side_channel = {}
+        dummy_res = ThunderFunction1.apply(
cache_entry.return_none_instead_of_grads,
cache_entry.backward_fn,
+            side_channel,
saved_tensors,
saved_other,
data_for_autograd["flat_output"],
*data_for_autograd["flat_args"],
)
+        # we need to pass the inputs to avoid "leaf has moved inside the graph"
+        # if the function returns an argument as is
+        ThunderFunction2.apply(dummy_res, side_channel, *data_for_autograd["flat_args"])
result = data_for_autograd["output"]

return result
52 changes: 46 additions & 6 deletions thunder/executors/torch_autograd.py
@@ -60,10 +60,24 @@ def rename_bwd_trace_outputs(bwd_trace: TraceCtx, fwd_trace: TraceCtx) -> TraceCtx:
return renamed_bwd_trace


-class ThunderFunction(torch.autograd.Function):
+# We split the autograd.Function into two parts because this allows
+# the args to ThunderFunction2.backward to go out of scope
+# and the tensors (the grad_outs matching the flattened output) to be
+# deallocated when they have been processed by the compiled backward function.
+# For the correspondence between the two functions that is hidden from autograd,
+# we use a side channel (an empty dict) passed as an argument. To link the two
+# functions in autograd, we use a dummy tensor on the meta device.
+class ThunderFunction1(torch.autograd.Function):
@staticmethod
def forward(
-        ctx, return_none_instead_of_grads, compiled_backward, saved_tensors, saved_other, flat_output, *flat_args
+        ctx,
+        return_none_instead_of_grads,
+        compiled_backward,
+        side_channel,
+        saved_tensors,
+        saved_other,
+        flat_output,
+        *flat_args,
):
# Here we just propagate the tensors through the autograd graph
ctx.return_none_instead_of_grads = return_none_instead_of_grads
@@ -85,17 +99,22 @@ def detach_if_tensor(t):
return t

saved_tensors = tuple(map(detach_if_tensor, saved_tensors))
+        assert not side_channel
+        ctx.side_channel = side_channel
+        ctx.side_channel["fw"] = flat_output

# We must save tensors using ctx.save_for_backward
ctx.save_for_backward(*saved_tensors)
-        return flat_output
+        return torch.randn(1, device="meta", requires_grad=True)

# NOTE: If `torch.autograd.function.once_differentiable` is to be removed,
# one must take care of correctly removing the `detach_if_tensor` above.
# For more context, see NOTE [Saved view of output of torch.autograd.Function leaks] above.
@staticmethod
@torch.autograd.function.once_differentiable
-    def backward(ctx, *args):
+    def backward(ctx, _):
+        args = ctx.side_channel.pop("bw")
+        assert not ctx.side_channel
# ctx.saved_tensors is a tuple of tensors saved in forward. Our compiled
# backward is a really long function that takes all the tensors saved in
# forward and gradually uses them to compute the gradients of the
ctx.maybe_clear_saved_tensors() # Delete the reference to all saved tensors in the context
grads = ctx.compiled_backward([saved_tensors_list, ctx.saved_other], args)

+        assert not args
# Inside the compiled backward we must clear the saved_tensors_list
assert not saved_tensors_list, "saved_tensors_list must be empty after calling compiled_backward"
# TODO(crcrpar): Remove if-else once `dist_prims.stash_grad_for_fsdp` starts to return `None`
# NOTE(crcrpar): In fsdp no-sync, unsharded gradients are attached and accumulated to their parameters as the attr of `_thunder_fsdp_unsharded_grad` in order to avoid shape mismatch of a param and its grad. When exiting the no_sync context, the accumulated, unsharded gradients are reduce-scattered into the attr of `grad` and `_thunder_fsdp_unsharded_grad` is removed.
if not ctx.return_none_instead_of_grads:
-            return (None, None, None, None, None, *grads)
+            return (None, None, None, None, None, None, *grads)
else:
n_grads = len(grads)
del grads
-            return (None, None, None, None, None, *([None] * n_grads))
+            return (None, None, None, None, None, None, *([None] * n_grads))


+class ThunderFunction2(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, dummy, side_channel, *args):
+        ctx.side_channel = side_channel
+        ctx.num_args = len(args)
+        res = ctx.side_channel.pop("fw")
+        assert not ctx.side_channel
+        return res
+
+    @staticmethod
+    def backward(ctx, *args):
+        assert not ctx.side_channel
+        ctx.side_channel["bw"] = list(args)
+        return torch.randn(1, device="meta"), None, *([None] * ctx.num_args)


def split_forward_backward(computation_trc: TraceCtx, compile_data, compile_stats, /, *flat_args):
@@ -345,4 +381,8 @@ def split_forward_backward(computation_trc: TraceCtx, compile_data, compile_stats, /, *flat_args):
# We only want the forward function to be called with `te.fp8_autocast` manager.
bw_extrace._include_te_fp8_autocast = False

+    if len(bw_extrace.bound_symbols) == 1:
+        # the trace has only a return and no unpacking, so no gradient is calculated
+        bw_extrace = None

return fw_extrace, bw_extrace
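
For readers unfamiliar with the pattern, here is a minimal, self-contained sketch of the idea behind the split (not Thunder's actual classes; Part1, Part2, and the toy computation are made up for illustration): one Function owns the real backward work, the other is the node the outputs point to, and the two are linked through a dummy meta tensor plus a shared side-channel dict so that the grad_outs can be dropped one by one while the backward is still running.

import torch


class Part1(torch.autograd.Function):
    # Owns the real backward work (in Thunder, the compiled backward).
    @staticmethod
    def forward(ctx, side_channel, *flat_args):
        ctx.side_channel = side_channel
        # Stand-in for the compiled forward: the real outputs travel to
        # Part2 through the side channel, not through autograd.
        ctx.side_channel["fw"] = [a * 2 for a in flat_args]
        # Autograd only sees a dummy meta tensor linking Part1 and Part2.
        return torch.randn(1, device="meta", requires_grad=True)

    @staticmethod
    @torch.autograd.function.once_differentiable
    def backward(ctx, _):
        grad_outs = ctx.side_channel.pop("bw")
        grads = []
        while grad_outs:
            # Popping drops the last reference to each grad_out, so it can
            # be deallocated before the remaining ones are processed.
            g = grad_outs.pop(0)
            grads.append(g * 2)
            del g
        return (None, *grads)


class Part2(torch.autograd.Function):
    # The node the user-visible outputs point to; it only shuttles data.
    @staticmethod
    def forward(ctx, dummy, side_channel, *flat_args):
        ctx.side_channel = side_channel
        ctx.num_args = len(flat_args)
        return tuple(ctx.side_channel.pop("fw"))

    @staticmethod
    def backward(ctx, *grad_outs):
        # Stash the incoming gradients and return immediately, so this
        # frame no longer pins them while Part1.backward does the work.
        ctx.side_channel["bw"] = list(grad_outs)
        return torch.randn(1, device="meta"), None, *([None] * ctx.num_args)


x = torch.randn(3, requires_grad=True)
side_channel = {}
dummy = Part1.apply(side_channel, x)
(y,) = Part2.apply(dummy, side_channel, x)
y.sum().backward()
print(torch.allclose(x.grad, torch.full_like(x, 2.0)))  # True: d(2x)/dx == 2

The point of the indirection is that Part2.backward returns before any real work starts, so autograd no longer holds the grad_outs and each one can be freed as soon as Part1.backward has consumed it.
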
4 changes: 2 additions & 2 deletions thunder/tests/test_core.py
@@ -2510,8 +2510,8 @@ def foo2(x):
return x + 1

x = torch.randn(3, 3, requires_grad=True)
-    thunder.jit(foo2)(x).sum().backward()
-    assert x.grad is None
+    res = thunder.jit(foo2)(x)
+    assert not res.requires_grad

# Test `no_grad` ctx correctly disable gradient computation
def foo3(x):
4 changes: 2 additions & 2 deletions thunder/tests/test_dynamo.py
@@ -65,7 +65,7 @@ def func(x):

# out should have grad_fn and its name should be ThunderFunctionBackward
assert out.grad_fn is not None
-    assert out.grad_fn.name() == "ThunderFunctionBackward"
+    assert out.grad_fn.name() == "ThunderFunction2Backward"

# We record the GraphModules that was compiled by ThunderCompiler
backend = compiled._backend
@@ -317,7 +317,7 @@ def func(x):

# out should have grad_fn and its name should be ThunderFunctionBackward
assert out.grad_fn is not None
-    assert out.grad_fn.name() == "ThunderFunctionBackward"
+    assert out.grad_fn.name() == "ThunderFunction2Backward"

backend = cfunc._backend
# We record the GraphModules that was compiled by ThunderCompiler
6 changes: 3 additions & 3 deletions thunder/tests/test_grad.py
@@ -1752,10 +1752,10 @@ def f(x, y):

# With activation checkpointing, we are saving only the original input.
# The intermediate values are recomputed during backward pass.
-    assert len(out.grad_fn.saved_tensors) == 2
+    assert len(out.grad_fn.next_functions[0][0].saved_tensors) == 2
# We detach the saved tensors (which returns a new Python tensor backed by same storage)
-    assert out.grad_fn.saved_tensors[0].data_ptr() == x.data_ptr()
-    assert out.grad_fn.saved_tensors[1].data_ptr() == y.data_ptr()
+    assert out.grad_fn.next_functions[0][0].saved_tensors[0].data_ptr() == x.data_ptr()
+    assert out.grad_fn.next_functions[0][0].saved_tensors[1].data_ptr() == y.data_ptr()

g = torch.ones_like(out)
out.backward(g)
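
The change above reflects the split: the output's grad_fn is now ThunderFunction2Backward, which saves nothing itself, while the saved tensors live on the ThunderFunction1Backward node one edge behind it. A hedged sketch of the traversal, assuming out is the output of a Thunder-jitted call as in the test:

node = out.grad_fn
assert node.name() == "ThunderFunction2Backward"
prev = node.next_functions[0][0]  # first input edge: the dummy meta tensor from ThunderFunction1
assert prev.name() == "ThunderFunction1Backward"
print(len(prev.saved_tensors))    # the tensors stashed via ctx.save_for_backward
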
2 changes: 1 addition & 1 deletion thunder/tests/test_jit_general.py
@@ -1194,7 +1194,7 @@ def forward(self, x):
x = torch.randn((2, 2), dtype=torch.float64, requires_grad=True)
model = Model().to(dtype=torch.float64)
jitted = thunder.jit(model)
-    gradcheck(jitted, (x,))
+    gradcheck(jitted, (x,), check_batched_grad=False)

jitted.zero_grad()
x = torch.randn((2, 2), dtype=torch.float64, requires_grad=True)
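
For reference, check_batched_grad=False makes gradcheck skip the vmap-based batched-gradient check and only compare analytical against numerical gradients; presumably the dummy-tensor/side-channel backward does not play well with batched grads. A generic usage sketch with a plain PyTorch op (not the jitted model from the test):

import torch
from torch.autograd import gradcheck

x = torch.randn(2, 2, dtype=torch.float64, requires_grad=True)
# Compare analytical and numerical gradients, skipping the batched (vmap) check.
assert gradcheck(torch.tanh, (x,), check_batched_grad=False)
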
