grad rule for copy_with_setitem #1322

Merged
merged 4 commits on Oct 29, 2024
Changes from 3 commits
20 changes: 20 additions & 0 deletions thunder/core/transforms.py
@@ -1414,6 +1414,26 @@ def _maximum_grad(a: TensorProxy, b: TensorProxy, /):
# This operation creates no grad associations
register_grad(pids.SHAPE, prims.shape)


def _copy_with_setitem_grad(a: TensorProxy, index, value: Number | TensorProxy):
    fwd = prims.copy_with_setitem(a, index, value)
    g = get_grad(fwd)

    # The gradient w.r.t. a is the upstream grad with the written region zeroed out.
    a_grad = prims.copy_with_setitem(g, index, 0)
    put_grad(a, a_grad)

    if isinstance(value, TensorProxy):
        # The gradient w.r.t. value is the indexed slice of the upstream grad,
        # summed over any leading dimensions added by broadcasting.
        value_grad = g[index]
        expanded_dims = value_grad.ndim - value.ndim
        if expanded_dims > 0:
            value_grad = prims.sum(value_grad, tuple(range(expanded_dims)))
        put_grad(value, value_grad)

    return fwd


register_grad(pids.COPY_WITH_SETITEM, _copy_with_setitem_grad)

#
# Phantom grad transform helpers
#
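The rule above mirrors how PyTorch autograd handles an in-place setitem: the gradient flowing to `a` is the upstream grad with the written region zeroed, and the gradient flowing to `value` is the indexed slice of the upstream grad, reduced over any broadcast dimensions. A minimal eager PyTorch sketch of that behavior (illustrative only, not part of the diff):

import torch
from torch.testing import assert_close

a = torch.randn(5, requires_grad=True)
value = torch.tensor(2.0, requires_grad=True)

out = a.clone()
out[:3] = value  # eager analogue of copy_with_setitem(a, slice(0, 3), value)
g = torch.randn_like(out)
out.backward(g)

# Gradient w.r.t. a: the upstream grad with the assigned region zeroed out.
expected_a_grad = g.clone()
expected_a_grad[:3] = 0
assert_close(a.grad, expected_a_grad)

# Gradient w.r.t. value: the indexed slice of the upstream grad, summed over
# the dimensions added by broadcasting (the scalar was broadcast to 3 elements).
assert_close(value.grad, g[:3].sum())
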
62 changes: 52 additions & 10 deletions thunder/tests/test_ops.py
@@ -239,18 +239,60 @@ def foo():
tfoo()


def test_setitem():
def fn(a):
a[:3] = 2
@pytest.mark.parametrize("requires_grad", (True, False))
def test_setitem(requires_grad):

def _test_forward_and_backward(fn, a, value):
a_ref = a.detach().clone()
a_ref.requires_grad_(a.requires_grad)

if isinstance(value, torch.Tensor):
value_ref = value.detach().clone()
value_ref.requires_grad_(value.requires_grad)
else:
value_ref = value

out_ref = fn(a_ref, value_ref)
jf = thunder.jit(fn)
out = jf(a, value)
assert_close(a, a_ref)
assert_close(out, out_ref)

if requires_grad:
g = torch.randn_like(out)
inputs = (a, value) if isinstance(value, torch.Tensor) else (a,)
actual_grad = torch.autograd.grad(out, inputs, g)

inputs_ref = (a_ref, value_ref) if isinstance(value, torch.Tensor) else (a_ref,)
expected_grad = torch.autograd.grad(out_ref, inputs_ref, g)
assert_close(actual_grad, expected_grad)

def clone_if_requires_grad(a):
if requires_grad:
# Without the clone, PyTorch eager errors with
# `RuntimeError: a view of a leaf Variable that requires grad is being used in an in-place operation.`
# and thunder has a silent correctness issue - https://github.com/Lightning-AI/lightning-thunder/issues/1284
return a.clone()
return a

def fn(a, value):
a = clone_if_requires_grad(a)
a[:3] = value
return a * 2

a_ref = torch.ones(5)
out_ref = fn(a_ref)
a = torch.ones(5)
jf = thunder.jit(fn)
out = jf(a)
assert_close(a, a_ref)
assert_close(out, out_ref)
# set value: scalar
_test_forward_and_backward(fn, torch.randn(5, requires_grad=requires_grad), 2.0)

# set value: tensor which needs to be broadcasted
_test_forward_and_backward(
fn, torch.randn(5, requires_grad=requires_grad), torch.tensor(2.0, requires_grad=requires_grad)
)

# set value: tensor of same rank
_test_forward_and_backward(
fn, torch.randn(5, requires_grad=requires_grad), torch.tensor([1.0, 2.0, 3.0], requires_grad=requires_grad)
)


# TODO: Add random operator support to OpInfo
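As a side note on `clone_if_requires_grad` above: writing in place into a leaf tensor that requires grad is rejected by PyTorch eager, while writing into a clone is tracked through the autograd graph. A small illustrative sketch (not part of the PR):

import torch

a = torch.randn(5, requires_grad=True)  # leaf tensor

try:
    a[:3] = 2.0  # in-place write into (a view of) a leaf that requires grad
except RuntimeError as e:
    print(e)  # PyTorch rejects in-place writes into leaf tensors that require grad

b = a.clone()   # non-leaf, differentiable copy
b[:3] = 2.0     # allowed; the write is recorded by autograd
(b * 2).sum().backward()
print(a.grad)   # tensor([0., 0., 0., 2., 2.]) - zero grad for the overwritten region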