Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add sanity check for inplace copy #285

Merged
merged 13 commits into from
May 16, 2024
Merged
1 change: 1 addition & 0 deletions thunder/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,6 +554,7 @@ def get_computation_and_inputs(*args, **kwargs):
computation_trc = extraces[-1]
cs.last_computation_transformation_stop = time.time_ns()

thunder.core.transform_common._inplace_copy_sanity_check(computation_trc)
kiya00 marked this conversation as resolved.
Show resolved Hide resolved
comp = computation_trc.python_callable()

if backward_trc is not None:
Expand Down
52 changes: 52 additions & 0 deletions thunder/core/transform_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,58 @@ def _remove_noop_subsymbols(bsym: BoundSymbol) -> None:
bsym.subsymbols = nsbsyms


def _inplace_copy_sanity_check(extrace: Trace):
"""The sanity check is based on the sharp edge of nvfuser's `add_ouput(output, input)` interface,
it makes sure that the `copy_to` argument of `prims.copy_` is not used as input for any of its subsequent operators, except for the Return and Del operators

Anti-pattern:

.. code-block:: python

c = prims.copy_(a, b)
d = torch.add(b, b) # or d = torch.add(c, c)
return d

Do not use the `copy_to` variable `b` or `c` after it has been updated, use the `copy_from` variable `a` instead to reflect the dependency:

.. code-block:: python

c = prims.copy_(a, b)
d = torch.add(a, a)
return c
"""
from thunder.core.trace import VariableInterface

inplace_copy_symbol_id = ("copy_", prims.PrimIDs.COPY_)
symbol_id_skip_list = (prims.PrimIDs.RETURN, prims.PrimIDs.DEL)
inplace_copy_to_arg: set[VariableInterface] = set()
kiya00 marked this conversation as resolved.
Show resolved Hide resolved

def check_symbol(bsym):
if bsym.sym.id in symbol_id_skip_list:
return
elif bsym.sym.is_fusion:
kiya00 marked this conversation as resolved.
Show resolved Hide resolved
for subbsym in bsym.subsymbols:
check_symbol(subbsym)
else:
for input in bsym.flat_proxy_args:
vinput = variableify(input)
if vinput in inplace_copy_to_arg:
raise NotImplementedError(
f"{bsym} trying to use {input} (the 'copy_to' argument of 'prims.copy_') as input, which is not supported"
kiya00 marked this conversation as resolved.
Show resolved Hide resolved
)
if bsym.sym.id in inplace_copy_symbol_id:
copy_to_arg = bsym.flat_proxy_args[1]
vcopy_to_arg = variableify(copy_to_arg)
out = bsym.flat_proxy_outs
if out:
vcopy_to_arg_ = variableify(out[0])
inplace_copy_to_arg.add(vcopy_to_arg_)
inplace_copy_to_arg.add(vcopy_to_arg)
kiya00 marked this conversation as resolved.
Show resolved Hide resolved

for bsym in extrace.bound_symbols:
check_symbol(bsym)


# TODO This calls variableify(), but we could directly construct Variable objects instead, which might slightly
# improve performance
# Runs a Dead Code Elimination (DCE) pass
Expand Down
38 changes: 38 additions & 0 deletions thunder/tests/test_inplace_copy.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from functools import partial

import pytest
import torch
from torch.testing import assert_close, make_tensor

Expand Down Expand Up @@ -121,3 +122,40 @@ def forward(self, x):
assert_close(net.state_dict()["dense1_bn.running_mean"], torch_net.state_dict()["dense1_bn.running_mean"])
assert_close(net.state_dict()["dense1_bn.running_var"], torch_net.state_dict()["dense1_bn.running_var"])
assert_close(x.grad, x1.grad)


@instantiate(dtypes=(thunder.float32,))
def test_inplace_copy_sanity_check(executor, device, dtype):
t-vi marked this conversation as resolved.
Show resolved Hide resolved
def func1(x, y):
z = x * y
x = thunder.core.prims.copy_(z, x)
return x + y

def func2(x, y):
z = x * y
thunder.core.prims.copy_(z, x)
thunder.core.prims.copy_(y, x)
return x

def func3(x, y):
z = x * y
thunder.core.prims.copy_(z, x)
thunder.core.prims.copy_(x, y)
return y

def func4(x, y):
z = x * y
o = thunder.core.prims.copy_(z, x)
thunder.core.prims.copy_(o, y)
return y

for foo in (func1, func2, func3, func4):
t-vi marked this conversation as resolved.
Show resolved Hide resolved
traced_foo = executor.make_callable(foo)

tdtype = ttorch.to_torch_dtype(dtype)
a = make_tensor((4, 4), device=device, dtype=tdtype)
b = make_tensor((4, 4), device=device, dtype=tdtype)
with pytest.raises(
NotImplementedError, match=r"\(the 'copy_to' argument of 'prims.copy_'\) as input, which is not supported$"
):
traced_foo(a, b)
Loading