Skip to content
Merged
58 changes: 58 additions & 0 deletions thunder/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,7 @@ def _recursive_jit_call_warning() -> None:
"backward_fn",
"backward_traces",
"return_none_instead_of_grads",
"vanilla_tensor_args",
],
)

Expand Down Expand Up @@ -353,6 +354,26 @@ def jit(
)
cs = CompileStats()

def _alias_tensor_of_args_kwargs(*args, **kwargs) -> int:
flat_args, _ = tree_flatten((args, kwargs))
data_ptr_to_tensor_group_index = {}
tensor_group_index_to_tensor_indices = defaultdict(list)
for idx, t in enumerate(flat_args):
if pytorch.is_tensor(t) and t.layout == pytorch.strided:
data_ptr = t.untyped_storage().data_ptr()
if data_ptr not in data_ptr_to_tensor_group_index:
data_ptr_to_tensor_group_index[data_ptr] = len(data_ptr_to_tensor_group_index)
tgi = data_ptr_to_tensor_group_index[data_ptr]
tensor_group_index_to_tensor_indices[tgi].append(idx)

alias_indices = []
for k, v in tensor_group_index_to_tensor_indices.items():
if len(v) > 1:
alias_indices.extend(v)
if not alias_indices:
return ""
return ",".join(f"{i}" for i in alias_indices)

@_with_cache_info_ctx
def get_computation_and_inputs(*args, **kwargs):
# set up a record of things in the current environment that impact caching / prologues
Expand Down Expand Up @@ -396,6 +417,12 @@ def get_computation_and_inputs(*args, **kwargs):
cache_info["no_grad_sync"] = no_grad_sync
return_none_instead_of_grads = is_fsdp_enabled and no_grad_sync

# NOTE(crcrpar): If a callable is free from in-place ops whose operands are args and/or their views,
# aliases wouldn't matter, thus it'd be better to nullify this entry in such cases.
# It however would require the functionalized computation trace to interact with `cache_info`,
# which seems to break the consistency of cache_info, leading to a failure in cache_info check.
cache_info["alias_tensor_indices"] = _alias_tensor_of_args_kwargs(*args, **kwargs)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

While you would not want to change the cache info, you can grab the prologue check for it and remove it. It would be better to actually not collect the information in this case, but I guess that's how it is for now.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we were to support the combo of in-place ops and alias tensor args, I do find it better to have separate traces whether or not the args include alias tensors


# TODO RC1 Add module and function checks to prologue (make it a compile option)

# Checks cache
Expand All @@ -413,6 +440,7 @@ def get_computation_and_inputs(*args, **kwargs):
backward_fn,
backward_traces,
_return_none_instead_of_grads,
_vanilla_args,
) = cache_entry
try:
cs.last_prologue_execution_start = time.perf_counter_ns()
Expand Down Expand Up @@ -502,13 +530,31 @@ def get_computation_and_inputs(*args, **kwargs):
prologue_traces = [prologue_trc]
computation_traces = [computation_trc]
orig_to_view_swap_map = check_inplace_to_views(computation_trc)
vanilla_tensor_args: set[int] | None = None
if not compile_options.get("skip_inplace_functionalization", False):
orig_len = len(computation_traces)
computation_traces.extend(
functionalize_inplace_ops(
computation_trace=computation_trc, orig_to_view_swap_map=orig_to_view_swap_map
)
)
computation_trc = computation_traces[-1]
if len(computation_traces) > orig_len:
from thunder.core.pytree import tree_flatten
from thunder.core.utils import ProxyDict

flat_args, _ = tree_flatten((computation_trc.args, computation_trc.kwargs))
arg_to_idx = ProxyDict()
for i, a in enumerate(flat_args):
if not isinstance(a, TensorProxy):
continue
arg_to_idx[a] = i

vanilla_tensor_args: set[int] = {
arg_to_idx[bsym.flat_proxy_args[1]]
for bsym in filter(lambda b: b.sym.id == prims.PrimIDs.COPY_, computation_trc.bound_symbols)
if bsym.flat_proxy_args[1] in arg_to_idx
}

