Skip to content

Commit

Permalink
Add sanity check for inplace copy (#265)
Browse files Browse the repository at this point in the history
  • Loading branch information
kiya00 committed Apr 26, 2024
1 parent e0ab648 commit ad1a187
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 0 deletions.
1 change: 1 addition & 0 deletions thunder/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -626,6 +626,7 @@ def get_computation_and_inputs(*args, **kwargs):
computation_trc = extraces[-1]
cs.last_computation_transformation_stop = time.time_ns()

thunder.core.transform_common._inplace_copy_sanity_check(computation_trc)
comp = computation_trc.python_callable()

if backward_trc is not None:
Expand Down
31 changes: 31 additions & 0 deletions thunder/core/transform_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,37 @@ def _remove_noop_subsymbols(bsym: BoundSymbol) -> None:
bsym.subsymbols = nsbsyms


def _inplace_copy_sanity_check(extrace: Trace):
    """Reject traces that read a ``prims.copy_`` destination after the copy.

    Walks ``extrace.bound_symbols`` in order. Once a ``copy_`` symbol has been
    seen, both its ``copy_to`` argument (the second flat proxy argument) and
    its first proxy output (if any) are remembered. Any later symbol that
    takes one of those proxies as a flat proxy argument triggers an error.
    RETURN and DEL symbols are exempt, and fusion symbols are checked through
    their subsymbols instead of their own arguments.

    Args:
        extrace: The trace whose bound symbols are validated.

    Raises:
        NotImplementedError: If a symbol consumes a remembered ``copy_to``
            proxy (or the output of the corresponding ``copy_``).
    """
    from thunder.core.trace import VariableInterface

    copy_symbol_ids = ("copy_", prims.PrimIDs.COPY_)
    exempt_symbol_ids = (prims.PrimIDs.RETURN, prims.PrimIDs.DEL)
    # Variables that must not be consumed by any later (non-exempt) symbol.
    forbidden_inputs: set[VariableInterface] = set()

    def check_symbol(bsym):
        if bsym.sym.id in exempt_symbol_ids:
            return

        if bsym.sym.is_fusion:
            # A fusion is validated through the operations it contains.
            for inner_bsym in bsym.subsymbols:
                check_symbol(inner_bsym)
            return

        for proxy in bsym.flat_proxy_args:
            if variableify(proxy) in forbidden_inputs:
                raise NotImplementedError(
                    f"{bsym} trying to use {proxy} (the 'copy_to' argument of 'prims.copy_') as input, which is not supported"
                )

        if bsym.sym.id not in copy_symbol_ids:
            return

        # From here on, neither the destination nor the copy's result may be read.
        destination_var = variableify(bsym.flat_proxy_args[1])
        outputs = bsym.flat_proxy_outs
        if outputs:
            forbidden_inputs.add(variableify(outputs[0]))
        forbidden_inputs.add(destination_var)

    for bsym in extrace.bound_symbols:
        check_symbol(bsym)


# TODO This calls variableify(), but we could directly construct Variable objects instead, which might slightly
# improve performance
# Runs a Dead Code Elimination (DCE) pass
Expand Down
37 changes: 37 additions & 0 deletions thunder/tests/test_inplace_copy.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,3 +121,40 @@ def forward(self, x):
assert_close(net.state_dict()["dense1_bn.running_mean"], torch_net.state_dict()["dense1_bn.running_mean"])
assert_close(net.state_dict()["dense1_bn.running_var"], torch_net.state_dict()["dense1_bn.running_var"])
assert_close(x.grad, x1.grad)


@instantiate(dtypes=(thunder.float32,))
def test_inplace_copy_sanity_check(executor, device, dtype):
    """Check that using a ``copy_`` destination after the copy is rejected.

    Each scenario below consumes either the ``copy_to`` argument of a previous
    ``prims.copy_`` or that copy's output, and is expected to fail with
    ``NotImplementedError`` when the compiled callable is invoked.
    """
    import pytest

    # The result of copy_ (rebound to x) feeds a later addition.
    def func1(x, y):
        z = x * y
        x = thunder.core.prims.copy_(z, x)
        return x + y

    # The same tensor is the destination of two consecutive copies.
    def func2(x, y):
        z = x * y
        thunder.core.prims.copy_(z, x)
        thunder.core.prims.copy_(y, x)
        return x

    # The destination of the first copy is the source of the second.
    def func3(x, y):
        z = x * y
        thunder.core.prims.copy_(z, x)
        thunder.core.prims.copy_(x, y)
        return y

    # The output of the first copy is the source of the second.
    def func4(x, y):
        z = x * y
        o = thunder.core.prims.copy_(z, x)
        thunder.core.prims.copy_(o, y)
        return y

    expected_error = r"\(the 'copy_to' argument of 'prims.copy_'\) as input, which is not supported$"

    for fn in (func1, func2, func3, func4):
        compiled_fn = executor.make_callable(fn)

        torch_dtype = ttorch.to_torch_dtype(dtype)
        lhs = make_tensor((4, 4), device=device, dtype=torch_dtype)
        rhs = make_tensor((4, 4), device=device, dtype=torch_dtype)
        with pytest.raises(NotImplementedError, match=expected_error):
            compiled_fn(lhs, rhs)

0 comments on commit ad1a187

Please sign in to comment.