Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add sanity check for inplace copy #285

Merged
merged 13 commits into from
May 16, 2024
Merged
3 changes: 3 additions & 0 deletions thunder/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,7 @@ def jit(
early_transforms: list | None = None,
additional_transforms: list | None = None,
record_history: bool = False,
disable_inplace_copy_check: bool = False,
t-vi marked this conversation as resolved.
Show resolved Hide resolved
**compile_options, # TODO RC1 Make this explicit -- dict of options
) -> Callable:
"""Just-in-time compile a callable (function or model).
Expand Down Expand Up @@ -572,6 +573,8 @@ def get_computation_and_inputs(*args, **kwargs):
)
computation_trc = extraces[-1]

if not disable_inplace_copy_check:
thunder.core.transform_common._inplace_copy_sanity_check(computation_trc)
comp = computation_trc.python_callable()

if backward_trc is not None:
Expand Down
46 changes: 46 additions & 0 deletions thunder/core/transform_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,52 @@ def _remove_noop_subsymbols(bsym: BoundSymbol) -> None:
bsym.subsymbols = nsbsyms


def _inplace_copy_sanity_check(extrace: Trace):
    """Reject nvFuser regions that read the destination of ``prims.copy_`` after the copy.

    The check exists because of a sharp edge in nvFuser's ``add_output(output, input)``
    interface: inside a fused nvFusion operator, the ``copy_to`` argument of
    ``prims.copy_`` (and the value ``prims.copy_`` returns) must not be used as an
    input by any operator that comes after the copy.

    Anti-pattern:

    .. code-block:: python

        [t2] = nvFusion0(x, y)
            # result = prims.mul(x, y)
            # a = prims.copy_(result, x)
            # t2 = prims.add(a, y) or t2 = prims.add(x, y)

    Once ``x`` has been updated, neither ``x`` nor ``a`` should be read again; use the
    ``copy_from`` variable ``result`` instead so the dependency is explicit:

    .. code-block:: python

        [t2] = nvFusion0(x, y)
            # result = prims.mul(x, y)
            # a = prims.copy_(result, x)
            # t2 = prims.add(result, y)

    Args:
        extrace: the execution trace whose bound symbols are inspected. Only bound
            symbols whose symbol name starts with ``"nvFusion"`` are checked.

    Raises:
        NotImplementedError: if, within one fusion, the last recorded consumer of a
            ``prims.copy_`` destination (or of its output) sits after the copy itself.
    """

    from thunder.core.utils import consumers

    for bsym in extrace.bound_symbols:
        if not bsym.sym.name.startswith("nvFusion"):
            continue

        # With _map_to_numbers=True the consumers are looked up by position in
        # bsym.subsymbols, which is what gets compared against `idx` below.
        consumer_dict = consumers(list(bsym.subsymbols), _map_to_numbers=True)

        for idx, subbsym in enumerate(bsym.subsymbols):
            if subbsym.sym.id != prims.PrimIDs.COPY_:
                continue

            # Both the tensor being overwritten and the value returned by copy_
            # are unsafe to read once the copy has been issued.
            candidates = (
                (subbsym.flat_args[1], "'copy_to' argument"),
                (subbsym.output, "output"),
            )
            for inp, log_str in candidates:
                if inp is None or inp not in consumer_dict:
                    continue
                last_used_idx = consumer_dict[inp][-1]
                if last_used_idx > idx:
                    raise NotImplementedError(
                        f"{bsym.subsymbols[last_used_idx]} trying to use {inp} (the {log_str} of 'prims.copy_') as input, which is not safe."
                        f" There is a risk of accessing the wrong memory. If you are sure you don't want to use this check, it can be disabled by setting `disable_inplace_copy_check=True` in `thunder.jit`."
                    )


# TODO This calls variableify(), but we could directly construct Variable objects instead, which might slightly
# improve performance
# Runs a Dead Code Elimination (DCE) pass
Expand Down
41 changes: 40 additions & 1 deletion thunder/tests/test_inplace_copy.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from functools import partial

import pytest
import torch
from torch.testing import assert_close, make_tensor

import thunder
import thunder.core.dtypes as datatypes
import thunder.torch as ttorch
from thunder.tests.framework import instantiate
from thunder.tests.framework import instantiate, nvFuserExecutor


@instantiate()
Expand Down Expand Up @@ -112,3 +113,41 @@ def forward(self, x):
assert_close(net.state_dict()["dense1_bn.running_mean"], torch_net.state_dict()["dense1_bn.running_mean"])
assert_close(net.state_dict()["dense1_bn.running_var"], torch_net.state_dict()["dense1_bn.running_var"])
assert_close(x.grad, x1.grad)


@instantiate(executors=(nvFuserExecutor,), dtypes=(thunder.float32,))
def test_inplace_copy_sanity_check(executor, device, dtype):
    """Each function below reuses a ``prims.copy_`` destination (or its output) after
    the copy, so calling the compiled function is expected to raise
    ``NotImplementedError`` with the hint about ``disable_inplace_copy_check``."""

    def func1(x, y):
        # The output of copy_ is fed into a later add.
        prod = x * y
        x = thunder.core.prims.copy_(prod, x)
        return x + y

    def func2(x, y):
        # The same destination is written by two consecutive copies.
        prod = x * y
        thunder.core.prims.copy_(prod, x)
        thunder.core.prims.copy_(y, x)
        return x

    def func3(x, y):
        # The destination of the first copy is the source of the second one.
        prod = x * y
        thunder.core.prims.copy_(prod, x)
        thunder.core.prims.copy_(x, y)
        return y

    def func4(x, y):
        # The output of the first copy is the source of the second one.
        prod = x * y
        copied = thunder.core.prims.copy_(prod, x)
        thunder.core.prims.copy_(copied, y)
        return y

    torch_dtype = ttorch.to_torch_dtype(dtype)
    expected_message = r"If you are sure you don't want to use this check, it can be disabled by setting `disable_inplace_copy_check=True` in `thunder.jit`.$"

    for fn in (func1, func2, func3, func4):
        compiled_fn = executor.make_callable(fn)

        # Fresh inputs per case, as in-place copies may mutate them.
        lhs = make_tensor((4, 4), device=device, dtype=torch_dtype)
        rhs = make_tensor((4, 4), device=device, dtype=torch_dtype)
        with pytest.raises(NotImplementedError, match=expected_message):
            compiled_fn(lhs, rhs)
Loading