From 14ccf6bb9321d1cf79d0a92b81dcb70b0cc518da Mon Sep 17 00:00:00 2001 From: Masaki Kozuki Date: Thu, 28 Nov 2024 21:12:39 +0900 Subject: [PATCH] Unrolling tensor subclasses in fwd/bwd split (#1489) Signed-off-by: Masaki Kozuki Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- thunder/__init__.py | 6 ++-- thunder/core/jit_ext.py | 2 +- thunder/core/pytree.py | 2 ++ thunder/executors/torch_autograd.py | 4 +++ thunder/executors/torchex.py | 3 -- thunder/tests/test_tensor_subclass.py | 20 +++++++---- thunder/torch/__init__.py | 15 +++++++- thunder/transforms/tensor_subclasses.py | 46 +++++++++++++++++-------- 8 files changed, 71 insertions(+), 27 deletions(-) diff --git a/thunder/__init__.py b/thunder/__init__.py index 661ca2232b..acf59c8f70 100644 --- a/thunder/__init__.py +++ b/thunder/__init__.py @@ -587,8 +587,6 @@ def get_computation_and_inputs(*args, **kwargs): if len(tensor_args_consumed_by_inplace_grouped_by_numel) > 1: vanilla_tensor_args = set(tensor_indices) - computation_trc = flatten_tensor_subclasses(computation_trc) - if epilogue_trc is not None: epilogue_traces = [epilogue_trc] else: @@ -649,6 +647,7 @@ def get_computation_and_inputs(*args, **kwargs): computation_trc = dce(computation_trc) computation_traces.append(computation_trc) + _tensor_subclass_transform_applied = False backward_trc = None if not cd.disable_torch_autograd_support: tensor_cls = (pytorch.Tensor, TensorProxy) @@ -662,6 +661,9 @@ def get_computation_and_inputs(*args, **kwargs): computation_trc, backward_trc = split_forward_backward(computation_trc, cd, cs, *inps) # Note computation_trc and backward_trc have been appended to cs.last_(backward_)traces # by split_forward_backward + _tensor_subclass_transform_applied = True + if not _tensor_subclass_transform_applied: + computation_trc, _ = flatten_tensor_subclasses(computation_trc) if backward_trc is None: from thunder.executors.passes import transform_for_execution as transform_for_execution_pass diff --git a/thunder/core/jit_ext.py b/thunder/core/jit_ext.py index 7cb060c2b3..5441053f71 100644 --- a/thunder/core/jit_ext.py +++ b/thunder/core/jit_ext.py @@ -637,7 +637,7 @@ def _convert_pytorchfunc_to_thundertrace( trace = TraceCtx() trace.bound_symbols.extend(active_jit_ctx.computation_trace.pop_scope()) func_result = unwrap(wrapped_func_result) - if shallow_copy_output: + if shallow_copy_output and not trace.bound_symbols: from thunder.core.baseutils import sequencify out_to_shallow_copy: dict[Variable, TensorProxy] = {} diff --git a/thunder/core/pytree.py b/thunder/core/pytree.py index 8c92a38555..262e547750 100644 --- a/thunder/core/pytree.py +++ b/thunder/core/pytree.py @@ -1,6 +1,7 @@ from functools import partial from types import FunctionType import dataclasses +from enum import Enum import optree import torch @@ -64,6 +65,7 @@ def tree_flatten(args, namespace=OPTREE_NAMESPACE): and not is_likely_from_collections_namedtuple(args) and not dataclasses.is_dataclass(args) and not type(args).__module__.startswith("torch.return_types") + and not issubclass(type(args), Enum) ): raise TypeError(f"tree_flatten of type {type(args)} is not supported.") return optree.tree_flatten(args, none_is_leaf=True, namespace=namespace) diff --git a/thunder/executors/torch_autograd.py b/thunder/executors/torch_autograd.py index 5374b23afe..3ed3b20ec5 100644 --- a/thunder/executors/torch_autograd.py +++ b/thunder/executors/torch_autograd.py @@ -132,6 +132,7 @@ def split_forward_backward(computation_trc: TraceCtx, compile_data, compile_stat from thunder.distributed.transforms import FSDPCommBucketing from thunder.distributed.utils import sort_data_parallel_syncs, sort_waits, sort_communication_ops from thunder.executors.passes import del_last_used, transform_for_execution + from thunder.transforms.tensor_subclasses import flatten_tensor_subclasses, DesugarTensorSubclass utils.check(compile_data is not None, lambda: "`compile_data` is required") # NOTE: This function is rather slow, so it's intended to be used @@ -154,6 +155,7 @@ def split_forward_backward(computation_trc: TraceCtx, compile_data, compile_stat # not any other container type. So we need to flatten the outputs of # the forward trace and inputs of the backward trace. fw_trace, bw_trace = forward_and_backward_from_trace(primal_trace, torch_autograd=True) + fw_trace, fw_tensor_subclass_desugar = flatten_tensor_subclasses(fw_trace) fw_traces = [fw_trace] bw_traces = [bw_trace] @@ -245,6 +247,8 @@ def split_forward_backward(computation_trc: TraceCtx, compile_data, compile_stat if getattr(compile_data.fn, "use_fsdp", False): bw_trace = _fsdp_comm_bucketing.apply_bucketing_to_backward_trace(bw_trace) + bw_trace, bw_tensor_subclass_desugar = flatten_tensor_subclasses(bw_trace) + # Now we can run the optimization passes on the backward trace # TODO Restore request for no rematerialization bw_extrace = transform_for_execution( diff --git a/thunder/executors/torchex.py b/thunder/executors/torchex.py index cfcc86b96d..eff0c93cbd 100644 --- a/thunder/executors/torchex.py +++ b/thunder/executors/torchex.py @@ -1410,9 +1410,6 @@ def _scaled_mm_transform( if b.stride()[0] != 1 and b.stride()[1] > 1: b = b.t().contiguous().t() - print( - f"{type(a)=}, {type(b)=}, {type(scale_a)=}, {type(scale_b)=}, {type(bias)=}, {type(scale_result)=}, {type(result_dtype)=}, {type(use_fast_accum)=}" - ) return _scaled_mm(a, b, scale_a, scale_b, bias, scale_result, result_dtype, use_fast_accum) diff --git a/thunder/tests/test_tensor_subclass.py b/thunder/tests/test_tensor_subclass.py index f98fe3bc7f..03243bda31 100644 --- a/thunder/tests/test_tensor_subclass.py +++ b/thunder/tests/test_tensor_subclass.py @@ -8,15 +8,19 @@ from torch.utils import _pytree as pytree import thunder -from thunder.core.proxies import SubclassTensorProxy -from thunder.tests.framework import instantiate +from thunder.tests.framework import ( + instantiate, + TorchExecutor, + TorchCompileCatExecutor, + nvFuserExecutor, + DynamoThunderExecutor, +) from thunder.tests.make_tensor import make_tensor TORCHAO_AVAILABLE = package_available("torchao") if TYPE_CHECKING: from typing import Any - from thunder.core.symbol import BoundSymbol @torch._dynamo.allow_in_graph @@ -243,14 +247,12 @@ def g(x: ScaleTensorSubclass, data: torch.Tensor, scale: torch.Tensor) -> torch. @instantiate( dtypes=(thunder.core.dtypes.float32,), devicetypes=(thunder.core.devices.DeviceType.CUDA,), + executors=(TorchExecutor, TorchCompileCatExecutor, nvFuserExecutor, DynamoThunderExecutor), decorators=( pytest.mark.skipif( not (TORCHAO_AVAILABLE and torch.cuda.get_device_capability() >= (8, 9)), reason="Requires capability >= 8.9 and torchao", ), - # forward-backward split is failing. - # TypeError: tree_flatten of type is not supported. - pytest.mark.xfail(), ), ) def test_torchao_float8_linear(executor, device, _): @@ -269,3 +271,9 @@ def test_torchao_float8_linear(executor, device, _): jitted = executor.make_callable(fp8_model) actual = jitted(x) + + if executor == DynamoThunderExecutor: + with pytest.raises(AssertionError): + torch.testing.assert_close(actual, expected) + else: + torch.testing.assert_close(actual, expected) diff --git a/thunder/torch/__init__.py b/thunder/torch/__init__.py index a9398ffa77..e03de2f54e 100644 --- a/thunder/torch/__init__.py +++ b/thunder/torch/__init__.py @@ -1259,7 +1259,9 @@ def t(a: TensorLike, /) -> TensorLike: lambda: f"t() expects a tensor with <= 2 dimensions, but self is {a.ndim}D", RuntimeError, ) - return transpose(a, 0, 1) if a.ndim == 2 else a + if a.ndim == 2: + return transpose(a, 0, 1) + return a @run_once @@ -1312,6 +1314,17 @@ def transpose(a: TensorLike, /, dim0: int, dim1: int) -> TensorLike: return clang.transpose(a, permutation) +def _transpose_grad(a: TensorLike, /, dim0: int, dim1: int) -> TensorLike: + fwd = transpose(a, dim0, dim1) + g = get_grad(fwd) + a_grad = transpose(g, dim0, dim1) + put_grad(a, a_grad) + return fwd + + +register_grad(transpose, _transpose_grad) + + @torchsymbol(torch.unbind, is_method=True) def unbind(a: TensorLike, /, dim: int = 0) -> tuple[TensorLike, ...]: utils.check( diff --git a/thunder/transforms/tensor_subclasses.py b/thunder/transforms/tensor_subclasses.py index 0538f5bc06..e9ab709a34 100644 --- a/thunder/transforms/tensor_subclasses.py +++ b/thunder/transforms/tensor_subclasses.py @@ -41,10 +41,10 @@ from torch.fx import GraphModule from torch._ops import OpOverload from thunder.core.symbol import Symbol, BoundSymbol - from torch._C import _TensorMeta __all__ = [ + "DesugarTensorSubclass", "flatten_tensor_subclasses", ] @@ -249,17 +249,18 @@ def translate_fx_graph_into_bsym( import thunder.torch as ltorch unwrapped_bsym_args: dict[int, ProxyInterface] = {} - list_of_unflatten_bsym: list[BoundSymbol] = [] + list_of_flattening_bsyms: list[BoundSymbol] = [] for a in bsym.flat_args: if isinstance(a, SubclassTensorProxy): if variableify(a) in self.subclass_proxy_to_flatten: self.computation_trace.push_scope([]) with tracectx(self.computation_trace): prims.flatten_tensor_subclass(a) - unflatten_bsym = self.computation_trace.pop_scope()[0] - list_of_unflatten_bsym.append(unflatten_bsym) + flattening_bsym = self.computation_trace.pop_scope()[0] + list_of_flattening_bsyms.append(flattening_bsym) tensor_attr_names = self._get_tensor_attr_names(a) tensors = a._tensors + non_tensor_attr_names = self._get_non_tensor_attr_names(a) non_tensors = a._non_tensors metadata = dict(zip(non_tensor_attr_names, non_tensors)) @@ -307,8 +308,8 @@ def translate_fx_graph_into_bsym( ltorch_ops_for_node_of_ops.append(getattr(ltorch, node.target._opname)) bsyms: list[BoundSymbol] = [] - if list_of_unflatten_bsym: - bsyms.extend(list_of_unflatten_bsym) + if list_of_flattening_bsyms: + bsyms.extend(list_of_flattening_bsyms) fxnode_output_name_to_tensor_proxy: dict[str, OpOverload] = {} for node, ltorch_op in zip(list_of_function_call_node, ltorch_ops_for_node_of_ops): args: list[Node] = node.args @@ -379,10 +380,22 @@ def translate_fx_graph_into_bsym( f"{len(new_tensor_proxies)=} != {len(orig_output._tensors)=}" ), ) - if [variableify(t) for t in orig_output._tensors] != [variableify(t) for t in new_tensor_proxies]: - orig_output._tensors = new_tensor_proxies - for name, tensor in zip(orig_output._tensor_attr_names, new_tensor_proxies): - setattr(orig_output, name, tensor) + with tracectx(self.computation_trace): + new_subclass = orig_output.replace() + new_subclass._tensors = new_tensor_proxies + for name, value in zip(new_subclass._tensor_attr_names, new_tensor_proxies): + setattr(new_subclass, name, value) + bsyms.append( + prims.unflatten_tensor_subclass.bind( + new_subclass._subclass_type, + dict(zip(new_subclass._tensor_attr_names, new_tensor_proxies)), + dict(zip(new_subclass._non_tensor_attr_names, new_subclass._non_tensors)), + output=new_subclass, + ) + ) + + self.swap_map[variableify(orig_output)] = new_subclass + self.subclass_proxy_to_flatten.add(variableify(new_subclass)) else: non_none_args = [n for n in node_of_output.args[0] if n is not None] @@ -502,7 +515,12 @@ def f_with_wrap_and_unwrap(*desugared_args) -> tuple[OutputWrapperForFxTracing, def __call__(self, bsym: BoundSymbol) -> list[BoundSymbol]: updated_bsym: BoundSymbol = bsym.from_bsym_swap_proxies(self.swap_map) - if updated_bsym.sym.id == prims.PrimIDs.RETURN: + if bsym.sym.id == prims.PrimIDs.RETURN: + new_swap_map = {} + for k, v in self.swap_map.items(): + if isinstance(v, SubclassTensorProxy): + continue + new_swap_map[k] = v if not self.subclass_proxy_to_flatten or True: return [updated_bsym] @@ -567,7 +585,7 @@ def __call__(self, bsym: BoundSymbol) -> list[BoundSymbol]: return self.translate_fx_graph_into_bsym(bsym_with_modified_output, fx) -def flatten_tensor_subclasses(computation_trace: TraceCtx) -> TraceCtx: +def flatten_tensor_subclasses(computation_trace: TraceCtx) -> tuple[TraceCtx, DesugarTensorSubclass]: """Flatten tensor subclasses in ``computation_trace``. Two things are happening inside of this function: @@ -601,9 +619,9 @@ def flatten_tensor_subclasses(computation_trace: TraceCtx) -> TraceCtx: updated_bsyms.extend(maybe_desugared_bsyms) if not desugar_tensor_subclass.subclass_proxy_to_flatten: - return computation_trace + return computation_trace, None computation_trace_with_subclass_tensor_proxy_output = from_trace(computation_trace) computation_trace_with_subclass_tensor_proxy_output.bound_symbols.extend(updated_bsyms) computation_trace_with_subclass_tensor_proxy_output.set_provenance(TraceProvenance("tensor subclasses desugared")) - return computation_trace_with_subclass_tensor_proxy_output + return computation_trace_with_subclass_tensor_proxy_output, desugar_tensor_subclass