Skip to content

Commit

Permalink
check tensor attrs of tensor wrapper subclasses in prologue
Browse files Browse the repository at this point in the history
also use `pytorch_executor` in the `transform_for_execution` of
`prologue_trace`, as the prologue trace could contain the prim for tensor
subclass flattening, whose definition is only available in the PyTorch executor.

Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
  • Loading branch information
crcrpar committed Dec 28, 2024
1 parent 6e6b077 commit f9fede7
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 10 deletions.
2 changes: 1 addition & 1 deletion thunder/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,7 +592,7 @@ def get_computation_and_inputs(*args, **kwargs):

prologue_traces += transform_for_execution(
prologue_trc,
executors_list=(pythonex,),
executors_list=(pythonex, pytorch_executor),
use_del_last_used=False,
)
prologue_trc = prologue_traces[-1]
Expand Down
13 changes: 7 additions & 6 deletions thunder/clang/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,16 @@
from thunder.core import utils
import thunder.core.prims as prims
from thunder.core.proxies import (
AnyProxy,
IntegerProxy,
NumberProxy,
NumberLike,
NumberProxy,
Proxy,
SubclassTensorProxy,
TensorProxy,
pyval,
pytype,
proxy,
AnyProxy,
Proxy,
pytype,
pyval,
)
import thunder.core.devices as devices

Expand Down Expand Up @@ -67,7 +68,7 @@ def __call__(self, fn: Callable) -> Callable:

# Checks a tensor's shape and metadata (for use with cache check)
@clangop()
def check_tensor_shape_and_metadata(t: TensorProxy, /) -> None:
def check_tensor_shape_and_metadata(t: TensorProxy | SubclassTensorProxy, /) -> None:
return prims.check_tensor_shape_and_metadata(
t,
# replace Proxy entries with `-1`s as wild cards, as any value is
Expand Down
14 changes: 11 additions & 3 deletions thunder/core/jit_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,9 +146,7 @@ class JITSharpEdgeError(RuntimeError):
def _general_jit_sharp_edge(desc: str, value: Any, /) -> Any | INTERPRETER_SIGNALS:
sharp_edges: SHARP_EDGES_OPTIONS = get_jit_ctx().sharp_edges

s: str = (
f"{desc} This is currently considered a sharp edge even with interpretation=INTERPRETATION_OPTIONS.TRANSLATE_PYTHON. For cases in which we are overly strict, please file an issue. Thank you!"
)
s: str = f"{desc} This is currently considered a sharp edge even with interpretation=INTERPRETATION_OPTIONS.TRANSLATE_PYTHON. For cases in which we are overly strict, please file an issue. Thank you!"

if sharp_edges is SHARP_EDGES_OPTIONS.ERROR:
return do_raise(JITSharpEdgeError(s))
Expand Down Expand Up @@ -1719,9 +1717,12 @@ def is_variableified_tensorproxy(v: Variable | Proxy) -> Proxy:

with tracectx(prologue_trace):
for prim, *args in ctx._constraints:
subclass_tensor: SubclassTensorProxy | None = None
for a in args:
if isinstance(a, Proxy):
unpack(a)
if isinstance(a, SubclassTensorProxy):
subclass_tensor = a
# unpacking Proxy in TensorProxy.shape which is used in `check_tensor_shape_and_metadata`
if prim == clang.check_tensor_shape_and_metadata:
for s in a.shape:
Expand All @@ -1730,6 +1731,13 @@ def is_variableified_tensorproxy(v: Variable | Proxy) -> Proxy:

prim(*args)

if isinstance(subclass_tensor, SubclassTensorProxy):
for t in prims.flatten_tensor_subclass(subclass_tensor):
for s in t.shape:
if isinstance(s, Proxy):
unpack(s)
prim(t)

cache_info = thunder._get_cache_info()
# assert len of cache info to ensure that we're not missing anything?
if cache_info:
Expand Down

0 comments on commit f9fede7

Please sign in to comment.