try removing ltorch.copy_ #1209

Closed
3 changes: 1 addition & 2 deletions thunder/executors/nvfuserex_impl.py
@@ -2071,8 +2071,7 @@ def copy_(
 ) -> Any:
     nvcopy_from = getnv(copy_from, fd, lc_to_nv_map)
     nvcopy_to = getnv(copy_to, fd, lc_to_nv_map)
-    alias_output = fd.ops.set(nvcopy_from)
-    fd.add_output(alias_output, alias_input=nvcopy_to)
+    fd.add_output(nvcopy_from, alias_input=nvcopy_to)
Collaborator

Are we supposed to do this? I'm worried that this will give us a wrong program.

Collaborator Author

This PR is composed of the revert of #1209 and the test update.
I thought fd.ops.set was something we ideally want to avoid, as per #1173.

Collaborator

Sorry about not being clearer in #1173.

We still wanted to have the set for copy_. I think the root cause of issue #1173 is that we are returning nvcopy_from in the fusion region.
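
A minimal sketch of the translation this comment describes (editor's illustration; the signature is simplified, and getnv, fd.ops.set, and add_output's alias_input are taken from the hunk above): keep the extra set, and return the destination rather than nvcopy_from.

def copy_(copy_from, copy_to, *, fd, lc_to_nv_map):
    # Look up the nvFuser TensorViews backing the thunder proxies.
    nvcopy_from = getnv(copy_from, fd, lc_to_nv_map)
    nvcopy_to = getnv(copy_to, fd, lc_to_nv_map)
    # The extra copy: materialize the source into a fresh TensorView.
    alias_output = fd.ops.set(nvcopy_from)
    # Declare that this fresh TensorView overwrites the destination buffer.
    fd.add_output(alias_output, alias_input=nvcopy_to)
    # Hand back the destination, never nvcopy_from, so the fusion does not
    # expose the aliased source as one of its outputs.
    return nvcopy_to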

Collaborator Author

why isn't there a copy op?

Collaborator

There is a copy op; the difference is whether we are returning a different buffer or the aliased source.

I think the example below would help.

The thunder program below behaves differently depending on what the return value is.

import thunder
import torch
from thunder.core import prims

def foo(x):
    x_relu = prims.abs(x)
    old_x = prims.copy_(x_relu, x)
    return old_x  # this returns x, the aliased source from the program.
    #return x_relu  # this returns the x_relu, which should be in a different buffer.

jfoo = thunder.jit(foo)

x = torch.randn(2, 4, device="cuda")
o = jfoo(x)

print(thunder.last_traces(jfoo)[-1])
print(x)
print(o)
print(x.data_ptr())
print(o.data_ptr())

They are translated into the corresponding nvFuser programs.

import torch
from nvfuser import FusionDefinition, DataType

def nvfuser_fusion_id4(fd : FusionDefinition) -> None :
    T0 = fd.define_tensor(shape=[2, 4], contiguity=[True, True], dtype=DataType.Float, is_cpu=False, stride_order=[1, 0])
    T1 = fd.ops.neg(T0)
    T2 = fd.ops.set(T1)
    fd.add_output(T2, T0)
    fd.add_output(T2)
    # fd.add_output(T1)  # doing this would cause us to return an extra buffer.

The copy_ part is translated to:

T2 = fd.ops.set(T1)  # this is the extra copy we added.
fd.add_output(T2, T0)  # the extra copy is used to overwrite T0.

When we return T2 as the output of the fusion, since T2 is marked as an alias of T0 (sharing storage), we are not returning an extra buffer here.

However, the story changes when we return T1 from the fusion, since T1 is a new buffer in the program. You can play with this using the snippet below:

with FusionDefinition() as fd:
    nvfuser_fusion_id4(fd)

inputs = [
    torch.testing.make_tensor((2, 4), dtype=torch.float32, device='cuda:0'),
]

print(inputs)
o = fd.execute(inputs)
print(inputs)

print(o[0].data_ptr())
print(inputs[0].data_ptr())

You'll notice that without the extra fd.ops.set in copy_, nvfuser won't be able to faithfully represent the intended thunder behavior.
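
For contrast, here is a hypothetical variant of the same fusion with the set dropped (the name nvfuser_fusion_no_set is invented for illustration; it mirrors what the nvfuserex_impl.py change in this PR would emit):

def nvfuser_fusion_no_set(fd: FusionDefinition) -> None:
    T0 = fd.define_tensor(shape=[2, 4], contiguity=[True, True], dtype=DataType.Float, is_cpu=False, stride_order=[1, 0])
    T1 = fd.ops.neg(T0)
    # Without the intermediate fd.ops.set, the op result itself is declared
    # to overwrite the input buffer.
    fd.add_output(T1, T0)
    # Returning T1 now hands back storage tied to the input rather than a fresh buffer.
    fd.add_output(T1)

Swapping this definition into the execute snippet above should show the returned tensor's data_ptr matching the input's, instead of pointing at a separate buffer.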

Collaborator Author

okay, then I don't think this PR deserves a merge.

f0eac73: compile 54s, 11.07ms
82dc7a7: compile 97s, 14.28ms

     return nvcopy_to


32 changes: 21 additions & 11 deletions thunder/tests/test_inplace_copy.py
@@ -5,6 +5,7 @@
 from torch.testing import assert_close, make_tensor
 
 import thunder
+from thunder.core import prims
 import thunder.core.dtypes as datatypes
 import thunder.torch as ttorch
 from thunder.tests.framework import instantiate, nvFuserExecutor
@@ -158,9 +159,9 @@ def func3(x, y):
 def test_inplace_copy_dst_copy_returned_issue_1109(executor, device, dtype):
     def func(T0):
         T1 = torch.sin(T0)
-        T0.copy_(T1) # destination.copy_(source)
+        prims.copy_(T1, T0)
         T2 = torch.cos(T1)
-        T0.copy_(T2)
+        prims.copy_(T2, T0)
         # T1 & T2 should be returned as separate buffer, instead of sharing
         # storage with T0
         return T1, T2
@@ -169,12 +170,21 @@ def func(T0):
     # This pattern is unsafe in general. Disabling sanity check to silence
     # exception for testing
     traced_foo = executor.make_callable(func, disable_inplace_copy_check=True)
-    a = make_tensor((4, 4), device=device, dtype=tdtype)
-    a_ref = a.clone()
-
-    o_thunder = traced_foo(a)
-    o_eager = func(a_ref)
-
-    assert_close(a_ref, a)
-    for o, o_ref in zip(o_thunder, o_eager):
-        assert_close(o, o_ref)
+    t0 = make_tensor((4, 4), device=device, dtype=tdtype)
+    t0_ref = t0.clone()
+
+    actual_t1, actual_t2 = traced_foo(t0)
+
+    expected = t0_ref.sin().cos()
+    expected_t1 = t0_ref.sin()
+    expected_t2 = expected_t1.cos()
+    expected = expected_t2
+
+    assert_close(t0, expected)
+    assert_close(actual_t2, expected_t2)
+    # FIXME(crcrpar): Since there's no `ltorch.Tensor.copy_`, functions like `func` would not
+    # be observed and executed with pytorch eager mode. Though there should be either an audit of
+    # `prims.copy_` in a nvfuser region and/or what #1110 did.
+    assert actual_t1.data_ptr() == actual_t2.data_ptr()
Collaborator

Thanks for putting a FIXME there. But is this the case where nvfuser is producing a wrong result for func?

One of the outputs will still be aliasing the input, and that's not how the program should be translated, i.e. I don't think an audit that removes redundant prims.copy_ would be able to solve that.

Collaborator

For the record, I'm all in for having a pass that removes redundant in-place copies, since that's a net improvement. But I'm not sure it would resolve the performance regression that we are seeing right now.

Collaborator Author

If a callable has multiple in-place ops, then the redundant copies would be cleaned up. We just don't have an equivalent for multiple raw prims.copy_ calls.
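
A hypothetical illustration of that distinction (not code from this PR): the in-place ops here go through the torch-level API, so the redundant intermediate copies can be cleaned up, whereas func above calls prims.copy_ directly and gets no such treatment.

import torch
import thunder

def multiple_inplace(x):
    # Two consecutive torch-level in-place ops; per the comment above, the
    # redundant copy produced for the first one can be cleaned up, unlike
    # back-to-back raw prims.copy_ calls.
    x.add_(1.0)
    x.mul_(2.0)
    return x

jfn = thunder.jit(multiple_inplace)
print(jfn(torch.randn(2, 4)))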

+    with pytest.raises(AssertionError):
+        assert_close(actual_t1, expected_t1)
5 changes: 0 additions & 5 deletions thunder/torch/__init__.py
@@ -1945,11 +1945,6 @@ def copysign_(a, b, /):
     return prims.copy_(copysign(a, b), a)
 
 
-@torchsymbol(torch.Tensor.copy_, is_method=True) # , tags=(prims.OpTags.IN_PLACE,))
-def copy_(a, b, /):
-    return prims.copy_(b, a)
-
-
 # TODO Implement div
 @torchsymbol(torch.div, is_method=True)
 def div(