From 3fb1302fdc4aeee506048e4560cc3715709f0770 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Fri, 4 Oct 2024 08:15:54 +0000 Subject: [PATCH 01/15] quick fix on reshape --- thunder/core/proxies.py | 9 +++++++-- thunder/executors/pythonex.py | 2 ++ 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/thunder/core/proxies.py b/thunder/core/proxies.py index ae7744dda3..01f9baadc6 100644 --- a/thunder/core/proxies.py +++ b/thunder/core/proxies.py @@ -1251,7 +1251,7 @@ def _infer_tensor_properties( else: # deferred computation of numel # TODO: similar to how `shape` is handled, this should be CSE or lifted for efficiency - _numel = lambda tp: reduce(operator.mul, tp.shape, 1) + _numel = lambda *args: reduce(operator.mul, _shape, 1) # TODO Alias rank to ndim? _ndim = len(_shape) @@ -1459,7 +1459,12 @@ def __init__( # outside of a trace or language context @property def shape(self): - return self._shape + trace: None | TraceCtx = get_tracectx() + if trace is None: + return self._shape + else: + from thunder.core.prims import shape + return shape(self) @property def ndim(self): diff --git a/thunder/executors/pythonex.py b/thunder/executors/pythonex.py index 942e9c9b0b..de2bae4d0b 100644 --- a/thunder/executors/pythonex.py +++ b/thunder/executors/pythonex.py @@ -344,6 +344,7 @@ def _elementwise_binary_checker(a: NumberLike | TensorProxy, b: NumberLike | Ten pythonex_pow = ex.register_operator("pow", like=prims.pow, module=operator) sub = ex.register_operator("sub", like=prims.sub, module=operator) div = ex.register_operator("div", like=prims.div, fn=_div_prim_impl) +shape = ex.register_operator("shape", like=prims.shape, fn=lambda x: x.shape) # TODO: Restore truediv once we find it... # truediv = ex.register_operator("truediv", like=prims.truediv, module=operator) @@ -367,6 +368,7 @@ def _elementwise_binary_checker(a: NumberLike | TensorProxy, b: NumberLike | Ten ex.register_implementation(prims.pow, pythonex_pow, checker=_elementwise_binary_checker) ex.register_implementation(prims.sub, sub, checker=_elementwise_binary_checker) ex.register_implementation(prims.div, div, checker=_elementwise_binary_checker) +ex.register_implementation(prims.shape, shape, checker=_always_executable) def _sink(*args, **kwargs): From 7c5677ccd64aca18b6e5e55b98f622a76d00d1d5 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Fri, 4 Oct 2024 13:03:42 +0000 Subject: [PATCH 02/15] patching tracectx for printing --- thunder/core/trace.py | 210 +++++++++++++++++++++--------------------- 1 file changed, 105 insertions(+), 105 deletions(-) diff --git a/thunder/core/trace.py b/thunder/core/trace.py index b8670212aa..27062a61c5 100644 --- a/thunder/core/trace.py +++ b/thunder/core/trace.py @@ -361,120 +361,120 @@ def python(self, *, print_depth: int = 1, include_decorators: bool = True) -> st try: # Acquires ctx and imports from the BoundSymbols... import_ctx, call_ctx, object_ctx = self._gather_ctxs() + finally: + reset_tracectx(token) - # ... and from the signature - if self._siginfo is None and self.fn is None: - signature_str = f"# No signature available" - else: - si = self.siginfo() - signature_str = si.prettyprint(trace=self, import_ctx=import_ctx, object_ctx=object_ctx) - - # Constructs program strings - program = [] + # ... 
and from the signature + if self._siginfo is None and self.fn is None: + signature_str = f"# No signature available" + else: + si = self.siginfo() + signature_str = si.prettyprint(trace=self, import_ctx=import_ctx, object_ctx=object_ctx) - # Prints provenance (if any) first - if self._provenance is not None: - provenance_str = f"{str(self._provenance)}" - program.append(provenance_str) + # Constructs program strings + program = [] - # NOTE torch is explicitly imported because we always run in the no_grad() ctx (see below) - import torch + # Prints provenance (if any) first + if self._provenance is not None: + provenance_str = f"{str(self._provenance)}" + program.append(provenance_str) - import_ctx["torch"] = torch + # NOTE torch is explicitly imported because we always run in the no_grad() ctx (see below) + import torch - # Prints imports, sorted by name + import_ctx["torch"] = torch - def keyfn(class_or_module: type | ModuleType) -> str: - if isinstance(class_or_module, ModuleType): - return class_or_module.__name__ - return class_or_module.__module__ + # Prints imports, sorted by name - name: str - class_or_module: type | ModuleType - for name, class_or_module in sorted(import_ctx.items(), key=lambda x: keyfn(x[1])): - import_str: str + def keyfn(class_or_module: type | ModuleType) -> str: + if isinstance(class_or_module, ModuleType): + return class_or_module.__name__ + return class_or_module.__module__ - # Handles class imports - if not isinstance(class_or_module, ModuleType): - cls: type = class_or_module - import_str = f"from {cls.__module__} import {cls.__name__}" + name: str + class_or_module: type | ModuleType + for name, class_or_module in sorted(import_ctx.items(), key=lambda x: keyfn(x[1])): + import_str: str + + # Handles class imports + if not isinstance(class_or_module, ModuleType): + cls: type = class_or_module + import_str = f"from {cls.__module__} import {cls.__name__}" + else: + # class_or_module is a module + module: ModuleType = class_or_module + if module.__name__ == name: + import_str = f"import {module.__name__}" else: - # class_or_module is a module - module: ModuleType = class_or_module - if module.__name__ == name: - import_str = f"import {module.__name__}" - else: - import_str = f"import {module.__name__} as {name}" - program.append(import_str) - - if include_decorators: - program.append("from thunder.executors.torchex import no_autocast") - - # Separates imports from the function for readability - if len(import_ctx) > 0: - program.append("") - - if include_decorators: - # NOTE: For TransformerEngine executor, we want to wrap the generated - # forward function in fp8_autocast ctx manager. - # In the future, if other executor has similar requirements, we should - # add a new extension point for executors - # NOTE: For TE v1.6 onwards, `fp8_autocast` checks if `torch.is_grad_enabled` for updating - # the FP8 scales/inverses. So this decorator should be applied before `torch.no_grad` (so that - # it is in grad enabled part). 
- from thunder.executors.transformer_engineex import _is_te_linear_enabled, _get_te_wrapper_string - - if self._include_te_fp8_autocast and _is_te_linear_enabled(import_ctx, object_ctx): - program.append(_get_te_wrapper_string()) - - # Disable gradients since Thunder takes care of this (for when calling torch operations) - program.append("@torch.no_grad()") - # Disable autocast since we already generated the trace with it in consideration (for when calling torch - # operations) - program.append("@no_autocast") - - # Prints the signature - program.append(signature_str) - - # TODO Print objects from context - # Prints constants (if any) upfront - # constants = tuple(om for om in self._object_meta_map.values() if om.is_constant) - # if len(constants) > 0: - # const_comment_str = f"{indent}# Initializes constants" - # program.append(const_comment_str) - # for c in constants: - # constant_python = c.python(indent=1) - # program.extend(constant_python) - - # Separates constants from operations - # if len(constants) > 0: - # program.append("") - - # Prints operations - - filename = None - lineno = None - for i, bsym in enumerate(self.bound_symbols): - if ( - bsym.source_filename is not None - and bsym.source_positions is not None - and bsym.source_positions.lineno is not None - ) and (filename != bsym.source_filename or lineno != bsym.source_positions.lineno): - if i > 0: - program.append("") - src_line = get_source_line(bsym.source_filename, bsym.source_positions.lineno) - program.append(f""" # {bsym.source_filename}:{bsym.source_positions.lineno}: \t{src_line}""") - filename = bsym.source_filename - lineno = bsym.source_positions and bsym.source_positions.lineno - - lines = bsym.python(indent=1, print_depth=print_depth) - program.extend(lines) - - python = "\n".join(program) - - return python - finally: - reset_tracectx(token) + import_str = f"import {module.__name__} as {name}" + program.append(import_str) + + if include_decorators: + program.append("from thunder.executors.torchex import no_autocast") + + # Separates imports from the function for readability + if len(import_ctx) > 0: + program.append("") + + if include_decorators: + # NOTE: For TransformerEngine executor, we want to wrap the generated + # forward function in fp8_autocast ctx manager. + # In the future, if other executor has similar requirements, we should + # add a new extension point for executors + # NOTE: For TE v1.6 onwards, `fp8_autocast` checks if `torch.is_grad_enabled` for updating + # the FP8 scales/inverses. So this decorator should be applied before `torch.no_grad` (so that + # it is in grad enabled part). 
+ from thunder.executors.transformer_engineex import _is_te_linear_enabled, _get_te_wrapper_string + + if self._include_te_fp8_autocast and _is_te_linear_enabled(import_ctx, object_ctx): + program.append(_get_te_wrapper_string()) + + # Disable gradients since Thunder takes care of this (for when calling torch operations) + program.append("@torch.no_grad()") + # Disable autocast since we already generated the trace with it in consideration (for when calling torch + # operations) + program.append("@no_autocast") + + # Prints the signature + program.append(signature_str) + + # TODO Print objects from context + # Prints constants (if any) upfront + # constants = tuple(om for om in self._object_meta_map.values() if om.is_constant) + # if len(constants) > 0: + # const_comment_str = f"{indent}# Initializes constants" + # program.append(const_comment_str) + # for c in constants: + # constant_python = c.python(indent=1) + # program.extend(constant_python) + + # Separates constants from operations + # if len(constants) > 0: + # program.append("") + + # Prints operations + + filename = None + lineno = None + for i, bsym in enumerate(self.bound_symbols): + if ( + bsym.source_filename is not None + and bsym.source_positions is not None + and bsym.source_positions.lineno is not None + ) and (filename != bsym.source_filename or lineno != bsym.source_positions.lineno): + if i > 0: + program.append("") + src_line = get_source_line(bsym.source_filename, bsym.source_positions.lineno) + program.append(f""" # {bsym.source_filename}:{bsym.source_positions.lineno}: \t{src_line}""") + filename = bsym.source_filename + lineno = bsym.source_positions and bsym.source_positions.lineno + + lines = bsym.python(indent=1, print_depth=print_depth) + program.extend(lines) + + python = "\n".join(program) + + return python # Returns a Python callable that executes the trace # TODO issue "Create a mechanism for freezing TraceCtx objects" From 60f9f1787a8993b1aec1b386e8fff5c6d3f9c398 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 4 Oct 2024 13:08:18 +0000 Subject: [PATCH 03/15] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- thunder/core/proxies.py | 1 + 1 file changed, 1 insertion(+) diff --git a/thunder/core/proxies.py b/thunder/core/proxies.py index 01f9baadc6..6f5d55cb16 100644 --- a/thunder/core/proxies.py +++ b/thunder/core/proxies.py @@ -1464,6 +1464,7 @@ def shape(self): return self._shape else: from thunder.core.prims import shape + return shape(self) @property From 3d3700d373dc99073588eb6a3d42179ea426319e Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Fri, 4 Oct 2024 13:47:00 +0000 Subject: [PATCH 04/15] switch to _shape to avoid trace --- thunder/common.py | 2 +- thunder/core/trace_interpreter.py | 4 ++-- thunder/core/transforms.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/thunder/common.py b/thunder/common.py index 92ebeac8e5..d44bc73453 100644 --- a/thunder/common.py +++ b/thunder/common.py @@ -293,7 +293,7 @@ def translate(x: Any, *, name: str | None = None) -> Any: if isinstance(x, Proxy): # register proxy name used by NumberProxies in TensorProxy.shape if isinstance(x, TensorProxy): - for s_p in filter(lambda s: isinstance(s, Proxy), x.shape): + for s_p in filter(lambda s: isinstance(s, Proxy), x._shape): # TODO need to avoid name conflict here, since s_p.name # could have conflicted with something defined earlier in # the trace. 
diff --git a/thunder/core/trace_interpreter.py b/thunder/core/trace_interpreter.py index a7d7e598e2..ed0e110191 100644 --- a/thunder/core/trace_interpreter.py +++ b/thunder/core/trace_interpreter.py @@ -122,9 +122,9 @@ def add_to_swap_map(old, new): # (FSDP, tensor parallel) transforms do "break" shape metadata new_trace.names.remove(old.name) # taken by the .replace proxy if isinstance(new, VJPDual): - old = old.replace(shape=new.primal.shape) + old = old.replace(shape=new.primal._shape) else: - old = old.replace(shape=new.shape) + old = old.replace(shape=new._shape) if isinstance(new, VJPDual): swap_map[variableify(new.primal)] = old diff --git a/thunder/core/transforms.py b/thunder/core/transforms.py index 71d6872a66..0979551a5e 100644 --- a/thunder/core/transforms.py +++ b/thunder/core/transforms.py @@ -2811,7 +2811,7 @@ def _vjp(primals, cotangents, **kwargs): # If the argument is a CPU scalar tensor, its gradient needs to be summed into a scalar tensor. vjp_result = tuple( ( - sum_to(grad, arg.shape) + sum_to(grad, arg._shape) if (grad is not None and isinstance(arg, TensorProxy) and arg.device.type == "cpu") else grad ) From 27a713005d005cb4e4f1a16721355bf5aa6abd40 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Fri, 4 Oct 2024 21:29:37 +0000 Subject: [PATCH 05/15] reverting this first and follow up to check the print again --- thunder/core/trace.py | 210 +++++++++++++++++++++--------------------- 1 file changed, 105 insertions(+), 105 deletions(-) diff --git a/thunder/core/trace.py b/thunder/core/trace.py index 27062a61c5..b8670212aa 100644 --- a/thunder/core/trace.py +++ b/thunder/core/trace.py @@ -361,120 +361,120 @@ def python(self, *, print_depth: int = 1, include_decorators: bool = True) -> st try: # Acquires ctx and imports from the BoundSymbols... import_ctx, call_ctx, object_ctx = self._gather_ctxs() - finally: - reset_tracectx(token) - # ... and from the signature - if self._siginfo is None and self.fn is None: - signature_str = f"# No signature available" - else: - si = self.siginfo() - signature_str = si.prettyprint(trace=self, import_ctx=import_ctx, object_ctx=object_ctx) + # ... 
and from the signature + if self._siginfo is None and self.fn is None: + signature_str = f"# No signature available" + else: + si = self.siginfo() + signature_str = si.prettyprint(trace=self, import_ctx=import_ctx, object_ctx=object_ctx) - # Constructs program strings - program = [] + # Constructs program strings + program = [] - # Prints provenance (if any) first - if self._provenance is not None: - provenance_str = f"{str(self._provenance)}" - program.append(provenance_str) + # Prints provenance (if any) first + if self._provenance is not None: + provenance_str = f"{str(self._provenance)}" + program.append(provenance_str) - # NOTE torch is explicitly imported because we always run in the no_grad() ctx (see below) - import torch + # NOTE torch is explicitly imported because we always run in the no_grad() ctx (see below) + import torch - import_ctx["torch"] = torch + import_ctx["torch"] = torch - # Prints imports, sorted by name + # Prints imports, sorted by name - def keyfn(class_or_module: type | ModuleType) -> str: - if isinstance(class_or_module, ModuleType): - return class_or_module.__name__ - return class_or_module.__module__ + def keyfn(class_or_module: type | ModuleType) -> str: + if isinstance(class_or_module, ModuleType): + return class_or_module.__name__ + return class_or_module.__module__ - name: str - class_or_module: type | ModuleType - for name, class_or_module in sorted(import_ctx.items(), key=lambda x: keyfn(x[1])): - import_str: str - - # Handles class imports - if not isinstance(class_or_module, ModuleType): - cls: type = class_or_module - import_str = f"from {cls.__module__} import {cls.__name__}" - else: - # class_or_module is a module - module: ModuleType = class_or_module - if module.__name__ == name: - import_str = f"import {module.__name__}" + name: str + class_or_module: type | ModuleType + for name, class_or_module in sorted(import_ctx.items(), key=lambda x: keyfn(x[1])): + import_str: str + + # Handles class imports + if not isinstance(class_or_module, ModuleType): + cls: type = class_or_module + import_str = f"from {cls.__module__} import {cls.__name__}" else: - import_str = f"import {module.__name__} as {name}" - program.append(import_str) - - if include_decorators: - program.append("from thunder.executors.torchex import no_autocast") - - # Separates imports from the function for readability - if len(import_ctx) > 0: - program.append("") - - if include_decorators: - # NOTE: For TransformerEngine executor, we want to wrap the generated - # forward function in fp8_autocast ctx manager. - # In the future, if other executor has similar requirements, we should - # add a new extension point for executors - # NOTE: For TE v1.6 onwards, `fp8_autocast` checks if `torch.is_grad_enabled` for updating - # the FP8 scales/inverses. So this decorator should be applied before `torch.no_grad` (so that - # it is in grad enabled part). 
- from thunder.executors.transformer_engineex import _is_te_linear_enabled, _get_te_wrapper_string - - if self._include_te_fp8_autocast and _is_te_linear_enabled(import_ctx, object_ctx): - program.append(_get_te_wrapper_string()) - - # Disable gradients since Thunder takes care of this (for when calling torch operations) - program.append("@torch.no_grad()") - # Disable autocast since we already generated the trace with it in consideration (for when calling torch - # operations) - program.append("@no_autocast") - - # Prints the signature - program.append(signature_str) - - # TODO Print objects from context - # Prints constants (if any) upfront - # constants = tuple(om for om in self._object_meta_map.values() if om.is_constant) - # if len(constants) > 0: - # const_comment_str = f"{indent}# Initializes constants" - # program.append(const_comment_str) - # for c in constants: - # constant_python = c.python(indent=1) - # program.extend(constant_python) - - # Separates constants from operations - # if len(constants) > 0: - # program.append("") - - # Prints operations - - filename = None - lineno = None - for i, bsym in enumerate(self.bound_symbols): - if ( - bsym.source_filename is not None - and bsym.source_positions is not None - and bsym.source_positions.lineno is not None - ) and (filename != bsym.source_filename or lineno != bsym.source_positions.lineno): - if i > 0: - program.append("") - src_line = get_source_line(bsym.source_filename, bsym.source_positions.lineno) - program.append(f""" # {bsym.source_filename}:{bsym.source_positions.lineno}: \t{src_line}""") - filename = bsym.source_filename - lineno = bsym.source_positions and bsym.source_positions.lineno - - lines = bsym.python(indent=1, print_depth=print_depth) - program.extend(lines) - - python = "\n".join(program) - - return python + # class_or_module is a module + module: ModuleType = class_or_module + if module.__name__ == name: + import_str = f"import {module.__name__}" + else: + import_str = f"import {module.__name__} as {name}" + program.append(import_str) + + if include_decorators: + program.append("from thunder.executors.torchex import no_autocast") + + # Separates imports from the function for readability + if len(import_ctx) > 0: + program.append("") + + if include_decorators: + # NOTE: For TransformerEngine executor, we want to wrap the generated + # forward function in fp8_autocast ctx manager. + # In the future, if other executor has similar requirements, we should + # add a new extension point for executors + # NOTE: For TE v1.6 onwards, `fp8_autocast` checks if `torch.is_grad_enabled` for updating + # the FP8 scales/inverses. So this decorator should be applied before `torch.no_grad` (so that + # it is in grad enabled part). 
+ from thunder.executors.transformer_engineex import _is_te_linear_enabled, _get_te_wrapper_string + + if self._include_te_fp8_autocast and _is_te_linear_enabled(import_ctx, object_ctx): + program.append(_get_te_wrapper_string()) + + # Disable gradients since Thunder takes care of this (for when calling torch operations) + program.append("@torch.no_grad()") + # Disable autocast since we already generated the trace with it in consideration (for when calling torch + # operations) + program.append("@no_autocast") + + # Prints the signature + program.append(signature_str) + + # TODO Print objects from context + # Prints constants (if any) upfront + # constants = tuple(om for om in self._object_meta_map.values() if om.is_constant) + # if len(constants) > 0: + # const_comment_str = f"{indent}# Initializes constants" + # program.append(const_comment_str) + # for c in constants: + # constant_python = c.python(indent=1) + # program.extend(constant_python) + + # Separates constants from operations + # if len(constants) > 0: + # program.append("") + + # Prints operations + + filename = None + lineno = None + for i, bsym in enumerate(self.bound_symbols): + if ( + bsym.source_filename is not None + and bsym.source_positions is not None + and bsym.source_positions.lineno is not None + ) and (filename != bsym.source_filename or lineno != bsym.source_positions.lineno): + if i > 0: + program.append("") + src_line = get_source_line(bsym.source_filename, bsym.source_positions.lineno) + program.append(f""" # {bsym.source_filename}:{bsym.source_positions.lineno}: \t{src_line}""") + filename = bsym.source_filename + lineno = bsym.source_positions and bsym.source_positions.lineno + + lines = bsym.python(indent=1, print_depth=print_depth) + program.extend(lines) + + python = "\n".join(program) + + return python + finally: + reset_tracectx(token) # Returns a Python callable that executes the trace # TODO issue "Create a mechanism for freezing TraceCtx objects" From db34394d829a40267f82a017e24db0bcc9448a0a Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Fri, 4 Oct 2024 21:35:50 +0000 Subject: [PATCH 06/15] TensorProxy string print shouldn't trace its shape access --- thunder/core/proxies.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/thunder/core/proxies.py b/thunder/core/proxies.py index 6f5d55cb16..0035ba850a 100644 --- a/thunder/core/proxies.py +++ b/thunder/core/proxies.py @@ -1554,10 +1554,10 @@ def replace(self, **changes): ) def __repr__(self): - return f'<{type(self).__name__}(name="{self.name}", dtype={self.dtype}, shape={self.shape})>' + return f'<{type(self).__name__}(name="{self.name}", dtype={self.dtype}, shape={self._shape})>' def type_string(self): - return f"{self.device.device_str()} {self.dtype.shortname()}{list(self.shape)}" + return f"{self.device.device_str()} {self.dtype.shortname()}{list(self._shape)}" # NOTE __getattr__ is overridden to support language-specific methods def __getattr__(self, attr: str, /): From cb791336d68f4c29f086226d995eb6e33331d8aa Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Fri, 4 Oct 2024 22:02:18 +0000 Subject: [PATCH 07/15] fixing tests --- thunder/tests/test_jit_general.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/thunder/tests/test_jit_general.py b/thunder/tests/test_jit_general.py index 44ea9cd04c..18f74f1951 100644 --- a/thunder/tests/test_jit_general.py +++ b/thunder/tests/test_jit_general.py @@ -1116,7 +1116,7 @@ def forward(self, x): ("cpu", "cuda"), ) def test_cache_symbolic_values_reshape(device): - if not 
torch.cuda.is_available(): + if device == "cuda" and not torch.cuda.is_available(): pytest.skip("CUDA not available") a = torch.randn((4, 8, 6), device=device) From d8e2d971cd2260251b7191d4a77d90d484d0feb5 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Sat, 5 Oct 2024 08:16:30 +0000 Subject: [PATCH 08/15] fixing dce to handle multiple producers; adding tests --- thunder/core/utils.py | 9 +++++-- thunder/tests/test_jit_general.py | 44 +++++++++++++++++++++++++++++++ 2 files changed, 51 insertions(+), 2 deletions(-) diff --git a/thunder/core/utils.py b/thunder/core/utils.py index edce139cc3..bb91172b46 100644 --- a/thunder/core/utils.py +++ b/thunder/core/utils.py @@ -1000,8 +1000,9 @@ def __repr__(self) -> str: return str(self._dict) -# NOTE That this pass does not assume that the bound symbols are in a reasonable order, -# but it does assume that each proxy is uniquely constructed once +# NOTE That this pass does not assume that the bound symbols are in a reasonable order. +# For bound symbols with multiple producers, this pass returns the first producer of +# in order of the presented bound symbols # Returns a proxy -> producer mapping # If _map_to_numbers is True then producers are represented by their position in the trace (their "line number") def producers(trace_or_bsyms: TraceCtx | list[BoundSymbolInterface], *, _map_to_numbers: bool = False) -> ProxyDict: @@ -1021,6 +1022,10 @@ def producers(trace_or_bsyms: TraceCtx | list[BoundSymbolInterface], *, _map_to_ continue for out in bsym.flat_proxy_outs: + # if a producer has already been traversed, skip + if producers.get(out, None) != None: + continue + vout = variableify(out) # Checks if the proxy was also an input (in which case this is not its producers) diff --git a/thunder/tests/test_jit_general.py b/thunder/tests/test_jit_general.py index 18f74f1951..3bd5dfb268 100644 --- a/thunder/tests/test_jit_general.py +++ b/thunder/tests/test_jit_general.py @@ -1439,3 +1439,47 @@ def foo(a): assert_close(actual, expected) assert thunder.cache_misses(jfoo) == 1 assert thunder.cache_hits(jfoo) == 1 + + +@pytest.mark.parametrize( + "device", + ("cpu", "cuda"), +) +def test_cache_symbolic_values_reshape_numel(device): + if device == "cuda" and not torch.cuda.is_available(): + pytest.skip("CUDA not available") + + def foo(a): + a = torch.reshape(a, [a.numel()]) + return a.relu() + + jfoo = thunder.jit(foo, cache="symbolic values") + + a = torch.randn(2, 3, 8, requires_grad=True, device=device) + + actual = jfoo(a) + expected = foo(a) + + assert_close(actual, expected) + + +@pytest.mark.parametrize( + "device", + ("cpu", "cuda"), +) +def test_cache_symbolic_values_slice(device): + if device == "cuda" and not torch.cuda.is_available(): + pytest.skip("CUDA not available") + + def foo(a): + a = a[..., : a.shape[-1]] + return a.relu() + + jfoo = thunder.jit(foo, cache="symbolic values") + + a = torch.randn(2, 3, 8, requires_grad=True, device=device) + + actual = jfoo(a) + expected = foo(a) + + assert_close(actual, expected) From 100c9f8d65f66f2546ca67cc2109e093f06be17a Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Sat, 5 Oct 2024 09:17:43 +0000 Subject: [PATCH 09/15] let's not set the whole world on fire --- thunder/core/proxies.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/thunder/core/proxies.py b/thunder/core/proxies.py index 0035ba850a..0270e64bdf 100644 --- a/thunder/core/proxies.py +++ b/thunder/core/proxies.py @@ -16,7 +16,12 @@ from thunder.core.compile_data import using_symbolic_values, using_jit from 
thunder.core.interpreter import is_jitting, ProvenanceRecord, PseudoInst -from thunder.core.trace import VariableInterface, get_tracectx, TraceCtx +from thunder.core.trace import ( + VariableInterface, + get_tracectx, + is_tracing, + TraceCtx, +) from thunder.core.baseutils import ( ProxyInterface, NumberProxyInterface, @@ -1459,8 +1464,7 @@ def __init__( # outside of a trace or language context @property def shape(self): - trace: None | TraceCtx = get_tracectx() - if trace is None: + if not using_symbolic_values() or not is_tracing(): return self._shape else: from thunder.core.prims import shape From 851ad773b2ee2d166b7139505ba04c412a34664f Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Mon, 7 Oct 2024 19:42:58 +0000 Subject: [PATCH 10/15] quick patch on tests --- thunder/tests/opinfos.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/thunder/tests/opinfos.py b/thunder/tests/opinfos.py index c292073482..3628c57db5 100644 --- a/thunder/tests/opinfos.py +++ b/thunder/tests/opinfos.py @@ -2207,7 +2207,13 @@ def polygamma_sample_input_generator(op, device, dtype, requires_grad, *, no_rhs def pow_sample_input_generator(op, device, dtype, requires_grad, *, no_rhs_numbers: bool = False, **kwargs): - default_generator = partial(elementwise_binary_generator, no_rhs_numbers=True) + # exclude_zero avoids having + # t_0 = tensor([...], device="cuda", dtype=torch.int8) + # t_1 = tensor(-1, dtype=torch.int8) + # torch.pow(t_0, t_1) + # which raise an issue with + # RuntimeError: "reciprocal_cuda" not implemented for 'Char' + default_generator = partial(elementwise_binary_generator, no_rhs_numbers=True, exclude_zero=True) yield from default_generator(op, device, dtype, requires_grad, **kwargs) # For backward of pow, we need to make sure that when the base is zero, the From 8995ed5fdb9966ae139f5198e7b382037ad0b88d Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Mon, 7 Oct 2024 19:59:47 +0000 Subject: [PATCH 11/15] updating torch issue --- thunder/tests/opinfos.py | 1 + 1 file changed, 1 insertion(+) diff --git a/thunder/tests/opinfos.py b/thunder/tests/opinfos.py index 3628c57db5..4659c4faad 100644 --- a/thunder/tests/opinfos.py +++ b/thunder/tests/opinfos.py @@ -2213,6 +2213,7 @@ def pow_sample_input_generator(op, device, dtype, requires_grad, *, no_rhs_numbe # torch.pow(t_0, t_1) # which raise an issue with # RuntimeError: "reciprocal_cuda" not implemented for 'Char' + # see issue: github.com/pytorch/pytorch/issues/137440 default_generator = partial(elementwise_binary_generator, no_rhs_numbers=True, exclude_zero=True) yield from default_generator(op, device, dtype, requires_grad, **kwargs) From 98c5f929563c3cff0b53772186ad2471f19deb9c Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Mon, 7 Oct 2024 21:03:06 +0000 Subject: [PATCH 12/15] refactor the fix --- thunder/tests/opinfos.py | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/thunder/tests/opinfos.py b/thunder/tests/opinfos.py index 4659c4faad..696f34f66c 100644 --- a/thunder/tests/opinfos.py +++ b/thunder/tests/opinfos.py @@ -2207,14 +2207,7 @@ def polygamma_sample_input_generator(op, device, dtype, requires_grad, *, no_rhs def pow_sample_input_generator(op, device, dtype, requires_grad, *, no_rhs_numbers: bool = False, **kwargs): - # exclude_zero avoids having - # t_0 = tensor([...], device="cuda", dtype=torch.int8) - # t_1 = tensor(-1, dtype=torch.int8) - # torch.pow(t_0, t_1) - # which raise an issue with - # RuntimeError: "reciprocal_cuda" not implemented for 'Char' - # see issue: 
github.com/pytorch/pytorch/issues/137440 - default_generator = partial(elementwise_binary_generator, no_rhs_numbers=True, exclude_zero=True) + default_generator = partial(elementwise_binary_generator, no_rhs_numbers=True) yield from default_generator(op, device, dtype, requires_grad, **kwargs) # For backward of pow, we need to make sure that when the base is zero, the @@ -2253,7 +2246,7 @@ def pow_sample_input_generator(op, device, dtype, requires_grad, *, no_rhs_numbe ), # NOTE: PyTorch fails with RuntimeError: "reciprocal_cuda" not implemented for 'Long' occasionally when the exponent is CPU scalar tensor # e.g.: x=torch.tensor([[ 6, 5, 1, -8],], device='cuda:0');y=torch.tensor(-1);torch.pow(x,y) - DecorateInfo(pytest.mark.xfail, "test_core_vs_torch_consistency", dtypes=(datatypes.int32, datatypes.int64)), + DecorateInfo(pytest.mark.xfail, "test_core_vs_torch_consistency", dtypes=(datatypes.int8, datatype.int16, datatypes.int32, datatypes.int64)), ), ) elementwise_binary_ops.append(pow_opinfo) From 0520e42e373d70677c46e78e652fd6de9d4b88aa Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 7 Oct 2024 21:04:02 +0000 Subject: [PATCH 13/15] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- thunder/tests/opinfos.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/thunder/tests/opinfos.py b/thunder/tests/opinfos.py index 696f34f66c..5b5476118e 100644 --- a/thunder/tests/opinfos.py +++ b/thunder/tests/opinfos.py @@ -2246,7 +2246,11 @@ def pow_sample_input_generator(op, device, dtype, requires_grad, *, no_rhs_numbe ), # NOTE: PyTorch fails with RuntimeError: "reciprocal_cuda" not implemented for 'Long' occasionally when the exponent is CPU scalar tensor # e.g.: x=torch.tensor([[ 6, 5, 1, -8],], device='cuda:0');y=torch.tensor(-1);torch.pow(x,y) - DecorateInfo(pytest.mark.xfail, "test_core_vs_torch_consistency", dtypes=(datatypes.int8, datatype.int16, datatypes.int32, datatypes.int64)), + DecorateInfo( + pytest.mark.xfail, + "test_core_vs_torch_consistency", + dtypes=(datatypes.int8, datatype.int16, datatypes.int32, datatypes.int64), + ), ), ) elementwise_binary_ops.append(pow_opinfo) From 30fa4b4a356c00db5cdf43bbcc65c03c659c4d59 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Mon, 7 Oct 2024 21:10:04 +0000 Subject: [PATCH 14/15] typo --- thunder/tests/opinfos.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/thunder/tests/opinfos.py b/thunder/tests/opinfos.py index 696f34f66c..32cb7f1e61 100644 --- a/thunder/tests/opinfos.py +++ b/thunder/tests/opinfos.py @@ -2246,7 +2246,7 @@ def pow_sample_input_generator(op, device, dtype, requires_grad, *, no_rhs_numbe ), # NOTE: PyTorch fails with RuntimeError: "reciprocal_cuda" not implemented for 'Long' occasionally when the exponent is CPU scalar tensor # e.g.: x=torch.tensor([[ 6, 5, 1, -8],], device='cuda:0');y=torch.tensor(-1);torch.pow(x,y) - DecorateInfo(pytest.mark.xfail, "test_core_vs_torch_consistency", dtypes=(datatypes.int8, datatype.int16, datatypes.int32, datatypes.int64)), + DecorateInfo(pytest.mark.xfail, "test_core_vs_torch_consistency", dtypes=(datatypes.int8, datatypes.int16, datatypes.int32, datatypes.int64)), ), ) elementwise_binary_ops.append(pow_opinfo) From 0843c06f3861d97cf1c8b0e927574567ecc3d4c5 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Wed, 9 Oct 2024 02:04:01 +0000 Subject: [PATCH 15/15] addressing review comments --- thunder/core/proxies.py | 3 +-- 
thunder/tests/test_jit_general.py | 22 ++++------------------ 2 files changed, 5 insertions(+), 20 deletions(-) diff --git a/thunder/core/proxies.py b/thunder/core/proxies.py index 0270e64bdf..181b325230 100644 --- a/thunder/core/proxies.py +++ b/thunder/core/proxies.py @@ -1247,8 +1247,7 @@ def _infer_tensor_properties( thunder_fsdp_padding_size if thunder_fsdp_padding_size is not None else _thunder_fsdp_padding_size ) - # dynamic shape not yet enabled, otherwise, the bake in should be guarded with if not using_symbolic_values(): - # dynamic shape support is currently block by #471 https://github.com/Lightning-AI/lightning-thunder/issues/471 + baseutils.check(_shape is not None, lambda: f"_shape cannot be None when creating TensorProxy") if not using_symbolic_values(): _shape = tuple(pyval(x) for x in _shape) # Computes derived properties diff --git a/thunder/tests/test_jit_general.py b/thunder/tests/test_jit_general.py index 3bd5dfb268..3af1aea5c2 100644 --- a/thunder/tests/test_jit_general.py +++ b/thunder/tests/test_jit_general.py @@ -1441,21 +1441,14 @@ def foo(a): assert thunder.cache_hits(jfoo) == 1 -@pytest.mark.parametrize( - "device", - ("cpu", "cuda"), -) -def test_cache_symbolic_values_reshape_numel(device): - if device == "cuda" and not torch.cuda.is_available(): - pytest.skip("CUDA not available") - +def test_cache_symbolic_values_reshape_numel(): def foo(a): a = torch.reshape(a, [a.numel()]) return a.relu() jfoo = thunder.jit(foo, cache="symbolic values") - a = torch.randn(2, 3, 8, requires_grad=True, device=device) + a = torch.randn(2, 3, 8, requires_grad=True, device="cpu") actual = jfoo(a) expected = foo(a) @@ -1463,21 +1456,14 @@ def foo(a): assert_close(actual, expected) -@pytest.mark.parametrize( - "device", - ("cpu", "cuda"), -) -def test_cache_symbolic_values_slice(device): - if device == "cuda" and not torch.cuda.is_available(): - pytest.skip("CUDA not available") - +def test_cache_symbolic_values_slice(): def foo(a): a = a[..., : a.shape[-1]] return a.relu() jfoo = thunder.jit(foo, cache="symbolic values") - a = torch.randn(2, 3, 8, requires_grad=True, device=device) + a = torch.randn(2, 3, 8, requires_grad=True, device="cpu") actual = jfoo(a) expected = foo(a)
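
# Note (illustrative only, not part of the patch series): a minimal usage sketch
# distilled from the tests added above (test_cache_symbolic_values_reshape_numel and
# test_cache_symbolic_values_slice). With cache="symbolic values", shape accesses such
# as a.numel() and a.shape[-1] go through the new TensorProxy.shape property and are
# traced via prims.shape rather than baked in as Python integers. Assumes a thunder
# installation that includes these patches.

import torch
import thunder
from torch.testing import assert_close


def foo(a):
    # reshape driven by a traced numel query
    return torch.reshape(a, [a.numel()]).relu()


def bar(a):
    # slice bounded by a traced shape query
    return a[..., : a.shape[-1]].relu()


a = torch.randn(2, 3, 8, device="cpu")

jfoo = thunder.jit(foo, cache="symbolic values")
jbar = thunder.jit(bar, cache="symbolic values")

assert_close(jfoo(a), foo(a))
assert_close(jbar(a), bar(a))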