diff --git a/thunder/core/transforms.py b/thunder/core/transforms.py
index a858a391a1..aeda1e0c3c 100644
--- a/thunder/core/transforms.py
+++ b/thunder/core/transforms.py
@@ -3157,6 +3157,8 @@ def backward_fn(saved_for_backward, cotangents):
     enable_saved_for_backward_recomputation: None | bool = get_compile_option(
         "enable_saved_for_backward_recomputation", "Enable save for backward tensors recomputation."
     )
+    if enable_saved_for_backward_recomputation is None:
+        enable_saved_for_backward_recomputation = True
     if enable_saved_for_backward_recomputation:
         forward_trace, backward_trace = recompute_saved_for_backward(forward_trace, backward_trace)
 
@@ -3195,6 +3197,9 @@ def recompute_saved_for_backward(fwd_trace: Trace, bwd_trace: Trace) -> tuple[Tr
         if thunder.core.proxies.ProxyTag.RECOMPUTE_IN_BACKWARD in thunder.core.proxies.unvariableify(p).tags
     }
 
+    if not rematerializable:
+        return fwd_trace, bwd_trace
+
     producers = find_producer_symbols(
         fwd_trace,
         tuple(unvariableify(i) for i in rematerializable),
diff --git a/thunder/tests/test_grad.py b/thunder/tests/test_grad.py
index 4520f517b1..6e054f6907 100644
--- a/thunder/tests/test_grad.py
+++ b/thunder/tests/test_grad.py
@@ -1753,9 +1753,10 @@ def f(x, y):
     # With activation checkpointing, we are saving only the original input.
     # The intermediate values are recomputed during backward pass.
     assert len(out.grad_fn.saved_tensors) == 2
-    assert out.grad_fn.saved_tensors[0].data_ptr() == x.data_ptr()
-    assert out.grad_fn.saved_tensors[1].data_ptr() == y.data_ptr()
+    # We detach the saved tensors (which returns a new Python tensor backed by the same storage),
+    # and the order of the saved tensors can be non-deterministic.
+    assert {t.data_ptr() for t in out.grad_fn.saved_tensors} == {x.data_ptr(), y.data_ptr()}
 
     g = torch.ones_like(out)
     out.backward(g)
 
@@ -1768,6 +1769,49 @@ def f(x, y):
     torch.testing.assert_close(y.grad, y_ref.grad)
 
 
+@requiresCUDA
+def test_checkpoint_max_memory():
+    import torch.utils.checkpoint
+
+    class Checkpoint(torch.nn.Module):
+        def __init__(self, module):
+            super().__init__()
+            self.module = module
+
+        def forward(self, *args):
+            return torch.utils.checkpoint.checkpoint(self.module, *args, use_reentrant=False)
+
+    with torch.device("cuda:0"):
+        m = torch.nn.Sequential(
+            torch.nn.Linear(1024, 16),
+            torch.nn.ReLU(),
+            *[
+                Checkpoint(
+                    torch.nn.Sequential(
+                        torch.nn.Linear(16, 2048),
+                        torch.nn.Linear(2048, 16),
+                        torch.nn.ReLU(),
+                    )
+                )
+                for _ in range(10)
+            ],
+            torch.nn.Linear(16, 1024),
+        )
+        inps = torch.randn(512, 1024, requires_grad=True)
+
+    jm = thunder.jit(m, executors=())  # no rematerialization
+    mem_base = torch.cuda.memory_allocated()
+    torch.cuda.reset_accumulated_memory_stats()
+    res = jm(inps)
+    res.sum().backward()
+    mem_max = torch.cuda.max_memory_allocated()
+    # The rematerialization pass used to move all(?) recomputation to the front,
+    # making the peak memory about 46MB.
+    # With checkpointing as coded in the model and recomputation where the
+    # values are used, we get about 12MB, so we set the limit at 16MB.
+    assert mem_max - mem_base < 16 * 2**20
+
+
 def test_inconsistent_output_length_grad_transform():
     from thunder.extend import OperatorExecutor
     from thunder.core.proxies import AnyProxy, TensorProxy
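
Usage note (not part of the patch): with this change, saved-for-backward recomputation is on by default when the compile option is left unset, and tensors tagged RECOMPUTE_IN_BACKWARD are recomputed in the backward trace instead of being saved. The sketch below shows how a caller could opt out; it assumes compile options are passed as keyword arguments to thunder.jit, which is where get_compile_option reads them from.

    import torch
    import thunder

    model = torch.nn.Linear(8, 8)

    # Default after this patch: recompute saved-for-backward tensors
    # that are tagged for recomputation.
    jm = thunder.jit(model)

    # Assumed opt-out: passing the compile option explicitly as a keyword
    # argument to thunder.jit restores the previous behavior.
    jm_opt_out = thunder.jit(model, enable_saved_for_backward_recomputation=False)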