
try removing ltorch.copy_ #1209

Closed
wants to merge 6 commits into from

Conversation

crcrpar
Collaborator

@crcrpar crcrpar commented Sep 27, 2024

It seems that I added ltorch.copy_ in #1063, but I don't remember exactly why I did so.
I even suspect it was an unwanted change that came with some cost.
So in this PR, I want to try reverting the change and see if things can be simplified.

@t-vi
Collaborator

t-vi commented Sep 28, 2024

isn't torch.Tensor.copy_ a legit method?
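(For reference, torch.Tensor.copy_ is indeed a public in-place method in eager PyTorch. A minimal sketch of its eager semantics — it writes the source's values into the destination's existing storage and returns the destination itself:)

```python
import torch

# Eager semantics of torch.Tensor.copy_: it writes src's values into
# dst's existing storage and returns dst itself (no new buffer).
dst = torch.zeros(3)
src = torch.arange(3, dtype=torch.float32)
out = dst.copy_(src)

assert out is dst                        # copy_ returns the destination tensor
assert torch.equal(dst, src)             # values were copied in place
assert out.data_ptr() == dst.data_ptr()  # no new storage was allocated
```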

@crcrpar
Collaborator Author

crcrpar commented Sep 28, 2024

We don't seem to have a solid way to support this op with reasonable implementation complexity, compile-time cost, and runtime performance, while it doesn't seem to have many uses at the moment.

@crcrpar crcrpar force-pushed the crpa/remove_copy_from_ltorch branch from baf3ce4 to 6785879 on October 14, 2024 15:39
@crcrpar crcrpar marked this pull request as ready for review October 15, 2024 04:28
@crcrpar crcrpar force-pushed the crpa/remove_copy_from_ltorch branch from 6785879 to b857701 on November 4, 2024 16:44
crcrpar and others added 3 commits November 13, 2024 02:30
@jjsjann123
Collaborator

How's this branch going?

@crcrpar
Collaborator Author

crcrpar commented Nov 13, 2024

We could still write an equivalent function without ltorch.copy_ by using prims.copy_, but that function would not be executable in eager mode.

BTW, would it be possible to audit the prims.copy_s in a region on the nvFuser executor side?

Another alternative could be:

diff --git a/thunder/core/functionalization.py b/thunder/core/functionalization.py
index aac6f150..a91e502a 100644
--- a/thunder/core/functionalization.py
+++ b/thunder/core/functionalization.py
@@ -315,6 +315,40 @@ def canonicalize_bsym_args(
     return intermediate_trace, reverse_swap_map


+def audit_raw_copy(computation_trace: Trace, copy_bsyms: list[BoundSymbol]) -> Trace:
+    producer_map, consumer_map = producers(computation_trace), consumers(computation_trace)
+    copy_outs, copy_dsts = [], []
+    bsym_to_id: dict[BoundSymbol, int] = {bsym: i for i, bsym in enumerate(computation_trace.bound_symbols)}
+    copy_bsym_to_filter: list[BoundSymbol] = []
+
+    bsym: BoundSymbol
+    for bsym in copy_bsyms:
+        idx_of_bsym = bsym_to_id[bsym]
+        out = bsym.flat_proxy_outs[0]
+        check(out not in consumer_map, lambda: f"`prims.copy_` output {out=} is used inside a trace")
+        dst = bsym.flat_proxy_args[1]
+        if dst not in consumer_map:
+            continue
+        consumer_of_dst = tuple(filter(lambda bsym: bsym.sym.id != prims.PrimIDs.RETURN and bsym_to_id[bsym] > idx_of_bsym, consumer_map[dst]))
+        if not consumer_of_dst:
+            continue
+        check(
+            all(bsym.sym.id == prims.PrimIDs.COPY_ for bsym in consumer_of_dst),
+            lambda: f"copy destination of {dst} has consumers other than {prims.PrimIDs.RETURN} and {prims.PrimIDs.COPY_}",
+        )
+        copy_bsym_to_filter.append(bsym)
+
+    if not copy_bsym_to_filter:
+        return computation_trace
+
+    trace = from_trace(computation_trace)
+    set_of_redundant_copy_bsym = set(copy_bsym_to_filter)
+    trace.bound_symbols.extend(list(filter(lambda bsym: bsym not in set_of_redundant_copy_bsym, computation_trace.bound_symbols)))
+    trace.set_provenance(TraceProvenance("`prims.copy_` audit"))
+
+    return trace
+
+
 def create_functional_bsym_from(inplace_bsym: BoundSymbol) -> BoundSymbol:
     from thunder.torch import _inplace_to_out_of_place, setitem_, setitem

@@ -911,6 +945,8 @@ def functionalize_inplace_ops(
     """

     if not any(is_functionalizable(bsym) for bsym in computation_trace.bound_symbols):
+        if (copy_bsyms := [bsym for bsym in computation_trace.bound_symbols if bsym.sym.id == prims.PrimIDs.COPY_]):
+            return [audit_raw_copy(computation_trace, copy_bsyms)]
         return []

     # Step 0:
diff --git a/thunder/tests/test_inplace_copy.py b/thunder/tests/test_inplace_copy.py
index e94a6624..0190e35f 100644
--- a/thunder/tests/test_inplace_copy.py
+++ b/thunder/tests/test_inplace_copy.py
@@ -182,9 +182,5 @@ def test_inplace_copy_dst_copy_returned_issue_1109(executor, device, dtype):

     assert_close(t0, expected)
     assert_close(actual_t2, expected_t2)
-    # FIXME(crcrpar): Since there's no `ltorch.Tensor.copy_`, functions like `func` would not
-    # be observed and executed with pytorch eager mode. Though there should be either an audit of
-    # `prims.copy_` in a nvfuser region and/or what #1110 did.
-    assert actual_t1.data_ptr() == actual_t2.data_ptr()
-    with pytest.raises(AssertionError):
-        assert_close(actual_t1, expected_t1)
+    assert actual_t1.data_ptr() != actual_t2.data_ptr()
+    assert_close(actual_t1, expected_t1)

Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
@@ -2071,8 +2071,7 @@ def copy_(
 ) -> Any:
     nvcopy_from = getnv(copy_from, fd, lc_to_nv_map)
     nvcopy_to = getnv(copy_to, fd, lc_to_nv_map)
-    alias_output = fd.ops.set(nvcopy_from)
-    fd.add_output(alias_output, alias_input=nvcopy_to)
+    fd.add_output(nvcopy_from, alias_input=nvcopy_to)
Collaborator

Are we supposed to do this? I'm worried that this will give us a wrong program.

Collaborator Author

This PR is composed of the revert of #1209 and the test update.
I thought fd.ops.set was something we ideally want to avoid, as per #1173.

Collaborator

Sorry about not being more clear in #1173.

We still want to have the set for copy_. I think the root cause of issue #1173 is that we are returning nvcopy_from from the fusion region.

Collaborator Author

why isn't there a copy op?

Collaborator

There is a copy op; the difference is whether we return a different buffer or the aliased source.

I think the example below will help.

The Thunder program below behaves differently depending on what the return value is.

import thunder
import torch
from thunder.core import prims

def foo(x):
    x_relu = prims.abs(x)
    old_x = prims.copy_(x_relu, x)
    return old_x  # this returns x, the aliased source from the program.
    #return x_relu  # this returns the x_relu, which should be in a different buffer.

jfoo = thunder.jit(foo)

x = torch.randn(2, 4, device="cuda")
o = jfoo(x)

print(thunder.last_traces(jfoo)[-1])
print(x)
print(o)
print(x.data_ptr())
print(o.data_ptr())

They are translated into the corresponding nvFuser programs.

import torch
from nvfuser import FusionDefinition, DataType

def nvfuser_fusion_id4(fd : FusionDefinition) -> None :
    T0 = fd.define_tensor(shape=[2, 4], contiguity=[True, True], dtype=DataType.Float, is_cpu=False, stride_order=[1, 0])
    T1 = fd.ops.neg(T0)
    T2 = fd.ops.set(T1)
    fd.add_output(T2, T0)
    fd.add_output(T2)
    # fd.add_output(T1)  # doing this would cause us to return an extra buffer.

The copy_ part is translated to

T2 = fd.ops.set(T1)  # this is the extra copy we added.
fd.add_output(T2, T0)  # the extra copy is used to overwrite T0.

When we return T2 as the fusion output, since T2 is marked as an alias of T0 (sharing storage), we are not returning an extra buffer here.

However, the story changes when we return T1 from the fusion, since T1 is a new buffer in the program. You can play with this using the snippet below:

with FusionDefinition() as fd:
    nvfuser_fusion_id4(fd)

inputs = [
    torch.testing.make_tensor((2, 4), dtype=torch.float32, device='cuda:0'),
]

print(inputs)
o = fd.execute(inputs)
print(inputs)

print(o[0].data_ptr())
print(inputs[0].data_ptr())

You'll notice that without the extra fd.ops.set in copy_, nvFuser won't be able to faithfully represent the intended Thunder behavior.

Collaborator Author

okay, then I don't think this PR deserves a merge.

f0eac73: compile 54s, 11.07ms
82dc7a7: compile 97s, 14.28ms

# FIXME(crcrpar): Since there's no `ltorch.Tensor.copy_`, functions like `func` would not
# be observed and executed with pytorch eager mode. Though there should be either an audit of
# `prims.copy_` in a nvfuser region and/or what #1110 did.
assert actual_t1.data_ptr() == actual_t2.data_ptr()
Collaborator

Thanks for putting a FIXME there. But is this a case where nvFuser produces a wrong result in func?

One of the outputs will still be aliasing the input, and that's not how the program should be translated. I.e., I don't think an audit that removes redundant prims.copy_ would be able to solve that.
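(To make the aliasing concern concrete, here is a plain-PyTorch sketch of eager semantics — illustrative code, not Thunder/nvFuser internals — showing why returning the aliased source differs from returning the fresh buffer:)

```python
import torch

# In eager mode, x.abs() allocates a fresh buffer, while x.copy_(y)
# returns x itself. A fusion that hands back the aliased source instead
# of the fresh result therefore returns storage shared with the input.
x = torch.randn(4)
y = x.abs()       # fresh buffer holding the functional result
z = x.copy_(y)    # in-place write; z aliases x

assert z.data_ptr() == x.data_ptr()  # aliased source shares storage
assert y.data_ptr() != x.data_ptr()  # the fresh result does not
```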

Collaborator

For the record, I'm all in for having a pass that removes redundant in-place copies, since that's a net improvement. But I'm not sure it would resolve the performance regression that we are seeing right now.

Collaborator Author

If a callable has multiple in-place ops, then the redundant copies would be cleaned up. We just don't have an equivalent pass for multiple raw prims.copy_s.
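(As a plain-PyTorch sketch — illustrative only, not Thunder's functionalization pass — of the cleanup described above: when several in-place writes target the same destination and nothing reads it in between, only the last write is observable, so the earlier copies are redundant:)

```python
import torch

def two_copies(dst, a, b):
    dst.copy_(a)  # redundant: overwritten below before dst is ever read
    dst.copy_(b)  # only this final write is observable
    return dst

def one_copy(dst, b):
    dst.copy_(b)  # the cleaned-up equivalent keeps just the last copy
    return dst

a, b = torch.ones(3), torch.full((3,), 2.0)
assert torch.equal(two_copies(torch.zeros(3), a, b),
                   one_copy(torch.zeros(3), b))
```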

crcrpar and others added 2 commits November 19, 2024 12:27
Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
@crcrpar crcrpar marked this pull request as draft November 19, 2024 11:30
@crcrpar crcrpar closed this Nov 20, 2024
3 participants