reverting special handling of copy_ in nvfuser executor (#806)
jjsjann123 authored Jul 25, 2024
1 parent 1ff7e2d commit c734e99
Showing 2 changed files with 27 additions and 16 deletions.
16 changes: 0 additions & 16 deletions thunder/executors/nvfuserex_impl.py
@@ -786,22 +786,6 @@ def _can_fuse_node(n: Node):
# (Used to name fusions like nvFusion0, nvFusion1, ...)
fusion_counter: int = 0
for bsyms in bound_symbol_groups:
    # Related to in-place ops.
    # prims.copy_ is a no-op for NVFuser in the sense that it neither
    # generates nor runs any kernels.
    # For that reason we should avoid fusing prims.copy_ unless
    # it comes after other non-copy symbols in a fusion.
    # See the following relevant issues:
    # https://github.com/Lightning-AI/lightning-thunder/issues/789
    # https://github.com/Lightning-AI/lightning-thunder/issues/791
    # NOTE: filter out all leading "dangling" no-op copies
    while len(bsyms) > 0 and bsyms[0].sym.id is prims.PrimIDs.COPY_:
        fused_bsyms.append(bsyms[0])
        bsyms = bsyms[1:]

    if len(bsyms) == 0:
        continue

    # TODO The following allows generating single node fusions, which
    # may be suboptimal for real-world performance.
    # Provide a mechanism to switch between "test" and "perf" modes
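For readers skimming the diff, here is a toy, self-contained sketch (not repository code; peel_leading, is_copy, and groups are names made up for this sketch) of the control flow the deleted lines implemented: leading items matching a predicate are peeled off the front of each group and kept out of the fused remainder.

# Toy illustration (not repository code) of the removed "peel leading
# no-op copies" pattern; only the control flow mirrors the deleted lines.
def peel_leading(group, is_copy):
    peeled = []
    while len(group) > 0 and is_copy(group[0]):
        peeled.append(group[0])
        group = group[1:]
    return peeled, group

groups = [["copy", "copy", "add", "mul"], ["copy"], ["add"]]
for g in groups:
    dangling, rest = peel_leading(g, lambda s: s == "copy")
    # dangling copies would have been appended to fused_bsyms unfused;
    # rest is what would have been considered for fusion.
    print(dangling, rest)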
27 changes: 27 additions & 0 deletions thunder/tests/test_inplace_functionalization.py
@@ -484,3 +484,30 @@ def functional_f(a, b, c):
    for t in filter(lambda t: t._provenance is not None and "Functionalize in-place ops" in t._provenance.pss, traces):
        for bsym in filter(lambda b: b.subsymbols, t.bound_symbols):
            assert bsym.rhs != bsym.subsymbols[0].rhs, bsym


@instantiate(
    dtypes=NOTHING,
)
def test_inplace_copy_on_fusion_inputs_issue_791(executor, device, _):

    def f(x, y, idx, src):
        x.index_copy_(0, idx, src)
        z = x + 1
        y.index_copy_(0, idx, src)
        return z

    a = make_tensor((2, 2), device=device, dtype=torch.float32)
    b = make_tensor((2, 2), device=device, dtype=torch.float32)
    a_, b_ = a.clone().detach(), b.clone().detach()
    idx = torch.arange(2).to(device=device)
    src = make_tensor((2, 2), device=device, dtype=torch.float32)

    o_ = f(a_, b_, idx, src)

    jitted = executor.make_callable(f)
    o = jitted(a, b, idx, src)

    assert a.allclose(a_)
    assert b.allclose(b_)
    assert o.allclose(o_)
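Outside the test harness, a minimal sketch of the same eager-versus-compiled check, assuming the public thunder.jit entry point and a CUDA device for the nvFuser path (both are assumptions of this sketch, not part of the commit):

import torch
import thunder  # assumes lightning-thunder is installed

def f(x, y, idx, src):
    x.index_copy_(0, idx, src)
    z = x + 1
    y.index_copy_(0, idx, src)
    return z

device = "cuda" if torch.cuda.is_available() else "cpu"  # nvFuser is exercised only on CUDA
a = torch.rand(2, 2, device=device)
b = torch.rand(2, 2, device=device)
a_, b_ = a.clone().detach(), b.clone().detach()
idx = torch.arange(2, device=device)
src = torch.rand(2, 2, device=device)

expected = f(a_, b_, idx, src)            # eager reference, mutates a_ and b_
actual = thunder.jit(f)(a, b, idx, src)   # compiled run, should mutate a and b the same way

torch.testing.assert_close(a, a_)
torch.testing.assert_close(b, b_)
torch.testing.assert_close(actual, expected)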
