Skip to content

Disallow in-place to view tensors #630

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions thunder/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
AdditionalTransform,
PostOptimizationTransform,
functionalize_inplace_ops,
check_inplace_to_views,
)
from thunder.common import (
CompileData,
Expand Down Expand Up @@ -509,6 +510,7 @@ def get_computation_and_inputs(*args, **kwargs):

prologue_traces = [prologue_trc]
computation_traces = [computation_trc]
check_inplace_to_views(computation_trc)
if not compile_options.get("skip_inplace_functionalization", False):
computation_traces.extend(functionalize_inplace_ops(computation_trace=computation_trc))
computation_trc = computation_traces[-1]
Expand Down
23 changes: 23 additions & 0 deletions thunder/core/transform_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,29 @@ def transform_trace(self, computation_trace: Trace, **kwargs):
pass


def check_inplace_to_views(computation_trace: Trace) -> None:
"""Error out if ``computation_trace`` has any in-place op of `torch.reshape`'s output."""
from thunder.core import utils
import thunder.torch as ltorch

producer_bsyms = producers(computation_trace)

bsym: BoundSymbol
for bsym in filter(lambda b: has_tags(b, {prims.OpTags.IN_PLACE}), computation_trace.bound_symbols):
for in_tensor in filter(lambda p: isinstance(p, TensorProxy), bsym.flat_proxy_args):
prod_bsym: BoundSymbol = producer_bsyms[in_tensor]
utils.check(
not has_tags(prod_bsym, {prims.OpTags.SHAPE_OP}),
lambda: f"in-place op to view tensors is not allowed but `{bsym.sym.id}` takes `{prod_bsym.sym.id}` output `{in_tensor}`",
NotImplementedError,
)
utils.check(
prod_bsym.sym != ltorch.contiguous,
lambda: f"in-place op to `torch.Tensor.contiguous` output is not allowed but `{bsym.sym.id}` takes `{prod_bsym.sym.id}` output `{in_tensor}`",
NotImplementedError,
)


def functionalize_inplace_ops(computation_trace: Trace) -> list[Trace]:
"""Functionalize in-place ops in ``computation_trace``.

Expand Down
27 changes: 26 additions & 1 deletion thunder/tests/test_inplace_functionalization.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from __future__ import annotations
from collections.abc import Callable
from dataclasses import dataclass
from functools import partial
from collections.abc import Callable

import pytest
import torch.testing

from thunder.core import dtypes
Expand Down Expand Up @@ -122,3 +123,27 @@ def test_functionalization(op: OpInfo, device: str, dtype: dtypes.dtype, executo
fw_extrace.bound_symbols,
)
)


def test_invalid_cases():
import thunder

a = torch.randn((2, 2))

def f_with_reshape(a: torch.Tensor) -> torch.Tensor:
b = torch.reshape(a, (-1,))
b.exp_()
return b

jitted = thunder.jit(f_with_reshape)
with pytest.raises(NotImplementedError, match="in-place op to view tensors is not allowed but"):
jitted(a)

def f_with_contiguous(a: torch.Tensor) -> torch.Tensor:
b = a.contiguous()
b.exp_()
return b

jitted = thunder.jit(f_with_contiguous)
with pytest.raises(NotImplementedError, match="in-place op to `torch.Tensor.contiguous`"):
jitted(a)
Loading