From f9fede76ab649527a33bd74022cbe8ab909b6bc1 Mon Sep 17 00:00:00 2001 From: Masaki Kozuki Date: Sat, 28 Dec 2024 13:57:31 +0900 Subject: [PATCH] check tensor attrs of tensor wrapper subclasses in prologue also use `pytorch_executor` in the `transform_for_execution` of `prologue_trace` as it could have the prim of tensor subclass flattening whose definition is only available in pytorch executor. Signed-off-by: Masaki Kozuki --- thunder/__init__.py | 2 +- thunder/clang/__init__.py | 13 +++++++------ thunder/core/jit_ext.py | 14 +++++++++++--- 3 files changed, 19 insertions(+), 10 deletions(-) diff --git a/thunder/__init__.py b/thunder/__init__.py index 5c915cd714..5524b796f3 100644 --- a/thunder/__init__.py +++ b/thunder/__init__.py @@ -592,7 +592,7 @@ def get_computation_and_inputs(*args, **kwargs): prologue_traces += transform_for_execution( prologue_trc, - executors_list=(pythonex,), + executors_list=(pythonex, pytorch_executor), use_del_last_used=False, ) prologue_trc = prologue_traces[-1] diff --git a/thunder/clang/__init__.py b/thunder/clang/__init__.py index 74b1456149..270b543898 100644 --- a/thunder/clang/__init__.py +++ b/thunder/clang/__init__.py @@ -20,15 +20,16 @@ from thunder.core import utils import thunder.core.prims as prims from thunder.core.proxies import ( + AnyProxy, IntegerProxy, - NumberProxy, NumberLike, + NumberProxy, + Proxy, + SubclassTensorProxy, TensorProxy, - pyval, - pytype, proxy, - AnyProxy, - Proxy, + pytype, + pyval, ) import thunder.core.devices as devices @@ -67,7 +68,7 @@ def __call__(self, fn: Callable) -> Callable: # Checks a tensor's shape and metadata (for use with cache check) @clangop() -def check_tensor_shape_and_metadata(t: TensorProxy, /) -> None: +def check_tensor_shape_and_metadata(t: TensorProxy | SubclassTensorProxy, /) -> None: return prims.check_tensor_shape_and_metadata( t, # replace Proxy entries with `-1`s as wild card, as we any value is diff --git a/thunder/core/jit_ext.py b/thunder/core/jit_ext.py index e9f41b3fc3..aceb30d102 100644 --- a/thunder/core/jit_ext.py +++ b/thunder/core/jit_ext.py @@ -146,9 +146,7 @@ class JITSharpEdgeError(RuntimeError): def _general_jit_sharp_edge(desc: str, value: Any, /) -> Any | INTERPRETER_SIGNALS: sharp_edges: SHARP_EDGES_OPTIONS = get_jit_ctx().sharp_edges - s: str = ( - f"{desc} This is currently considered a sharp edge even with interpretation=INTERPRETATION_OPTIONS.TRANSLATE_PYTHON. For cases in which we are overly strict, please file an issue. Thank you!" - ) + s: str = f"{desc} This is currently considered a sharp edge even with interpretation=INTERPRETATION_OPTIONS.TRANSLATE_PYTHON. For cases in which we are overly strict, please file an issue. Thank you!" if sharp_edges is SHARP_EDGES_OPTIONS.ERROR: return do_raise(JITSharpEdgeError(s)) @@ -1719,9 +1717,12 @@ def is_variableified_tensorproxy(v: Variable | Proxy) -> Proxy: with tracectx(prologue_trace): for prim, *args in ctx._constraints: + subclass_tensor: SubclassTensorProxy | None = None for a in args: if isinstance(a, Proxy): unpack(a) + if isinstance(a, SubclassTensorProxy): + subclass_tensor = a # unpacking Proxy in TensorProxy.shape which is used in `check_tensor_shape_and_metadata` if prim == clang.check_tensor_shape_and_metadata: for s in a.shape: @@ -1730,6 +1731,13 @@ def is_variableified_tensorproxy(v: Variable | Proxy) -> Proxy: prim(*args) + if isinstance(subclass_tensor, SubclassTensorProxy): + for t in prims.flatten_tensor_subclass(subclass_tensor): + for s in t.shape: + if isinstance(s, Proxy): + unpack(s) + prim(t) + cache_info = thunder._get_cache_info() # assert len of cache info to ensure that we're not missing anything? if cache_info: