diff --git a/thunder/core/jit_ext.py b/thunder/core/jit_ext.py index 726353983a..e88419546a 100644 --- a/thunder/core/jit_ext.py +++ b/thunder/core/jit_ext.py @@ -63,6 +63,7 @@ StringProxy, TensorProxy, FutureTensorProxy, + SubclassTensorProxy, make_proxy_name, Variable, variableify, @@ -930,6 +931,42 @@ def thunder_function(*args, **kwargs): return interpreter_needs_wrap(checkpoint)(wrapped_thunder_function, *args, **kwargs) +@register_general_jit_lookaside(torch.Tensor._make_wrapper_subclass) +def _make_wrapper_subclass( + cls: torch._C._TensorMeta, + size: Sequence[int], + strides: Sequence[int] | None = None, + storage_offset: int | None = None, + memory_format: torch.memory_format | None = None, + dtype: torch.dtype | None = None, + layout: torch.layout | None = torch.strided, + device: torch.device | None = None, + pin_memory: bool = False, + requires_grad: bool = False, + dispatch_sizes_strides_policy: str | None = None, + dispatch_device: bool = False, + dispatch_layout: bool = False, + _extra_dispatch_keys: torch.DispatchKeySet | None = None, + storage_size: int | None = None, +): + ucls = unwrap(cls) + usize = unwrap(size) + udtype = unwrap(dtype) + udevice = unwrap(device) + urequires_grad = unwrap(requires_grad) + + subclass = SubclassTensorProxy( + None, + shape=usize, + device=udevice, + dtype=udtype, + requires_grad=urequires_grad, + history=ProvenanceRecord(PseudoInst.LOOKASIDE, [cls.provenance]), + subclass_type=ucls, + ) + return wrap(subclass, provenance=ProvenanceRecord(PseudoInst.LOOKASIDE, [cls.provenance])) + + # Adds proxy methods # NOTE These methods map to themselves, which prevents the interpreter from looking into them # This is OK because these methods are written in a tracing-safe manner, and trying to diff --git a/thunder/core/prims.py b/thunder/core/prims.py index da46147155..90b062eeb4 100644 --- a/thunder/core/prims.py +++ b/thunder/core/prims.py @@ -1,3 +1,4 @@ +from __future__ import annotations from enum import auto, Enum from numbers import Number from functools import reduce, wraps @@ -5,7 +6,7 @@ import builtins import math from types import NoneType -from typing import Union, Type, Any, List, Dict, Tuple, Optional +from typing import Union, Type, Any, List, Dict, Tuple, Optional, TYPE_CHECKING from collections.abc import Callable from collections.abc import Callable, Hashable, Sequence @@ -13,6 +14,10 @@ from thunder.core.langctxs import LanguageContext, register_langctx, Languages, langctx +if TYPE_CHECKING: + from collections.abc import Iterable + from thunder.core.codeutils import ContextObject + # # Creates and registers the torch language context # @@ -77,6 +82,7 @@ def register_method(method_name: str, method: Callable, /) -> None: TupleProxy, AnyProxy, IntegerProxy, + SubclassTensorProxy, ) import thunder.core.codeutils as codeutils from thunder.core.codeutils import Printable @@ -272,6 +278,8 @@ class PrimIDs(Enum): COPY_ = auto() # SINK = auto() + # Tensor Subclasses methods + TENSOR_SUBCLASS_CTOR = auto() class OpTags(Enum): @@ -3543,7 +3551,10 @@ def transpose_meta(a: TensorProxy, /, permutation: tuple[int, ...]) -> TensorPro view = make_prim(PrimIDs.VIEW, "view", meta=reshape_meta, tags=(OpTags.SHAPE_OP,)) -def shallow_copy_meta(a: TensorProxy, /) -> TensorProxy: +def shallow_copy_meta(a: TensorProxy | SubclassTensorProxy, /) -> TensorProxy: + if isinstance(a, SubclassTensorProxy): + # SubclassTensorProxy(like=...) would not copy some attrs such as `_tensors` while replace does. + return a.replace() return TensorProxy(like=a) @@ -4048,3 +4059,139 @@ def sink_meta(*args, **kwargs): # TODO do we want another tag to remove this after prologue is constructed? sink = make_prim(PrimIDs.SINK, "sink", meta=sink_meta, tags=(OpTags.DONT_DCE,)) + + +def tensor_subclass_ctor_meta( + cls, name, shape, device, dtype, requires_grad, tensors, non_tensors +) -> SubclassTensorProxy: + s = SubclassTensorProxy( + name, + subclass_type=cls, + shape=shape, + device=device, + dtype=dtype, + requires_grad=requires_grad, + tensors=tensors, + non_tensors=non_tensors, + history=[t.history for t in tensors], + ) + return s + + +def get_nested_types(collection): + collection = utils.sequencify(collection) + types_set = {type(t) for t in collection} + + def check_types(coll): + for item in coll: + types_set.add(type(item)) + # Check if the item is a nested collection + if baseutils.is_collection(item): + # If it's a dictionary, check its values + if isinstance(item, dict): + check_types(item.values()) + # Recursively check nested collections + else: + check_types(item) + + check_types(collection) + return tuple(types_set) + + +def filter_types(types: tuple[Any, ...]) -> tuple[Any, ...]: + return tuple( + filter( + lambda t: ( + t.__module__ != "builtins" + and t != Number + # note(crcrpar): maybe `thunder.core`? + and not t.__module__.startswith("thunder.") + and not t.__module__.startswith("torch.") + ), + types, + ) + ) + + +def printer_of_tensor_subclass_ctor( + bsym: BoundSymbol, + out_printables: Any, + arg_printables: Sequence[Printable], + kwarg_printables: dict[str, Printable], +) -> str | Iterable[str]: + from itertools import chain + + baseutils.check(not kwarg_printables, lambda: f"No kwargs are supported but {kwarg_printables = }") + + # NOTE(crcrpar): It's not a context but at the moment Tensor subclass is treated as `ContextObject`. + wrapped_cls: ContextObject | torch._C._TensorMeta = arg_printables[0] + if isinstance(wrapped_cls, torch._C._TensorMeta): + cls = wrapped_cls + else: + cls: torch._C._TensorMeta = wrapped_cls.obj + tensors, non_tensors = arg_printables[-2:] + new_non_tensors = [] + for a in non_tensors: + if isinstance(a, dtypes.dtype): + new_non_tensors.append(dtypes.to_torch_dtype(a)) + elif isinstance(a, devices.Device): + new_non_tensors.append(devices.to_torch_device(a)) + else: + new_non_tensors.append(a) + + arg_str = ", ".join(codeutils.prettyprint(x) for x in [*tensors, *new_non_tensors]) + kwarg_str = "" + + result_str: str + if bsym.output is None or (baseutils.is_collection(bsym.output) and len(bsym.output) == 0): + result_str = "" + else: + result_str = f"{codeutils.prettyprint(out_printables, literals_as_underscores=True)} = " + + # Creates a comment describing the output + comment_str: str + if isinstance(bsym.output, Proxy): + comment_str = f" # {codeutils.prettyprint(out_printables, with_type=True)}" + else: + comment_str = "" + + cls_with_module = f"{cls.__name__}" + s = f"{result_str}{cls_with_module}({arg_str}{', ' if (len(arg_str) > 0 and len(kwarg_str) > 0) else ''}{kwarg_str}){comment_str}" + + if bsym.header: + header_lines = ( + bsym.header + if isinstance(bsym.header, Sequence) and not isinstance(bsym.header, str) + else bsym.header.splitlines() + ) + header_lines = (f"# {line}" for line in header_lines) + return chain(header_lines, [s]) + + filtered_types = (cls,) + if non_tensors: + types = get_nested_types([t.obj if isinstance(t, codeutils.ContextObject) else t for t in non_tensors]) + filtered_types += filter_types(types) + new_imports = {t.__name__: t for t in filtered_types} + bsym._import_ctx.update(new_imports) + return s + + +def bind_postprocess_of_tensor_subclass_ctor(bsym: BoundSymbol) -> None: + cls = bsym.args[0] + non_tensors = bsym.args[-1] + + filtered_types: tuple[Any, ...] = (cls,) + if non_tensors: + types = get_nested_types(non_tensors) + filtered_types += filter_types(types) + new_imports = {t.__name__: t for t in filtered_types} + bsym._import_ctx.update(new_imports) + + +tensor_subclass_ctor = make_prim( + PrimIDs.TENSOR_SUBCLASS_CTOR, + "tensor_subclass_ctor", + meta=tensor_subclass_ctor_meta, + python_printer=printer_of_tensor_subclass_ctor, + _bind_postprocess=bind_postprocess_of_tensor_subclass_ctor, +) diff --git a/thunder/core/proxies.py b/thunder/core/proxies.py index 2f2eb1c665..c6c8865675 100644 --- a/thunder/core/proxies.py +++ b/thunder/core/proxies.py @@ -3,7 +3,7 @@ import copy from enum import auto, Enum from numbers import Number -from typing import Any +from typing import Any, ClassVar from collections.abc import Callable from collections.abc import Sequence @@ -1223,7 +1223,7 @@ def _infer_tensor_properties( _thunder_fsdp_padding_size = None if like is not None: - baseutils.check_type(like, (TensorProxy, FutureTensorProxy)) + baseutils.check_type(like, (TensorProxy, FutureTensorProxy, SubclassTensorProxy)) _shape = tuple(like.shape) _device = like.device _dtype = like.true_dtype @@ -1880,6 +1880,166 @@ def real(self): return method(self) +class SubclassTensorProxy(TensorProxy): + SUBCLASS_TYPE_ATTR: ClassVar[str] = "_subclass_type" + _tensors: list[TensorProxy] + _non_tensors: list[Any] + _subclass_type: torch._C._TensorMeta + + _tensor_attr_names: list[str] | None + _non_tensor_attr_names: list[str] | None + + def __init__(self, *args, **kwargs): + from thunder.core.pytree import tree_flatten + + kwarg_tensors = kwargs.pop("tensors", []) + kwarg_non_tensors = kwargs.pop("non_tensors", []) + subclass_type = kwargs.pop("subclass_type", None) + + has_name_before_init = hasattr(self, "_name") + # If tensors (and non_tensors) are not empty, then it should be the path of `_make_wrapper_subclass` + # where `self` should already have gotten its name. + flat_args, spec = tree_flatten((args, kwargs)) + tensors: list[TensorProxy] = [] + non_tensors: list[Any] = [] + for t in args + tuple(kwargs.values()): + if type(t) is SubclassTensorProxy: + continue + if type(t) is TensorProxy: + tensors.append(t) + else: + non_tensors.append(t) + + is_dunder_init_following_make_wrapper_subclass: bool = False + if tensors: + baseutils.check( + has_name_before_init + and not kwarg_tensors + and not kwarg_non_tensors + and self._subclass_type is not None, + lambda: ( + f"{flat_args=} indicates this instance is created by" + "`torch.Tensor._make_wrapper_subclass`'s lookaside but `name` is not set" + ), + ) + is_dunder_init_following_make_wrapper_subclass = True + + if not is_dunder_init_following_make_wrapper_subclass: + super().__init__(*args, **kwargs) + + self._tensors = kwarg_tensors + self._non_tensors = kwarg_non_tensors + self._subclass_type = subclass_type + else: + # TODO(crcrpar): Think about materializing `self` so that we can + # call `__tensor_init__` to know each attribute names. + from dataclasses import replace + import inspect + from thunder.core import prims + + bsym = prims.tensor_subclass_ctor.bind( + self._subclass_type, + self.name, + self.shape, + self.device, + self.dtype, + self.requires_grad, + self._tensors, + self._non_tensors, + output=self, + ) + cls_module = inspect.getmodule(self._subclass_type) + bsym.sym = replace(bsym.sym, _module=cls_module) + + # NOTE(crcrpar): A callable being `thunder.jit`ed can call `MySubclassTensor(...)` + # inside of it either directly or indirectly: indirect way is to call it through + # a custom `torch.autograd.Function` as in + # https://github.com/pytorch/ao/blob/000a490/torchao/float8/float8_tensor.py#L139-L209. + # If it's a direct call, `trace.bound_symbols` and `trace.scopes[-1]` are identical, + # but not, otherwise. As [the lookasdie of `torch.autograd.Function`]( + # https://github.com/Lightning-AI/lightning-thunder/blob/3d42c10/thunder/core/jit_ext.py#L655) + # puts the temporary scope to the current trace. + current_trace = get_tracectx() + if id(current_trace.bound_symbols) == id(cur_tail_scope := current_trace.scopes[-1]): + current_trace.add_bound_symbol(bsym) + else: + cur_tail_scope.append(bsym) + + if not self._tensors and not self._non_tensors: + for a in tensors: + self._tensors.append(a) + for a in non_tensors: + self._non_tensors.append(a) + baseutils.check(self._tensors, lambda: f"`{self._name}._tensors` must not be empty") + + def __tensor_flatten__(self) -> tuple[list[TensorProxy], dict[str, Any]]: + return self._tensor_attr_names, self.metadata + + @property + def metadata(self) -> dict[str, Any]: + return dict(zip(self._non_tensor_attr_names, self._non_tensors)) + + def replace(self, **changes): + r"""Return a copy of the SubclassTensorProxy object with new values for the specified fields as given to the constructor as arguments. + Valid keyword arguments are ``name``, ``history``, ``shape``, ``dtype``, ``device``, ``requires_grad``, ``distparallel_type``, ``thunder_fsdp_padding_size``. + ``like`` is also a valid keyword and will take metadata from the tensor proxy argument + in preference to the old values but overridable by keyword arguments. + Note that the copy will use the current (environment) tracectx.""" + + like = changes.get("like") + ( + shape, + device, + dtype, + true_dtype, + numel, + ndim, + requires_grad, + grad, + distparallel_type, + thunder_fsdp_padding_size, + ) = _infer_tensor_properties( + like, + changes.get("shape", self._shape if like is None else None), + changes.get("device", self._device if like is None else None), + changes.get("dtype", self._dtype if like is None else None), + changes.get("requires_grad", self._requires_grad if like is None else None), + changes.get("grad", self._grad if like is None else None), + changes.get("distparallel_type", self._distparallel_type if like is None else None), + changes.get("thunder_fsdp_padding_size", self._thunder_fsdp_padding_size if like is None else None), + ) + name = changes.get("name", self.name) + history = changes.get("history", self.history) + tags = changes.get("tags", self.tags) + p = SubclassTensorProxy( + name=name, + tags=tags, + shape=shape, + device=device, + dtype=dtype, + requires_grad=requires_grad, + distparallel_type=distparallel_type, + thunder_fsdp_padding_size=thunder_fsdp_padding_size, + history=history, + tensors=self._tensors, + non_tensors=self._non_tensors, + subclass_type=self._subclass_type, + ) + if hasattr(self, "_tensor_attr_names") and hasattr(self, "_non_tensor_attr_names"): + p._tensor_attr_names = self._tensor_attr_names + p._non_tensor_attr_names = self._non_tensor_attr_names + for name, value in zip(p._tensor_attr_names + p._non_tensor_attr_names, p._tensors + p._non_tensors): + setattr(p, name, value) + return p + + def __repr__(self): + tensors, metadata = {}, {} + if hasattr(self, "_tensor_attr_names"): + tensor_names, metadata = self.__tensor_flatten__() + tensors = {n: getattr(self, n) for n in tensor_names} + return f'<{type(self).__name__}(name="{self.name}", dtype={self.dtype}, shape={self._shape}, {tensors=}, {metadata=})>' + + class TorchAutogradFunctionCtxProxy(Proxy, TorchAutogradFunctionCtxProxyInterface): def __init__( self, diff --git a/thunder/executors/torchex.py b/thunder/executors/torchex.py index 6de644204d..826d9d8a01 100644 --- a/thunder/executors/torchex.py +++ b/thunder/executors/torchex.py @@ -2220,3 +2220,37 @@ def _shape_impl(t): shallow_copy = ex.register_operator("shallow_copy", meta=prims.shallow_copy, fn=lambda x: x) _register_implementation(prims.shallow_copy, shallow_copy, checker=_always_executable) + + +def _tensor_subclass_ctor(cls, name, shape, device, dtype, requires_grad, tensors, non_tensors): + new_non_tensors = [] + for a in non_tensors: + if isinstance(a, dtypes.dtype): + new_non_tensors.append(to_torch_dtype(a)) + elif isinstance(a, devices.Device): + new_non_tensors.append(to_torch_device(a)) + else: + new_non_tensors.append(a) + return cls(*tensors, *new_non_tensors) + + +def _bind_postprocess_of_tensor_subclass_ctor(bsym: BoundSymbol) -> None: + from thunder.core.prims import get_nested_types, filter_types + + cls, _name, _shape, _device, _dtype, _requires_grad, _tensors, non_tensors = bsym.args + filtered_types = (cls,) + if non_tensors: + types = get_nested_types(non_tensors) + filtered_types += filter_types(types) + new_imports = {t.__name__: t for t in filtered_types} + bsym._import_ctx.update(new_imports) + + +tensor_subclass_ctor = ex.register_operator( + "tensor_subclass_ctor", + meta=prims.tensor_subclass_ctor, + fn=_tensor_subclass_ctor, + bind_postprocess=_bind_postprocess_of_tensor_subclass_ctor, + python_printer=prims.printer_of_tensor_subclass_ctor, +) +_register_implementation(prims.tensor_subclass_ctor, tensor_subclass_ctor, checker=_always_executable) diff --git a/thunder/tests/test_tensor_subclass.py b/thunder/tests/test_tensor_subclass.py new file mode 100644 index 0000000000..7e3bad2460 --- /dev/null +++ b/thunder/tests/test_tensor_subclass.py @@ -0,0 +1,182 @@ +from __future__ import annotations +from typing import TYPE_CHECKING + +import pytest +import torch +from torch.utils import _pytree as pytree + +import thunder +from thunder.tests.framework import instantiate +from thunder.tests.make_tensor import make_tensor + +if TYPE_CHECKING: + from typing import Any + + +@torch._dynamo.allow_in_graph +class EncapsulateXandScale(torch.autograd.Function): + @staticmethod + def forward(ctx, x: torch.Tensor, scale: torch.Tensor): + return ScaleTensorSubclass(x, scale) + + @staticmethod + def backward(ctx, grad): + return grad, None + + +def encapsulate_x_and_scale(x, scale) -> ScaleTensorSubclass: + return EncapsulateXandScale.apply(x, scale) + + +@torch._dynamo.allow_in_graph +class ToScaleTensorSubclass(torch.autograd.Function): + @staticmethod + def forward(ctx, x: torch.Tensor): + return ScaleTensorSubclass.from_tensor(x) + + @staticmethod + def backward(ctx, grad): + return grad + + +def to_scale_tensor_subclass(x: torch.Tensor) -> ScaleTensorSubclass: + return ToScaleTensorSubclass.apply(x) + + +class ScaleTensorSubclass(torch.Tensor): + _x: torch.Tensor + _scale: torch.Tensor + __slots__ = ["_x", "_scale"] + + def __new__(cls, x: torch.Tensor, scale: torch.Tensor): + assert scale.numel() == 1, f"Invalid `scale`: {scale}" + dtype = x.dtype + device = x.device + self = torch.Tensor._make_wrapper_subclass( + cls, + x.size(), + dtype=dtype, + device=device, + # strides=x.stride(), + # storage_offset=x.storage_offset(), + # layout=x.layout, + requires_grad=x.requires_grad, + ) + self._x = x + self._scale = scale + + return self + + # ref: https://github.com/albanD/subclass_zoo/blob/ec47458/base_tensor.py#L22 + __torch_function__ = torch._C._disabled_torch_function_impl + + def __repr__(self): + return f"ScaleTensorSubclass(dtype={self._x.dtype}, device={self._x.device}, x={self._x}, scale={self._scale})" + + def __tensor_flatten__(self) -> tuple[list[str], dict[str, Any]]: + return ["_x", "_scale"], {} + + @staticmethod + def __tensor_unflatten__( + inner_tensors: dict[str, torch.Tensor], + metadata: dict[str, Any], + outer_size, + outer_stride, + ) -> ScaleTensorSubclass: + return ScaleTensorSubclass(inner_tensors["_x"], inner_tensors["_scale"]) + + @staticmethod + def from_tensor(x: torch.Tensor) -> ScaleTensorSubclass: + scale = x.abs().max() + return ScaleTensorSubclass(x, scale) + + @classmethod + def __torch_dispatch__(cls, aten_ir_op: torch._ops.OpOverload, types, args=(), kwargs=None): + + def allowed_subclass(typ): + return ( + issubclass(cls, typ) + or issubclass(torch._subclasses.FakeTensor, typ) + or issubclass(torch._subclasses.functional_tensor.FunctionalTensor, typ) + ) + + def maybe_unwrap_and_scale(t: ScaleTensorSubclass | Any): + if isinstance(t, ScaleTensorSubclass): + if t.is_floating_point(): + return t._x * t._scale + else: + return t._x + return t + + if not all(allowed_subclass(t) for t in types): + return NotImplementedError(f"Unsupported types are included: {types}") + + scales = tuple(t._scale for t in pytree.tree_flatten((args, kwargs))[0] if isinstance(t, ScaleTensorSubclass)) + unwrapped_args, unwrapped_kwargs = pytree.tree_map(maybe_unwrap_and_scale, (args, kwargs)) + out = aten_ir_op(*unwrapped_args, **unwrapped_kwargs) + return out + + +@instantiate( + dtypes=(thunder.core.dtypes.float32,), +) +def test_func_of_subclass_ctor_wrapper(executor, device, _): + + def f(x: torch.Tensor, scale: torch.Tensor) -> ScaleTensorSubclass: + y = ScaleTensorSubclass(x, scale) + return y + + jitted = executor.make_callable(f) + + dtype = torch.float32 + shape = (2, 2) + x = make_tensor(shape, device=device, dtype=dtype) + scale = make_tensor((), device=device, dtype=dtype) + + expected = f(x, scale) + actual = jitted(x, scale) + torch.testing.assert_close((expected._x, expected._scale), (actual._x, actual._scale)) + + def f(x: torch.Tensor, scale: torch.Tensor): + y = ScaleTensorSubclass(x, scale) + z = ScaleTensorSubclass(y._x, y._scale) + return z + + jitted = executor.make_callable(f) + + expected = f(x, scale) + actual = jitted(x, scale) + torch.testing.assert_close((expected._x, expected._scale), (actual._x, actual._scale)) + + +@instantiate( + dtypes=(thunder.core.dtypes.float32,), +) +def test_func_calling_converter(executor, device, _): + + def f(x: torch.Tensor, scale: torch.Tensor) -> ScaleTensorSubclass: + y = encapsulate_x_and_scale(x, scale) + return y + + jitted = executor.make_callable(f) + + dtype = torch.float32 + shape = (2, 2) + + x = make_tensor(shape, device=device, dtype=dtype) + scale = make_tensor((), device=device, dtype=dtype) + + expected = f(x, scale) + actual = jitted(x, scale) + torch.testing.assert_close((expected._x, expected._scale), (actual._x, actual._scale)) + + def g(x: torch.Tensor) -> ScaleTensorSubclass: + y = to_scale_tensor_subclass(x) + return y + + jitted = thunder.jit(g) + x = make_tensor(shape, device=device, dtype=dtype) + + expected = g(x) + actual = jitted(x) + torch.testing.assert_close((expected._x, expected._scale), (actual._x, actual._scale))