From c5f8bf73f21f20be8e748bbf0801d62e69f4bd92 Mon Sep 17 00:00:00 2001
From: Ivan Yashchuk
Date: Fri, 1 Nov 2024 17:13:09 +0200
Subject: [PATCH] Fix thunder.torch.checkpoint to support multiple arguments
 (#1391)

---
 thunder/tests/test_grad.py | 18 +++++++++++-------
 thunder/torch/__init__.py  |  4 ++--
 2 files changed, 13 insertions(+), 9 deletions(-)

diff --git a/thunder/tests/test_grad.py b/thunder/tests/test_grad.py
index a7b0898c7f..25f42ccb87 100644
--- a/thunder/tests/test_grad.py
+++ b/thunder/tests/test_grad.py
@@ -1704,8 +1704,8 @@ def test_torch_checkpoint():
     import torch.utils.checkpoint
     import torch._higher_order_ops.wrap
 
-    def fn_to_checkpoint(x):
-        return x.sin().cos().exp()
+    def fn_to_checkpoint(x, y):
+        return x.sin().cos().exp().mul(y)
 
     checkpoint_fns = (
         thunder.torch.checkpoint,
@@ -1715,26 +1715,30 @@ def fn_to_checkpoint(x):
 
     for checkpoint_fn in checkpoint_fns:
 
-        def f(x):
-            return checkpoint_fn(fn_to_checkpoint, x)
+        def f(x, y):
+            return checkpoint_fn(fn_to_checkpoint, x, y)
 
         x = make_tensor((2, 2), device="cpu", dtype=torch.float32, requires_grad=True)
+        y = make_tensor((2, 2), device="cpu", dtype=torch.float32, requires_grad=True)
         jf = thunder.jit(f)
-        out = jf(x)
+        out = jf(x, y)
 
         # With activation checkpointing, we are saving only the original input.
         # The intermediate values are recomputed during backward pass.
-        assert len(out.grad_fn.saved_tensors) == 1
+        assert len(out.grad_fn.saved_tensors) == 2
         # We detach the saved tensors (which returns a new Python tensor backed by same storage)
         assert out.grad_fn.saved_tensors[0].data_ptr() == x.data_ptr()
+        assert out.grad_fn.saved_tensors[1].data_ptr() == y.data_ptr()
         g = torch.ones_like(out)
         out.backward(g)
 
         x_ref = x.detach().requires_grad_()
-        out_ref = fn_to_checkpoint(x_ref)
+        y_ref = y.detach().requires_grad_()
+        out_ref = fn_to_checkpoint(x_ref, y_ref)
         out_ref.backward(g)
         torch.testing.assert_close(x.grad, x_ref.grad)
+        torch.testing.assert_close(y.grad, y_ref.grad)
 
 
 def test_inconsistent_output_length_grad_transform():
diff --git a/thunder/torch/__init__.py b/thunder/torch/__init__.py
index 4056dee2c9..b89508e080 100644
--- a/thunder/torch/__init__.py
+++ b/thunder/torch/__init__.py
@@ -5301,8 +5301,8 @@ def _backward_checkpoint(
 ) -> tuple[None | TensorLike, ...]:
     from thunder.core.transforms import vjp
 
-    result = vjp(function)(args, grad_outputs, **kwargs)
-    return result
+    _, grads = vjp(function)(args, grad_outputs, **kwargs)
+    return grads
#
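
The `_backward_checkpoint` change above hinges on the convention that thunder's `vjp` transform returns a (primal outputs, input gradients) pair rather than the gradients alone; returning the whole pair handed autograd a mis-shaped result once a checkpointed function took more than one argument. The following is an illustrative sketch, separate from the patch: the `vjp` helper here is a hypothetical stand-in built on `torch.autograd.grad`, not `thunder.core.transforms.vjp` itself, but it mimics the same output pair.

    import torch

    def vjp(fn):
        # Stand-in for thunder.core.transforms.vjp: evaluates fn at the
        # primals, pulls the cotangents back through it, and returns the
        # pair (primal outputs, input gradients).
        def wrapped(primals, cotangents):
            primals = tuple(p.detach().requires_grad_() for p in primals)
            outputs = fn(*primals)
            grads = torch.autograd.grad(outputs, primals, grad_outputs=cotangents)
            return outputs, grads

        return wrapped

    def fn_to_checkpoint(x, y):
        return x.sin().cos().exp().mul(y)

    x, y = torch.randn(2, 2), torch.randn(2, 2)
    cotangents = (torch.ones(2, 2),)

    # The old code returned the whole pair, i.e. (output, (grad_x, grad_y)),
    # where autograd expected one gradient per input; the fix unpacks it.
    _, grads = vjp(fn_to_checkpoint)((x, y), cotangents)
    assert len(grads) == 2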
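
For reference, the behavior the updated test asserts, identical gradients for a checkpointed two-argument function, can be reproduced with stock PyTorch checkpointing alone. This eager-mode sketch assumes only torch is installed and mirrors the test's `fn_to_checkpoint`:

    import torch
    import torch.utils.checkpoint

    def fn_to_checkpoint(x, y):
        return x.sin().cos().exp().mul(y)

    x = torch.randn(2, 2, requires_grad=True)
    y = torch.randn(2, 2, requires_grad=True)

    # Non-reentrant checkpointing saves only the inputs for backward and
    # recomputes sin/cos/exp/mul during the backward pass.
    out = torch.utils.checkpoint.checkpoint(fn_to_checkpoint, x, y, use_reentrant=False)
    out.backward(torch.ones_like(out))

    # Reference: the same function without checkpointing.
    x_ref = x.detach().requires_grad_()
    y_ref = y.detach().requires_grad_()
    out_ref = fn_to_checkpoint(x_ref, y_ref)
    out_ref.backward(torch.ones_like(out_ref))

    torch.testing.assert_close(x.grad, x_ref.grad)
    torch.testing.assert_close(y.grad, y_ref.grad)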