Skip to content

Commit

Permalink
trace transform of tensor wrapper subclass
Browse files Browse the repository at this point in the history
to support `__torch_dispatch__`.
Since it extends the behavior that is implemented in C++ level,
we'd need to apply the transform to split forward and backward traces
separately.

Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
  • Loading branch information
crcrpar committed Dec 30, 2024
1 parent 0998d81 commit db3a6c9
Show file tree
Hide file tree
Showing 9 changed files with 1,056 additions and 8 deletions.
1 change: 1 addition & 0 deletions docs/source/reference/transforms/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@ thunder.transforms

MaterializationTransform
ConstantFolding
unroll_tensor_subclasses
9 changes: 8 additions & 1 deletion thunder/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@
from thunder.core.interpreter import print_interpreter_log, print_to_log
from thunder.core.jit_ext import thunder_general_jit
from thunder.executors.torch_autograd import split_forward_backward, ThunderFunction
from thunder.transforms.tensor_wrapper_subclass import unroll_tensor_subclasses

# NOTE This import is intentionally pytorch so that it thunder.torch doesn't import this
import torch as pytorch
Expand Down Expand Up @@ -369,7 +370,7 @@ def _alias_tensor_of_args_kwargs_dict(*args, **kwargs) -> dict[int, list[int]]:
data_ptr_to_tensor_group_index = {}
tensor_group_index_to_tensor_indices = defaultdict(list)
for idx, t in enumerate(flat_args):
if pytorch.is_tensor(t) and t.layout == pytorch.strided:
if type(t) in {pytorch.Tensor, pytorch.nn.Parameter} and t.layout == pytorch.strided:
data_ptr = t.untyped_storage().data_ptr()
if data_ptr not in data_ptr_to_tensor_group_index:
data_ptr_to_tensor_group_index[data_ptr] = len(data_ptr_to_tensor_group_index)
Expand Down Expand Up @@ -616,6 +617,7 @@ def get_computation_and_inputs(*args, **kwargs):
computation_trc = dce(computation_trc)
computation_traces.append(computation_trc)

_unroll_tensor_subclasses_applied = False
backward_trc = None
if not cd.disable_torch_autograd_support:
tensor_cls = (pytorch.Tensor, TensorProxy)
Expand All @@ -626,10 +628,15 @@ def get_computation_and_inputs(*args, **kwargs):
# transform_for_execution and various sorting of symbols,
# applying transform_for_execution after this would be
# breaking the order of operations
_unroll_tensor_subclasses_applied = True
computation_trc, backward_trc = split_forward_backward(computation_trc, cd, cs, *inps)
# Note computation_trc and backward_trc have been appended to cs.last_(backward_)traces
# by split_forward_backward

if not _unroll_tensor_subclasses_applied:
computation_trc = unroll_tensor_subclasses(computation_trc)
computation_traces.append(computation_trc)

if backward_trc is None:
from thunder.executors.passes import transform_for_execution as transform_for_execution_pass
from thunder.executors.passes import _transform_for_operator_executor_execution
Expand Down
168 changes: 165 additions & 3 deletions thunder/core/prims.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,8 @@ class PrimIDs(Enum):
SINK = auto()
# Tensor Subclasses methods
TENSOR_SUBCLASS_CTOR = auto()
FLATTEN_TENSOR_SUBCLASS = auto()
UNFLATTEN_TENSOR_SUBCLASS = auto()


class OpTags(Enum):
Expand Down Expand Up @@ -4098,7 +4100,7 @@ def check_types(coll):
return tuple(types_set)


