diff --git a/thunder/__init__.py b/thunder/__init__.py index eceacc4cad..6533ac8f44 100644 --- a/thunder/__init__.py +++ b/thunder/__init__.py @@ -73,6 +73,7 @@ from thunder.core.interpreter import print_interpreter_log, print_to_log from thunder.core.jit_ext import thunder_general_jit from thunder.executors.torch_autograd import split_forward_backward, ThunderFunction +from thunder.transforms.tensor_subclasses import flatten_tensor_subclasses # NOTE This import is intentionally pytorch so that it thunder.torch doesn't import this import torch as pytorch @@ -370,7 +371,7 @@ def _alias_tensor_of_args_kwargs_dict(*args, **kwargs) -> dict[int, list[int]]: data_ptr_to_tensor_group_index = {} tensor_group_index_to_tensor_indices = defaultdict(list) for idx, t in enumerate(flat_args): - if pytorch.is_tensor(t) and t.layout == pytorch.strided: + if type(t) in {pytorch.Tensor, pytorch.nn.Parameter} and t.layout == pytorch.strided: data_ptr = t.untyped_storage().data_ptr() if data_ptr not in data_ptr_to_tensor_group_index: data_ptr_to_tensor_group_index[data_ptr] = len(data_ptr_to_tensor_group_index) diff --git a/thunder/core/jit_ext.py b/thunder/core/jit_ext.py index 726353983a..b84cc904bb 100644 --- a/thunder/core/jit_ext.py +++ b/thunder/core/jit_ext.py @@ -62,6 +62,7 @@ NumberProxy, StringProxy, TensorProxy, + SubclassTensorProxy, FutureTensorProxy, make_proxy_name, Variable, @@ -863,6 +864,42 @@ def grad_transform(*args, **kwargs): return output +@register_general_jit_lookaside(torch.Tensor._make_wrapper_subclass) +def _make_wrapper_subclass( + cls: torch._C._TensorMeta, + size: Sequence[int], + strides: Sequence[int] | None = None, + storage_offset: int | None = None, + memory_format: torch.memory_format | None = None, + dtype: torch.dtype | None = None, + layout: torch.layout | None = torch.strided, + device: torch.device | None = None, + pin_memory: bool = False, + requires_grad: bool = False, + dispatch_sizes_strides_policy: str | None = None, + dispatch_device: bool = False, + dispatch_layout: bool = False, + _extra_dispatch_keys: torch.DispatchKeySet | None = None, + storage_size: int | None = None, +): + ucls = unwrap(cls) + usize = unwrap(size) + udtype = unwrap(dtype) + udevice = unwrap(device) + urequires_grad = unwrap(requires_grad) + + subclass = SubclassTensorProxy( + None, + shape=usize, + device=udevice, + dtype=udtype, + requires_grad=urequires_grad, + history=ProvenanceRecord(PseudoInst.LOOKASIDE, [cls.provenance]), + subclass_type=ucls, + ) + return wrap(subclass, provenance=ProvenanceRecord(PseudoInst.LOOKASIDE, [cls.provenance])) + + @register_general_jit_lookaside(torch.autocast.__enter__) def autocast_enter(autocast_obj): unwrap_autocast_obj = unwrap(autocast_obj) diff --git a/thunder/core/prims.py b/thunder/core/prims.py index da46147155..7a7d194796 100644 --- a/thunder/core/prims.py +++ b/thunder/core/prims.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from enum import auto, Enum from numbers import Number from functools import reduce, wraps @@ -77,6 +79,7 @@ def register_method(method_name: str, method: Callable, /) -> None: TupleProxy, AnyProxy, IntegerProxy, + SubclassTensorProxy, ) import thunder.core.codeutils as codeutils from thunder.core.codeutils import Printable @@ -272,6 +275,10 @@ class PrimIDs(Enum): COPY_ = auto() # SINK = auto() + # Tensor Subclasses methods + TENSOR_SUBCLASS_CTOR = auto() + FLATTEN_TENSOR_SUBCLASS = auto() + UNFLATTEN_TENSOR_SUBCLASS = auto() class OpTags(Enum): @@ -4048,3 +4055,77 @@ def sink_meta(*args, **kwargs): # TODO do 
we want another tag to remove this after prologue is constructed? sink = make_prim(PrimIDs.SINK, "sink", meta=sink_meta, tags=(OpTags.DONT_DCE,)) + + +def tensor_subclass_ctor_meta( + cls, name, shape, device, dtype, requires_grad, tensors, non_tensors +) -> SubclassTensorProxy: + s = SubclassTensorProxy( + name, + subclass_type=cls, + shape=shape, + device=device, + dtype=dtype, + requires_grad=requires_grad, + tensors=tensors, + non_tensors=non_tensors, + history=[t.history for t in tensors], + ) + return s + + +tensor_subclass_ctor = make_prim( + PrimIDs.TENSOR_SUBCLASS_CTOR, + "tensor_subclass_ctor", + meta=tensor_subclass_ctor_meta, +) + + +def flatten_tensor_subclass_meta(t: SubclassTensorProxy) -> tuple[TensorProxy, ...]: + tensor_attr_names, metadata = t.__tensor_flatten__() + tensors = tuple(getattr(t, name) for name in tensor_attr_names) + return tensors + + +flatten_tensor_subclass = make_prim( + PrimIDs.FLATTEN_TENSOR_SUBCLASS, + "flatten_tensor_subclass", + meta=flatten_tensor_subclass_meta, +) + + +def unflatten_tensor_subclass_meta( + tensor_subclass_type, + inner_tensors: dict[str, TensorProxy], + metadata: dict[str, Any], +) -> SubclassTensorProxy: + first_tensor: TensorProxy = list(inner_tensors.values())[0] + a = SubclassTensorProxy( + shape=first_tensor.shape, + device=first_tensor.device, + dtype=first_tensor.dtype, + requires_grad=first_tensor.requires_grad, + tensors=list(inner_tensors.values()), + non_tensors=list(metadata.values()), + subclass_type=tensor_subclass_type, + ) + for name, value in inner_tensors.items(): + setattr(a, name, value) + for name, value in metadata.items(): + setattr(a, name, value) + return a + + +def unflatten_tensor_subclass_python_impl( + tensor_subclass_type, + inner_tensors: dict[str, TensorProxy], + metadata: dict[str, Any], +) -> torch.Tensor: + return tensor_subclass_type.__tensor_unflatten__(inner_tensors, metadata, -1, -1) + + +unflatten_tensor_subclass = make_prim( + PrimIDs.UNFLATTEN_TENSOR_SUBCLASS, + "unflatten_tensor_subclass", + meta=unflatten_tensor_subclass_meta, +) diff --git a/thunder/core/proxies.py b/thunder/core/proxies.py index 2f2eb1c665..9799ea0561 100644 --- a/thunder/core/proxies.py +++ b/thunder/core/proxies.py @@ -3,7 +3,7 @@ import copy from enum import auto, Enum from numbers import Number -from typing import Any +from typing import Any, ClassVar from collections.abc import Callable from collections.abc import Sequence @@ -1880,6 +1880,154 @@ def real(self): return method(self) +class SubclassTensorProxy(TensorProxy): + SUBCLASS_TYPE_ATTR: ClassVar[str] = "_subclass_type" + _tensors: list[TensorProxy] + _non_tensors: list[Any] + _subclass_type: torch._C._TensorMeta + + _tensor_attr_names: list[str] | None + _non_tensor_attr_names: list[str] | None + + def __init__(self, *args, **kwargs): + from thunder.core.pytree import tree_flatten + + kwarg_tensors = kwargs.pop("tensors", []) + kwarg_non_tensors = kwargs.pop("non_tensors", []) + subclass_type = kwargs.pop("subclass_type", None) + + has_name_before_init = hasattr(self, "_name") + # If tensors (and non_tensors) are not empty, then it should be the path of `_make_wrapper_subclass` + # where `self` should already have gotten its name. 
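+        # Illustrative sketch of the two construction paths (hypothetical names,
+        # commentary only, not executed), assuming a ScaleTensorSubclass-style wrapper:
+        #
+        #   # (a) proxying a concrete subclass tensor handed to the jitted function:
+        #   SubclassTensorProxy("t0", shape=(2, 2), device=dev, dtype=dt, requires_grad=False,
+        #                       tensors=[x_proxy, scale_proxy], non_tensors=[],
+        #                       subclass_type=ScaleTensorSubclass)
+        #
+        #   # (b) the `torch.Tensor._make_wrapper_subclass` lookaside already created and
+        #   # named `self`; the traced subclass construction then re-enters this __init__
+        #   # with the inner tensors as positional args, and a `tensor_subclass_ctor`
+        #   # bound symbol is recorded below.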
+ flat_args, spec = tree_flatten((args, kwargs)) + tensors: list[TensorProxy] = [] + non_tensors: list[Any] = [] + for t in args + tuple(kwargs.values()): + if type(t) is SubclassTensorProxy: + continue + if type(t) is TensorProxy: + tensors.append(t) + else: + non_tensors.append(t) + + is_dunder_init_following_make_wrapper_subclass: bool = False + if tensors: + baseutils.check( + has_name_before_init + and not kwarg_tensors + and not kwarg_non_tensors + and self._subclass_type is not None, + lambda: ( + f"{flat_args=} indicates this instance is created by" + "`torch.Tensor._make_wrapper_subclass`'s lookaside but `name` is not set" + ), + ) + is_dunder_init_following_make_wrapper_subclass = True + + if not is_dunder_init_following_make_wrapper_subclass: + super().__init__(*args, **kwargs) + + self._tensors = kwarg_tensors + self._non_tensors = kwarg_non_tensors + self._subclass_type = subclass_type + else: + # TODO(crcrpar): Think about materializing `self` so that we can + # call `__tensor_init__` to know each attribute names. + from thunder.core import prims + + bsym = prims.tensor_subclass_ctor.bind( + self._subclass_type, + self.name, + self.shape, + self.device, + self.dtype, + self.requires_grad, + self._tensors, + self._non_tensors, + output=self, + ) + # NOTE(crcrpar): A callable being `thunder.jit`ed can call `MySubclassTensor(...)` + # inside of it either directly or indirectly: indirect way is to call it through + # a custom `torch.autograd.Function` as in + # https://github.com/pytorch/ao/blob/000a490/torchao/float8/float8_tensor.py#L139-L209. + # If it's a direct call, `trace.bound_symbols` and `trace.scopes[-1]` are identical, + # but not, otherwise. As [the lookasdie of `torch.autograd.Function`]( + # https://github.com/Lightning-AI/lightning-thunder/blob/3d42c10/thunder/core/jit_ext.py#L655) + # puts the temporary scope to the current trace. + current_trace = get_tracectx() + if id(current_trace.bound_symbols) == id(cur_tail_scope := current_trace.scopes[-1]): + current_trace.add_bound_symbol(bsym) + else: + cur_tail_scope.append(bsym) + + if not self._tensors and not self._non_tensors: + for a in tensors: + self._tensors.append(a) + for a in non_tensors: + self._non_tensors.append(a) + baseutils.check(self._tensors, lambda: f"`{self._name}._tensors` must not be empty") + + def __tensor_flatten__(self) -> tuple[list[TensorProxy], dict[str, Any]]: + return self._tensor_attr_names, self.metadata + + @property + def metadata(self) -> dict[str, Any]: + return dict(zip(self._non_tensor_attr_names, self._non_tensors)) + + def replace(self, **changes): + r"""Return a copy of the SubclassTensorProxy object with new values for the specified fields as given to the constructor as arguments. + Valid keyword arguments are ``name``, ``history``, ``shape``, ``dtype``, ``device``, ``requires_grad``, ``distparallel_type``, ``thunder_fsdp_padding_size``. + ``like`` is also a valid keyword and will take metadata from the tensor proxy argument + in preference to the old values but overridable by keyword arguments. 
+ Note that the copy will use the current (environment) tracectx.""" + + like = changes.get("like") + ( + shape, + device, + dtype, + true_dtype, + numel, + ndim, + requires_grad, + grad, + distparallel_type, + thunder_fsdp_padding_size, + ) = _infer_tensor_properties( + like, + changes.get("shape", self._shape if like is None else None), + changes.get("device", self._device if like is None else None), + changes.get("dtype", self._dtype if like is None else None), + changes.get("requires_grad", self._requires_grad if like is None else None), + changes.get("grad", self._grad if like is None else None), + changes.get("distparallel_type", self._distparallel_type if like is None else None), + changes.get("thunder_fsdp_padding_size", self._thunder_fsdp_padding_size if like is None else None), + ) + name = changes.get("name", self.name) + history = changes.get("history", self.history) + tags = changes.get("tags", self.tags) + p = SubclassTensorProxy( + name=name, + tags=tags, + shape=shape, + device=device, + dtype=dtype, + requires_grad=requires_grad, + distparallel_type=distparallel_type, + thunder_fsdp_padding_size=thunder_fsdp_padding_size, + history=history, + tensors=self._tensors, + non_tensors=self._non_tensors, + subclass_type=self._subclass_type, + ) + if hasattr(self, "_tensor_attr_names") and hasattr(self, "_non_tensor_attr_names"): + p._tensor_attr_names = self._tensor_attr_names + p._non_tensor_attr_names = self._non_tensor_attr_names + for name, value in zip(p._tensor_attr_names + p._non_tensor_attr_names, p._tensors + p._non_tensors): + setattr(p, name, value) + return p + + class TorchAutogradFunctionCtxProxy(Proxy, TorchAutogradFunctionCtxProxyInterface): def __init__( self, @@ -1951,6 +2099,7 @@ def __setattr__(self, name, value): # TODO: move this function to jit_ext.py def tensorproxy(t: torch.Tensor, /, *, name: None | str, history: None | tuple = None) -> TensorProxy: + from torch._subclasses.fake_tensor import FakeTensor from thunder.core.interpreter import ProvenanceRecord, PseudoInst, wrap_const if hasattr(t, "_thunder_device"): @@ -1985,8 +2134,8 @@ def tensorproxy(t: torch.Tensor, /, *, name: None | str, history: None | tuple = else: # NOTE Without tuple(t.shape) then the shape would be a torch.Size object shape = tuple(t.shape) - return TensorProxy( - name, + ctor_kwargs = dict( + name=name, shape=tuple(shape), device=device, dtype=dtype, @@ -1997,6 +2146,40 @@ def tensorproxy(t: torch.Tensor, /, *, name: None | str, history: None | tuple = thunder_fsdp_padding_size=_thunder_fsdp_padding_size, ) + # n.b.(crcrpar): :class:`thunder.dynamo.ThunderCompiler.__call__` takes torch.fx GraphModule + # where `FakeTensor` seems to be used, leading to failures observed in e.g. 
+ # https://github.com/Lightning-AI/lightning-thunder/actions/runs/11689709564/job/32553053319#step:10:5747 + # https://dev.azure.com/Lightning-AI/lightning/_build/results?buildId=219328&view=logs&jobId=5b0799f7-725e-5b16-9b83-c0a5a25d03f0&j=5b0799f7-725e-5b16-9b83-c0a5a25d03f0 + if ( + isinstance(t, torch.Tensor) + and type(t) not in (torch.Tensor, torch.nn.Parameter, FakeTensor) + and hasattr(t, "__tensor_flatten__") + and hasattr(t, "__tensor_unflatten__") + ): + baseutils.check( + hasattr(t, "__tensor_flatten__") and hasattr(t, "__tensor_unflatten__"), + lambda: f"{t=} seems to be a tensor subclass but not traceable", + ) + tensor_attr_names, metadata = t.__tensor_flatten__() + tensors = [tensorproxy(getattr(t, name), name=None, history=history) for name in tensor_attr_names] + ctor_kwargs.update( + { + "tensors": tensors, + "non_tensors": list(metadata.values()), + "subclass_type": type(t), + } + ) + p = SubclassTensorProxy(**ctor_kwargs) + p._tensor_attr_names = tensor_attr_names + p._non_tensor_attr_names = list(metadata.keys()) + for name, tensor in zip(tensor_attr_names, tensors): + setattr(p, name, tensor) + for name, value in metadata.items(): + setattr(p, name, value) + return p + else: + return TensorProxy(**ctor_kwargs) + def futuretensorproxy( t: torch.Tensor | TensorProxy | FutureTensorProxy, /, *, name: None | str, history: None | tuple = None diff --git a/thunder/executors/torchex.py b/thunder/executors/torchex.py index 6de644204d..9ad2cf4107 100644 --- a/thunder/executors/torchex.py +++ b/thunder/executors/torchex.py @@ -2220,3 +2220,53 @@ def _shape_impl(t): shallow_copy = ex.register_operator("shallow_copy", meta=prims.shallow_copy, fn=lambda x: x) _register_implementation(prims.shallow_copy, shallow_copy, checker=_always_executable) + + +def _tensor_subclass_ctor(cls, name, shape, device, dtype, requires_grad, tensors, non_tensors): + return cls(*tensors, *non_tensors) + + +tensor_subclass_ctor = ex.register_operator( + "tensor_subclass_ctor", + meta=prims.tensor_subclass_ctor, + fn=_tensor_subclass_ctor, +) +_register_implementation(prims.tensor_subclass_ctor, tensor_subclass_ctor, checker=_always_executable) + + +def flatten_tensor_subclass_impl(t): + tensor_attr_names, metadata = t.__tensor_flatten__() + tensors = tuple(getattr(t, name) for name in tensor_attr_names) + return tensors + + +flatten_tensor_subclass = ex.register_operator( + "flatten_tensor_subclass", + meta=prims.flatten_tensor_subclass.meta, + fn=flatten_tensor_subclass_impl, +) +_register_implementation( + prims.flatten_tensor_subclass, + flatten_tensor_subclass, + checker=_always_executable, +) + + +def unflatten_tensor_subclass_impl( + tensor_subclass_type: torch._C._TensorMeta, + inner_tensors: dict[str, TensorLike], + metadata: dict, +): + return tensor_subclass_type.__tensor_unflatten__(inner_tensors, metadata, -1, -1) + + +unflatten_tensor_subclass = ex.register_operator( + "unflatten_tensor_subclass", + meta=prims.unflatten_tensor_subclass.meta, + fn=unflatten_tensor_subclass_impl, +) +_register_implementation( + prims.unflatten_tensor_subclass, + unflatten_tensor_subclass, + checker=_always_executable, +) diff --git a/thunder/tests/test_tensor_subclass.py b/thunder/tests/test_tensor_subclass.py new file mode 100644 index 0000000000..17f3cf0e07 --- /dev/null +++ b/thunder/tests/test_tensor_subclass.py @@ -0,0 +1,218 @@ +from __future__ import annotations +from typing import TYPE_CHECKING + +import torch +from torch.utils import _pytree as pytree + +import thunder +from thunder.core.proxies 
import SubclassTensorProxy +from thunder.tests.framework import instantiate +from thunder.tests.make_tensor import make_tensor + +if TYPE_CHECKING: + from typing import Any + from thunder.core.symbol import BoundSymbol + + +@torch._dynamo.allow_in_graph +class EncapsulateXandScale(torch.autograd.Function): + @staticmethod + def forward(ctx, x: torch.Tensor, scale: torch.Tensor): + return ScaleTensorSubclass(x, scale) + + @staticmethod + def backward(ctx, grad): + return grad, None + + +def encapsulate_x_and_scale(x, scale) -> ScaleTensorSubclass: + return EncapsulateXandScale.apply(x, scale) + + +@torch._dynamo.allow_in_graph +class ToScaleTensorSubclass(torch.autograd.Function): + @staticmethod + def forward(ctx, x: torch.Tensor): + return ScaleTensorSubclass.from_tensor(x) + + @staticmethod + def backward(ctx, grad): + return grad + + +def to_scale_tensor_subclass(x: torch.Tensor) -> ScaleTensorSubclass: + return ToScaleTensorSubclass.apply(x) + + +class ScaleTensorSubclass(torch.Tensor): + _x: torch.Tensor + _scale: torch.Tensor + __slots__ = ["_x", "_scale"] + + def __new__(cls, x: torch.Tensor, scale: torch.Tensor): + assert scale.numel() == 1, f"Invalid `scale`: {scale}" + dtype = x.dtype + device = x.device + self = torch.Tensor._make_wrapper_subclass( + cls, + x.size(), + dtype=dtype, + device=device, + # strides=x.stride(), + # storage_offset=x.storage_offset(), + # layout=x.layout, + # requires_grad=x.requires_grad, + ) + self._x = x + self._scale = scale + + return self + + # ref: https://github.com/albanD/subclass_zoo/blob/ec47458/base_tensor.py#L22 + __torch_function__ = torch._C._disabled_torch_function_impl + + def __repr__(self): + return f"ScaleTensorSubclass(dtype={self._x.dtype}, device={self._x.device}, x={self._x}, scale={self._scale})" + + def __tensor_flatten__(self) -> tuple[list[str], dict[str, Any]]: + return ["_x", "_scale"], {} + + @staticmethod + def __tensor_unflatten__( + inner_tensors: dict[str, torch.Tensor], + metadata: dict[str, Any], + outer_size, + outer_stride, + ) -> ScaleTensorSubclass: + return ScaleTensorSubclass(inner_tensors["_x"], inner_tensors["_scale"]) + + @staticmethod + def from_tensor(x: torch.Tensor) -> ScaleTensorSubclass: + scale = x.abs().max() + return ScaleTensorSubclass(x, scale) + + @classmethod + def __torch_dispatch__(cls, aten_ir_op: torch._ops.OpOverload, types, args=(), kwargs=None): + + def allowed_subclass(typ): + return ( + issubclass(cls, typ) + or issubclass(torch._subclasses.FakeTensor, typ) + or issubclass(torch._subclasses.functional_tensor.FunctionalTensor, typ) + ) + + def maybe_unwrap_and_scale(t: ScaleTensorSubclass | Any): + if isinstance(t, ScaleTensorSubclass): + if t.is_floating_point(): + return t._x * t._scale + else: + return t._x + return t + + if not all(allowed_subclass(t) for t in types): + return NotImplementedError(f"Unsupported types are included: {types}") + + scales = tuple(t._scale for t in pytree.tree_flatten((args, kwargs))[0] if isinstance(t, ScaleTensorSubclass)) + unwrapped_args, unwrapped_kwargs = pytree.tree_map(maybe_unwrap_and_scale, (args, kwargs)) + out = aten_ir_op(*unwrapped_args, **unwrapped_kwargs) + if not isinstance(out, torch.Tensor): + return out + else: + return ScaleTensorSubclass(out, scales[0]) + + +@instantiate( + dtypes=(thunder.core.dtypes.float32,), +) +def test_func_of_subclass_ctor_wrapper(executor, device, _): + + def f(x: torch.Tensor, scale: torch.Tensor) -> ScaleTensorSubclass: + y = ScaleTensorSubclass(x, scale) + return y + + jitted = executor.make_callable(f) + + 
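+    # A minimal eager sketch of what the wrapper itself does (commentary only, not
+    # executed by this test): the subclass round-trips through the flatten protocol
+    # that the tensor-subclass transform relies on.
+    #   s = ScaleTensorSubclass(torch.randn(2, 2), torch.tensor(1.5))
+    #   names, meta = s.__tensor_flatten__()            # (["_x", "_scale"], {})
+    #   r = ScaleTensorSubclass.__tensor_unflatten__(
+    #       {name: getattr(s, name) for name in names}, meta, -1, -1)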
dtype = torch.float32 + shape = (2, 2) + x = make_tensor(shape, device=device, dtype=dtype) + scale = make_tensor((), device=device, dtype=dtype) + + expected = f(x, scale) + actual = jitted(x, scale) + torch.testing.assert_close((expected._x, expected._scale), (actual._x, actual._scale)) + + def f(x: torch.Tensor, scale: torch.Tensor): + y = ScaleTensorSubclass(x, scale) + z = ScaleTensorSubclass(y._x, y._scale) + return z + + jitted = executor.make_callable(f) + + expected = f(x, scale) + actual = jitted(x, scale) + torch.testing.assert_close((expected._x, expected._scale), (actual._x, actual._scale)) + + +@instantiate( + dtypes=(thunder.core.dtypes.float32,), +) +def test_func_calling_converter(executor, device, _): + + def f(x: torch.Tensor, scale: torch.Tensor) -> ScaleTensorSubclass: + y = encapsulate_x_and_scale(x, scale) + return y + + jitted = executor.make_callable(f) + + dtype = torch.float32 + shape = (2, 2) + + x = make_tensor(shape, device=device, dtype=dtype) + scale = make_tensor((), device=device, dtype=dtype) + + expected = f(x, scale) + actual = jitted(x, scale) + torch.testing.assert_close((expected._x, expected._scale), (actual._x, actual._scale)) + + def g(x: torch.Tensor) -> ScaleTensorSubclass: + y = to_scale_tensor_subclass(x) + return y + + jitted = thunder.jit(g) + x = make_tensor(shape, device=device, dtype=dtype) + + expected = g(x) + actual = jitted(x) + torch.testing.assert_close((expected._x, expected._scale), (actual._x, actual._scale)) + + +@instantiate( + dtypes=(thunder.core.dtypes.float32,), +) +def test_func_of_subclass_simple_math(executor, device, _): + + def f(x: ScaleTensorSubclass, data: torch.Tensor, scale: torch.Tensor) -> ScaleTensorSubclass: + + y = ScaleTensorSubclass(data, scale) + out = x + y + return out + + jitted = executor.make_callable(f) + + dtype = torch.float32 + shape = (2, 2) + x = ScaleTensorSubclass( + make_tensor(shape, device=device, dtype=dtype), + make_tensor((), device=device, dtype=dtype), + ) + data = make_tensor(shape, device=device, dtype=dtype) + scale = make_tensor((), device=device, dtype=dtype) + + expected = f(x, data, scale) + actual = jitted(x, data, scale) + assert type(expected) is type(actual) + torch.testing.assert_close((expected._x, expected._scale), (actual._x, actual._scale)) + + return_bsym: BoundSymbol = thunder.last_traces(jitted)[-1].bound_symbols[-1] + return_proxy = return_bsym.flat_args[0] + assert isinstance(return_proxy, SubclassTensorProxy) diff --git a/thunder/transforms/tensor_subclasses.py b/thunder/transforms/tensor_subclasses.py new file mode 100644 index 0000000000..9d3a4adfa8 --- /dev/null +++ b/thunder/transforms/tensor_subclasses.py @@ -0,0 +1,514 @@ +from __future__ import annotations +from dataclasses import dataclass +from dataclasses import field +from numbers import Number +from typing import TYPE_CHECKING, NamedTuple + +import torch +from torch.fx.experimental.proxy_tensor import make_fx +from torch._dispatch.python import enable_python_dispatcher +from torch._subclasses.fake_tensor import FakeTensor +from torch._subclasses.fake_tensor import FakeTensorMode +from torch._subclasses.functional_tensor import FunctionalTensorMode +from torch.utils._python_dispatch import is_traceable_wrapper_subclass + +from thunder.core.codeutils import SigInfo +from thunder.core import devices +from thunder.core import dtypes +from thunder.core import prims +from thunder.core import utils +from thunder.core.proxies import ProxyInterface +from thunder.core.proxies import SubclassTensorProxy +from 
thunder.core.proxies import TensorProxy +from thunder.core.proxies import Variable +from thunder.core.proxies import variableify +from thunder.core.pytree import tree_flatten +from thunder.core.pytree import tree_map +from thunder.core.pytree import tree_unflatten +from thunder.core.trace import TraceCtx +from thunder.core.trace import TraceProvenance +from thunder.core.trace import from_trace +from thunder.core.trace import tracectx +from thunder.executors.passes import transform_for_execution +from thunder.extend import get_executor + +if TYPE_CHECKING: + from collections.abc import Sequence + from typing import Any + from optree import PyTreeSpec + from torch.fx import GraphModule + from torch.fx import Node + from torch._ops import OpOverload + from thunder.core.symbol import Symbol, BoundSymbol + from torch._C import _TensorMeta + + +__all__ = [ + "flatten_tensor_subclasses", +] + + +PLACEHOLDER: str = "placeholder" +CALL_FUNCTION: str = "call_function" +OUTPUT: str = "output" + + +class OutputWrapperForFxTracing(NamedTuple): + inner_tensors: dict[str, torch.Tensor] | torch.Tensor + metadata: dict[str, Any] | None + + +def _materialize_tensor_proxy(t: TensorProxy, fake_tensor_mode: FakeTensorMode | None) -> torch.Tensor: + shape = t.shape + device = devices.to_torch_device(t.device) + dtype = dtypes.to_torch_dtype(t.dtype) + requires_grad = t.requires_grad + + with torch.device("meta"): + t = torch.empty(shape, dtype=dtype, requires_grad=requires_grad) + if fake_tensor_mode is None: + return t + fakified_empty_tensor = fake_tensor_mode.fake_tensor_converter.from_meta_and_device( + fake_mode=fake_tensor_mode, t=t, device=device + ) + return fakified_empty_tensor + + +def _make_fake_subclass_tensor_from_subclass_tensor_proxy( + tensor_proxy: SubclassTensorProxy, + fake_tensor_mode: FakeTensorMode, +) -> torch.Tensor: + utils.check( + (subclass_type := getattr(tensor_proxy, SubclassTensorProxy.SUBCLASS_TYPE_ATTR, None)) is not None, + lambda: f"{tensor_proxy} does not have `{SubclassTensorProxy.SUBCLASS_TYPE_ATTR}`", + ) + utils.check( + tensor_proxy._tensors, + lambda: f"{tensor_proxy} has an empty `{tensor_proxy._tensors=}`", + ) + tensor_attr_names = tensor_proxy._tensor_attr_names + non_tensor_attr_names = tensor_proxy._non_tensor_attr_names + inner_tensors = dict( + zip( + tensor_attr_names, + [_materialize_tensor_proxy(t, fake_tensor_mode=fake_tensor_mode) for t in tensor_proxy._tensors], + ) + ) + metadata = dict(zip(non_tensor_attr_names, tensor_proxy._non_tensors)) + subclass_tensor = subclass_type.__tensor_unflatten__( + inner_tensors, + metadata, + outer_size=-1, + outer_stride=-1, + ) + fakified = fake_tensor_mode.from_tensor(subclass_tensor, static_shapes=True) + return fakified + + +def materialize_tensor_proxy( + t: TensorProxy | SubclassTensorProxy, + fake_tensor_mode: FakeTensorMode, +) -> torch.Tensor: + if isinstance(t, SubclassTensorProxy): + return _make_fake_subclass_tensor_from_subclass_tensor_proxy(t, fake_tensor_mode) + return _materialize_tensor_proxy(t, fake_tensor_mode) + + +def maybe_materialize_tensor( + t: ProxyInterface, + fake_tensor_mode: FakeTensorMode, +) -> ProxyInterface | torch.Tensor: + if isinstance(t, (TensorProxy, SubclassTensorProxy)): + return materialize_tensor_proxy(t, fake_tensor_mode) + if isinstance(t, (Number, str)): + return t + return t.value + + +def proxy_fake_tensor(t: torch.Tensor | FakeTensor) -> ProxyInterface: + if isinstance(t, FakeTensor) or (isinstance(t, torch.Tensor) and not issubclass(type(t), torch.Tensor)): + return 
TensorProxy( + None, + shape=list(t.shape), + dtype=dtypes.to_dtype(t.dtype), + device=devices.to_device(t.device), + requires_grad=t.requires_grad, + ) + if torch.utils._python_dispatch.is_traceable_wrapper_subclass(t): + tensor_attr_names, metadata = t.__tensor_flatten__() + tensor_proxies = [proxy_fake_tensor(getattr(t, name)) for name in tensor_attr_names] + non_tensor_attr_names = list(metadata.keys()) + non_tensors = list(metadata.values()) + p = SubclassTensorProxy( + None, + shape=list(t.shape), + dtype=dtypes.to_dtype(t.dtype), + device=devices.to_device(t.device), + requires_grad=t.requires_grad, + tensors=tensor_proxies, + non_tensors=non_tensors, + subclass_type=type(t), + ) + p._tensor_attr_names = tensor_attr_names + p._non_tensor_attr_names = non_tensor_attr_names + for name, value in zip(tensor_attr_names + non_tensor_attr_names, tensor_proxies + non_tensors): + setattr(p, name, value) + return p + return t + + +def trace_from_bsym_or_bsyms(bsym_or_bsyms: BoundSymbol | Sequence[BoundSymbol]) -> TraceCtx: + bsyms = utils.sequencify(bsym_or_bsyms) + + trace = TraceCtx() + trace.bound_symbols.extend(bsyms) + trace.args = bsyms[0].flat_proxy_args + with tracectx(trace): + prims.python_return(bsyms[-1].output) + with tracectx(trace): + # note(crcrpar): Give prefix `tmp` to avoid infinite recursion due to the same name + trace._siginfo = SigInfo.from_name_and_args(f"tmp_{bsyms[0].sym.name}", trace.args) + return trace + + +def aten_core_ir_op_to_ltorch_op(aten_op: OpOverload) -> Symbol: + import thunder.torch as ltorch + + op_name_without_overload = aten_op._opname + utils.check( + hasattr(ltorch, op_name_without_overload), + lambda: f"{aten_op=} cannot find an appropriate ltorch op. Query: {op_name_without_overload}", + ) + return getattr(ltorch, op_name_without_overload) + + +@dataclass +class DesugarTensorSubclass: + computation_trace: TraceCtx + swap_map: dict[Variable, ProxyInterface] = field(init=False, default_factory=dict) + fake_tensor_mode: FakeTensorMode = field(init=False, default_factory=FakeTensorMode) + flat_trace_args: Sequence[ProxyInterface] = field(init=False, default=None) + flat_trace_args_spec: Any = field(init=False, default=None) + subclass_proxy_to_flatten: set[Variable] = field(init=False, default_factory=set) + bsym_to_new_outputs: dict[BoundSymbol, list[TensorProxy]] = field(init=False, default_factory=dict) + + def __post_init__(self) -> None: + self.flat_trace_args, self.flat_trace_args_spec = tree_flatten( + (self.computation_trace.args, self.computation_trace.kwargs) + ) + for arg in self.flat_trace_args: + if isinstance(arg, SubclassTensorProxy): + self.subclass_proxy_to_flatten.add(variableify(arg)) + + def _get_tensor_attr_names(self, p: SubclassTensorProxy) -> list[str]: + return p._tensor_attr_names + + def _get_non_tensor_attr_names(self, p: SubclassTensorProxy) -> list[str]: + return p._non_tensor_attr_names + + def translate_fx_graph_into_bsym( + self, + bsym: BoundSymbol, + fx_graph: GraphModule, + ) -> BoundSymbol | tuple[BoundSymbol, ...]: + import thunder.torch as ltorch + + unwrapped_bsym_args: dict[int, ProxyInterface] = {} + list_of_unflatten_bsym: list[BoundSymbol] = [] + for a in bsym.flat_args: + if isinstance(a, SubclassTensorProxy): + if variableify(a) in self.subclass_proxy_to_flatten: + self.computation_trace.push_scope([]) + with tracectx(self.computation_trace): + prims.flatten_tensor_subclass(a) + unflatten_bsym = self.computation_trace.pop_scope()[0] + list_of_unflatten_bsym.append(unflatten_bsym) + tensor_attr_names = 
self._get_tensor_attr_names(a) + tensors = a._tensors + non_tensor_attr_names = self._get_non_tensor_attr_names(a) + non_tensors = a._non_tensors + metadata = dict(zip(non_tensor_attr_names, non_tensors)) + for name, t in zip(tensor_attr_names, tensors): + utils.check( + isinstance(t, TensorProxy), + lambda: f"{a=}, {tensor_attr_names = }, {tensors=}", + ) + unwrapped_bsym_args[len(unwrapped_bsym_args)] = t + utils.check( + not metadata, + lambda: f"Tensor Subclasses with nonempty metadata are not supported.", + exception_type=NotImplementedError, + ) + else: + if not isinstance(a, ProxyInterface): + from thunder.core.proxies import proxy + + with tracectx(self.computation_trace): + a = proxy(a) + unwrapped_bsym_args[len(unwrapped_bsym_args)] = a + + node: Node + list_of_placeholder_node: list[Node] = [] + list_of_function_call_node: list[Node] = [] + node_of_output: Node + for node in fx_graph.graph.nodes: + if node.op == PLACEHOLDER: + list_of_placeholder_node.append(node) + if node.op == CALL_FUNCTION: + list_of_function_call_node.append(node) + if node.op == OUTPUT: + node_of_output = node + args = [n.target for n in list_of_placeholder_node] + arg_name_to_index = {a: i for i, a in enumerate(args)} + ltorch_ops_for_node_of_ops = [getattr(ltorch, node.target._opname) for node in list_of_function_call_node] + + bsyms: list[BoundSymbol] = [] + if list_of_unflatten_bsym: + bsyms.extend(list_of_unflatten_bsym) + fxnode_output_name_to_tensor_proxy: dict[str, OpOverload] = {} + for node, ltorch_op in zip(list_of_function_call_node, ltorch_ops_for_node_of_ops): + args: list[Node] = node.args + + arg_proxies: list[ProxyInterface] = [] + for a in args: + if isinstance(a.target, str): + arg_proxies.append(unwrapped_bsym_args[arg_name_to_index[a.target]]) + else: + arg_proxies.append(fxnode_output_name_to_tensor_proxy[str(a)]) + + self.computation_trace.push_scope([]) + + with tracectx(self.computation_trace): + out = ltorch_op(*arg_proxies) + fxnode_output_name_to_tensor_proxy[str(node)] = out + bsyms.extend(self.computation_trace.pop_scope()) + if len(bsyms) == 0: + return [bsym] + + orig_output = bsym.flat_outs[0] + if is_subclass_ctor_bsym := bsym.sym.id == prims.PrimIDs.TENSOR_SUBCLASS_CTOR: + utils.check_type(orig_output, SubclassTensorProxy) + if isinstance(orig_output, SubclassTensorProxy): + # note(crcrpar): args[0] would be list of tensors, and args[1] could be list of non-tensors. 
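+            # Schematic of what gets emitted below for a ctor bsym (attribute names are
+            # from the toy ScaleTensorSubclass used in the tests; purely illustrative):
+            #   unflatten_tensor_subclass(ScaleTensorSubclass,
+            #                             {"_x": t0, "_scale": t1},
+            #                             {}) -> new_subclass
+            # i.e. the flattened fx outputs are rebound to the original subclass proxy's
+            # tensor attribute names and stitched back into a SubclassTensorProxy.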
+ args: list[Node] = node_of_output.args[0] + new_tensor_proxies = [] + for a in args: + if isinstance(a.target, str): + new_tensor_proxies.append(unwrapped_bsym_args[arg_name_to_index[a.target]]) + else: + new_tensor_proxies.append(fxnode_output_name_to_tensor_proxy[str(a)]) + utils.check( + len(orig_output._tensors) == len(new_tensor_proxies), + lambda: ( + f"The number of new tensor proxies for {orig_output=} does not match: " + f"{len(new_tensor_proxies)=} != {len(orig_output._tensors)=}" + ), + ) + with tracectx(self.computation_trace): + new_subclass = orig_output.replace() + for name, value in zip(new_subclass._tensor_attr_names, new_tensor_proxies): + setattr(new_subclass, name, value) + bsyms.append( + prims.unflatten_tensor_subclass.bind( + new_subclass._subclass_type, + dict(zip(new_subclass._tensor_attr_names, new_tensor_proxies)), + dict(zip(new_subclass._non_tensor_attr_names, new_subclass._non_tensors)), + output=new_subclass, + ) + ) + + self.swap_map[variableify(orig_output)] = new_subclass + return bsyms + + def convert_trace_to_fx_graph_and_get_fake_result( + self, + trace: TraceCtx, + ) -> tuple[GraphModule, tuple[OutputWrapperForFxTracing, ...], tuple[torch.Tensor, ...], PyTreeSpec]: + + def create_ctor(unflatten_method, tensor_names): + + def ctor(tensors, metadata): + inner_tensors = dict(zip(tensor_names, tensors)) + return unflatten_method(inner_tensors, metadata, -1, -1) + + return ctor + + args = tree_map( + lambda t: maybe_materialize_tensor( + t, + self.fake_tensor_mode, + ), + trace.args, + ) + desugared_args = [] + arg_idx_to_sugar: dict[int, tuple[int, Any]] = {} + for a in args: + if is_traceable_wrapper_subclass(a): + start_idx = len(desugared_args) + attrs, metadta = a.__tensor_flatten__() + desugared_args.extend([getattr(a, name) for name in attrs]) + desugared_args.append(metadta) + end_idx = len(desugared_args) + arg_idx_to_sugar[start_idx] = end_idx, create_ctor(type(a).__tensor_unflatten__, attrs) + else: + desugared_args.append(a) + + out_specs: list[Any] = [] + orig_output: list[torch.Tensor] = [] + + def transform_out(out: torch.Tensor) -> OutputWrapperForFxTracing: + orig_output.append(out) + if is_traceable_wrapper_subclass(out): + attrs, metadata = out.__tensor_flatten__() + tensors = [getattr(out, name) for name in attrs] + output = OutputWrapperForFxTracing(dict(zip(attrs, tensors)), metadata) + else: + output = OutputWrapperForFxTracing(out, None) + return output + + extrace = transform_for_execution(trace, [get_executor("torch")]) + f = extrace.python_callable(include_decorators=False) + + def f_with_wrap_and_unwrap(*desugared_args) -> tuple[OutputWrapperForFxTracing, ...]: + args = [] + cur_idx = 0 + while cur_idx < len(desugared_args): + if cur_idx in arg_idx_to_sugar: + end_idx, construct_subclass = arg_idx_to_sugar[cur_idx] + args_of_subclass = desugared_args[cur_idx:end_idx] + tensors = args_of_subclass[:-1] + metadata = args_of_subclass[-1] + subclass = construct_subclass(tensors, metadata) + args.append(subclass) + + cur_idx = end_idx + else: + args.append(desugared_args[cur_idx]) + cur_idx += 1 + + out = f(*args) + # Specialcasing the output of initial computation trace + if isinstance(out, dict) and len(out) == 2 and ("output", "flat_args") == tuple(out.keys()): + sequencified_out = out + else: + sequencified_out = utils.sequencify(out) + flat_out, out_spec = tree_flatten(sequencified_out) + out_specs.append(out_spec) + flat_cosmeticized_out = tree_map(transform_out, flat_out) + return tree_unflatten(flat_cosmeticized_out, 
out_spec) + + with ( + enable_python_dispatcher(), + FunctionalTensorMode( + pre_dispatch=False, + export=False, + _allow_token_discovery=True, + ), + ): + fx: GraphModule = make_fx(f_with_wrap_and_unwrap)(*desugared_args) + + return fx, fx(*desugared_args), tuple(orig_output), out_specs[0] + + def __call__(self, bsym: BoundSymbol) -> list[BoundSymbol]: + updated_bsym: BoundSymbol = bsym.from_bsym_swap_proxies(self.swap_map) + if updated_bsym.sym.id == prims.PrimIDs.RETURN: + if not self.subclass_proxy_to_flatten or True: + return [updated_bsym] + + is_subclass_ctor = bsym.sym.id == prims.PrimIDs.TENSOR_SUBCLASS_CTOR + if not is_subclass_ctor and not any(isinstance(a, SubclassTensorProxy) for a in updated_bsym.flat_proxy_args): + return [updated_bsym] + + utils.check( + len(updated_bsym.flat_outs) < 2, + lambda: f"bsym has {len(updated_bsym.flat_outs)} outputs", + exception_type=NotImplementedError, + ) + + trace = trace_from_bsym_or_bsyms(updated_bsym) + fx, sequencified_cosmeticized_out, orig_output, _ = self.convert_trace_to_fx_graph_and_get_fake_result(trace) + utils.check( + len(sequencified_cosmeticized_out) == len(orig_output), + lambda: f"{len(sequencified_cosmeticized_out)=}, {len(orig_output)=}", + ) + if is_subclass_ctor: + utils.check(len(sequencified_cosmeticized_out) == 1 and len(orig_output) == 1, lambda: "") + fake_tensor_subclass = orig_output[0] + subclass_proxy = updated_bsym.flat_outs[0] + tensor_attr_names, metadata = fake_tensor_subclass.__tensor_flatten__() + subclass_proxy._tensor_attr_names = tensor_attr_names + subclass_proxy._non_tensor_attr_names = list(metadata.keys()) + self.subclass_proxy_to_flatten.add(variableify(subclass_proxy)) + for name, value in zip( + tensor_attr_names + subclass_proxy._non_tensor_attr_names, + subclass_proxy._tensors + subclass_proxy._non_tensor_attr_names, + ): + setattr(subclass_proxy, name, value) + return [updated_bsym] + + out = [] + for i, (cosmeticized_out, orig_out) in enumerate(zip(sequencified_cosmeticized_out, orig_output)): + if isinstance(cosmeticized_out.inner_tensors, dict): + utils.check( + is_traceable_wrapper_subclass(orig_out), lambda: f"{cosmeticized_out=} don't match {orig_out=}" + ) + out.append(orig_out) + else: + out.append(orig_out.tensors) + + with tracectx(self.computation_trace): + out_proxy = tree_map(proxy_fake_tensor, out) + + utils.check( + len(updated_bsym.flat_outs) == len(out_proxy), + lambda: f"{len(bsym.flat_outs)=}, {len(out_proxy)=}, {out_proxy=}, {bsym.flat_outs=}", + ) + sequence_out = [variableify(a) for a in updated_bsym.flat_outs] + self.swap_map.update(dict(zip(sequence_out, utils.sequencify(out_proxy)))) + + bsym_with_modified_output = updated_bsym.from_bsym_swap_proxies(self.swap_map) + + self.bsym_to_new_outputs[bsym_with_modified_output] = bsym_with_modified_output + return self.translate_fx_graph_into_bsym(bsym_with_modified_output, fx) + + +def flatten_tensor_subclasses(computation_trace: TraceCtx) -> TraceCtx: + """Flatten tensor subclasses in ``computation_trace``. + + Two things are happening inside of this function: + * Reevaluate every single bsym of ``computation_trace.bound_symbols``. + * Flatten tensor subclasses + + Each :class:`thunder.core.symbol.BoundSymbol` is reevaluated with torch.fx tracing and + ``FakeTensorMode``. This is necessary because Thunder's initial trace cannot correctly infer the output + type of an op with tensor subclasses. 
By translating each bsym into a callable and tracing it with
+    ``torch.fx`` and ``FakeTensorMode``, we can determine the output type and the exact behavior of the bsym
+    as extended by the subclass's ``__torch_dispatch__`` (note that the observed sequence of operations
+    is free from tensor subclasses; everything is flattened).
+    The output type information is then reflected in the output :class:`thunder.core.proxies.Proxy`.
+
+    With this function applied, the :class:`thunder.core.trace.TraceCtx` is free from tensor subclasses.
+    Exceptions are the prologue (the first few lines of the trace, before any math) and the epilogue (the
+    last few lines of the trace, right before the return statement).
+
+    Args:
+        computation_trace: the computation trace whose bound symbols are to be flattened.
+
+    Returns:
+        TraceCtx: a transformed trace that is free from tensor subclasses, with every ``__torch_dispatch__``
+            behavior spelled out.
+    """
+    desugar_tensor_subclass = DesugarTensorSubclass(computation_trace=computation_trace)
+    updated_bsyms: list[BoundSymbol] = []
+    bsym: BoundSymbol
+    for bsym in computation_trace.bound_symbols:
+        maybe_desugared_bsyms = desugar_tensor_subclass(bsym)
+        updated_bsyms.extend(maybe_desugared_bsyms)
+
+    if not desugar_tensor_subclass.subclass_proxy_to_flatten:
+        return computation_trace
+
+    computation_trace_with_subclass_tensor_proxy_output = from_trace(computation_trace)
+    computation_trace_with_subclass_tensor_proxy_output.bound_symbols.extend(updated_bsyms)
+    computation_trace_with_subclass_tensor_proxy_output.set_provenance(TraceProvenance("tensor subclasses desugared"))
+    return computation_trace_with_subclass_tensor_proxy_output
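For reviewers who want to exercise the feature end to end, a minimal usage sketch (assuming a Thunder build with this patch applied; `ScaleTensorSubclass` is the toy wrapper defined in `thunder/tests/test_tensor_subclass.py` above, and whether the jitted output stays a subclass follows its `__torch_dispatch__`):

    import torch
    import thunder
    from thunder.tests.test_tensor_subclass import ScaleTensorSubclass

    def f(x: torch.Tensor, scale: torch.Tensor):
        # Constructing the wrapper inside the jitted region goes through the
        # torch.Tensor._make_wrapper_subclass lookaside and yields a
        # SubclassTensorProxy; flatten_tensor_subclasses() later desugars the
        # addition via ScaleTensorSubclass.__torch_dispatch__.
        y = ScaleTensorSubclass(x, scale)
        return y + y

    jitted = thunder.jit(f)
    out = jitted(torch.randn(2, 2), torch.tensor(0.5))
    print(type(out), out._x, out._scale)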