[tensor wrapper subclass] Add trace transform for tensor subclasses #1584

Draft · wants to merge 3 commits into base: tensor_subclass_1
1 change: 1 addition & 0 deletions docs/source/reference/transforms/index.rst
@@ -8,3 +8,4 @@ thunder.transforms

MaterializationTransform
ConstantFolding
unroll_tensor_subclasses
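
For context, here is a minimal sketch of the kind of tensor wrapper subclass this transform targets, assuming only PyTorch's traceable-subclass protocol (`__tensor_flatten__`, `__tensor_unflatten__`, `__torch_dispatch__`). `ScaledTensor` and its attribute names are illustrative and not part of this PR.

```python
import torch
from torch.utils._pytree import tree_map_only


class ScaledTensor(torch.Tensor):
    """Toy wrapper subclass: one inner tensor plus one piece of const metadata."""

    @staticmethod
    def __new__(cls, inner: torch.Tensor, scale: float):
        return torch.Tensor._make_wrapper_subclass(
            cls,
            inner.shape,
            dtype=inner.dtype,
            device=inner.device,
            requires_grad=inner.requires_grad,
        )

    def __init__(self, inner: torch.Tensor, scale: float):
        self._inner = inner
        self._scale = scale

    def __tensor_flatten__(self):
        # Per the PyTorch protocol: names of tensor attributes plus const metadata.
        return ["_inner"], {"_scale": self._scale}

    @staticmethod
    def __tensor_unflatten__(inner_tensors, metadata, outer_size, outer_stride):
        return ScaledTensor(inner_tensors["_inner"], metadata["_scale"])

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        # Unwrap, run the op on the inner tensors, then re-wrap the outputs.
        scale = next(a._scale for a in args if isinstance(a, ScaledTensor))
        args = tree_map_only(ScaledTensor, lambda t: t._inner, args)
        kwargs = tree_map_only(ScaledTensor, lambda t: t._inner, kwargs)
        out = func(*args, **kwargs)
        return tree_map_only(torch.Tensor, lambda t: ScaledTensor(t, scale), out)
```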
9 changes: 8 additions & 1 deletion thunder/__init__.py
@@ -73,6 +73,7 @@
from thunder.core.interpreter import print_interpreter_log, print_to_log
from thunder.core.jit_ext import thunder_general_jit
from thunder.executors.torch_autograd import split_forward_backward, ThunderFunction
from thunder.transforms.tensor_wrapper_subclass import unroll_tensor_subclasses

# NOTE This import is intentionally pytorch so that thunder.torch doesn't import this
import torch as pytorch
@@ -369,7 +370,7 @@ def _alias_tensor_of_args_kwargs_dict(*args, **kwargs) -> dict[int, list[int]]:
data_ptr_to_tensor_group_index = {}
tensor_group_index_to_tensor_indices = defaultdict(list)
for idx, t in enumerate(flat_args):
if pytorch.is_tensor(t) and t.layout == pytorch.strided:
if type(t) in {pytorch.Tensor, pytorch.nn.Parameter} and t.layout == pytorch.strided:
data_ptr = t.untyped_storage().data_ptr()
if data_ptr not in data_ptr_to_tensor_group_index:
data_ptr_to_tensor_group_index[data_ptr] = len(data_ptr_to_tensor_group_index)
@@ -616,6 +617,7 @@ def get_computation_and_inputs(*args, **kwargs):
computation_trc = dce(computation_trc)
computation_traces.append(computation_trc)

_unroll_tensor_subclasses_applied = False
backward_trc = None
if not cd.disable_torch_autograd_support:
tensor_cls = (pytorch.Tensor, TensorProxy)
@@ -626,10 +628,15 @@
# transform_for_execution and various sorting of symbols,
# applying transform_for_execution after this would be
# breaking the order of operations
_unroll_tensor_subclasses_applied = True
computation_trc, backward_trc = split_forward_backward(computation_trc, cd, cs, *inps)
# Note computation_trc and backward_trc have been appended to cs.last_(backward_)traces
# by split_forward_backward

if not _unroll_tensor_subclasses_applied:
computation_trc = unroll_tensor_subclasses(computation_trc)
computation_traces.append(computation_trc)

if backward_trc is None:
from thunder.executors.passes import transform_for_execution as transform_for_execution_pass
from thunder.executors.passes import _transform_for_operator_executor_execution
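
With the change above, intended usage is roughly the following sketch; `ScaledTensor` is the illustrative subclass from earlier, and which branch post-processes the trace depends on whether `split_forward_backward` ran.

```python
import torch
import thunder

def f(x):
    return x + 1  # routed through ScaledTensor.__torch_dispatch__ while tracing

jf = thunder.jit(f)
out = jf(ScaledTensor(torch.randn(4), 2.0))
# If split_forward_backward did not run (torch autograd support disabled, or no
# input requires grad), get_computation_and_inputs applies
# unroll_tensor_subclasses to the computation trace directly, as added above.
```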
168 changes: 165 additions & 3 deletions thunder/core/prims.py
@@ -280,6 +280,8 @@ class PrimIDs(Enum):
SINK = auto()
# Tensor Subclasses methods
TENSOR_SUBCLASS_CTOR = auto()
FLATTEN_TENSOR_SUBCLASS = auto()
UNFLATTEN_TENSOR_SUBCLASS = auto()


class OpTags(Enum):
@@ -4098,7 +4100,7 @@ def check_types(coll):
return tuple(types_set)


def filter_types(types: tuple[Any, ...]) -> tuple[Any, ...]:
def filter_types_for_tensor_wrapper_subclass(types: tuple[Any, ...]) -> tuple[Any, ...]:
return tuple(
filter(
lambda t: (
@@ -4170,7 +4172,7 @@ def printer_of_tensor_subclass_ctor(
filtered_types = (cls,)
if non_tensors:
types = get_nested_types([t.obj if isinstance(t, codeutils.ContextObject) else t for t in non_tensors])
filtered_types += filter_types(types)
filtered_types += filter_types_for_tensor_wrapper_subclass(types)
new_imports = {t.__name__: t for t in filtered_types}
bsym._import_ctx.update(new_imports)
return s
@@ -4183,7 +4185,7 @@ def bind_postprocess_of_tensor_subclass_ctor(bsym: BoundSymbol) -> None:
filtered_types: tuple[Any, ...] = (cls,)
if non_tensors:
types = get_nested_types(non_tensors)
filtered_types += filter_types(types)
filtered_types += filter_types_for_tensor_wrapper_subclass(types)
new_imports = {t.__name__: t for t in filtered_types}
bsym._import_ctx.update(new_imports)

@@ -4195,3 +4197,163 @@ def bind_postprocess_of_tensor_subclass_ctor(bsym: BoundSymbol) -> None:
python_printer=printer_of_tensor_subclass_ctor,
_bind_postprocess=bind_postprocess_of_tensor_subclass_ctor,
)


def printer_of_tensor_subclass_flatten(
bsym: BoundSymbol,
out_printables: Any,
arg_printables: Sequence[Printable],
kwarg_printables: dict[str, Printable],
) -> str | Iterable[str]:
from itertools import chain

arg_str = (
""
if (arg_printables is None or len(arg_printables) == 0)
else ", ".join(codeutils.prettyprint(x) for x in arg_printables)
)

result_str: str
if bsym.output is None or (baseutils.is_collection(bsym.output) and len(bsym.output) == 0):
result_str = ""
else:
result_str = f"{codeutils.prettyprint(out_printables, literals_as_underscores=True)} = "

# Creates a comment describing the output
comment_str = ""
if isinstance(bsym.output, Proxy):
comment_str = f" # {codeutils.prettyprint(out_printables, with_type=True)}"

s = f"{result_str}{arg_str}.__tensor_flatten__(){comment_str}"

if bsym.header:
header_lines = (
bsym.header
if isinstance(bsym.header, Sequence) and not isinstance(bsym.header, str)
else bsym.header.splitlines()
)
header_lines = (f"# {line}" for line in header_lines)
return chain(header_lines, [s])

return s


# NOTE(crcrpar): The behavior is different from PyTorch's `subclass_tensor.__tensor_flatten__()`,
# which returns a list of tensor attr names and a dict of const metadata. In Thunder traces,
# const values can be omitted, and actual tensor proxies are more useful
# than tensor attr names.
def flatten_tensor_subclass_meta(t: SubclassTensorProxy) -> tuple[TensorProxy, ...]:
tensor_attr_names, metadata = t.__tensor_flatten__()
tensors = tuple(getattr(t, name) for name in tensor_attr_names)
return tensors


flatten_tensor_subclass = make_prim(
PrimIDs.FLATTEN_TENSOR_SUBCLASS,
"flatten_tensor_subclass",
meta=flatten_tensor_subclass_meta,
python_printer=printer_of_tensor_subclass_flatten,
)
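
To make the NOTE above concrete (reusing the illustrative `ScaledTensor`; the values shown are assumptions):

```python
st = ScaledTensor(torch.randn(2, 2), 0.5)
names, meta = st.__tensor_flatten__()  # PyTorch: (["_inner"], {"_scale": 0.5})
# flatten_tensor_subclass_meta instead returns the inner tensor proxies themselves,
# e.g. (proxy_of__inner,), and drops the const metadata from the prim's output.
```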


def printer_of_unflatten_tensor_subclass(
bsym: BoundSymbol,
out_printables: Any,
arg_printables: Sequence[Printable],
kwarg_printables: dict[str, Printable],
) -> str | Iterable[str]:
from itertools import chain

wrapped_cls: ContextObject | torch._C._TensorMeta = arg_printables[0]
if isinstance(wrapped_cls, torch._C._TensorMeta):
cls = wrapped_cls
else:
cls: torch._C._TensorMeta = wrapped_cls.obj

arg_str = (
""
if (arg_printables is None or len(arg_printables) == 0)
else ", ".join(codeutils.prettyprint(x) for x in arg_printables[1:])
)
kwarg_str: str

if len(kwarg_printables) == 0:
kwarg_str = ""
else:
kwarg_str = ", ".join(f"{k}={codeutils.prettyprint(v)}" for k, v in kwarg_printables.items())

result_str: str
if bsym.output is None or (baseutils.is_collection(bsym.output) and len(bsym.output) == 0):
result_str = ""
else:
result_str = f"{codeutils.prettyprint(out_printables, literals_as_underscores=True)} = "

# Creates a comment describing the output
comment_str = ""
if isinstance(bsym.output, Proxy):
comment_str = f" # {codeutils.prettyprint(out_printables, with_type=True)}"

s = f"{result_str}{cls.__name__}.__tensor_unflatten__({arg_str}{', ' if (len(arg_str) > 0 and len(kwarg_str) > 0) else ''}{kwarg_str}){comment_str}"

if bsym.header:
header_lines = (
bsym.header
if isinstance(bsym.header, Sequence) and not isinstance(bsym.header, str)
else bsym.header.splitlines()
)
header_lines = (f"# {line}" for line in header_lines)
return chain(header_lines, [s])

return s


def bind_postprocess_of_unflatten_tensor_subclass(bsym: BoundSymbol) -> None:
cls = bsym.args[0]
inner_tensors = bsym.args[1]
metadata = bsym.args[2]

filtered_types: tuple[Any, ...] = (cls,)
if metadata:
types = get_nested_types(list(metadata.values()))
filtered_types += filter_types_for_tensor_wrapper_subclass(types)
new_imports = {t.__name__: t for t in filtered_types}
bsym._import_ctx.update(new_imports)


def unflatten_tensor_subclass_meta(
tensor_subclass_type,
inner_tensors: dict[str, TensorProxy],
metadata: dict[str, Any],
) -> SubclassTensorProxy:
first_tensor: TensorProxy = list(inner_tensors.values())[0]
a = SubclassTensorProxy(
shape=first_tensor.shape,
device=first_tensor.device,
dtype=first_tensor.dtype,
requires_grad=first_tensor.requires_grad,
tensors=list(inner_tensors.values()),
non_tensors=list(metadata.values()),
subclass_type=tensor_subclass_type,
)
for name, value in inner_tensors.items():
setattr(a, name, value)
for name, value in metadata.items():
setattr(a, name, value)
return a


def unflatten_tensor_subclass_python_impl(
tensor_subclass_type,
inner_tensors: dict[str, TensorProxy],
metadata: dict[str, Any],
) -> torch.Tensor:
return tensor_subclass_type.__tensor_unflatten__(inner_tensors, metadata, -1, -1)


unflatten_tensor_subclass = make_prim(
PrimIDs.UNFLATTEN_TENSOR_SUBCLASS,
"unflatten_tensor_subclass",
meta=unflatten_tensor_subclass_meta,
python_printer=printer_of_unflatten_tensor_subclass,
_bind_postprocess=bind_postprocess_of_unflatten_tensor_subclass,
)
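
Together, the flatten/unflatten prims are meant to bracket subclass operations in a trace. A hypothetical printed excerpt, with names and metadata values purely illustrative:

```python
# (t0,) = x.__tensor_flatten__()  # prim: flatten_tensor_subclass; x is a ScaledTensor proxy
# t1 = torch.add(t0, 1.0)
# out = ScaledTensor.__tensor_unflatten__({"_inner": t1}, {"_scale": 0.5}, -1, -1)
```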
38 changes: 36 additions & 2 deletions thunder/core/proxies.py
@@ -2111,6 +2111,7 @@ def __setattr__(self, name, value):

# TODO: move this function to jit_ext.py
def tensorproxy(t: torch.Tensor, /, *, name: None | str, history: None | tuple = None) -> TensorProxy:
from torch._subclasses.fake_tensor import FakeTensor
from thunder.core.interpreter import ProvenanceRecord, PseudoInst, wrap_const

if hasattr(t, "_thunder_device"):
@@ -2145,8 +2146,8 @@ def tensorproxy(t: torch.Tensor, /, *, name: None | str, history: None | tuple =
else:
# NOTE Without tuple(t.shape) then the shape would be a torch.Size object
shape = tuple(t.shape)
return TensorProxy(
name,
ctor_kwargs = dict(
name=name,
shape=tuple(shape),
device=device,
dtype=dtype,
@@ -2156,6 +2157,39 @@ def tensorproxy(t: torch.Tensor, /, *, name: None | str, history: None | tuple =
history=history,
thunder_fsdp_padding_size=_thunder_fsdp_padding_size,
)
# n.b.(crcrpar): :class:`thunder.dynamo.ThunderCompiler.__call__` takes a torch.fx GraphModule
# in which `FakeTensor` appears to be used, leading to failures observed in e.g.
# https://github.com/Lightning-AI/lightning-thunder/actions/runs/11689709564/job/32553053319#step:10:5747
# https://dev.azure.com/Lightning-AI/lightning/_build/results?buildId=219328&view=logs&jobId=5b0799f7-725e-5b16-9b83-c0a5a25d03f0&j=5b0799f7-725e-5b16-9b83-c0a5a25d03f0
if (
isinstance(t, torch.Tensor)
and type(t) not in (torch.Tensor, torch.nn.Parameter, FakeTensor)
and hasattr(t, "__tensor_flatten__")
and hasattr(t, "__tensor_unflatten__")
):
baseutils.check(
hasattr(t, "__tensor_flatten__") and hasattr(t, "__tensor_unflatten__"),
lambda: f"{t=} seems to be a tensor subclass but not traceable",
)
tensor_attr_names, metadata = t.__tensor_flatten__()
tensors = [tensorproxy(getattr(t, name), name=None, history=history) for name in tensor_attr_names]
ctor_kwargs.update(
{
"tensors": tensors,
"non_tensors": list(metadata.values()),
"subclass_type": type(t),
}
)
p = SubclassTensorProxy(**ctor_kwargs)
p._tensor_attr_names = tensor_attr_names
p._non_tensor_attr_names = list(metadata.keys())
for name, tensor in zip(tensor_attr_names, tensors):
setattr(p, name, tensor)
for name, value in metadata.items():
setattr(p, name, value)
return p
else:
return TensorProxy(**ctor_kwargs)
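
For reference, the subclass branch above is expected to produce something like the following for the illustrative `ScaledTensor` (attribute names come from `__tensor_flatten__`; the shown values are assumptions):

```python
# p = tensorproxy(ScaledTensor(torch.randn(3), 0.5), name="x", history=None)
# type(p)                   -> SubclassTensorProxy
# p._tensor_attr_names      -> ["_inner"]
# p._non_tensor_attr_names  -> ["_scale"]
# p._inner                  -> TensorProxy with shape (3,)
# p._scale                  -> 0.5
```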


def futuretensorproxy(
7 changes: 7 additions & 0 deletions thunder/executors/torch_autograd.py
@@ -132,6 +132,7 @@ def split_forward_backward(computation_trc: TraceCtx, compile_data, compile_stat
from thunder.distributed.transforms import FSDPCommBucketing
from thunder.distributed.utils import sort_data_parallel_syncs, sort_waits, sort_communication_ops
from thunder.executors.passes import del_last_used, transform_for_execution
from thunder.transforms.tensor_wrapper_subclass import unroll_tensor_subclasses

utils.check(compile_data is not None, lambda: "`compile_data` is required")
# NOTE: This function is rather slow, so it's intended to be used
@@ -158,6 +159,9 @@ def split_forward_backward(computation_trc: TraceCtx, compile_data, compile_stat
fw_traces = [fw_trace]
bw_traces = [bw_trace]

fw_trace = unroll_tensor_subclasses(fw_trace)
fw_traces.append(fw_trace)

from thunder.distributed import FSDPType

# only enable rematerialize_params_in_backward when using FSDP ZeRO3
@@ -262,6 +266,9 @@ def split_forward_backward(computation_trc: TraceCtx, compile_data, compile_stat
if getattr(compile_data.fn, "use_fsdp", False):
bw_trace = _fsdp_comm_bucketing.apply_bucketing_to_backward_trace(bw_trace)

bw_trace = unroll_tensor_subclasses(bw_trace)
bw_traces.append(bw_trace)

# Now we can run the optimization passes on the backward trace
# TODO Restore request for no rematerialization
bw_extrace = transform_for_execution(
48 changes: 46 additions & 2 deletions thunder/executors/torchex.py
@@ -2235,13 +2235,13 @@ def _tensor_subclass_ctor(cls, name, shape, device, dtype, requires_grad, tensor


def _bind_postprocess_of_tensor_subclass_ctor(bsym: BoundSymbol) -> None:
from thunder.core.prims import get_nested_types, filter_types
from thunder.core.prims import get_nested_types, filter_types_for_tensor_wrapper_subclass

cls, _name, _shape, _device, _dtype, _requires_grad, _tensors, non_tensors = bsym.args
filtered_types = (cls,)
if non_tensors:
types = get_nested_types(non_tensors)
filtered_types += filter_types(types)
filtered_types += filter_types_for_tensor_wrapper_subclass(types)
new_imports = {t.__name__: t for t in filtered_types}
bsym._import_ctx.update(new_imports)

@@ -2254,3 +2254,47 @@ def _bind_postprocess_of_tensor_subclass_ctor(bsym: BoundSymbol) -> None:
python_printer=prims.printer_of_tensor_subclass_ctor,
)
_register_implementation(prims.tensor_subclass_ctor, tensor_subclass_ctor, checker=_always_executable)


def flatten_tensor_subclass_impl(t):
tensor_attr_names, metadata = t.__tensor_flatten__()
tensors = tuple(getattr(t, name) for name in tensor_attr_names)
return tensors


flatten_tensor_subclass = ex.register_operator(
"flatten_tensor_subclass",
meta=prims.flatten_tensor_subclass.meta,
fn=flatten_tensor_subclass_impl,
)
_register_implementation(
prims.flatten_tensor_subclass,
flatten_tensor_subclass,
checker=_always_executable,
)


def unflatten_tensor_subclass_impl(
tensor_subclass_type: torch._C._TensorMeta,
inner_tensors: dict[str, TensorLike],
metadata: dict,
):
for key in metadata:
v = metadata[key]
if isinstance(v, dtypes.dtype):
metadata[key] = to_torch_dtype(v)
elif isinstance(v, devices.Device):
metadata[key] = to_torch_device(v)
return tensor_subclass_type.__tensor_unflatten__(inner_tensors, metadata, -1, -1)


unflatten_tensor_subclass = ex.register_operator(
"unflatten_tensor_subclass",
meta=prims.unflatten_tensor_subclass.meta,
fn=unflatten_tensor_subclass_impl,
)
_register_implementation(
prims.unflatten_tensor_subclass,
unflatten_tensor_subclass,
checker=_always_executable,
)
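
The metadata conversion in `unflatten_tensor_subclass_impl` matters because trace-level metadata may hold Thunder dtype/device objects while the subclass's `__tensor_unflatten__` expects torch ones. A rough illustration, assuming Thunder's dtype and device wrappers:

```python
# before: {"_scale": 0.5, "dtype": dtypes.float16, "device": devices.Device("cuda:0")}
# after:  {"_scale": 0.5, "dtype": torch.float16,  "device": torch.device("cuda:0")}
```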