diff --git a/thunder/common.py b/thunder/common.py
index 92ebeac8e5..d44bc73453 100644
--- a/thunder/common.py
+++ b/thunder/common.py
@@ -293,7 +293,7 @@ def translate(x: Any, *, name: str | None = None) -> Any:
     if isinstance(x, Proxy):
         # register proxy name used by NumberProxies in TensorProxy.shape
         if isinstance(x, TensorProxy):
-            for s_p in filter(lambda s: isinstance(s, Proxy), x.shape):
+            for s_p in filter(lambda s: isinstance(s, Proxy), x._shape):
                # TODO need to avoid name conflict here, since s_p.name
                # could have conflicted with something defined earlier in
                # the trace.
diff --git a/thunder/core/proxies.py b/thunder/core/proxies.py
index ae7744dda3..181b325230 100644
--- a/thunder/core/proxies.py
+++ b/thunder/core/proxies.py
@@ -16,7 +16,12 @@
 
 from thunder.core.compile_data import using_symbolic_values, using_jit
 from thunder.core.interpreter import is_jitting, ProvenanceRecord, PseudoInst
-from thunder.core.trace import VariableInterface, get_tracectx, TraceCtx
+from thunder.core.trace import (
+    VariableInterface,
+    get_tracectx,
+    is_tracing,
+    TraceCtx,
+)
 from thunder.core.baseutils import (
     ProxyInterface,
     NumberProxyInterface,
@@ -1242,8 +1247,7 @@ def _infer_tensor_properties(
         thunder_fsdp_padding_size if thunder_fsdp_padding_size is not None else _thunder_fsdp_padding_size
     )
 
-    # dynamic shape not yet enabled, otherwise, the bake in should be guarded with if not using_symbolic_values():
-    # dynamic shape support is currently block by #471 https://github.com/Lightning-AI/lightning-thunder/issues/471
+    baseutils.check(_shape is not None, lambda: f"_shape cannot be None when creating TensorProxy")
     if not using_symbolic_values():
         _shape = tuple(pyval(x) for x in _shape)
         # Computes derived properties
@@ -1251,7 +1255,7 @@ def _infer_tensor_properties(
     else:
         # deferred computation of numel
         # TODO: similar to how `shape` is handled, this should be CSE or lifted for efficiency
-        _numel = lambda tp: reduce(operator.mul, tp.shape, 1)
+        _numel = lambda *args: reduce(operator.mul, _shape, 1)
 
     # TODO Alias rank to ndim?
     _ndim = len(_shape)
@@ -1459,7 +1463,12 @@ def __init__(
     # outside of a trace or language context
     @property
     def shape(self):
-        return self._shape
+        if not using_symbolic_values() or not is_tracing():
+            return self._shape
+        else:
+            from thunder.core.prims import shape
+
+            return shape(self)
 
     @property
     def ndim(self):
@@ -1548,10 +1557,10 @@ def replace(self, **changes):
         )
 
     def __repr__(self):
-        return f'<{type(self).__name__}(name="{self.name}", dtype={self.dtype}, shape={self.shape})>'
+        return f'<{type(self).__name__}(name="{self.name}", dtype={self.dtype}, shape={self._shape})>'
 
     def type_string(self):
-        return f"{self.device.device_str()} {self.dtype.shortname()}{list(self.shape)}"
+        return f"{self.device.device_str()} {self.dtype.shortname()}{list(self._shape)}"
 
     # NOTE __getattr__ is overridden to support language-specific methods
     def __getattr__(self, attr: str, /):
diff --git a/thunder/core/trace_interpreter.py b/thunder/core/trace_interpreter.py
index a7d7e598e2..ed0e110191 100644
--- a/thunder/core/trace_interpreter.py
+++ b/thunder/core/trace_interpreter.py
@@ -122,9 +122,9 @@ def add_to_swap_map(old, new):
             # (FSDP, tensor parallel) transforms do "break" shape metadata
             new_trace.names.remove(old.name)  # taken by the .replace proxy
             if isinstance(new, VJPDual):
-                old = old.replace(shape=new.primal.shape)
+                old = old.replace(shape=new.primal._shape)
             else:
-                old = old.replace(shape=new.shape)
+                old = old.replace(shape=new._shape)
 
             if isinstance(new, VJPDual):
                 swap_map[variableify(new.primal)] = old
diff --git a/thunder/core/transforms.py b/thunder/core/transforms.py
index 71d6872a66..0979551a5e 100644
--- a/thunder/core/transforms.py
+++ b/thunder/core/transforms.py
@@ -2811,7 +2811,7 @@ def _vjp(primals, cotangents, **kwargs):
         # If the argument is a CPU scalar tensor, its gradient needs to be summed into a scalar tensor.
         vjp_result = tuple(
             (
-                sum_to(grad, arg.shape)
+                sum_to(grad, arg._shape)
                 if (grad is not None and isinstance(arg, TensorProxy) and arg.device.type == "cpu")
                 else grad
             )
diff --git a/thunder/core/utils.py b/thunder/core/utils.py
index edce139cc3..bb91172b46 100644
--- a/thunder/core/utils.py
+++ b/thunder/core/utils.py
@@ -1000,8 +1000,9 @@ def __repr__(self) -> str:
         return str(self._dict)
 
 
-# NOTE That this pass does not assume that the bound symbols are in a reasonable order,
-# but it does assume that each proxy is uniquely constructed once
+# NOTE That this pass does not assume that the bound symbols are in a reasonable order.
+# For proxies with multiple producers, this pass records the first producer encountered
+# in the order of the presented bound symbols.
 # Returns a proxy -> producer mapping
 # If _map_to_numbers is True then producers are represented by their position in the trace (their "line number")
 def producers(trace_or_bsyms: TraceCtx | list[BoundSymbolInterface], *, _map_to_numbers: bool = False) -> ProxyDict:
@@ -1021,6 +1022,10 @@ def producers(trace_or_bsyms: TraceCtx | list[BoundSymbolInterface], *, _map_to_
             continue
 
         for out in bsym.flat_proxy_outs:
+            # if a producer has already been recorded for this proxy, skip it
+            if producers.get(out, None) is not None:
+                continue
+
             vout = variableify(out)
 
             # Checks if the proxy was also an input (in which case this is not its producers)
diff --git a/thunder/executors/pythonex.py b/thunder/executors/pythonex.py
index 942e9c9b0b..de2bae4d0b 100644
--- a/thunder/executors/pythonex.py
+++ b/thunder/executors/pythonex.py
@@ -344,6 +344,7 @@ def _elementwise_binary_checker(a: NumberLike | TensorProxy, b: NumberLike | Ten
 pythonex_pow = ex.register_operator("pow", like=prims.pow, module=operator)
 sub = ex.register_operator("sub", like=prims.sub, module=operator)
 div = ex.register_operator("div", like=prims.div, fn=_div_prim_impl)
+shape = ex.register_operator("shape", like=prims.shape, fn=lambda x: x.shape)
 
 # TODO: Restore truediv once we find it...
 # truediv = ex.register_operator("truediv", like=prims.truediv, module=operator)
@@ -367,6 +368,7 @@ def _elementwise_binary_checker(a: NumberLike | TensorProxy, b: NumberLike | Ten
 ex.register_implementation(prims.pow, pythonex_pow, checker=_elementwise_binary_checker)
 ex.register_implementation(prims.sub, sub, checker=_elementwise_binary_checker)
 ex.register_implementation(prims.div, div, checker=_elementwise_binary_checker)
+ex.register_implementation(prims.shape, shape, checker=_always_executable)
 
 
 def _sink(*args, **kwargs):
diff --git a/thunder/tests/opinfos.py b/thunder/tests/opinfos.py
index c292073482..038997fd57 100644
--- a/thunder/tests/opinfos.py
+++ b/thunder/tests/opinfos.py
@@ -2246,7 +2246,11 @@ def pow_sample_input_generator(op, device, dtype, requires_grad, *, no_rhs_numbe
         ),
         # NOTE: PyTorch fails with RuntimeError: "reciprocal_cuda" not implemented for 'Long' occasionally when the exponent is CPU scalar tensor
         # e.g.: x=torch.tensor([[ 6, 5, 1, -8],], device='cuda:0');y=torch.tensor(-1);torch.pow(x,y)
-        DecorateInfo(pytest.mark.xfail, "test_core_vs_torch_consistency", dtypes=(datatypes.int32, datatypes.int64)),
+        DecorateInfo(
+            pytest.mark.xfail,
+            "test_core_vs_torch_consistency",
+            dtypes=(datatypes.int8, datatypes.int16, datatypes.int32, datatypes.int64),
+        ),
     ),
 )
 elementwise_binary_ops.append(pow_opinfo)
diff --git a/thunder/tests/test_jit_general.py b/thunder/tests/test_jit_general.py
index 44ea9cd04c..3af1aea5c2 100644
--- a/thunder/tests/test_jit_general.py
+++ b/thunder/tests/test_jit_general.py
@@ -1116,7 +1116,7 @@ def forward(self, x):
     ("cpu", "cuda"),
 )
 def test_cache_symbolic_values_reshape(device):
-    if not torch.cuda.is_available():
+    if device == "cuda" and not torch.cuda.is_available():
         pytest.skip("CUDA not available")
 
     a = torch.randn((4, 8, 6), device=device)
@@ -1439,3 +1439,33 @@ def foo(a):
     assert_close(actual, expected)
     assert thunder.cache_misses(jfoo) == 1
     assert thunder.cache_hits(jfoo) == 1
+
+
+def test_cache_symbolic_values_reshape_numel():
+    def foo(a):
+        a = torch.reshape(a, [a.numel()])
+        return a.relu()
+
+    jfoo = thunder.jit(foo, cache="symbolic values")
+
+    a = torch.randn(2, 3, 8, requires_grad=True, device="cpu")
+
+    actual = jfoo(a)
+    expected = foo(a)
+
+    assert_close(actual, expected)
+
+
+def test_cache_symbolic_values_slice():
+    def foo(a):
+        a = a[..., : a.shape[-1]]
+        return a.relu()
+
+    jfoo = thunder.jit(foo, cache="symbolic values")
+
+    a = torch.randn(2, 3, 8, requires_grad=True, device="cpu")
+
+    actual = jfoo(a)
+    expected = foo(a)
+
+    assert_close(actual, expected)
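
For context, a minimal usage sketch of the behavior the new tests exercise, using only the public entry points already shown in this diff (`thunder.jit` with `cache="symbolic values"`, `thunder.cache_misses`, `thunder.cache_hits`). It is illustrative only, not part of the patch, and whether the second call actually reuses the first trace depends on the symbolic-values cache:

```python
# Illustrative sketch only (not part of the patch). Assumes the API used in the
# tests above: thunder.jit(..., cache="symbolic values"), thunder.cache_misses,
# and thunder.cache_hits.
import torch
import thunder


def foo(a):
    # Under cache="symbolic values", a.numel() is meant to stay symbolic rather
    # than being baked in as a Python int at trace time.
    return torch.reshape(a, [a.numel()]).relu()


jfoo = thunder.jit(foo, cache="symbolic values")

out_small = jfoo(torch.randn(2, 3, 8))  # first call: traces and compiles
out_large = jfoo(torch.randn(4, 5, 6))  # different shape; may reuse the symbolic trace

print(thunder.cache_misses(jfoo), thunder.cache_hits(jfoo))
```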