Skip to content

Commit

Permalink
bring back test case
Browse files Browse the repository at this point in the history
  • Loading branch information
crcrpar committed Nov 12, 2024
1 parent f574939 commit b952e2a
Showing 1 changed file with 27 additions and 0 deletions.
27 changes: 27 additions & 0 deletions thunder/tests/test_inplace_copy.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from torch.testing import assert_close, make_tensor

import thunder
from thunder.core import prims
import thunder.core.dtypes as datatypes
import thunder.torch as ttorch
from thunder.tests.framework import instantiate, nvFuserExecutor
Expand Down Expand Up @@ -152,3 +153,29 @@ def func3(x, y):
match=r"If you are sure you don't want to use this check, it can be disabled by setting `disable_inplace_copy_check=True` in `thunder.jit`.$",
):
traced_foo(a, b)


@instantiate(executors=(nvFuserExecutor,), dtypes=(thunder.float32,))
def test_inplace_copy_dst_copy_returned_issue_1109(executor, device, dtype):
def func(T0):
T1 = torch.sin(T0)
prims.copy_(T1, T0)
T2 = torch.cos(T1)
prims.copy_(T2, T0)
# T1 & T2 should be returned as separate buffer, instead of sharing
# storage with T0
return T1, T2

tdtype = ttorch.to_torch_dtype(dtype)
# This pattern is unsafe in general. Disabling sanity check to silence
# exception for testing
traced_foo = executor.make_callable(func, disable_inplace_copy_check=True)
a = make_tensor((4, 4), device=device, dtype=tdtype)
a_ref = a.clone()

o_thunder = traced_foo(a)
o_eager = func(a_ref)

assert_close(a_ref, a)
for o, o_ref in zip(o_thunder, o_eager):
assert_close(o, o_ref)

0 comments on commit b952e2a

Please sign in to comment.