From 6dbb332ee34450b5556df6754a7943c882b7fa56 Mon Sep 17 00:00:00 2001
From: kshitij12345
Date: Thu, 17 Oct 2024 11:58:13 +0200
Subject: [PATCH 1/2] grad rule for copy_with_setitem

---
 thunder/core/transforms.py | 20 ++++++++++++++++++++
 1 file changed, 20 insertions(+)

diff --git a/thunder/core/transforms.py b/thunder/core/transforms.py
index 4e81bf3261..a21f422a8a 100644
--- a/thunder/core/transforms.py
+++ b/thunder/core/transforms.py
@@ -1414,6 +1414,26 @@ def _maximum_grad(a: TensorProxy, b: TensorProxy, /):
 # This operation creates no grad associations
 register_grad(pids.SHAPE, prims.shape)
 
+
+def _copy_with_setitem_grad(a: TensorProxy, index, value: Number | TensorProxy):
+    fwd = prims.copy_with_setitem(a, index, value)
+    g = get_grad(fwd)
+
+    a_grad = prims.copy_with_setitem(g, index, 0)
+    put_grad(a, a_grad)
+
+    if isinstance(value, TensorProxy):
+        value_grad = g[index]
+        expanded_dims = value_grad.ndim - value.ndim
+        if expanded_dims > 0:
+            value_grad = prims.sum(value_grad, tuple(range(expanded_dims)))
+            put_grad(value, value_grad)
+
+    return fwd
+
+
+register_grad(pids.COPY_WITH_SETITEM, _copy_with_setitem_grad)
+
 #
 # Phantom grad transform helpers
 #

From fd8e2577cd2ed2f6e5813edf64870c9686e6d580 Mon Sep 17 00:00:00 2001
From: kshitij12345
Date: Mon, 28 Oct 2024 22:07:46 +0100
Subject: [PATCH 2/2] add test

---
 thunder/core/transforms.py |  2 +-
 thunder/tests/test_ops.py  | 62 ++++++++++++++++++++++++++++++++------
 2 files changed, 53 insertions(+), 11 deletions(-)

diff --git a/thunder/core/transforms.py b/thunder/core/transforms.py
index a21f422a8a..50ee60971d 100644
--- a/thunder/core/transforms.py
+++ b/thunder/core/transforms.py
@@ -1427,7 +1427,7 @@ def _copy_with_setitem_grad(a: TensorProxy, index, value: Number | TensorProxy):
         expanded_dims = value_grad.ndim - value.ndim
         if expanded_dims > 0:
             value_grad = prims.sum(value_grad, tuple(range(expanded_dims)))
-            put_grad(value, value_grad)
+        put_grad(value, value_grad)
 
     return fwd
 
diff --git a/thunder/tests/test_ops.py b/thunder/tests/test_ops.py
index a92f9650a5..a588e94f3b 100644
--- a/thunder/tests/test_ops.py
+++ b/thunder/tests/test_ops.py
@@ -239,18 +239,60 @@ def foo():
     tfoo()
 
 
-def test_setitem():
-    def fn(a):
-        a[:3] = 2
+@pytest.mark.parametrize("requires_grad", (True, False))
+def test_setitem(requires_grad):
+
+    def _test_forward_and_backward(fn, a, value):
+        a_ref = a.detach().clone()
+        a_ref.requires_grad_(a.requires_grad)
+
+        if isinstance(value, torch.Tensor):
+            value_ref = value.detach().clone()
+            value_ref.requires_grad_(value.requires_grad)
+        else:
+            value_ref = value
+
+        out_ref = fn(a_ref, value_ref)
+        jf = thunder.jit(fn)
+        out = jf(a, value)
+        assert_close(a, a_ref)
+        assert_close(out, out_ref)
+
+        if requires_grad:
+            g = torch.randn_like(out)
+            inputs = (a, value) if isinstance(value, torch.Tensor) else (a,)
+            actual_grad = torch.autograd.grad(out, inputs, g)
+
+            inputs_ref = (a_ref, value_ref) if isinstance(value, torch.Tensor) else (a_ref,)
+            expected_grad = torch.autograd.grad(out_ref, inputs_ref, g)
+            assert_close(actual_grad, expected_grad)
+
+    def clone_if_requires_grad(a):
+        if requires_grad:
+            # Without the clone,
+            # PyTorch eager errors with
+            # `RuntimeError: a view of a leaf Variable that requires grad is being used in an in-place operation.`
+            # and thunder has a silent correctness issue - https://github.com/Lightning-AI/lightning-thunder/issues/1284
+            return a.clone()
+        return a
+
+    def fn(a, value):
+        a = clone_if_requires_grad(a)
+        a[:3] = value
         return a * 2
 
-    a_ref = torch.ones(5)
-    out_ref = fn(a_ref)
-    a = torch.ones(5)
-    jf = thunder.jit(fn)
-    out = jf(a)
-    assert_close(a, a_ref)
-    assert_close(out, out_ref)
+    # set value: scalar
+    _test_forward_and_backward(fn, torch.randn(5, requires_grad=requires_grad), 2.0)
+
+    # set value: tensor which needs to be broadcasted
+    _test_forward_and_backward(
+        fn, torch.randn(5, requires_grad=requires_grad), torch.tensor(2.0, requires_grad=requires_grad)
+    )
+
+    # set value: tensor of same rank
+    _test_forward_and_backward(
+        fn, torch.randn(5, requires_grad=requires_grad), torch.tensor([1.0, 2.0, 3.0], requires_grad=requires_grad)
+    )
 
 
 # TODO: Add random operator support to OpInfo
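
Note for reviewers, not part of the patch series: the backward rule added in the first patch can be sanity-checked against plain PyTorch autograd. The sketch below is an illustration under stated assumptions, not thunder API: `manual_grads` is a hypothetical helper, and `a.clone()` followed by in-place indexing stands in for `prims.copy_with_setitem`.

# Illustrative sketch (assumed names, plain PyTorch): checks the rule that
# _copy_with_setitem_grad implements against PyTorch autograd.
import torch

def manual_grads(g, index, value):
    # Grad w.r.t. `a`: the upstream grad with zeros at the written positions,
    # mirroring `prims.copy_with_setitem(g, index, 0)` in the patch.
    a_grad = g.clone()
    a_grad[index] = 0
    # Grad w.r.t. `value`: the upstream grad gathered at `index`, summed over
    # the leading dimensions `value` gained through broadcasting.
    value_grad = g[index]
    expanded_dims = value_grad.ndim - value.ndim
    if expanded_dims > 0:
        value_grad = value_grad.sum(tuple(range(expanded_dims)))
    return a_grad, value_grad

a = torch.randn(5, requires_grad=True)
value = torch.tensor(2.0, requires_grad=True)  # scalar tensor, broadcast to a[:3]
g = torch.randn(5)

out = a.clone()  # clone so the leaf tensor is not mutated in place
out[:3] = value
out.backward(g)

a_grad, value_grad = manual_grads(g, slice(0, 3), value)
torch.testing.assert_close(a.grad, a_grad)
torch.testing.assert_close(value.grad, value_grad)

The sum over the leading `expanded_dims` axes is what makes the broadcast case (scalar tensor value) agree with eager PyTorch. Note that with the first patch alone, `put_grad(value, ...)` ran only when broadcasting occurred, so a same-rank `value` silently received no grad; that is what the one-line change in the second patch, exercised by the "tensor of same rank" test case, appears to fix.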