From ad1a187ae1778e773216e2af3badf01bedee54c2 Mon Sep 17 00:00:00 2001 From: Yan Wang Date: Wed, 24 Apr 2024 16:35:39 +0200 Subject: [PATCH 1/9] Add sanity check for inplace copy (#265) --- thunder/__init__.py | 1 + thunder/core/transform_common.py | 31 +++++++++++++++++++++++++ thunder/tests/test_inplace_copy.py | 37 ++++++++++++++++++++++++++++++ 3 files changed, 69 insertions(+) diff --git a/thunder/__init__.py b/thunder/__init__.py index f83b8c2eb1..fe23a86851 100644 --- a/thunder/__init__.py +++ b/thunder/__init__.py @@ -626,6 +626,7 @@ def get_computation_and_inputs(*args, **kwargs): computation_trc = extraces[-1] cs.last_computation_transformation_stop = time.time_ns() + thunder.core.transform_common._inplace_copy_sanity_check(computation_trc) comp = computation_trc.python_callable() if backward_trc is not None: diff --git a/thunder/core/transform_common.py b/thunder/core/transform_common.py index fa74313bfe..29e8db8742 100644 --- a/thunder/core/transform_common.py +++ b/thunder/core/transform_common.py @@ -33,6 +33,37 @@ def _remove_noop_subsymbols(bsym: BoundSymbol) -> None: bsym.subsymbols = nsbsyms +def _inplace_copy_sanity_check(extrace: Trace): + """Make sure that the copy_to argument of prims.copy_ is not used as input for any of its subsequent operators, except for the Return and Del operators.""" + from thunder.core.trace import VariableInterface + inplace_copy_symbol_id = ("copy_", prims.PrimIDs.COPY_) + symbol_id_skip_list = (prims.PrimIDs.RETURN, prims.PrimIDs.DEL) + inplace_copy_to_arg: set[VariableInterface] = set() + + def check_symbol(bsym): + if bsym.sym.id in symbol_id_skip_list: + return + elif bsym.sym.is_fusion: + for subbsym in bsym.subsymbols: + check_symbol(subbsym) + else: + for input in bsym.flat_proxy_args: + vinput = variableify(input) + if vinput in inplace_copy_to_arg: + raise NotImplementedError(f"{bsym} trying to use {input} (the 'copy_to' argument of 'prims.copy_') as input, which is not supported") + if bsym.sym.id in inplace_copy_symbol_id: + copy_to_arg = bsym.flat_proxy_args[1] + vcopy_to_arg = variableify(copy_to_arg) + out = bsym.flat_proxy_outs + if out: + vcopy_to_arg_ = variableify(out[0]) + inplace_copy_to_arg.add(vcopy_to_arg_) + inplace_copy_to_arg.add(vcopy_to_arg) + + for bsym in extrace.bound_symbols: + check_symbol(bsym) + + # TODO This calls variableify(), but we could directly construct Variable objects instead, which might slightly # improve performance # Runs a Dead Code Elimination (DCE) pass diff --git a/thunder/tests/test_inplace_copy.py b/thunder/tests/test_inplace_copy.py index 33b1b668b1..f77ee05294 100644 --- a/thunder/tests/test_inplace_copy.py +++ b/thunder/tests/test_inplace_copy.py @@ -121,3 +121,40 @@ def forward(self, x): assert_close(net.state_dict()["dense1_bn.running_mean"], torch_net.state_dict()["dense1_bn.running_mean"]) assert_close(net.state_dict()["dense1_bn.running_var"], torch_net.state_dict()["dense1_bn.running_var"]) assert_close(x.grad, x1.grad) + + +@instantiate(dtypes=(thunder.float32,)) +def test_inplace_copy_sanity_check(executor, device, dtype): + def func1(x, y): + z = x * y + x = thunder.core.prims.copy_(z, x) + return x + y + + def func2(x, y): + z = x * y + thunder.core.prims.copy_(z, x) + thunder.core.prims.copy_(y, x) + return x + + def func3(x, y): + z = x * y + thunder.core.prims.copy_(z, x) + thunder.core.prims.copy_(x, y) + return y + + def func4(x, y): + z = x * y + o = thunder.core.prims.copy_(z, x) + thunder.core.prims.copy_(o, y) + return y + + + import pytest + for foo in (func1, func2, func3, func4): + traced_foo = executor.make_callable(foo) + + tdtype = ttorch.to_torch_dtype(dtype) + a = make_tensor((4, 4), device=device, dtype=tdtype) + b = make_tensor((4, 4), device=device, dtype=tdtype) + with pytest.raises(NotImplementedError, match=r"\(the 'copy_to' argument of 'prims.copy_'\) as input, which is not supported$"): + traced_foo(a, b) From 91a942119c30ce215ec81c8bf3e63871e1ca8b30 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 26 Apr 2024 14:55:51 +0000 Subject: [PATCH 2/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- thunder/core/transform_common.py | 5 ++++- thunder/tests/test_inplace_copy.py | 6 ++++-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/thunder/core/transform_common.py b/thunder/core/transform_common.py index 29e8db8742..af48cd685d 100644 --- a/thunder/core/transform_common.py +++ b/thunder/core/transform_common.py @@ -36,6 +36,7 @@ def _remove_noop_subsymbols(bsym: BoundSymbol) -> None: def _inplace_copy_sanity_check(extrace: Trace): """Make sure that the copy_to argument of prims.copy_ is not used as input for any of its subsequent operators, except for the Return and Del operators.""" from thunder.core.trace import VariableInterface + inplace_copy_symbol_id = ("copy_", prims.PrimIDs.COPY_) symbol_id_skip_list = (prims.PrimIDs.RETURN, prims.PrimIDs.DEL) inplace_copy_to_arg: set[VariableInterface] = set() @@ -50,7 +51,9 @@ def check_symbol(bsym): for input in bsym.flat_proxy_args: vinput = variableify(input) if vinput in inplace_copy_to_arg: - raise NotImplementedError(f"{bsym} trying to use {input} (the 'copy_to' argument of 'prims.copy_') as input, which is not supported") + raise NotImplementedError( + f"{bsym} trying to use {input} (the 'copy_to' argument of 'prims.copy_') as input, which is not supported" + ) if bsym.sym.id in inplace_copy_symbol_id: copy_to_arg = bsym.flat_proxy_args[1] vcopy_to_arg = variableify(copy_to_arg) diff --git a/thunder/tests/test_inplace_copy.py b/thunder/tests/test_inplace_copy.py index f77ee05294..409565bf71 100644 --- a/thunder/tests/test_inplace_copy.py +++ b/thunder/tests/test_inplace_copy.py @@ -148,13 +148,15 @@ def func4(x, y): thunder.core.prims.copy_(o, y) return y - import pytest + for foo in (func1, func2, func3, func4): traced_foo = executor.make_callable(foo) tdtype = ttorch.to_torch_dtype(dtype) a = make_tensor((4, 4), device=device, dtype=tdtype) b = make_tensor((4, 4), device=device, dtype=tdtype) - with pytest.raises(NotImplementedError, match=r"\(the 'copy_to' argument of 'prims.copy_'\) as input, which is not supported$"): + with pytest.raises( + NotImplementedError, match=r"\(the 'copy_to' argument of 'prims.copy_'\) as input, which is not supported$" + ): traced_foo(a, b) From f9ea516d204d15d7b7717169b49c472cf1eecb68 Mon Sep 17 00:00:00 2001 From: Yan Wang Date: Tue, 7 May 2024 16:34:19 +0200 Subject: [PATCH 3/9] follow comments --- thunder/core/transform_common.py | 20 +++++++++++++++++++- thunder/tests/test_inplace_copy.py | 2 +- 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/thunder/core/transform_common.py b/thunder/core/transform_common.py index af48cd685d..4f347e5bf8 100644 --- a/thunder/core/transform_common.py +++ b/thunder/core/transform_common.py @@ -34,7 +34,25 @@ def _remove_noop_subsymbols(bsym: BoundSymbol) -> None: def _inplace_copy_sanity_check(extrace: Trace): - """Make sure that the copy_to argument of prims.copy_ is not used as input for any of its subsequent operators, except for the Return and Del operators.""" + """The sanity check is based on the sharp edge of nvfuser's `add_ouput(output, input)` interface, + it makes sure that the `copy_to` argument of `prims.copy_` is not used as input for any of its subsequent operators, except for the Return and Del operators + + Anti-pattern: + + .. code-block:: python + + c = prims.copy_(a, b) + d = torch.add(b, b) # or d = torch.add(c, c) + return d + + Do not use the `copy_to` variable `b` or `c` after it has been updated, use the `copy_from` variable `a` instead to reflect the dependency: + + .. code-block:: python + + c = prims.copy_(a, b) + d = torch.add(a, a) + return c + """ from thunder.core.trace import VariableInterface inplace_copy_symbol_id = ("copy_", prims.PrimIDs.COPY_) diff --git a/thunder/tests/test_inplace_copy.py b/thunder/tests/test_inplace_copy.py index 409565bf71..dc245504e5 100644 --- a/thunder/tests/test_inplace_copy.py +++ b/thunder/tests/test_inplace_copy.py @@ -1,5 +1,6 @@ from functools import partial +import pytest import torch from torch.testing import assert_close, make_tensor @@ -148,7 +149,6 @@ def func4(x, y): thunder.core.prims.copy_(o, y) return y - import pytest for foo in (func1, func2, func3, func4): traced_foo = executor.make_callable(foo) From 428e15d06c96a3bc750d93b170a6709259fed7b5 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 7 May 2024 14:44:41 +0000 Subject: [PATCH 4/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- thunder/tests/test_inplace_copy.py | 1 - 1 file changed, 1 deletion(-) diff --git a/thunder/tests/test_inplace_copy.py b/thunder/tests/test_inplace_copy.py index dc245504e5..bbf77894c9 100644 --- a/thunder/tests/test_inplace_copy.py +++ b/thunder/tests/test_inplace_copy.py @@ -149,7 +149,6 @@ def func4(x, y): thunder.core.prims.copy_(o, y) return y - for foo in (func1, func2, func3, func4): traced_foo = executor.make_callable(foo) From 8b93f2438ae4df0301bdfb11e048916bfc31b06f Mon Sep 17 00:00:00 2001 From: Yan Wang Date: Mon, 13 May 2024 12:07:13 +0200 Subject: [PATCH 5/9] Fix: only check inplace copy in each nvFusion; follow comments to use consumers and check index --- thunder/core/transform_common.py | 68 +++++++++++++----------------- thunder/tests/test_inplace_copy.py | 7 +-- 2 files changed, 34 insertions(+), 41 deletions(-) diff --git a/thunder/core/transform_common.py b/thunder/core/transform_common.py index 4f347e5bf8..c4d97cebc9 100644 --- a/thunder/core/transform_common.py +++ b/thunder/core/transform_common.py @@ -35,54 +35,46 @@ def _remove_noop_subsymbols(bsym: BoundSymbol) -> None: def _inplace_copy_sanity_check(extrace: Trace): """The sanity check is based on the sharp edge of nvfuser's `add_ouput(output, input)` interface, - it makes sure that the `copy_to` argument of `prims.copy_` is not used as input for any of its subsequent operators, except for the Return and Del operators + it makes sure that the `copy_to` argument of `prims.copy_` is not used as input for any of its subsequent operators in a nvFusion fused operator Anti-pattern: .. code-block:: python - c = prims.copy_(a, b) - d = torch.add(b, b) # or d = torch.add(c, c) - return d + [t2] = nvFusion0(x, y) + # result = prims.mul(x, y) + # a = prims.copy_(result, x) + # t2 = prims.add(a, y) or t2 = prims.add(x, y) - Do not use the `copy_to` variable `b` or `c` after it has been updated, use the `copy_from` variable `a` instead to reflect the dependency: + Do not use the `copy_to` variable `x` or `a` after it has been updated, use the `copy_from` variable `result` instead to reflect the dependency: .. code-block:: python - c = prims.copy_(a, b) - d = torch.add(a, a) - return c + [t2] = nvFusion0(x, y) + # result = prims.mul(x, y) + # a = prims.copy_(result, x) + # t2 = prims.add(result, y) """ - from thunder.core.trace import VariableInterface - - inplace_copy_symbol_id = ("copy_", prims.PrimIDs.COPY_) - symbol_id_skip_list = (prims.PrimIDs.RETURN, prims.PrimIDs.DEL) - inplace_copy_to_arg: set[VariableInterface] = set() - - def check_symbol(bsym): - if bsym.sym.id in symbol_id_skip_list: - return - elif bsym.sym.is_fusion: - for subbsym in bsym.subsymbols: - check_symbol(subbsym) - else: - for input in bsym.flat_proxy_args: - vinput = variableify(input) - if vinput in inplace_copy_to_arg: - raise NotImplementedError( - f"{bsym} trying to use {input} (the 'copy_to' argument of 'prims.copy_') as input, which is not supported" - ) - if bsym.sym.id in inplace_copy_symbol_id: - copy_to_arg = bsym.flat_proxy_args[1] - vcopy_to_arg = variableify(copy_to_arg) - out = bsym.flat_proxy_outs - if out: - vcopy_to_arg_ = variableify(out[0]) - inplace_copy_to_arg.add(vcopy_to_arg_) - inplace_copy_to_arg.add(vcopy_to_arg) - - for bsym in extrace.bound_symbols: - check_symbol(bsym) + + from thunder.core.utils import consumers + nvfuser_symbols = (bsym for bsym in extrace.bound_symbols if bsym.sym.name.startswith("nvFusion")) + for bsym in nvfuser_symbols: + consumer_dict = consumers(list(bsym.subsymbols), _map_to_numbers=True) + inplace_copy_idx = [(idx, sym) for idx, sym in enumerate(bsym.subsymbols) if sym.sym.id == prims.PrimIDs.COPY_] + for idx, subbsym in inplace_copy_idx: + copy_to_arg = subbsym.flat_args[1] + copy_to_out = subbsym.output + + def check(inp, log_str): + if inp is not None and inp in consumer_dict: + last_used_idx = consumer_dict[inp][-1] + if last_used_idx > idx: + raise NotImplementedError( + f"{bsym.subsymbols[last_used_idx]} trying to use {inp} (the {log_str} of 'prims.copy_') as input, which is not supported" + ) + + check(copy_to_arg, "'copy_to' argument") + check(copy_to_out, "output") # TODO This calls variableify(), but we could directly construct Variable objects instead, which might slightly diff --git a/thunder/tests/test_inplace_copy.py b/thunder/tests/test_inplace_copy.py index bbf77894c9..daf42a8604 100644 --- a/thunder/tests/test_inplace_copy.py +++ b/thunder/tests/test_inplace_copy.py @@ -7,7 +7,7 @@ import thunder import thunder.core.dtypes as datatypes import thunder.torch as ttorch -from thunder.tests.framework import instantiate +from thunder.tests.framework import instantiate, nvFuserExecutor @instantiate() @@ -124,7 +124,7 @@ def forward(self, x): assert_close(x.grad, x1.grad) -@instantiate(dtypes=(thunder.float32,)) +@instantiate(executors=(nvFuserExecutor,),dtypes=(thunder.float32,)) def test_inplace_copy_sanity_check(executor, device, dtype): def func1(x, y): z = x * y @@ -149,6 +149,7 @@ def func4(x, y): thunder.core.prims.copy_(o, y) return y + for foo in (func1, func2, func3, func4): traced_foo = executor.make_callable(foo) @@ -156,6 +157,6 @@ def func4(x, y): a = make_tensor((4, 4), device=device, dtype=tdtype) b = make_tensor((4, 4), device=device, dtype=tdtype) with pytest.raises( - NotImplementedError, match=r"\(the 'copy_to' argument of 'prims.copy_'\) as input, which is not supported$" + NotImplementedError, match=r"of 'prims.copy_'\) as input, which is not supported$" ): traced_foo(a, b) From ab809a0ba856979b554607539ef2f30fd3ca789d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 13 May 2024 10:09:26 +0000 Subject: [PATCH 6/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- thunder/core/transform_common.py | 1 + thunder/tests/test_inplace_copy.py | 7 ++----- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/thunder/core/transform_common.py b/thunder/core/transform_common.py index c4d97cebc9..6f8f7035b0 100644 --- a/thunder/core/transform_common.py +++ b/thunder/core/transform_common.py @@ -57,6 +57,7 @@ def _inplace_copy_sanity_check(extrace: Trace): """ from thunder.core.utils import consumers + nvfuser_symbols = (bsym for bsym in extrace.bound_symbols if bsym.sym.name.startswith("nvFusion")) for bsym in nvfuser_symbols: consumer_dict = consumers(list(bsym.subsymbols), _map_to_numbers=True) diff --git a/thunder/tests/test_inplace_copy.py b/thunder/tests/test_inplace_copy.py index daf42a8604..36a77578d9 100644 --- a/thunder/tests/test_inplace_copy.py +++ b/thunder/tests/test_inplace_copy.py @@ -124,7 +124,7 @@ def forward(self, x): assert_close(x.grad, x1.grad) -@instantiate(executors=(nvFuserExecutor,),dtypes=(thunder.float32,)) +@instantiate(executors=(nvFuserExecutor,), dtypes=(thunder.float32,)) def test_inplace_copy_sanity_check(executor, device, dtype): def func1(x, y): z = x * y @@ -149,14 +149,11 @@ def func4(x, y): thunder.core.prims.copy_(o, y) return y - for foo in (func1, func2, func3, func4): traced_foo = executor.make_callable(foo) tdtype = ttorch.to_torch_dtype(dtype) a = make_tensor((4, 4), device=device, dtype=tdtype) b = make_tensor((4, 4), device=device, dtype=tdtype) - with pytest.raises( - NotImplementedError, match=r"of 'prims.copy_'\) as input, which is not supported$" - ): + with pytest.raises(NotImplementedError, match=r"of 'prims.copy_'\) as input, which is not supported$"): traced_foo(a, b) From eedf31edfcbc4ba255bccf69ba87a1c4aabb9101 Mon Sep 17 00:00:00 2001 From: Yan Wang Date: Mon, 13 May 2024 15:00:38 +0200 Subject: [PATCH 7/9] Add compile option to disable the check --- thunder/__init__.py | 4 +++- thunder/core/transform_common.py | 3 ++- thunder/tests/test_inplace_copy.py | 4 +++- 3 files changed, 8 insertions(+), 3 deletions(-) diff --git a/thunder/__init__.py b/thunder/__init__.py index 16907002ac..49dc8b85ae 100644 --- a/thunder/__init__.py +++ b/thunder/__init__.py @@ -262,6 +262,7 @@ def jit( disable_torch_autograd: bool = False, # TODO Revisit this UX for RC1 additional_transforms: list | None = None, record_history: bool = False, + disable_inplace_copy_check: bool = False, **compile_options, # TODO RC1 Make this explicit -- dict of options ) -> Callable: """Just-in-time compile a callable (function or model). @@ -554,7 +555,8 @@ def get_computation_and_inputs(*args, **kwargs): computation_trc = extraces[-1] cs.last_computation_transformation_stop = time.time_ns() - thunder.core.transform_common._inplace_copy_sanity_check(computation_trc) + if not disable_inplace_copy_check: + thunder.core.transform_common._inplace_copy_sanity_check(computation_trc) comp = computation_trc.python_callable() if backward_trc is not None: diff --git a/thunder/core/transform_common.py b/thunder/core/transform_common.py index 6f8f7035b0..5fef39b695 100644 --- a/thunder/core/transform_common.py +++ b/thunder/core/transform_common.py @@ -71,7 +71,8 @@ def check(inp, log_str): last_used_idx = consumer_dict[inp][-1] if last_used_idx > idx: raise NotImplementedError( - f"{bsym.subsymbols[last_used_idx]} trying to use {inp} (the {log_str} of 'prims.copy_') as input, which is not supported" + f"{bsym.subsymbols[last_used_idx]} trying to use {inp} (the {log_str} of 'prims.copy_') as input, which is not safe." + f" There is a risk of accessing the wrong memory. If you are sure you don't want to use this check, it can be disabled by setting `disable_inplace_copy_check=True` in `thunder.jit`." ) check(copy_to_arg, "'copy_to' argument") diff --git a/thunder/tests/test_inplace_copy.py b/thunder/tests/test_inplace_copy.py index 36a77578d9..7cd577d536 100644 --- a/thunder/tests/test_inplace_copy.py +++ b/thunder/tests/test_inplace_copy.py @@ -155,5 +155,7 @@ def func4(x, y): tdtype = ttorch.to_torch_dtype(dtype) a = make_tensor((4, 4), device=device, dtype=tdtype) b = make_tensor((4, 4), device=device, dtype=tdtype) - with pytest.raises(NotImplementedError, match=r"of 'prims.copy_'\) as input, which is not supported$"): + with pytest.raises( + NotImplementedError, match=r"If you are sure you don't want to use this check, it can be disabled by setting `disable_inplace_copy_check=True` in `thunder.jit`.$" + ): traced_foo(a, b) From 3eec723b50bd5503eb7b08305c12a01a0a96b035 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 13 May 2024 13:19:29 +0000 Subject: [PATCH 8/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- thunder/tests/test_inplace_copy.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/thunder/tests/test_inplace_copy.py b/thunder/tests/test_inplace_copy.py index 7cd577d536..eff0631550 100644 --- a/thunder/tests/test_inplace_copy.py +++ b/thunder/tests/test_inplace_copy.py @@ -156,6 +156,7 @@ def func4(x, y): a = make_tensor((4, 4), device=device, dtype=tdtype) b = make_tensor((4, 4), device=device, dtype=tdtype) with pytest.raises( - NotImplementedError, match=r"If you are sure you don't want to use this check, it can be disabled by setting `disable_inplace_copy_check=True` in `thunder.jit`.$" + NotImplementedError, + match=r"If you are sure you don't want to use this check, it can be disabled by setting `disable_inplace_copy_check=True` in `thunder.jit`.$", ): traced_foo(a, b) From 45d1228a635ee6b98f40275101bb7d9380dd96ee Mon Sep 17 00:00:00 2001 From: Yan Wang Date: Thu, 16 May 2024 09:57:48 +0200 Subject: [PATCH 9/9] fix for comments --- thunder/__init__.py | 3 +-- thunder/core/transform_common.py | 4 ++-- thunder/tests/test_inplace_copy.py | 10 +++++----- 3 files changed, 8 insertions(+), 9 deletions(-) diff --git a/thunder/__init__.py b/thunder/__init__.py index ef950a1e84..eaff1d7a84 100644 --- a/thunder/__init__.py +++ b/thunder/__init__.py @@ -263,7 +263,6 @@ def jit( early_transforms: list | None = None, additional_transforms: list | None = None, record_history: bool = False, - disable_inplace_copy_check: bool = False, **compile_options, # TODO RC1 Make this explicit -- dict of options ) -> Callable: """Just-in-time compile a callable (function or model). @@ -573,7 +572,7 @@ def get_computation_and_inputs(*args, **kwargs): ) computation_trc = extraces[-1] - if not disable_inplace_copy_check: + if not compile_options.get("disable_inplace_copy_check", False): thunder.core.transform_common._inplace_copy_sanity_check(computation_trc) comp = computation_trc.python_callable() diff --git a/thunder/core/transform_common.py b/thunder/core/transform_common.py index 5fef39b695..34feb6a7a3 100644 --- a/thunder/core/transform_common.py +++ b/thunder/core/transform_common.py @@ -61,14 +61,14 @@ def _inplace_copy_sanity_check(extrace: Trace): nvfuser_symbols = (bsym for bsym in extrace.bound_symbols if bsym.sym.name.startswith("nvFusion")) for bsym in nvfuser_symbols: consumer_dict = consumers(list(bsym.subsymbols), _map_to_numbers=True) - inplace_copy_idx = [(idx, sym) for idx, sym in enumerate(bsym.subsymbols) if sym.sym.id == prims.PrimIDs.COPY_] + inplace_copy_idx = ((idx, sym) for idx, sym in enumerate(bsym.subsymbols) if sym.sym.id == prims.PrimIDs.COPY_) for idx, subbsym in inplace_copy_idx: copy_to_arg = subbsym.flat_args[1] copy_to_out = subbsym.output def check(inp, log_str): if inp is not None and inp in consumer_dict: - last_used_idx = consumer_dict[inp][-1] + last_used_idx = max(consumer_dict[inp]) if last_used_idx > idx: raise NotImplementedError( f"{bsym.subsymbols[last_used_idx]} trying to use {inp} (the {log_str} of 'prims.copy_') as input, which is not safe." diff --git a/thunder/tests/test_inplace_copy.py b/thunder/tests/test_inplace_copy.py index 21fec5617d..f98ba024e3 100644 --- a/thunder/tests/test_inplace_copy.py +++ b/thunder/tests/test_inplace_copy.py @@ -117,30 +117,30 @@ def forward(self, x): @instantiate(executors=(nvFuserExecutor,), dtypes=(thunder.float32,)) def test_inplace_copy_sanity_check(executor, device, dtype): - def func1(x, y): + def func0(x, y): z = x * y x = thunder.core.prims.copy_(z, x) return x + y - def func2(x, y): + def func1(x, y): z = x * y thunder.core.prims.copy_(z, x) thunder.core.prims.copy_(y, x) return x - def func3(x, y): + def func2(x, y): z = x * y thunder.core.prims.copy_(z, x) thunder.core.prims.copy_(x, y) return y - def func4(x, y): + def func3(x, y): z = x * y o = thunder.core.prims.copy_(z, x) thunder.core.prims.copy_(o, y) return y - for foo in (func1, func2, func3, func4): + for foo in (func0, func1, func2, func3): traced_foo = executor.make_callable(foo) tdtype = ttorch.to_torch_dtype(dtype)