From ad1a187ae1778e773216e2af3badf01bedee54c2 Mon Sep 17 00:00:00 2001
From: Yan Wang
Date: Wed, 24 Apr 2024 16:35:39 +0200
Subject: [PATCH] Add sanity check for inplace copy (#265)

---
Notes (text between "---" and the diffstat is ignored by git am):
- Line structure restored from a copy in which the whole patch had been collapsed
  onto two lines. Indentation of the unchanged context lines, in particular in the
  thunder/__init__.py hunk, is a best-effort reconstruction; verify it against the
  tree or apply with `git apply --ignore-whitespace`.
- The added code differs from the original commit (expanded docstring, comments,
  loop variable `input` renamed to `arg`, test loop invariants hoisted), so the
  post-image blob ids on the `index` lines no longer describe the resulting files.

 thunder/__init__.py                |  1 +
 thunder/core/transform_common.py   | 48 ++++++++++++++++++++++++++++++
 thunder/tests/test_inplace_copy.py | 39 ++++++++++++++++++++++++
 3 files changed, 88 insertions(+)

diff --git a/thunder/__init__.py b/thunder/__init__.py
index f83b8c2eb1..fe23a86851 100644
--- a/thunder/__init__.py
+++ b/thunder/__init__.py
@@ -626,6 +626,7 @@ def get_computation_and_inputs(*args, **kwargs):
             computation_trc = extraces[-1]
             cs.last_computation_transformation_stop = time.time_ns()
 
+            thunder.core.transform_common._inplace_copy_sanity_check(computation_trc)
             comp = computation_trc.python_callable()
 
             if backward_trc is not None:
diff --git a/thunder/core/transform_common.py b/thunder/core/transform_common.py
index fa74313bfe..29e8db8742 100644
--- a/thunder/core/transform_common.py
+++ b/thunder/core/transform_common.py
@@ -33,6 +33,54 @@ def _remove_noop_subsymbols(bsym: BoundSymbol) -> None:
     bsym.subsymbols = nsbsyms
 
 
+def _inplace_copy_sanity_check(extrace: Trace):
+    """Reject traces that read a tensor after ``prims.copy_`` has written to it.
+
+    Walks ``extrace.bound_symbols`` in order. When a ``copy_`` symbol is seen, its
+    ``copy_to`` argument and its first output are recorded; any later operator other
+    than ``RETURN`` or ``DEL`` that takes a recorded proxy as an argument is rejected.
+    Fusion symbols are checked through their subsymbols.
+
+    Args:
+        extrace: The trace whose bound symbols are checked.
+
+    Raises:
+        NotImplementedError: If a recorded proxy is used as input by a later operator.
+    """
+    from thunder.core.trace import VariableInterface
+
+    inplace_copy_symbol_id = ("copy_", prims.PrimIDs.COPY_)
+    symbol_id_skip_list = (prims.PrimIDs.RETURN, prims.PrimIDs.DEL)
+    inplace_copy_to_arg: set[VariableInterface] = set()
+
+    def check_symbol(bsym):
+        if bsym.sym.id in symbol_id_skip_list:
+            return
+        if bsym.sym.is_fusion:
+            for subbsym in bsym.subsymbols:
+                check_symbol(subbsym)
+            return
+
+        # Named `arg` rather than `input` so the builtin is not shadowed
+        for arg in bsym.flat_proxy_args:
+            if variableify(arg) in inplace_copy_to_arg:
+                raise NotImplementedError(
+                    f"{bsym} trying to use {arg} (the 'copy_to' argument of 'prims.copy_') as input, which is not supported"
+                )
+
+        if bsym.sym.id in inplace_copy_symbol_id:
+            # Index 1 of the flattened proxy args is taken as the copy_to argument
+            inplace_copy_to_arg.add(variableify(bsym.flat_proxy_args[1]))
+            # The first output, when present, is tracked too so that uses of the
+            # proxy returned by copy_ are rejected as well
+            outs = bsym.flat_proxy_outs
+            if outs:
+                inplace_copy_to_arg.add(variableify(outs[0]))
+
+    for bsym in extrace.bound_symbols:
+        check_symbol(bsym)
+
+
 # TODO This calls variableify(), but we could directly construct Variable objects instead, which might slightly
 # improve performance
 # Runs a Dead Code Elimination (DCE) pass
diff --git a/thunder/tests/test_inplace_copy.py b/thunder/tests/test_inplace_copy.py
index 33b1b668b1..f77ee05294 100644
--- a/thunder/tests/test_inplace_copy.py
+++ b/thunder/tests/test_inplace_copy.py
@@ -121,3 +121,42 @@ def forward(self, x):
     assert_close(net.state_dict()["dense1_bn.running_mean"], torch_net.state_dict()["dense1_bn.running_mean"])
     assert_close(net.state_dict()["dense1_bn.running_var"], torch_net.state_dict()["dense1_bn.running_var"])
     assert_close(x.grad, x1.grad)
+
+
+@instantiate(dtypes=(thunder.float32,))
+def test_inplace_copy_sanity_check(executor, device, dtype):
+    # Every function below feeds the destination or the output of prims.copy_ into a
+    # later operator, so calling the compiled function must fail the sanity check.
+    def func1(x, y):
+        z = x * y
+        x = thunder.core.prims.copy_(z, x)
+        return x + y
+
+    def func2(x, y):
+        z = x * y
+        thunder.core.prims.copy_(z, x)
+        thunder.core.prims.copy_(y, x)
+        return x
+
+    def func3(x, y):
+        z = x * y
+        thunder.core.prims.copy_(z, x)
+        thunder.core.prims.copy_(x, y)
+        return y
+
+    def func4(x, y):
+        z = x * y
+        o = thunder.core.prims.copy_(z, x)
+        thunder.core.prims.copy_(o, y)
+        return y
+
+    import pytest
+
+    tdtype = ttorch.to_torch_dtype(dtype)
+    expected_msg = r"\(the 'copy_to' argument of 'prims.copy_'\) as input, which is not supported$"
+    for foo in (func1, func2, func3, func4):
+        traced_foo = executor.make_callable(foo)
+        a = make_tensor((4, 4), device=device, dtype=tdtype)
+        b = make_tensor((4, 4), device=device, dtype=tdtype)
+        with pytest.raises(NotImplementedError, match=expected_msg):
+            traced_foo(a, b)