try removing ltorch.copy_
#1209
Conversation
Isn't torch.Tensor.copy_ a legitimate method?
We don't seem to have a solid way to support the op without considerable implementation complexity and cost in compile time and performance, and it doesn't seem to have many uses at the moment.
Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
for more information, see https://pre-commit.ci
How's this branch going?
Still, we could write an equivalent function even without ltorch.copy_. BTW, would it be possible to audit raw prims.copy_s? Another alternative could be:

diff --git a/thunder/core/functionalization.py b/thunder/core/functionalization.py
index aac6f150..a91e502a 100644
--- a/thunder/core/functionalization.py
+++ b/thunder/core/functionalization.py
@@ -315,6 +315,40 @@ def canonicalize_bsym_args(
     return intermediate_trace, reverse_swap_map


+def audit_raw_copy(computation_trace: Trace, copy_bsyms: list[BoundSymbol]) -> Trace:
+    producer_map, consumer_map = producers(computation_trace), consumers(computation_trace)
+    copy_outs, copy_dsts = [], []
+    bsym_to_id: dict[BoundSymbol, int] = {bsym: i for i, bsym in enumerate(computation_trace.bound_symbols)}
+    copy_bsym_to_filter: list[BoundSymbol] = []
+
+    bsym: BoundSymbol
+    for bsym in copy_bsyms:
+        idx_of_bsym = bsym_to_id[bsym]
+        out = bsym.flat_proxy_outs[0]
+        check(out not in consumer_map, lambda: f"`prims.copy_` output {out=} is used inside a trace")
+        dst = bsym.flat_proxy_args[1]
+        if dst not in consumer_map:
+            continue
+        consumer_of_dst = tuple(
+            filter(lambda bsym: bsym.sym.id != prims.PrimIDs.RETURN and bsym_to_id[bsym] > idx_of_bsym, consumer_map[dst])
+        )
+        if not consumer_of_dst:
+            continue
+        check(
+            all(bsym.sym.id == prims.PrimIDs.COPY_ for bsym in consumer_of_dst),
+            lambda: f"copy destination of {dst} has consumers other than {prims.PrimIDs.RETURN} and {prims.PrimIDs.COPY_}",
+        )
+        copy_bsym_to_filter.append(bsym)
+
+    if not copy_bsym_to_filter:
+        return computation_trace
+
+    trace = from_trace(computation_trace)
+    set_of_redundant_copy_bsym = set(copy_bsym_to_filter)
+    trace.bound_symbols.extend(list(filter(lambda bsym: bsym not in set_of_redundant_copy_bsym, computation_trace.bound_symbols)))
+    trace.set_provenance(TraceProvenance("`prims.copy_` audit"))
+
+    return trace
+
+
 def create_functional_bsym_from(inplace_bsym: BoundSymbol) -> BoundSymbol:
     from thunder.torch import _inplace_to_out_of_place, setitem_, setitem

@@ -911,6 +945,8 @@ def functionalize_inplace_ops(
     """
     if not any(is_functionalizable(bsym) for bsym in computation_trace.bound_symbols):
+        if (copy_bsyms := [bsym for bsym in computation_trace.bound_symbols if bsym.sym.id == prims.PrimIDs.COPY_]):
+            return [audit_raw_copy(computation_trace, copy_bsyms)]
         return []

     # Step 0:
diff --git a/thunder/tests/test_inplace_copy.py b/thunder/tests/test_inplace_copy.py
index e94a6624..0190e35f 100644
--- a/thunder/tests/test_inplace_copy.py
+++ b/thunder/tests/test_inplace_copy.py
@@ -182,9 +182,5 @@ def test_inplace_copy_dst_copy_returned_issue_1109(executor, device, dtype):
     assert_close(t0, expected)
     assert_close(actual_t2, expected_t2)
-    # FIXME(crcrpar): Since there's no `ltorch.Tensor.copy_`, functions like `func` would not
-    # be observed and executed with pytorch eager mode. Though there should be either an audit of
-    # `prims.copy_` in a nvfuser region and/or what #1110 did.
-    assert actual_t1.data_ptr() == actual_t2.data_ptr()
-    with pytest.raises(AssertionError):
-        assert_close(actual_t1, expected_t1)
+    assert actual_t1.data_ptr() != actual_t2.data_ptr()
+    assert_close(actual_t1, expected_t1)
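The core idea of the proposed audit can be sketched without thunder's trace machinery. In this hypothetical stand-in (my illustration, not the actual `BoundSymbol` API), ops are `(name, args, out)` tuples, and a `copy_` is dropped when every later consumer of its destination is itself another `copy_` that will overwrite it anyway:

```python
from collections import defaultdict

def audit_raw_copy(ops):
    """Drop a copy_ whose destination's only later consumers are other copy_ ops.

    ops: list of (name, args, out) tuples; for copy_, args is (src, dst).
    """
    consumers = defaultdict(list)  # proxy name -> indices of ops that read it
    for i, (name, args, out) in enumerate(ops):
        for a in args:
            consumers[a].append(i)

    kept = []
    for i, op in enumerate(ops):
        name, args, out = op
        if name == "copy_":
            dst = args[1]
            # Consumers of dst that come after this op, ignoring the return.
            later = [j for j in consumers[dst] if j > i and ops[j][0] != "return"]
            if later and all(ops[j][0] == "copy_" for j in later):
                continue  # redundant: a later copy_ overwrites dst anyway
        kept.append(op)
    return kept

ops = [
    ("abs", ("x",), "t0"),
    ("copy_", ("t0", "x"), "c0"),  # redundant: overwritten by the later copy_
    ("neg", ("t0",), "t1"),
    ("copy_", ("t1", "x"), "c1"),
    ("return", ("c1",), None),
]
print([o[0] for o in audit_raw_copy(ops)])  # ['abs', 'neg', 'copy_', 'return']
```

The real pass additionally checks that the `copy_` output proxy is never consumed inside the trace, which is what the `check(out not in consumer_map, ...)` call in the diff enforces.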
thunder/executors/nvfuserex_impl.py
@@ -2071,8 +2071,7 @@ def copy_(
 ) -> Any:
     nvcopy_from = getnv(copy_from, fd, lc_to_nv_map)
     nvcopy_to = getnv(copy_to, fd, lc_to_nv_map)
-    alias_output = fd.ops.set(nvcopy_from)
-    fd.add_output(alias_output, alias_input=nvcopy_to)
+    fd.add_output(nvcopy_from, alias_input=nvcopy_to)
Are we supposed to do this? I'm worried that this will give us a wrong program.
Why isn't there a copy op?
There is a copy op; the difference is whether we are returning a different buffer or the aliased source.
I think the example below will help. The thunder program below behaves differently depending on what the return value is.
import thunder
import torch
from thunder.core import prims

def foo(x):
    x_relu = prims.abs(x)
    old_x = prims.copy_(x_relu, x)
    return old_x  # this returns x, the aliased source from the program.
    # return x_relu  # this returns x_relu, which should be in a different buffer.

jfoo = thunder.jit(foo)
x = torch.randn(2, 4, device="cuda")
o = jfoo(x)
print(thunder.last_traces(jfoo)[-1])
print(x)
print(o)
print(x.data_ptr())
print(o.data_ptr())
It is translated to the corresponding nvFuser program:
import torch
from nvfuser import FusionDefinition, DataType

def nvfuser_fusion_id4(fd: FusionDefinition) -> None:
    T0 = fd.define_tensor(shape=[2, 4], contiguity=[True, True], dtype=DataType.Float, is_cpu=False, stride_order=[1, 0])
    T1 = fd.ops.abs(T0)
    T2 = fd.ops.set(T1)
    fd.add_output(T2, T0)
    fd.add_output(T2)
    # fd.add_output(T1)  # doing this would cause us to return an extra buffer.
The copy_ part is translated to

    T2 = fd.ops.set(T1)    # this is the extra copy we added.
    fd.add_output(T2, T0)  # the extra copy is used to overwrite T0.

When we return T2 as the output of the fusion, since T2 is marked as an alias of T0 (sharing storage), we are not returning an extra buffer here.
However, the story changes when we return T1 from the fusion, since T1 is a new buffer in the program. You can play with this using the snippet below:
with FusionDefinition() as fd:
    nvfuser_fusion_id4(fd)

inputs = [
    torch.testing.make_tensor((2, 4), dtype=torch.float32, device='cuda:0'),
]
print(inputs)
o = fd.execute(inputs)
print(inputs)
print(o[0].data_ptr())
print(inputs[0].data_ptr())
You'll notice that without the extra fd.ops.set in copy_, nvfuser won't be able to faithfully represent the intended thunder behavior.
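As a loose analogy in plain Python (my illustration, not nvfuser API; list aliasing stands in for shared tensor storage):

```python
# Aliased vs. fresh buffers, mimicking the two return choices discussed above.
buf = [1.0, -2.0, 3.0]   # plays the role of the input T0's storage

aliased = buf            # like add_output(T2, alias_input=T0): same storage
fresh = list(buf)        # like returning T1: a brand-new buffer

buf[1] = 99.0            # an "in-place" write into the input's storage
print(aliased[1])        # 99.0 -- the alias observes the write
print(fresh[1])          # -2.0 -- the fresh buffer does not
```

This is why returning the aliased output versus the pre-aliasing intermediate gives observably different `data_ptr()` and values.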
thunder/tests/test_inplace_copy.py
    # FIXME(crcrpar): Since there's no `ltorch.Tensor.copy_`, functions like `func` would not
    # be observed and executed with pytorch eager mode. Though there should be either an audit of
    # `prims.copy_` in a nvfuser region and/or what #1110 did.
    assert actual_t1.data_ptr() == actual_t2.data_ptr()
Thanks for putting a FIXME there, but isn't this a case of nvfuser producing a wrong result in func?
One of the outputs will still be aliasing the input, and that's not how the program should be translated. I.e., I don't think an audit that removes redundant prims.copy_ would be able to solve that.
For the record, I'm all for having a remove-redundant-inplace-copy pass, since that's a net improvement. But I'm not sure it would resolve the performance regression that we are seeing right now.
If a callable has multiple in-place ops, then the redundant copies would be cleaned up. We just don't have an equivalent for multiple raw prims.copy_s.
It seems that I added ltorch.copy_ in #1063, but I don't remember exactly why I did so. I even suspect it was an unwanted change with some cost.
So in this PR, I just want to try reverting the change and see if things can be simplified.