Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add prims.copy_to_out_ #1194

Closed
wants to merge 8 commits into from
Prev Previous commit
Next Next commit
Allow computed to come from/go outside of the fused region
shino16 committed Sep 25, 2024
commit 77657728dd9021d9933d2a700ed98d8610a5f437
6 changes: 3 additions & 3 deletions thunder/__init__.py
Original file line number Diff line number Diff line change
@@ -652,9 +652,6 @@ def get_computation_and_inputs(*args, **kwargs):

if backward_trc is None:
computation_trc = thunder.executors.passes.del_last_used(computation_trc)

if not compile_options.get("disable_inplace_copy_check", False):
thunder.core.transform_common._inplace_copy_sanity_check(computation_trc)
computation_traces.append(computation_trc)

for transform in transforms:
@@ -681,6 +678,9 @@ def get_computation_and_inputs(*args, **kwargs):
computation_trc = unwrap_return_value(computation_trc)
computation_traces.append(computation_trc)

if not compile_options.get("disable_inplace_copy_check", False):
thunder.core.transform_common._inplace_copy_sanity_check(computation_trc)

computation_trc = transform_to_torch_types(computation_trc)
comp = computation_trc.python_callable()

34 changes: 19 additions & 15 deletions thunder/core/transform_common.py
Original file line number Diff line number Diff line change
@@ -62,7 +62,7 @@ def _remove_noop_subsymbols(bsym: BoundSymbol) -> None:

def _inplace_copy_sanity_check(extrace: Trace):
"""The sanity check is based on the sharp edge of nvfuser's `add_output(output, input)` interface.
It makes sure that the `copy_to` argument of `prims.copy_` and the both arguments of `prims.copy_to_out_` are not used as input for any of its subsequent operators in a nvFusion fused operator. It also checks that the `computed` argument of `prims.copy_to_out_` is used only within the nvFusion region.
It makes sure that the `copy_to` argument of `prims.copy_` and both arguments of `prims.copy_to_out_` are not used as input for any of its subsequent operators in a nvFusion fused operator. It also checks that the `computed` argument of `prims.copy_to_out_` is created within the trace and not used even beyond the nvFusion region.

Anti-patterns:

@@ -93,15 +93,13 @@ def _inplace_copy_sanity_check(extrace: Trace):
# a = prims.copy_(result, x)
# t2 = prims.add(result, y)
"""
# The checks on prims.copy_to_out_ are rather conservative and technical, but they will not be exposed
# as long as users use prims.copy_to_out_ in a standard way, to copy outputs of intermediate arithmetic ops.

from thunder.core.utils import consumers

