[Tensor Subclasses] Trace transform to interpret __torch_dispatch__ #1394

Open · wants to merge 20 commits into base: main
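For context before the diff: the changes below add a `SubclassTensorProxy`, a lookaside for `torch.Tensor._make_wrapper_subclass`, and `tensor_subclass_ctor` / `flatten_tensor_subclass` / `unflatten_tensor_subclass` prims, and they route traceable subclass instances through `tensorproxy`; the `flatten_tensor_subclasses` transform itself is imported in `thunder/__init__.py`. Below is a minimal sketch, not part of this PR, of the kind of `__torch_dispatch__` wrapper subclass these changes target; the class name `ScaleTensor` and its behavior are illustrative assumptions.

```python
import torch


class ScaleTensor(torch.Tensor):
    """Illustrative wrapper subclass: a plain tensor plus a python-number scale."""

    @staticmethod
    def __new__(cls, data: torch.Tensor, scale: float):
        # Wrapper subclasses are created through _make_wrapper_subclass, which is
        # exactly the call the new lookaside in thunder/core/jit_ext.py intercepts.
        return torch.Tensor._make_wrapper_subclass(
            cls,
            data.shape,
            dtype=data.dtype,
            device=data.device,
            requires_grad=data.requires_grad,
        )

    def __init__(self, data: torch.Tensor, scale: float):
        self._data = data
        self._scale = scale

    def __tensor_flatten__(self):
        # (names of tensor attributes, metadata) -- the pair the new
        # flatten_tensor_subclass prim relies on.
        return ["_data"], {"scale": self._scale}

    @staticmethod
    def __tensor_unflatten__(inner_tensors, metadata, outer_size, outer_stride):
        return ScaleTensor(inner_tensors["_data"], metadata["scale"])

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # Unwrap to plain tensors, run the op, and rewrap (scale propagation is
        # deliberately simplified for this sketch).
        kwargs = kwargs or {}
        unwrapped = [a._data if isinstance(a, ScaleTensor) else a for a in args]
        out = func(*unwrapped, **kwargs)
        return ScaleTensor(out, args[0]._scale) if isinstance(out, torch.Tensor) else out
```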
3 changes: 2 additions & 1 deletion thunder/__init__.py
@@ -73,6 +73,7 @@
from thunder.core.interpreter import print_interpreter_log, print_to_log
from thunder.core.jit_ext import thunder_general_jit
from thunder.executors.torch_autograd import split_forward_backward, ThunderFunction
from thunder.transforms.tensor_subclasses import flatten_tensor_subclasses

# NOTE This import is intentionally pytorch so that thunder.torch doesn't import this
import torch as pytorch
@@ -370,7 +371,7 @@ def _alias_tensor_of_args_kwargs_dict(*args, **kwargs) -> dict[int, list[int]]:
data_ptr_to_tensor_group_index = {}
tensor_group_index_to_tensor_indices = defaultdict(list)
for idx, t in enumerate(flat_args):
if pytorch.is_tensor(t) and t.layout == pytorch.strided:
if type(t) in {pytorch.Tensor, pytorch.nn.Parameter} and t.layout == pytorch.strided:
data_ptr = t.untyped_storage().data_ptr()
if data_ptr not in data_ptr_to_tensor_group_index:
data_ptr_to_tensor_group_index[data_ptr] = len(data_ptr_to_tensor_group_index)
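A small sketch of my reading of the predicate change above (an assumption, not text from the PR): `pytorch.is_tensor` is also true for wrapper subclasses, while the exact-type check keeps the storage-based alias grouping restricted to plain tensors and parameters. `ScaleTensor` is the illustrative subclass sketched earlier.

```python
import torch

t = torch.randn(2, 2)
p = torch.nn.Parameter(torch.randn(2, 2))
s = ScaleTensor(torch.randn(2, 2), 0.5)  # hypothetical subclass from the sketch above

for x in (t, p, s):
    old_check = torch.is_tensor(x)                              # True for all three
    new_check = type(x) in {torch.Tensor, torch.nn.Parameter}   # False for the subclass
    print(type(x).__name__, old_check, new_check)
```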
37 changes: 37 additions & 0 deletions thunder/core/jit_ext.py
@@ -62,6 +62,7 @@
NumberProxy,
StringProxy,
TensorProxy,
SubclassTensorProxy,
FutureTensorProxy,
make_proxy_name,
Variable,
@@ -863,6 +864,42 @@ def grad_transform(*args, **kwargs):
return output


@register_general_jit_lookaside(torch.Tensor._make_wrapper_subclass)
def _make_wrapper_subclass(
cls: torch._C._TensorMeta,
size: Sequence[int],
strides: Sequence[int] | None = None,
storage_offset: int | None = None,
memory_format: torch.memory_format | None = None,
dtype: torch.dtype | None = None,
layout: torch.layout | None = torch.strided,
device: torch.device | None = None,
pin_memory: bool = False,
requires_grad: bool = False,
dispatch_sizes_strides_policy: str | None = None,
dispatch_device: bool = False,
dispatch_layout: bool = False,
_extra_dispatch_keys: torch.DispatchKeySet | None = None,
storage_size: int | None = None,
):
ucls = unwrap(cls)
usize = unwrap(size)
udtype = unwrap(dtype)
udevice = unwrap(device)
urequires_grad = unwrap(requires_grad)

subclass = SubclassTensorProxy(
None,
shape=usize,
device=udevice,
dtype=udtype,
requires_grad=urequires_grad,
history=ProvenanceRecord(PseudoInst.LOOKASIDE, [cls.provenance]),
subclass_type=ucls,
)
return wrap(subclass, provenance=ProvenanceRecord(PseudoInst.LOOKASIDE, [cls.provenance]))


@register_general_jit_lookaside(torch.autocast.__enter__)
def autocast_enter(autocast_obj):
unwrap_autocast_obj = unwrap(autocast_obj)
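A hedged usage sketch for the lookaside above: when the interpreter reaches `ScaleTensor.__new__` from the earlier example, the `torch.Tensor._make_wrapper_subclass(cls, size, ...)` call is answered with a `SubclassTensorProxy` carrying shape, dtype, device, requires_grad, and `subclass_type` instead of a real wrapper tensor. Whether the end-to-end program below is supported as-is is an assumption on my part.

```python
import torch
import thunder


def make_scaled(x: torch.Tensor):
    # ScaleTensor.__new__ calls torch.Tensor._make_wrapper_subclass, which hits the
    # lookaside above during interpretation and yields a SubclassTensorProxy.
    return ScaleTensor(x, 2.0)


jitted = thunder.jit(make_scaled)
out = jitted(torch.randn(4, 4))
```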
81 changes: 81 additions & 0 deletions thunder/core/prims.py
@@ -1,3 +1,5 @@
from __future__ import annotations

from enum import auto, Enum
from numbers import Number
from functools import reduce, wraps
@@ -77,6 +79,7 @@ def register_method(method_name: str, method: Callable, /) -> None:
TupleProxy,
AnyProxy,
IntegerProxy,
SubclassTensorProxy,
)
import thunder.core.codeutils as codeutils
from thunder.core.codeutils import Printable
@@ -272,6 +275,10 @@ class PrimIDs(Enum):
COPY_ = auto()
#
SINK = auto()
# Tensor Subclasses methods
TENSOR_SUBCLASS_CTOR = auto()
FLATTEN_TENSOR_SUBCLASS = auto()
UNFLATTEN_TENSOR_SUBCLASS = auto()


class OpTags(Enum):
@@ -4048,3 +4055,77 @@ def sink_meta(*args, **kwargs):

# TODO do we want another tag to remove this after prologue is constructed?
sink = make_prim(PrimIDs.SINK, "sink", meta=sink_meta, tags=(OpTags.DONT_DCE,))


def tensor_subclass_ctor_meta(
cls, name, shape, device, dtype, requires_grad, tensors, non_tensors
) -> SubclassTensorProxy:
s = SubclassTensorProxy(
name,
subclass_type=cls,
shape=shape,
device=device,
dtype=dtype,
requires_grad=requires_grad,
tensors=tensors,
non_tensors=non_tensors,
history=[t.history for t in tensors],
)
return s


tensor_subclass_ctor = make_prim(
PrimIDs.TENSOR_SUBCLASS_CTOR,
"tensor_subclass_ctor",
meta=tensor_subclass_ctor_meta,
)


def flatten_tensor_subclass_meta(t: SubclassTensorProxy) -> tuple[TensorProxy, ...]:
tensor_attr_names, metadata = t.__tensor_flatten__()
tensors = tuple(getattr(t, name) for name in tensor_attr_names)
return tensors


flatten_tensor_subclass = make_prim(
PrimIDs.FLATTEN_TENSOR_SUBCLASS,
"flatten_tensor_subclass",
meta=flatten_tensor_subclass_meta,
)


def unflatten_tensor_subclass_meta(
tensor_subclass_type,
inner_tensors: dict[str, TensorProxy],
metadata: dict[str, Any],
) -> SubclassTensorProxy:
first_tensor: TensorProxy = list(inner_tensors.values())[0]
a = SubclassTensorProxy(
shape=first_tensor.shape,
device=first_tensor.device,
dtype=first_tensor.dtype,
requires_grad=first_tensor.requires_grad,
tensors=list(inner_tensors.values()),
non_tensors=list(metadata.values()),
subclass_type=tensor_subclass_type,
)
for name, value in inner_tensors.items():
setattr(a, name, value)
for name, value in metadata.items():
setattr(a, name, value)
return a


def unflatten_tensor_subclass_python_impl(
tensor_subclass_type,
inner_tensors: dict[str, TensorProxy],
metadata: dict[str, Any],
) -> torch.Tensor:
return tensor_subclass_type.__tensor_unflatten__(inner_tensors, metadata, -1, -1)


unflatten_tensor_subclass = make_prim(
PrimIDs.UNFLATTEN_TENSOR_SUBCLASS,
"unflatten_tensor_subclass",
meta=unflatten_tensor_subclass_meta,
)
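A conceptual round-trip sketch of how the two new prims mirror the subclass protocol, assuming an active trace context and a `SubclassTensorProxy` `a` built for the illustrative `ScaleTensor`; this is how I read the metas above, not code from the PR.

```python
from thunder.core import prims


def roundtrip(a):
    # flatten_tensor_subclass_meta: attribute names come from __tensor_flatten__ and the
    # inner TensorProxys are read off those attributes, e.g. (a._data,) for ScaleTensor.
    inner_tensors = prims.flatten_tensor_subclass(a)
    names, metadata = a.__tensor_flatten__()          # (["_data"], {"scale": ...})

    # ... transformed ops on the plain TensorProxys would be traced here ...

    # unflatten_tensor_subclass_meta: rebuild a SubclassTensorProxy with the same inner
    # tensors, metadata, and subclass type.
    return prims.unflatten_tensor_subclass(
        a._subclass_type,
        dict(zip(names, inner_tensors)),
        metadata,
    )
```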
189 changes: 186 additions & 3 deletions thunder/core/proxies.py
@@ -3,7 +3,7 @@
import copy
from enum import auto, Enum
from numbers import Number
from typing import Any
from typing import Any, ClassVar
from collections.abc import Callable
from collections.abc import Sequence

@@ -1880,6 +1880,154 @@ def real(self):
return method(self)


class SubclassTensorProxy(TensorProxy):
SUBCLASS_TYPE_ATTR: ClassVar[str] = "_subclass_type"
_tensors: list[TensorProxy]
_non_tensors: list[Any]
_subclass_type: torch._C._TensorMeta

_tensor_attr_names: list[str] | None
_non_tensor_attr_names: list[str] | None

def __init__(self, *args, **kwargs):
from thunder.core.pytree import tree_flatten

kwarg_tensors = kwargs.pop("tensors", [])
kwarg_non_tensors = kwargs.pop("non_tensors", [])
subclass_type = kwargs.pop("subclass_type", None)

has_name_before_init = hasattr(self, "_name")
# If `tensors` (and `non_tensors`) are not empty, this `__init__` follows the
# `_make_wrapper_subclass` lookaside, so `self` should already have its name.
flat_args, spec = tree_flatten((args, kwargs))
tensors: list[TensorProxy] = []
non_tensors: list[Any] = []
for t in args + tuple(kwargs.values()):
if type(t) is SubclassTensorProxy:
continue
if type(t) is TensorProxy:
tensors.append(t)
else:
non_tensors.append(t)

is_dunder_init_following_make_wrapper_subclass: bool = False
if tensors:
baseutils.check(
has_name_before_init
and not kwarg_tensors
and not kwarg_non_tensors
and self._subclass_type is not None,
lambda: (
f"{flat_args=} indicates this instance is created by"
"`torch.Tensor._make_wrapper_subclass`'s lookaside but `name` is not set"
),
)
is_dunder_init_following_make_wrapper_subclass = True

if not is_dunder_init_following_make_wrapper_subclass:
super().__init__(*args, **kwargs)

self._tensors = kwarg_tensors
self._non_tensors = kwarg_non_tensors
self._subclass_type = subclass_type
else:
# TODO(crcrpar): Think about materializing `self` so that we can
# call `__tensor_init__` to know each attribute names.
from thunder.core import prims

bsym = prims.tensor_subclass_ctor.bind(
self._subclass_type,
self.name,
self.shape,
self.device,
self.dtype,
self.requires_grad,
self._tensors,
self._non_tensors,
output=self,
)
# NOTE(crcrpar): A callable being `thunder.jit`ed can call `MySubclassTensor(...)`
# inside of it either directly or indirectly: indirect way is to call it through
# a custom `torch.autograd.Function` as in
# https://github.com/pytorch/ao/blob/000a490/torchao/float8/float8_tensor.py#L139-L209.
# If it's a direct call, `trace.bound_symbols` and `trace.scopes[-1]` are identical;
# otherwise they are not, because [the lookaside of `torch.autograd.Function`](
# https://github.com/Lightning-AI/lightning-thunder/blob/3d42c10/thunder/core/jit_ext.py#L655)
# pushes a temporary scope onto the current trace.
current_trace = get_tracectx()
if id(current_trace.bound_symbols) == id(cur_tail_scope := current_trace.scopes[-1]):
current_trace.add_bound_symbol(bsym)
else:
cur_tail_scope.append(bsym)

if not self._tensors and not self._non_tensors:
for a in tensors:
self._tensors.append(a)
for a in non_tensors:
self._non_tensors.append(a)
baseutils.check(self._tensors, lambda: f"`{self._name}._tensors` must not be empty")

def __tensor_flatten__(self) -> tuple[list[str], dict[str, Any]]:
return self._tensor_attr_names, self.metadata

@property
def metadata(self) -> dict[str, Any]:
return dict(zip(self._non_tensor_attr_names, self._non_tensors))

def replace(self, **changes):
r"""Return a copy of the SubclassTensorProxy object with new values for the specified fields as given to the constructor as arguments.
Valid keyword arguments are ``name``, ``history``, ``shape``, ``dtype``, ``device``, ``requires_grad``, ``distparallel_type``, ``thunder_fsdp_padding_size``.
``like`` is also a valid keyword and will take metadata from the tensor proxy argument
in preference to the old values but overridable by keyword arguments.
Note that the copy will use the current (environment) tracectx."""

like = changes.get("like")
(
shape,
device,
dtype,
true_dtype,
numel,
ndim,
requires_grad,
grad,
distparallel_type,
thunder_fsdp_padding_size,
) = _infer_tensor_properties(
like,
changes.get("shape", self._shape if like is None else None),
changes.get("device", self._device if like is None else None),
changes.get("dtype", self._dtype if like is None else None),
changes.get("requires_grad", self._requires_grad if like is None else None),
changes.get("grad", self._grad if like is None else None),
changes.get("distparallel_type", self._distparallel_type if like is None else None),
changes.get("thunder_fsdp_padding_size", self._thunder_fsdp_padding_size if like is None else None),
)
name = changes.get("name", self.name)
history = changes.get("history", self.history)
tags = changes.get("tags", self.tags)
p = SubclassTensorProxy(
name=name,
tags=tags,
shape=shape,
device=device,
dtype=dtype,
requires_grad=requires_grad,
distparallel_type=distparallel_type,
thunder_fsdp_padding_size=thunder_fsdp_padding_size,
history=history,
tensors=self._tensors,
non_tensors=self._non_tensors,
subclass_type=self._subclass_type,
)
if hasattr(self, "_tensor_attr_names") and hasattr(self, "_non_tensor_attr_names"):
p._tensor_attr_names = self._tensor_attr_names
p._non_tensor_attr_names = self._non_tensor_attr_names
for name, value in zip(p._tensor_attr_names + p._non_tensor_attr_names, p._tensors + p._non_tensors):
setattr(p, name, value)
return p


class TorchAutogradFunctionCtxProxy(Proxy, TorchAutogradFunctionCtxProxyInterface):
def __init__(
self,
@@ -1951,6 +2099,7 @@ def __setattr__(self, name, value):

# TODO: move this function to jit_ext.py
def tensorproxy(t: torch.Tensor, /, *, name: None | str, history: None | tuple = None) -> TensorProxy:
from torch._subclasses.fake_tensor import FakeTensor
from thunder.core.interpreter import ProvenanceRecord, PseudoInst, wrap_const

if hasattr(t, "_thunder_device"):
@@ -1985,8 +2134,8 @@ def tensorproxy(t: torch.Tensor, /, *, name: None | str, history: None | tuple =
else:
# NOTE Without tuple(t.shape) then the shape would be a torch.Size object
shape = tuple(t.shape)
return TensorProxy(
name,
ctor_kwargs = dict(
name=name,
shape=tuple(shape),
device=device,
dtype=dtype,
@@ -1997,6 +2146,40 @@ def tensorproxy(t: torch.Tensor, /, *, name: None | str, history: None | tuple =
thunder_fsdp_padding_size=_thunder_fsdp_padding_size,
)

# n.b.(crcrpar): :class:`thunder.dynamo.ThunderCompiler.__call__` takes a torch.fx GraphModule
# in which `FakeTensor`s seem to be used, leading to the failures observed in e.g.
# https://github.com/Lightning-AI/lightning-thunder/actions/runs/11689709564/job/32553053319#step:10:5747
# https://dev.azure.com/Lightning-AI/lightning/_build/results?buildId=219328&view=logs&jobId=5b0799f7-725e-5b16-9b83-c0a5a25d03f0&j=5b0799f7-725e-5b16-9b83-c0a5a25d03f0
if (
isinstance(t, torch.Tensor)
and type(t) not in (torch.Tensor, torch.nn.Parameter, FakeTensor)
and hasattr(t, "__tensor_flatten__")
and hasattr(t, "__tensor_unflatten__")
):
baseutils.check(
hasattr(t, "__tensor_flatten__") and hasattr(t, "__tensor_unflatten__"),
lambda: f"{t=} seems to be a tensor subclass but not traceable",
)
tensor_attr_names, metadata = t.__tensor_flatten__()
tensors = [tensorproxy(getattr(t, name), name=None, history=history) for name in tensor_attr_names]
ctor_kwargs.update(
{
"tensors": tensors,
"non_tensors": list(metadata.values()),
"subclass_type": type(t),
}
)
p = SubclassTensorProxy(**ctor_kwargs)
p._tensor_attr_names = tensor_attr_names
p._non_tensor_attr_names = list(metadata.keys())
for name, tensor in zip(tensor_attr_names, tensors):
setattr(p, name, tensor)
for name, value in metadata.items():
setattr(p, name, value)
return p
else:
return TensorProxy(**ctor_kwargs)


def futuretensorproxy(
t: torch.Tensor | TensorProxy | FutureTensorProxy, /, *, name: None | str, history: None | tuple = None
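A hedged sketch of what the updated `tensorproxy` above is expected to produce for the illustrative `ScaleTensor`, assuming it is called with an active trace context (it normally runs while thunder builds a trace); the expectations in the comments are read off the diff, not observed output.

```python
import torch
from thunder.core.proxies import SubclassTensorProxy, tensorproxy

s = ScaleTensor(torch.randn(2, 2), 0.5)
p = tensorproxy(s, name="t0", history=None)

# Expected, per the branch added above:
#   type(p) is SubclassTensorProxy
#   p._tensor_attr_names == ["_data"]     (from __tensor_flatten__)
#   p.metadata == {"scale": 0.5}          (the non-tensor metadata)
#   p._data is a TensorProxy mirroring the inner plain tensor
```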