Skip to content

Commit

Permalink
try removing ltorch.copy_
Browse files Browse the repository at this point in the history
Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
  • Loading branch information
crcrpar committed Sep 27, 2024
1 parent 59467aa commit baf3ce4
Show file tree
Hide file tree
Showing 3 changed files with 1 addition and 33 deletions.
3 changes: 1 addition & 2 deletions thunder/executors/nvfuserex_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -2068,8 +2068,7 @@ def copy_(
) -> Any:
nvcopy_from = getnv(copy_from, fd, lc_to_nv_map)
nvcopy_to = getnv(copy_to, fd, lc_to_nv_map)
alias_output = fd.ops.set(nvcopy_from)
fd.add_output(alias_output, alias_input=nvcopy_to)
fd.add_output(nvcopy_from, alias_input=nvcopy_to)
return nvcopy_to


Expand Down
26 changes: 0 additions & 26 deletions thunder/tests/test_inplace_copy.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,29 +152,3 @@ def func3(x, y):
match=r"If you are sure you don't want to use this check, it can be disabled by setting `disable_inplace_copy_check=True` in `thunder.jit`.$",
):
traced_foo(a, b)


@instantiate(executors=(nvFuserExecutor,), dtypes=(thunder.float32,))
def test_inplace_copy_dst_copy_returned_issue_1109(executor, device, dtype):
def func(T0):
T1 = torch.sin(T0)
T0.copy_(T1) # destination.copy_(source)
T2 = torch.cos(T1)
T0.copy_(T2)
# T1 & T2 should be returned as separate buffer, instead of sharing
# storage with T0
return T1, T2

tdtype = ttorch.to_torch_dtype(dtype)
# This pattern is unsafe in general. Disabling sanity check to silence
# exception for testing
traced_foo = executor.make_callable(func, disable_inplace_copy_check=True)
a = make_tensor((4, 4), device=device, dtype=tdtype)
a_ref = a.clone()

o_thunder = traced_foo(a)
o_eager = func(a_ref)

assert_close(a_ref, a)
for o, o_ref in zip(o_thunder, o_eager):
assert_close(o, o_ref)
5 changes: 0 additions & 5 deletions thunder/torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1917,11 +1917,6 @@ def copysign_(a, b, /):
return prims.copy_(copysign(a, b), a)


@torchsymbol(torch.Tensor.copy_, is_method=True) # , tags=(prims.OpTags.IN_PLACE,))
def copy_(a, b, /):
return prims.copy_(b, a)


# TODO Implement div
@torchsymbol(torch.div, is_method=True)
def div(
Expand Down

0 comments on commit baf3ce4

Please sign in to comment.