Skip to content

Commit

Permalink
Unrolling tensor subclasses in fwd/bwd split (#1489)
Browse files Browse the repository at this point in the history
Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
crcrpar and pre-commit-ci[bot] committed Dec 21, 2024
1 parent 5c197cf commit 4b9e67f
Show file tree
Hide file tree
Showing 7 changed files with 70 additions and 24 deletions.
4 changes: 4 additions & 0 deletions thunder/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -618,6 +618,7 @@ def get_computation_and_inputs(*args, **kwargs):
computation_trc = dce(computation_trc)
computation_traces.append(computation_trc)

_tensor_subclass_transform_applied = False
backward_trc = None
if not cd.disable_torch_autograd_support:
tensor_cls = (pytorch.Tensor, TensorProxy)
Expand All @@ -631,6 +632,9 @@ def get_computation_and_inputs(*args, **kwargs):
computation_trc, backward_trc = split_forward_backward(computation_trc, cd, cs, *inps)
# Note computation_trc and backward_trc have been appended to cs.last_(backward_)traces
# by split_forward_backward
_tensor_subclass_transform_applied = True
if not _tensor_subclass_transform_applied:
computation_trc, _ = flatten_tensor_subclasses(computation_trc)

if backward_trc is None:
from thunder.executors.passes import transform_for_execution as transform_for_execution_pass
Expand Down
2 changes: 2 additions & 0 deletions thunder/core/pytree.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from functools import partial
from types import FunctionType
import dataclasses
from enum import Enum

import optree
import torch
Expand Down Expand Up @@ -64,6 +65,7 @@ def tree_flatten(args, namespace=OPTREE_NAMESPACE):
and not is_likely_from_collections_namedtuple(args)
and not dataclasses.is_dataclass(args)
and not type(args).__module__.startswith("torch.return_types")
and not issubclass(type(args), Enum)
):
raise TypeError(f"tree_flatten of type {type(args)} is not supported.")
return optree.tree_flatten(args, none_is_leaf=True, namespace=namespace)
Expand Down
4 changes: 4 additions & 0 deletions thunder/executors/torch_autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ def split_forward_backward(computation_trc: TraceCtx, compile_data, compile_stat
from thunder.distributed.transforms import FSDPCommBucketing
from thunder.distributed.utils import sort_data_parallel_syncs, sort_waits, sort_communication_ops
from thunder.executors.passes import del_last_used, transform_for_execution
from thunder.transforms.tensor_subclasses import flatten_tensor_subclasses, DesugarTensorSubclass

utils.check(compile_data is not None, lambda: "`compile_data` is required")
# NOTE: This function is rather slow, so it's intended to be used
Expand All @@ -154,6 +155,7 @@ def split_forward_backward(computation_trc: TraceCtx, compile_data, compile_stat
# not any other container type. So we need to flatten the outputs of
# the forward trace and inputs of the backward trace.
fw_trace, bw_trace = forward_and_backward_from_trace(primal_trace, torch_autograd=True)
fw_trace, fw_tensor_subclass_desugar = flatten_tensor_subclasses(fw_trace)

fw_traces = [fw_trace]
bw_traces = [bw_trace]
Expand Down Expand Up @@ -262,6 +264,8 @@ def split_forward_backward(computation_trc: TraceCtx, compile_data, compile_stat
if getattr(compile_data.fn, "use_fsdp", False):
bw_trace = _fsdp_comm_bucketing.apply_bucketing_to_backward_trace(bw_trace)

bw_trace, bw_tensor_subclass_desugar = flatten_tensor_subclasses(bw_trace)

# Now we can run the optimization passes on the backward trace
# TODO Restore request for no rematerialization
bw_extrace = transform_for_execution(
Expand Down
3 changes: 0 additions & 3 deletions thunder/executors/torchex.py
Original file line number Diff line number Diff line change
Expand Up @@ -1421,9 +1421,6 @@ def _scaled_mm_transform(
if b.stride()[0] != 1 and b.stride()[1] > 1:
b = b.t().contiguous().t()

print(
f"{type(a)=}, {type(b)=}, {type(scale_a)=}, {type(scale_b)=}, {type(bias)=}, {type(scale_result)=}, {type(result_dtype)=}, {type(use_fast_accum)=}"
)
return _scaled_mm(a, b, scale_a, scale_b, bias, scale_result, result_dtype, use_fast_accum)


Expand Down
20 changes: 14 additions & 6 deletions thunder/tests/test_tensor_subclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,19 @@
from torch.utils import _pytree as pytree

import thunder
from thunder.core.proxies import SubclassTensorProxy
from thunder.tests.framework import instantiate
from thunder.tests.framework import (
instantiate,
TorchExecutor,
TorchCompileCatExecutor,
nvFuserExecutor,
DynamoThunderExecutor,
)
from thunder.tests.make_tensor import make_tensor

TORCHAO_AVAILABLE = package_available("torchao")

if TYPE_CHECKING:
from typing import Any
from thunder.core.symbol import BoundSymbol


@torch._dynamo.allow_in_graph
Expand Down Expand Up @@ -243,14 +247,12 @@ def g(x: ScaleTensorSubclass, data: torch.Tensor, scale: torch.Tensor) -> torch.
@instantiate(
dtypes=(thunder.core.dtypes.float32,),
devicetypes=(thunder.core.devices.DeviceType.CUDA,),
executors=(TorchExecutor, TorchCompileCatExecutor, nvFuserExecutor, DynamoThunderExecutor),
decorators=(
pytest.mark.skipif(
not (TORCHAO_AVAILABLE and torch.cuda.get_device_capability() >= (8, 9)),
reason="Requires capability >= 8.9 and torchao",
),
# forward-backward split is failing.
# TypeError: tree_flatten of type <enum 'GemmInputRole'> is not supported.
pytest.mark.xfail(),
),
)
def test_torchao_float8_linear(executor, device, _):
Expand All @@ -269,3 +271,9 @@ def test_torchao_float8_linear(executor, device, _):

jitted = executor.make_callable(fp8_model)
actual = jitted(x)

if executor == DynamoThunderExecutor:
with pytest.raises(AssertionError):
torch.testing.assert_close(actual, expected)
else:
torch.testing.assert_close(actual, expected)
15 changes: 14 additions & 1 deletion thunder/torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1260,7 +1260,9 @@ def t(a: TensorLike, /) -> TensorLike:
lambda: f"t() expects a tensor with <= 2 dimensions, but self is {a.ndim}D",
RuntimeError,
)
return transpose(a, 0, 1) if a.ndim == 2 else a
if a.ndim == 2:
return transpose(a, 0, 1)
return a


@run_once
Expand Down Expand Up @@ -1313,6 +1315,17 @@ def transpose(a: TensorLike, /, dim0: int, dim1: int) -> TensorLike:
return clang.transpose(a, permutation)


def _transpose_grad(a: TensorLike, /, dim0: int, dim1: int) -> TensorLike:
    """Grad rule for ``transpose``: run the forward op and record ``a``'s gradient.

    Swapping the same two dimensions a second time restores the original
    layout, so the gradient for ``a`` is the gradient of the forward output
    transposed with the same ``dim0``/``dim1`` pair.

    Args:
        a: Tensor being transposed.
        dim0: First dimension to swap.
        dim1: Second dimension to swap.

    Returns:
        The forward result ``transpose(a, dim0, dim1)``.
    """
    out = transpose(a, dim0, dim1)
    # Evaluation order matches the step-by-step form: fetch the gradient of
    # the forward output, transpose it back, then hand it to ``put_grad``.
    # NOTE(review): ``get_grad``/``put_grad`` are defined elsewhere; presumably
    # they read/record gradients in the active trace -- confirm.
    put_grad(a, transpose(get_grad(out), dim0, dim1))
    return out


# Associate `_transpose_grad` with `transpose` via `register_grad`.
# NOTE(review): presumably this makes `_transpose_grad` the rule used when
# differentiating `transpose`; `register_grad` is defined elsewhere -- confirm.
register_grad(transpose, _transpose_grad)


@torchsymbol(torch.unbind, is_method=True)
def unbind(a: TensorLike, /, dim: int = 0) -> tuple[TensorLike, ...]:
utils.check(
Expand Down
46 changes: 32 additions & 14 deletions thunder/transforms/tensor_subclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,10 @@
from torch.fx import GraphModule
from torch._ops import OpOverload
from thunder.core.symbol import Symbol, BoundSymbol
from torch._C import _TensorMeta


__all__ = [
"DesugarTensorSubclass",
"flatten_tensor_subclasses",
]

Expand Down Expand Up @@ -249,17 +249,18 @@ def translate_fx_graph_into_bsym(
import thunder.torch as ltorch

unwrapped_bsym_args: dict[int, ProxyInterface] = {}
list_of_unflatten_bsym: list[BoundSymbol] = []
list_of_flattening_bsyms: list[BoundSymbol] = []
for a in bsym.flat_args:
if isinstance(a, SubclassTensorProxy):
if variableify(a) in self.subclass_proxy_to_flatten:
self.computation_trace.push_scope([])
with tracectx(self.computation_trace):
prims.flatten_tensor_subclass(a)
unflatten_bsym = self.computation_trace.pop_scope()[0]
list_of_unflatten_bsym.append(unflatten_bsym)
flattening_bsym = self.computation_trace.pop_scope()[0]
list_of_flattening_bsyms.append(flattening_bsym)
tensor_attr_names = self._get_tensor_attr_names(a)
tensors = a._tensors

non_tensor_attr_names = self._get_non_tensor_attr_names(a)
non_tensors = a._non_tensors
metadata = dict(zip(non_tensor_attr_names, non_tensors))
Expand Down Expand Up @@ -307,8 +308,8 @@ def translate_fx_graph_into_bsym(
ltorch_ops_for_node_of_ops.append(getattr(ltorch, node.target._opname))

bsyms: list[BoundSymbol] = []
if list_of_unflatten_bsym:
bsyms.extend(list_of_unflatten_bsym)
if list_of_flattening_bsyms:
bsyms.extend(list_of_flattening_bsyms)
fxnode_output_name_to_tensor_proxy: dict[str, OpOverload] = {}
for node, ltorch_op in zip(list_of_function_call_node, ltorch_ops_for_node_of_ops):
args: list[Node] = node.args
Expand Down Expand Up @@ -379,10 +380,22 @@ def translate_fx_graph_into_bsym(
f"{len(new_tensor_proxies)=} != {len(orig_output._tensors)=}"
),
)
if [variableify(t) for t in orig_output._tensors] != [variableify(t) for t in new_tensor_proxies]:
orig_output._tensors = new_tensor_proxies
for name, tensor in zip(orig_output._tensor_attr_names, new_tensor_proxies):
setattr(orig_output, name, tensor)
with tracectx(self.computation_trace):
new_subclass = orig_output.replace()
new_subclass._tensors = new_tensor_proxies
for name, value in zip(new_subclass._tensor_attr_names, new_tensor_proxies):
setattr(new_subclass, name, value)
bsyms.append(
prims.unflatten_tensor_subclass.bind(
new_subclass._subclass_type,
dict(zip(new_subclass._tensor_attr_names, new_tensor_proxies)),
dict(zip(new_subclass._non_tensor_attr_names, new_subclass._non_tensors)),
output=new_subclass,
)
)

self.swap_map[variableify(orig_output)] = new_subclass
self.subclass_proxy_to_flatten.add(variableify(new_subclass))

else:
non_none_args = [n for n in node_of_output.args[0] if n is not None]
Expand Down Expand Up @@ -502,7 +515,12 @@ def f_with_wrap_and_unwrap(*desugared_args) -> tuple[OutputWrapperForFxTracing,

def __call__(self, bsym: BoundSymbol) -> list[BoundSymbol]:
updated_bsym: BoundSymbol = bsym.from_bsym_swap_proxies(self.swap_map)
if updated_bsym.sym.id == prims.PrimIDs.RETURN:
if bsym.sym.id == prims.PrimIDs.RETURN:
new_swap_map = {}
for k, v in self.swap_map.items():
if isinstance(v, SubclassTensorProxy):
continue
new_swap_map[k] = v
if not self.subclass_proxy_to_flatten or True:
return [updated_bsym]

Expand Down Expand Up @@ -567,7 +585,7 @@ def __call__(self, bsym: BoundSymbol) -> list[BoundSymbol]:
return self.translate_fx_graph_into_bsym(bsym_with_modified_output, fx)


def flatten_tensor_subclasses(computation_trace: TraceCtx) -> TraceCtx:
def flatten_tensor_subclasses(computation_trace: TraceCtx) -> tuple[TraceCtx, DesugarTensorSubclass]:
"""Flatten tensor subclasses in ``computation_trace``.
Two things are happening inside of this function:
Expand Down Expand Up @@ -601,9 +619,9 @@ def flatten_tensor_subclasses(computation_trace: TraceCtx) -> TraceCtx:
updated_bsyms.extend(maybe_desugared_bsyms)

if not desugar_tensor_subclass.subclass_proxy_to_flatten:
return computation_trace
return computation_trace, None

computation_trace_with_subclass_tensor_proxy_output = from_trace(computation_trace)
computation_trace_with_subclass_tensor_proxy_output.bound_symbols.extend(updated_bsyms)
computation_trace_with_subclass_tensor_proxy_output.set_provenance(TraceProvenance("tensor subclasses desugared"))
return computation_trace_with_subclass_tensor_proxy_output
return computation_trace_with_subclass_tensor_proxy_output, desugar_tensor_subclass

0 comments on commit 4b9e67f

Please sign in to comment.