if epilogue_trc is not None:
epilogue_traces = [epilogue_trc]
Expand Down Expand Up @@ -671,6 +717,7 @@ def get_computation_and_inputs(*args, **kwargs):
backward_fn,
backward_traces,
return_none_instead_of_grads,
vanilla_tensor_args,
)
if cd.cache_option is not CACHE_OPTIONS.NO_CACHING:
cs.interpreter_cache.append(cache_entry)
Expand All @@ -696,6 +743,17 @@ def fn_(*args, **kwargs) -> Any:
cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs)
cs.last_trace_host_execution_start = time.perf_counter_ns()

if cache_entry.vanilla_tensor_args:

if alias_tensor_indices_str := _alias_tensor_of_args_kwargs(*inps):
alias_tensor_indices = {int(i) for i in alias_tensor_indices_str.split(",")}
vanilla_tensor_args = cache_entry.vanilla_tensor_args
check(
not vanilla_tensor_args & alias_tensor_indices,
lambda: f"It seems that {vanilla_tensor_args} are {alias_tensor_indices=} share their storage and some of them are modified in-place",
NotImplementedError,
)

result = cache_entry.computation_fn(*inps)

if cache_entry.backward_fn:
Expand Down
8 changes: 7 additions & 1 deletion thunder/core/proxies.py
Original file line number Diff line number Diff line change
Expand Up @@ -1367,7 +1367,13 @@ def __init__(
self._distparallel_type,
self._thunder_fsdp_padding_size,
) = _infer_tensor_properties(
like, shape, device, dtype, requires_grad, distparallel_type, thunder_fsdp_padding_size
like,
shape,
device,
dtype,
requires_grad,
distparallel_type,
thunder_fsdp_padding_size,
)

# NOTE The following properties DO NOT depend on the language context or record
Expand Down
1 change: 1 addition & 0 deletions thunder/core/transform_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import time
from typing import TYPE_CHECKING
from abc import ABC, abstractmethod
# NOTE: keep a single `defaultdict` import; the diff had introduced a duplicate.
from collections import defaultdict
from collections.abc import Sequence
from itertools import filterfalse
Expand Down
46 changes: 46 additions & 0 deletions thunder/tests/test_inplace_functionalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -564,3 +564,49 @@ def f(x, y, idx, src):
assert a.allclose(a_)
assert b.allclose(b_)
assert o.allclose(o_)


def test_error_out_func_with_alias_args():
    """Aliased tensor args fed to an in-place function must raise, and the
    alias pattern must be part of the cache key (distinct patterns miss)."""

    @thunder.jit
    def f_with_inplace(a, b):
        return a.exp_() + b.tanh_()

    a = torch.ones((1, 1))
    b = torch.zeros((1, 1))

    msg = "share their storage and some of them are modified in-place"

    def cache_stats(fn):
        # (hits, misses) snapshot for readability of the assertions below.
        return thunder.cache_hits(fn), thunder.cache_misses(fn)

    # First aliased call: compile (miss), then error out at execution time.
    with pytest.raises(NotImplementedError) as exc:
        f_with_inplace(a, a)
    assert msg in str(exc.value)
    assert cache_stats(f_with_inplace) == (0, 1)

    # Same alias pattern with different storage: cache hit, still errors.
    with pytest.raises(NotImplementedError) as exc:
        f_with_inplace(b, b)
    assert msg in str(exc.value)
    assert cache_stats(f_with_inplace) == (1, 1)

    # Non-aliased inputs compile a fresh entry and succeed.
    f_with_inplace(a, b)
    assert cache_stats(f_with_inplace) == (1, 2)

    f_with_inplace(b, a)
    assert cache_stats(f_with_inplace) == (2, 2)

    # Aliased again: hits the aliased entry and errors once more.
    with pytest.raises(NotImplementedError) as exc:
        f_with_inplace(b, b)
    assert msg in str(exc.value)
    assert cache_stats(f_with_inplace) == (3, 2)

    # Without in-place ops, aliasing never splits the cache.
    @thunder.jit
    def f(a, b):
        return a.exp() + b.tanh()

    f(a, a)
    assert cache_stats(f) == (0, 1)
    f(a, b)
    assert cache_stats(f) == (0, 2)
    f(b, a)
    assert cache_stats(f) == (1, 2)
    f(b, b)
    assert cache_stats(f) == (2, 2)