diff --git a/thunder/__init__.py b/thunder/__init__.py index 9b85ad47ac..eaff1d7a84 100644 --- a/thunder/__init__.py +++ b/thunder/__init__.py @@ -572,6 +572,8 @@ def get_computation_and_inputs(*args, **kwargs): ) computation_trc = extraces[-1] + if not compile_options.get("disable_inplace_copy_check", False): + thunder.core.transform_common._inplace_copy_sanity_check(computation_trc) comp = computation_trc.python_callable() if backward_trc is not None: diff --git a/thunder/core/transform_common.py b/thunder/core/transform_common.py index fa74313bfe..34feb6a7a3 100644 --- a/thunder/core/transform_common.py +++ b/thunder/core/transform_common.py @@ -33,6 +33,52 @@ def _remove_noop_subsymbols(bsym: BoundSymbol) -> None: bsym.subsymbols = nsbsyms +def _inplace_copy_sanity_check(extrace: Trace): + """The sanity check is based on the sharp edge of nvfuser's `add_output(output, input)` interface, + it makes sure that the `copy_to` argument of `prims.copy_` is not used as input for any of its subsequent operators in an nvFusion fused operator + + Anti-pattern: + + .. code-block:: python + + [t2] = nvFusion0(x, y) + # result = prims.mul(x, y) + # a = prims.copy_(result, x) + # t2 = prims.add(a, y) or t2 = prims.add(x, y) + + Do not use the `copy_to` variable `x` or `a` after it has been updated, use the `copy_from` variable `result` instead to reflect the dependency: + + .. 
code-block:: python + + [t2] = nvFusion0(x, y) + # result = prims.mul(x, y) + # a = prims.copy_(result, x) + # t2 = prims.add(result, y) + """ + + from thunder.core.utils import consumers + + nvfuser_symbols = (bsym for bsym in extrace.bound_symbols if bsym.sym.name.startswith("nvFusion")) + for bsym in nvfuser_symbols: + consumer_dict = consumers(list(bsym.subsymbols), _map_to_numbers=True) + inplace_copy_idx = ((idx, sym) for idx, sym in enumerate(bsym.subsymbols) if sym.sym.id == prims.PrimIDs.COPY_) + for idx, subbsym in inplace_copy_idx: + copy_to_arg = subbsym.flat_args[1] + copy_to_out = subbsym.output + + def check(inp, log_str): + if inp is not None and inp in consumer_dict: + last_used_idx = max(consumer_dict[inp]) + if last_used_idx > idx: + raise NotImplementedError( + f"{bsym.subsymbols[last_used_idx]} trying to use {inp} (the {log_str} of 'prims.copy_') as input, which is not safe." + f" There is a risk of accessing the wrong memory. If you are sure you don't want to use this check, it can be disabled by setting `disable_inplace_copy_check=True` in `thunder.jit`." 
+ ) + + check(copy_to_arg, "'copy_to' argument") + check(copy_to_out, "output") + + # TODO This calls variableify(), but we could directly construct Variable objects instead, which might slightly # improve performance # Runs a Dead Code Elimination (DCE) pass diff --git a/thunder/tests/test_inplace_copy.py b/thunder/tests/test_inplace_copy.py index 81abbe1e05..f98ba024e3 100644 --- a/thunder/tests/test_inplace_copy.py +++ b/thunder/tests/test_inplace_copy.py @@ -1,12 +1,13 @@ from functools import partial +import pytest import torch from torch.testing import assert_close, make_tensor import thunder import thunder.core.dtypes as datatypes import thunder.torch as ttorch -from thunder.tests.framework import instantiate +from thunder.tests.framework import instantiate, nvFuserExecutor @instantiate() @@ -112,3 +113,41 @@ def forward(self, x): assert_close(net.state_dict()["dense1_bn.running_mean"], torch_net.state_dict()["dense1_bn.running_mean"]) assert_close(net.state_dict()["dense1_bn.running_var"], torch_net.state_dict()["dense1_bn.running_var"]) assert_close(x.grad, x1.grad) + + +@instantiate(executors=(nvFuserExecutor,), dtypes=(thunder.float32,)) +def test_inplace_copy_sanity_check(executor, device, dtype): + def func0(x, y): + z = x * y + x = thunder.core.prims.copy_(z, x) + return x + y + + def func1(x, y): + z = x * y + thunder.core.prims.copy_(z, x) + thunder.core.prims.copy_(y, x) + return x + + def func2(x, y): + z = x * y + thunder.core.prims.copy_(z, x) + thunder.core.prims.copy_(x, y) + return y + + def func3(x, y): + z = x * y + o = thunder.core.prims.copy_(z, x) + thunder.core.prims.copy_(o, y) + return y + + for foo in (func0, func1, func2, func3): + traced_foo = executor.make_callable(foo) + + tdtype = ttorch.to_torch_dtype(dtype) + a = make_tensor((4, 4), device=device, dtype=tdtype) + b = make_tensor((4, 4), device=device, dtype=tdtype) + with pytest.raises( + NotImplementedError, + match=r"If you are sure you don't want to use this check, it can 
be disabled by setting `disable_inplace_copy_check=True` in `thunder.jit`.$", + ): + traced_foo(a, b)