diff --git a/thunder/__init__.py b/thunder/__init__.py index 42aa3260e5..0b1d7b82aa 100644 --- a/thunder/__init__.py +++ b/thunder/__init__.py @@ -251,6 +251,7 @@ def _recursive_jit_call_warning() -> None: "backward_fn", "backward_traces", "return_none_instead_of_grads", + "vanilla_tensor_args", ], ) @@ -353,6 +354,26 @@ def jit( ) cs = CompileStats() + def _alias_tensor_of_args_kwargs(*args, **kwargs) -> str: + flat_args, _ = tree_flatten((args, kwargs)) + data_ptr_to_tensor_group_index = {} + tensor_group_index_to_tensor_indices = defaultdict(list) + for idx, t in enumerate(flat_args): + if pytorch.is_tensor(t) and t.layout == pytorch.strided: + data_ptr = t.untyped_storage().data_ptr() + if data_ptr not in data_ptr_to_tensor_group_index: + data_ptr_to_tensor_group_index[data_ptr] = len(data_ptr_to_tensor_group_index) + tgi = data_ptr_to_tensor_group_index[data_ptr] + tensor_group_index_to_tensor_indices[tgi].append(idx) + + alias_indices = [] + for k, v in tensor_group_index_to_tensor_indices.items(): + if len(v) > 1: + alias_indices.extend(v) + if not alias_indices: + return "" + return ",".join(f"{i}" for i in alias_indices) + @_with_cache_info_ctx def get_computation_and_inputs(*args, **kwargs): # set up a record of things in the current environment that impact caching / prologues @@ -396,6 +417,12 @@ def get_computation_and_inputs(*args, **kwargs): cache_info["no_grad_sync"] = no_grad_sync return_none_instead_of_grads = is_fsdp_enabled and no_grad_sync + # NOTE(crcrpar): If a callable is free from in-place ops whose operand is args and/or their views, + # aliases wouldn't matter, so it'd be better to nullify this entry in such cases. + # It however would require the functionalized computation trace to interact with `cache_info`, + # which seems to break the consistency of cache_info, leading to a failure in cache_info check. 
+ cache_info["alias_tensor_indices"] = _alias_tensor_of_args_kwargs(*args, **kwargs) + # TODO RC1 Add module and function checks to prologue (make it a compile option) # Checks cache @@ -413,6 +440,7 @@ def get_computation_and_inputs(*args, **kwargs): backward_fn, backward_traces, _return_none_instead_of_grads, + _vanilla_args, ) = cache_entry try: cs.last_prologue_execution_start = time.perf_counter_ns() @@ -502,13 +530,31 @@ def get_computation_and_inputs(*args, **kwargs): prologue_traces = [prologue_trc] computation_traces = [computation_trc] orig_to_view_swap_map = check_inplace_to_views(computation_trc) + vanilla_tensor_args: set[int] | None = None if not compile_options.get("skip_inplace_functionalization", False): + orig_len = len(computation_traces) computation_traces.extend( functionalize_inplace_ops( computation_trace=computation_trc, orig_to_view_swap_map=orig_to_view_swap_map ) ) computation_trc = computation_traces[-1] + if len(computation_traces) > orig_len: + from thunder.core.pytree import tree_flatten + from thunder.core.utils import ProxyDict + + flat_args, _ = tree_flatten((computation_trc.args, computation_trc.kwargs)) + arg_to_idx = ProxyDict() + for i, a in enumerate(flat_args): + if not isinstance(a, TensorProxy): + continue + arg_to_idx[a] = i + + vanilla_tensor_args: set[int] = { + arg_to_idx[bsym.flat_proxy_args[1]] + for bsym in filter(lambda b: b.sym.id == prims.PrimIDs.COPY_, computation_trc.bound_symbols) + if bsym.flat_proxy_args[1] in arg_to_idx + } if epilogue_trc is not None: epilogue_traces = [epilogue_trc] @@ -671,6 +717,7 @@ def get_computation_and_inputs(*args, **kwargs): backward_fn, backward_traces, return_none_instead_of_grads, + vanilla_tensor_args, ) if cd.cache_option is not CACHE_OPTIONS.NO_CACHING: cs.interpreter_cache.append(cache_entry) @@ -696,6 +743,17 @@ def fn_(*args, **kwargs) -> Any: cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs) cs.last_trace_host_execution_start = 
time.perf_counter_ns() + if cache_entry.vanilla_tensor_args: + + if alias_tensor_indices_str := _alias_tensor_of_args_kwargs(*inps): + alias_tensor_indices = {int(i) for i in alias_tensor_indices_str.split(",")} + vanilla_tensor_args = cache_entry.vanilla_tensor_args + check( + not vanilla_tensor_args & alias_tensor_indices, + lambda: f"It seems that tensors at indices {vanilla_tensor_args} and {alias_tensor_indices=} share their storage and some of them are modified in-place", + NotImplementedError, + ) + result = cache_entry.computation_fn(*inps) if cache_entry.backward_fn: diff --git a/thunder/core/proxies.py b/thunder/core/proxies.py index d83abf11dc..343527eb46 100644 --- a/thunder/core/proxies.py +++ b/thunder/core/proxies.py @@ -1367,7 +1367,13 @@ def __init__( self._distparallel_type, self._thunder_fsdp_padding_size, ) = _infer_tensor_properties( - like, shape, device, dtype, requires_grad, distparallel_type, thunder_fsdp_padding_size + like, + shape, + device, + dtype, + requires_grad, + distparallel_type, + thunder_fsdp_padding_size, ) # NOTE The following properties DO NOT depend on the language context or record diff --git a/thunder/core/transform_common.py b/thunder/core/transform_common.py index a43e9b3229..38dff4a343 100644 --- a/thunder/core/transform_common.py +++ b/thunder/core/transform_common.py @@ -2,6 +2,7 @@ import time from typing import TYPE_CHECKING from abc import ABC, abstractmethod +from collections import defaultdict from collections.abc import Sequence from collections import defaultdict from itertools import filterfalse diff --git a/thunder/tests/test_inplace_functionalization.py b/thunder/tests/test_inplace_functionalization.py index 21f47309db..915f143747 100644 --- a/thunder/tests/test_inplace_functionalization.py +++ b/thunder/tests/test_inplace_functionalization.py @@ -564,3 +564,49 @@ def f(x, y, idx, src): assert a.allclose(a_) assert b.allclose(b_) assert o.allclose(o_) + + +def test_error_out_func_with_alias_args(): + + @thunder.jit + def 
f_with_inplace(a, b): + return a.exp_() + b.tanh_() + + a = torch.ones((1, 1)) + b = torch.zeros((1, 1)) + + msg = "share their storage and some of them are modified in-place" + with pytest.raises(NotImplementedError) as excinfo: + f_with_inplace(a, a) + assert msg in str(excinfo.value) + assert (thunder.cache_hits(f_with_inplace), thunder.cache_misses(f_with_inplace)) == (0, 1) + + with pytest.raises(NotImplementedError) as excinfo: + f_with_inplace(b, b) + assert msg in str(excinfo.value) + assert (thunder.cache_hits(f_with_inplace), thunder.cache_misses(f_with_inplace)) == (1, 1) + + # Make sure the cache changes accordingly + f_with_inplace(a, b) + assert (thunder.cache_hits(f_with_inplace), thunder.cache_misses(f_with_inplace)) == (1, 2) + + f_with_inplace(b, a) + assert (thunder.cache_hits(f_with_inplace), thunder.cache_misses(f_with_inplace)) == (2, 2) + + with pytest.raises(NotImplementedError) as excinfo: + f_with_inplace(b, b) + assert msg in str(excinfo.value) + assert (thunder.cache_hits(f_with_inplace), thunder.cache_misses(f_with_inplace)) == (3, 2) + + @thunder.jit + def f(a, b): + return a.exp() + b.tanh() + + f(a, a) + assert (thunder.cache_hits(f), thunder.cache_misses(f)) == (0, 1) + f(a, b) + assert (thunder.cache_hits(f), thunder.cache_misses(f)) == (0, 2) + f(b, a) + assert (thunder.cache_hits(f), thunder.cache_misses(f)) == (1, 2) + f(b, b) + assert (thunder.cache_hits(f), thunder.cache_misses(f)) == (2, 2)