def filter_types(types: tuple[Any, ...]) -> tuple[Any, ...]:
def filter_types_for_tensor_wrapper_subclass(types: tuple[Any, ...]) -> tuple[Any, ...]:
return tuple(
filter(
lambda t: (
Expand Down Expand Up @@ -4170,7 +4172,7 @@ def printer_of_tensor_subclass_ctor(
filtered_types = (cls,)
if non_tensors:
types = get_nested_types([t.obj if isinstance(t, codeutils.ContextObject) else t for t in non_tensors])
filtered_types += filter_types(types)
filtered_types += filter_types_for_tensor_wrapper_subclass(types)
new_imports = {t.__name__: t for t in filtered_types}
bsym._import_ctx.update(new_imports)
return s
Expand All @@ -4183,7 +4185,7 @@ def bind_postprocess_of_tensor_subclass_ctor(bsym: BoundSymbol) -> None:
filtered_types: tuple[Any, ...] = (cls,)
if non_tensors:
types = get_nested_types(non_tensors)
filtered_types += filter_types(types)
filtered_types += filter_types_for_tensor_wrapper_subclass(types)
new_imports = {t.__name__: t for t in filtered_types}
bsym._import_ctx.update(new_imports)

Expand All @@ -4195,3 +4197,163 @@ def bind_postprocess_of_tensor_subclass_ctor(bsym: BoundSymbol) -> None:
python_printer=printer_of_tensor_subclass_ctor,
_bind_postprocess=bind_postprocess_of_tensor_subclass_ctor,
)


def printer_of_tensor_subclass_flatten(
bsym: BoundSymbol,
out_printables: Any,
arg_printables: Sequence[Printable],
kwarg_printables: dict[str, Printable],
) -> str | Iterable[str]:
from itertools import chain

arg_str = (
""
if (arg_printables is None or len(arg_printables) == 0)
else ", ".join(codeutils.prettyprint(x) for x in arg_printables)
)

result_str: str
if bsym.output is None or (baseutils.is_collection(bsym.output) and len(bsym.output) == 0):
result_str = ""
else:
result_str = f"{codeutils.prettyprint(out_printables, literals_as_underscores=True)} = "

# Creates a comment describing the output
comment_str = ""
if isinstance(bsym.output, Proxy):
comment_str = f" # {codeutils.prettyprint(out_printables, with_type=True)}"

s = f"{result_str}{arg_str}.__tensor_flatten__(){comment_str}"

if bsym.header:
header_lines = (
bsym.header
if isinstance(bsym.header, Sequence) and not isinstance(bsym.header, str)
else bsym.header.splitlines()
)
header_lines = (f"# {line}" for line in header_lines)
return chain(header_lines, [s])

return s


# NOTE(crcrpar): The behavior is different from PyTorch `subclass_tensor.__tensor_flatten__()`
# that returns a list of tensor attr names and a dict of const metadata. In Thunder traces,
# const values could be obviated and actual tensor proxies would be more useful
# than tensor attr names.
def flatten_tensor_subclass_meta(t: SubclassTensorProxy) -> tuple[TensorProxy, ...]:
tensor_attr_names, metadata = t.__tensor_flatten__()
tensors = tuple(getattr(t, name) for name in tensor_attr_names)
return tensors


flatten_tensor_subclass = make_prim(
PrimIDs.FLATTEN_TENSOR_SUBCLASS,
"flatten_tensor_subclass",
meta=flatten_tensor_subclass_meta,
python_printer=printer_of_tensor_subclass_flatten,
)


def printer_of_unflatten_tensor_subclass(
bsym: BoundSymbol,
out_printables: Any,
arg_printables: Sequence[Printable],
kwarg_printables: dict[str, Printable],
) -> str | Iterable[str]:
from itertools import chain

wrapped_cls: ContextObject | torch._C._TensorMeta = arg_printables[0]
if isinstance(wrapped_cls, torch._C._TensorMeta):
cls = wrapped_cls
else:
cls: torch._C._TensorMeta = wrapped_cls.obj

arg_str = (
""
if (arg_printables is None or len(arg_printables) == 0)
else ", ".join(codeutils.prettyprint(x) for x in arg_printables[1:])
)
kwarg_str: str

if len(kwarg_printables) == 0:
kwarg_str = ""
else:
kwarg_str = ", ".join(f"{k}={codeutils.prettyprint(v)}" for k, v in kwarg_printables.items())

result_str: str
if bsym.output is None or (baseutils.is_collection(bsym.output) and len(bsym.output) == 0):
result_str = ""
else:
result_str = f"{codeutils.prettyprint(out_printables, literals_as_underscores=True)} = "

# Creates a comment describing the output
comment_str = ""
if isinstance(bsym.output, Proxy):
comment_str = f" # {codeutils.prettyprint(out_printables, with_type=True)}"

s = f"{result_str}{cls.__name__}.__tensor_unflatten__({arg_str}{', ' if (len(arg_str) > 0 and len(kwarg_str) > 0) else ''}{kwarg_str}){comment_str}"

if bsym.header:
header_lines = (
bsym.header
if isinstance(bsym.header, Sequence) and not isinstance(bsym.header, str)
else bsym.header.splitlines()
)
header_lines = (f"# {line}" for line in header_lines)
return chain(header_lines, [s])

return s


def bind_postprocess_of_unflatten_tensor_subclass(bsym: BoundSymbol) -> None:
cls = bsym.args[0]
inner_tensors = bsym.args[1]
metadata = bsym.args[2]

filtered_types: tuple[Any, ...] = (cls,)
if metadata:
types = get_nested_types(list(metadata.values()))
filtered_types += filter_types_for_tensor_wrapper_subclass(types)
new_imports = {t.__name__: t for t in filtered_types}
bsym._import_ctx.update(new_imports)


def unflatten_tensor_subclass_meta(
tensor_subclass_type,
inner_tensors: dict[str, TensorProxy],
metadata: dict[str, Any],
) -> SubclassTensorProxy:
first_tensor: TensorProxy = list(inner_tensors.values())[0]
a = SubclassTensorProxy(
shape=first_tensor.shape,
device=first_tensor.device,
dtype=first_tensor.dtype,
requires_grad=first_tensor.requires_grad,
tensors=list(inner_tensors.values()),
non_tensors=list(metadata.values()),
subclass_type=tensor_subclass_type,
)
for name, value in inner_tensors.items():
setattr(a, name, value)
for name, value in metadata.items():
setattr(a, name, value)
return a


def unflatten_tensor_subclass_python_impl(
tensor_subclass_type,
inner_tensors: dict[str, TensorProxy],
metadata: dict[str, Any],
) -> torch.Tensor:
return tensor_subclass_type.__tensor_unflatten__(inner_tensors, metadata, -1, -1)


unflatten_tensor_subclass = make_prim(
PrimIDs.UNFLATTEN_TENSOR_SUBCLASS,
"unflatten_tensor_subclass",
meta=unflatten_tensor_subclass_meta,
python_printer=printer_of_unflatten_tensor_subclass,
_bind_postprocess=bind_postprocess_of_unflatten_tensor_subclass,
)
38 changes: 36 additions & 2 deletions thunder/core/proxies.py
Original file line number Diff line number Diff line change
Expand Up @@ -2111,6 +2111,7 @@ def __setattr__(self, name, value):

# TODO: move this function to jit_ext.py
def tensorproxy(t: torch.Tensor, /, *, name: None | str, history: None | tuple = None) -> TensorProxy:
from torch._subclasses.fake_tensor import FakeTensor
from thunder.core.interpreter import ProvenanceRecord, PseudoInst, wrap_const

if hasattr(t, "_thunder_device"):
Expand Down Expand Up @@ -2145,8 +2146,8 @@ def tensorproxy(t: torch.Tensor, /, *, name: None | str, history: None | tuple =
else:
# NOTE Without tuple(t.shape) then the shape would be a torch.Size object
shape = tuple(t.shape)
return TensorProxy(
name,
ctor_kwargs = dict(
name=name,
shape=tuple(shape),
device=device,
dtype=dtype,
Expand All @@ -2156,6 +2157,39 @@ def tensorproxy(t: torch.Tensor, /, *, name: None | str, history: None | tuple =
history=history,
thunder_fsdp_padding_size=_thunder_fsdp_padding_size,
)
# n.b.(crcrpar): :class:`thunder.dynamo.ThunderCompiler.__call__` takes torch.fx GraphModule
# where `FakeTensor` seems to be used, leading to failures observed in e.g.
# https://github.com/Lightning-AI/lightning-thunder/actions/runs/11689709564/job/32553053319#step:10:5747
# https://dev.azure.com/Lightning-AI/lightning/_build/results?buildId=219328&view=logs&jobId=5b0799f7-725e-5b16-9b83-c0a5a25d03f0&j=5b0799f7-725e-5b16-9b83-c0a5a25d03f0
if (
isinstance(t, torch.Tensor)
and type(t) not in (torch.Tensor, torch.nn.Parameter, FakeTensor)
and hasattr(t, "__tensor_flatten__")
and hasattr(t, "__tensor_unflatten__")
):
baseutils.check(
hasattr(t, "__tensor_flatten__") and hasattr(t, "__tensor_unflatten__"),
lambda: f"{t=} seems to be a tensor subclass but not traceable",
)
tensor_attr_names, metadata = t.__tensor_flatten__()
tensors = [tensorproxy(getattr(t, name), name=None, history=history) for name in tensor_attr_names]
ctor_kwargs.update(
{
"tensors": tensors,
"non_tensors": list(metadata.values()),
"subclass_type": type(t),
}
)
p = SubclassTensorProxy(**ctor_kwargs)
p._tensor_attr_names = tensor_attr_names
p._non_tensor_attr_names = list(metadata.keys())
for name, tensor in zip(tensor_attr_names, tensors):
setattr(p, name, tensor)
for name, value in metadata.items():
setattr(p, name, value)
return p
else:
return TensorProxy(**ctor_kwargs)


def futuretensorproxy(
Expand Down
7 changes: 7 additions & 0 deletions thunder/executors/torch_autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ def split_forward_backward(computation_trc: TraceCtx, compile_data, compile_stat
from thunder.distributed.transforms import FSDPCommBucketing
from thunder.distributed.utils import sort_data_parallel_syncs, sort_waits, sort_communication_ops
from thunder.executors.passes import del_last_used, transform_for_execution
from thunder.transforms.tensor_wrapper_subclass import unroll_tensor_subclasses

utils.check(compile_data is not None, lambda: "`compile_data` is required")
# NOTE: This function is rather slow, so it's intended to be used
Expand All @@ -158,6 +159,9 @@ def split_forward_backward(computation_trc: TraceCtx, compile_data, compile_stat
fw_traces = [fw_trace]
bw_traces = [bw_trace]

fw_trace = unroll_tensor_subclasses(fw_trace)
fw_traces.append(fw_trace)

from thunder.distributed import FSDPType

# only enable rematerialize_params_in_backward when using FSDP ZeRO3
Expand Down Expand Up @@ -262,6 +266,9 @@ def split_forward_backward(computation_trc: TraceCtx, compile_data, compile_stat
if getattr(compile_data.fn, "use_fsdp", False):
bw_trace = _fsdp_comm_bucketing.apply_bucketing_to_backward_trace(bw_trace)

bw_trace = unroll_tensor_subclasses(bw_trace)
bw_traces.append(bw_trace)

# Now we can run the optimization passes on the backward trace
# TODO Restore request for no rematerialization
bw_extrace = transform_for_execution(
Expand Down
48 changes: 46 additions & 2 deletions thunder/executors/torchex.py
Original file line number Diff line number Diff line change
Expand Up @@ -2235,13 +2235,13 @@ def _tensor_subclass_ctor(cls, name, shape, device, dtype, requires_grad, tensor


def _bind_postprocess_of_tensor_subclass_ctor(bsym: BoundSymbol) -> None:
from thunder.core.prims import get_nested_types, filter_types
from thunder.core.prims import get_nested_types, filter_types_for_tensor_wrapper_subclass

cls, _name, _shape, _device, _dtype, _requires_grad, _tensors, non_tensors = bsym.args
filtered_types = (cls,)
if non_tensors:
types = get_nested_types(non_tensors)
filtered_types += filter_types(types)
filtered_types += filter_types_for_tensor_wrapper_subclass(types)
new_imports = {t.__name__: t for t in filtered_types}
bsym._import_ctx.update(new_imports)

Expand All @@ -2254,3 +2254,47 @@ def _bind_postprocess_of_tensor_subclass_ctor(bsym: BoundSymbol) -> None:
python_printer=prims.printer_of_tensor_subclass_ctor,
)
_register_implementation(prims.tensor_subclass_ctor, tensor_subclass_ctor, checker=_always_executable)


def flatten_tensor_subclass_impl(t):
tensor_attr_names, metadata = t.__tensor_flatten__()
tensors = tuple(getattr(t, name) for name in tensor_attr_names)
return tensors


flatten_tensor_subclass = ex.register_operator(
"flatten_tensor_subclass",
meta=prims.flatten_tensor_subclass.meta,
fn=flatten_tensor_subclass_impl,
)
_register_implementation(
prims.flatten_tensor_subclass,
flatten_tensor_subclass,
checker=_always_executable,
)


def unflatten_tensor_subclass_impl(
tensor_subclass_type: torch._C._TensorMeta,
inner_tensors: dict[str, TensorLike],
metadata: dict,
):
for key in metadata:
v = metadata[key]
if isinstance(v, dtypes.dtype):
metadata[key] = to_torch_dtype(v)
elif isinstance(v, devices.Device):
metadata[key] = to_torch_device(v)
return tensor_subclass_type.__tensor_unflatten__(inner_tensors, metadata, -1, -1)


unflatten_tensor_subclass = ex.register_operator(
"unflatten_tensor_subclass",
meta=prims.unflatten_tensor_subclass.meta,
fn=unflatten_tensor_subclass_impl,
)
_register_implementation(
prims.unflatten_tensor_subclass,
unflatten_tensor_subclass,
checker=_always_executable,
)
Loading

0 comments on commit db3a6c9

Please sign in to comment.