trace_args = {p.name for p in tree_iter((extrace.args, extrace.kwargs))}
trace_consumers = consumers(extrace)
nvfuser_symbols = (bsym for bsym in extrace.bound_symbols if bsym.sym.name.startswith("nvFusion"))
for bsym in nvfuser_symbols:
consumer_dict = consumers(list(bsym.subsymbols), _map_to_numbers=True)
region_args = {p.name for p in bsym.flat_proxy_args}
region_consumers = consumers(list(bsym.subsymbols), _map_to_numbers=True)
region_outputs = {p.name for p in bsym.flat_proxy_outs}
inplace_copy_idx = (
(idx, sym)
@@ -112,8 +110,8 @@ def _inplace_copy_sanity_check(extrace: Trace):
instruction = "If you are sure you don't want to use this check, it can be disabled by setting `disable_inplace_copy_check=True` in `thunder.jit`."

def check(inp, description):
if inp is not None and inp in consumer_dict:
last_used_idx = max(consumer_dict[inp])
if inp is not None and inp in region_consumers:
last_used_idx = max(region_consumers[inp])
if last_used_idx > idx:
raise NotImplementedError(
f"{bsym.subsymbols[last_used_idx]} trying to use {inp} (the {description}) as input, which is not safe. There is a risk of accessing the wrong memory. "
@@ -133,17 +131,23 @@ def check(inp, description):
check(out, "'out' argument of 'prims.copy_to_out_")
check(output, "output of 'prims.copy_to_out_")

if computed.name in region_args:
if computed.name in trace_args:
raise NotImplementedError(
f"{computed} (the 'computed' argument of 'prims.copy_to_out_') is defined outside of the nvFuser region. "
f"{computed} (the 'computed' argument of 'prims.copy_to_out_') is created outside of the execution trace. "
f"Copies onto {out} or {output} in the region may propagate to {computed}. " + instruction
)

if computed.name in region_outputs:
raise NotImplementedError(
f"{computed} (the 'computed' argument of 'prims.copy_to_out_') is used outside of the nvFuser region. "
f"Copies onto {out} or {output} in the region may propagate to {computed}. " + instruction
)
if computed in trace_consumers:
for consumer in trace_consumers[computed]:
# If the bsym creating the `computed` tensor cannot be fused, it will be put outside
# of the fused region, and the created `computed` tensor will be passed to the fused
# region. The tensor should be deleted immediately after the region
# These are the only consumers of the `computed` tensor in a normal setting
if consumer.sym.name != bsym.sym.name and consumer.sym != prims.python_del:
raise NotImplementedError(
f"{consumer} trying to use {computed} (the 'computed' argument of 'prims.copy_to_out_') as input, which is not safe. There is a risk of accessing the wrong memory. "
+ instruction
)


# TODO This calls variableify(), but we could directly construct Variable objects instead, which might slightly
54 changes: 26 additions & 28 deletions thunder/tests/test_inplace_copy.py
Original file line number Diff line number Diff line change
@@ -177,56 +177,54 @@ def test_copy_to_out_sanity_check_on_computed(executor, device, dtype):
a = make_tensor((4, 4), device=device, dtype=tdtype)
b = make_tensor((4, 4), device=device, dtype=tdtype)
a_ref = a.detach().clone()
b_ref = b.detach().clone()
idx = torch.arange(4).to(device=device)
src = make_tensor((4, 4), device=device, dtype=torch.float32)

def torch_good(x, y):
z = x * y
def torch_good(x, idx, src):
z = torch.index_copy(x, 0, idx, src)
o = x.copy_(z)
return o

def good(x, y):
z = x * y
def good(x, idx, src):
z = torch.index_copy(x, 0, idx, src)
# `computed` comes from outside of the fused region, which is inevitable
o = thunder.core.prims.copy_to_out_(z, out=x)
return o

def bad1(x, y):
z = x * y
z = x + x
o = thunder.core.prims.copy_to_out_(z, out=x)
return o, z
return o, z + z # `computed` consumed after `copy_to_out_`

def bad2(x, y):
z = x * y
z = x + x
o = thunder.core.prims.copy_to_out_(z, out=x)
return o + z
return o, z # `computed` consumed outside of the fused region

def bad3(x, y):
o = thunder.core.prims.copy_to_out_(y, out=x)
o = thunder.core.prims.copy_to_out_(y, out=x) # `computed` is a function argument
return o

def bad4(x, y):
z = x * y
o = thunder.core.prims.copy_to_out_(z, out=x)
return o, torch.concat((z, z)) # not fused

def bad5(x, y):
x2 = torch.concat((x, x))
y2 = torch.concat((y, y)) # not fused
o = thunder.core.prims.copy_to_out_(y2, out=x2)
return o
# This is not checked
# def bad4(x, y):
# y2 = y.view(-1).view(x.shape)
# o = thunder.core.prims.copy_to_out_(y2, out=x) # `computed` aliases a function argument
# return o

traced_good = executor.make_callable(good)
out = traced_good(a, b)
out_ref = torch_good(a_ref, b_ref)
assert_close([a, b, out], [a_ref, b_ref, out_ref])

for foo in [bad1, bad2, bad3, bad4, bad5]:
print(foo.__name__)
traced_foo = executor.make_callable(foo)
out = traced_good(a, idx, src)
out_ref = torch_good(a_ref, idx, src)
assert_close(a, a_ref)
assert_close(out, out_ref)
assert (a.data_ptr() == out.data_ptr()) == (a_ref.data_ptr() == out_ref.data_ptr())

for bad in [bad1, bad2, bad3]:
traced_bad = executor.make_callable(bad)
with pytest.raises(
NotImplementedError,
match=r"If you are sure you don't want to use this check, it can be disabled by setting `disable_inplace_copy_check=True` in `thunder.jit`.$",
):
traced_foo(a, b)
traced_bad(a, b)


@instantiate(executors=(nvFuserExecutor,), dtypes=(thunder.float32,))