From 8435406378ddca6933d29cafcd6055c1f913c1e4 Mon Sep 17 00:00:00 2001 From: Masaki Kozuki Date: Thu, 26 Dec 2024 03:19:00 +0900 Subject: [PATCH] updates for MLP with `torchao.float8` - Add `_scaled_mm`. - Change how the lookaside of `torch.autograd.Function.apply` applies `dce`, taking the failure of the apex fused RMS norm into consideration. ```python @torch.no_grad() @no_autocast def FusedRMSNormAffineMixedDtypesFunction(t_0, t_1, tup11, f12, b13): # /usr/local/lib/python3.12/dist-packages/apex/normalization/fused_layer_norm.py:128: weight_ = weight.contiguous() # t_0: "cuda:0 f32[4, 5, 3, 2]" # t_1: "cuda:0 f32[3, 2]" # /usr/local/lib/python3.12/dist-packages/apex/normalization/fused_layer_norm.py:127: input_ = input.contiguous() t5 = ltorch.contiguous(t_0, memory_format=_torch_memory_format_0) # t5: "cuda:0 f32[4, 5, 3, 2]" # t5 = prims.stride_order(t_0, (3, 2, 1, 0)) # t5: "cuda:0 f32[4, 5, 3, 2]" # /usr/local/lib/python3.12/dist-packages/apex/normalization/fused_layer_norm.py:128: weight_ = weight.contiguous() t6 = ltorch.contiguous(t_1, memory_format=_torch_memory_format_0) # t6: "cuda:0 f32[3, 2]" # t6 = prims.stride_order(t_1, (1, 0)) # t6: "cuda:0 f32[3, 2]" (t10, t9) = apex_fused_rms_norm_forward_affine_mixed_dtypes(t5, (3, 2), t6, 1e-05) return t10 ``` For this trace, `thunder.core.transforms.dce` replaces `t9` with `_`, so the augmented forward trace would lose access to it. By reusing the augmented forward trace in the basic forward trace, `dce` no longer does so.
Signed-off-by: Masaki Kozuki --- thunder/core/jit_ext.py | 66 ++++++++++++------- thunder/core/proxies.py | 20 ++++++ thunder/core/pytree.py | 2 + thunder/core/trace_interpreter.py | 7 ++- thunder/core/transform_common.py | 2 +- thunder/executors/torch_compile.py | 6 ++ thunder/executors/torchex.py | 31 +++++++++ thunder/tests/test_tensor_subclass.py | 81 +++++++++++++++++++++++- thunder/torch/__init__.py | 91 ++++++++++++++++++++++++++- 9 files changed, 279 insertions(+), 27 deletions(-) diff --git a/thunder/core/jit_ext.py b/thunder/core/jit_ext.py index e88419546a..f260b5ebc6 100644 --- a/thunder/core/jit_ext.py +++ b/thunder/core/jit_ext.py @@ -667,6 +667,9 @@ def _general_jit_torch_autograd_function_apply_lookaside(obj: Any, *args, **kwar So far, non-tensor ``ctx`` attributes seem to be folded into a trace. """ from thunder.core.baseutils import check, sequencify + from thunder.core.trace_interpreter import interpret_trace + from thunder.core.transforms import dce + from thunder.core.pytree import tree_flatten, tree_unflatten custom_autograd_function_cls = unwrap(obj) custom_forward = custom_autograd_function_cls.forward @@ -679,25 +682,36 @@ def _general_jit_torch_autograd_function_apply_lookaside(obj: Any, *args, **kwar if trace_of_fwd is INTERPRETER_SIGNALS.EXCEPTION_RAISED: return trace_of_fwd - # Forward. + # augmented forward trace. 
unwrapped_custom_forward_args = tree_map(lambda a: unwrap(a), args) - trace_of_fwd._siginfo = SigInfo.from_name_and_args( - custom_autograd_function_cls.__name__, - unwrapped_custom_forward_args, - ) - trace_of_fwd.args = unwrapped_custom_forward_args unpack_bsyms = [ prims.unpack_trivial.bind(a, name=a.name, output=a) - for a in filter(lambda a: isinstance(a, Proxy), trace_of_fwd.args) + for a in filter(lambda a: isinstance(a, Proxy), unwrapped_custom_forward_args) ] - trace_of_fwd.bound_symbols = unpack_bsyms + trace_of_fwd.bound_symbols - @wraps(trace_of_fwd.python_callable()) + augmented_bsym_output: tuple[tuple[TensorProxy, ...], tuple[TensorProxy, ...]] = ( + tuple(sequencify(trace_of_fwd.output)), + ctx_proxy.saved_tensors, + ) + trace_of_augmented_fwd = TraceCtx() + trace_of_augmented_fwd.bound_symbols.extend((unpack_bsyms + trace_of_fwd.bound_symbols)[:-1]) + with tracectx(trace_of_augmented_fwd): + prims.python_return(augmented_bsym_output) + trace_of_augmented_fwd._siginfo = SigInfo.from_name_and_args( + custom_autograd_function_cls.__name__, unwrapped_custom_forward_args + ) + trace_of_augmented_fwd.args = unwrapped_custom_forward_args + trace_of_augmented_fwd = dce(trace_of_augmented_fwd) + _, spec_of_fwd_output = tree_flatten(trace_of_fwd.output) + + @wraps(trace_of_augmented_fwd.python_callable()) def core_of_forward(*args, **kwargs): - return thunder.core.trace_interpreter.interpret_trace(trace_of_fwd, *args, **kwargs) + output, _ = interpret_trace(trace_of_augmented_fwd, *args, **kwargs) + flat_output, _ = tree_flatten(output) + return tree_unflatten(flat_output, spec_of_fwd_output) custom_fwd_sym = get_jit_ctx().ad_hoc_executor.register_operator( - trace_of_fwd._siginfo.name, + custom_autograd_function_cls.__name__, like=core_of_forward, ) unwrapped_forward_result = custom_fwd_sym(*unwrapped_custom_forward_args) @@ -706,17 +720,6 @@ def core_of_forward(*args, **kwargs): provenance=ProvenanceRecord(PseudoInst.LOOKASIDE, inputs=[obj.provenance, 
fwd_output_provenance]), ) - augmented_bsym_output: tuple[tuple[TensorProxy, ...], tuple[TensorProxy, ...]] = ( - tuple(sequencify(trace_of_fwd.output)), - ctx_proxy.saved_tensors, - ) - trace_of_augmented_fwd = TraceCtx() - trace_of_augmented_fwd.bound_symbols.extend(trace_of_fwd.bound_symbols[:-1]) - with tracectx(trace_of_augmented_fwd): - prims.python_return(augmented_bsym_output) - trace_of_augmented_fwd._siginfo = SigInfo.from_name_and_args(custom_fwd_sym.name, unwrapped_custom_forward_args) - trace_of_augmented_fwd.args = unwrapped_custom_forward_args - # Backward definition custom_backward = custom_autograd_function_cls.backward grads = tree_map( @@ -745,6 +748,7 @@ def core_of_forward(*args, **kwargs): ctx_proxy.saved_consts + ctx_proxy.saved_tensors + grads, ) bwd_trace_impl.args = tuple(ctx_proxy.saved_consts + ctx_proxy.saved_tensors + grads) + bwd_trace_impl = dce(bwd_trace_impl) @wraps(bwd_trace_impl.python_callable()) def bwd_impl_callable(*args, **kwargs): @@ -770,6 +774,24 @@ def grad_transform(*args, **kwargs): execution_transform=core_of_forward, grad_transform=grad_transform, ) + + added_bsym: BoundSymbol = get_jit_ctx().computation_trace.scopes[-1][-1] + import_ctx, call_ctx, object_ctx = {}, {}, {} + for bsym in trace_of_fwd.bound_symbols: + cur_import_ctx, cur_call_ctx, cur_object_ctx = bsym.gather_ctxs() + import_ctx.update(cur_import_ctx) + call_ctx.update(cur_call_ctx) + object_ctx.update(cur_object_ctx) + + if import_ctx: + added_bsym._import_ctx.update(import_ctx) + if call_ctx: + if added_bsym._call_ctx is not None: + added_bsym._call_ctx.update(call_ctx) + else: + added_bsym._call_ctx = call_ctx + if object_ctx: + added_bsym._object_ctx.update(object_ctx) return forward_result diff --git a/thunder/core/proxies.py b/thunder/core/proxies.py index 439c9f95b6..a772e80b66 100644 --- a/thunder/core/proxies.py +++ b/thunder/core/proxies.py @@ -1502,6 +1502,26 @@ def distparallel_type(self): def thunder_fsdp_padding_size(self): return 
self._thunder_fsdp_padding_size + # n.b.(crcrpar): just returning contiguous for `_make_wrapper_subclasses` + def stride(self) -> Sequence[int]: + shape = self.shape + if len(shape) == 1: + return (1,) + elif len(shape) == 0: + return tuple() + else: + import numpy + + _stride = reversed(numpy.cumprod([1] + list(reversed(shape[1:]))).tolist()) + return tuple(_stride) + + def storage_offset(self) -> int: + return -1 + + @property + def layout(self) -> torch.layout: + return torch.strided + # We need to implement `__len__` as # > In addition to bypassing any instance attributes in the # > interest of correctness, implicit special method lookup diff --git a/thunder/core/pytree.py b/thunder/core/pytree.py index 8c92a38555..f50861d878 100644 --- a/thunder/core/pytree.py +++ b/thunder/core/pytree.py @@ -1,3 +1,4 @@ +from enum import Enum from functools import partial from types import FunctionType import dataclasses @@ -64,6 +65,7 @@ def tree_flatten(args, namespace=OPTREE_NAMESPACE): and not is_likely_from_collections_namedtuple(args) and not dataclasses.is_dataclass(args) and not type(args).__module__.startswith("torch.return_types") + and not issubclass(type(args), Enum) ): raise TypeError(f"tree_flatten of type {type(args)} is not supported.") return optree.tree_flatten(args, none_is_leaf=True, namespace=namespace) diff --git a/thunder/core/trace_interpreter.py b/thunder/core/trace_interpreter.py index edff5937a9..95ca6ebd20 100644 --- a/thunder/core/trace_interpreter.py +++ b/thunder/core/trace_interpreter.py @@ -128,8 +128,11 @@ def add_to_swap_map(old, new): old = old.replace(shape=new._shape) if isinstance(new, VJPDual): - swap_map[variableify(new.primal)] = old - new.primal = old + # note(crcrpar): Without this sanity check, `subclass.__tensor_flatten__`, + # seems to cause `new.primal` == `old`, leading to a cycle in swapping.
+ if (key := variableify(new.primal)) != variableify(old): + swap_map[variableify(new.primal)] = old + new.primal = old else: assert isinstance(new, ProxyInterface), (old, new) swap_map[variableify(new)] = old diff --git a/thunder/core/transform_common.py b/thunder/core/transform_common.py index bfe4123dc6..fd1dcd325f 100644 --- a/thunder/core/transform_common.py +++ b/thunder/core/transform_common.py @@ -165,7 +165,7 @@ def dce(trace: Trace, needed_proxies: None | set[Variable] = None) -> Trace: # may mark some of the operation's outputs as unused some_unused = False for out in bsym.flat_proxy_outs: - if variableify(out) in needed_proxies and producer_map[out] == bsym: + if variableify(out) in needed_proxies and producer_map.get(out, None) == bsym: needed = True else: some_unused = True diff --git a/thunder/executors/torch_compile.py b/thunder/executors/torch_compile.py index ce95b91bf8..24cb7a7e02 100644 --- a/thunder/executors/torch_compile.py +++ b/thunder/executors/torch_compile.py @@ -56,6 +56,12 @@ def _to_torch(*args, **kwargs) -> Any: if torch_op is None: raise RuntimeError("op not found for {bsym.sym.name}") + # NOTE(crcrpar): Currently `ltorch.t` is mapped to `torchex.transpose` + # thus `args` needs to be updated to have dim0 and dim1 + if bsym.sym.id == "torch.t": + utils.check(len(args) == 1, lambda: f"{bsym.sym.id} takes only one argument but {args=}") + args = args + (0, 1) + return torch_op(*args, **kwargs) return _to_torch diff --git a/thunder/executors/torchex.py b/thunder/executors/torchex.py index e17c6f5841..cc08935e0e 100644 --- a/thunder/executors/torchex.py +++ b/thunder/executors/torchex.py @@ -1403,6 +1403,7 @@ def _copy_with_setitem_impl(a, key, value): # matmul = _register_torch_operation("matmul") +_scaled_mm = _register_torch_operation("_scaled_mm") outer = _register_torch_operation("outer") _register_implementation(prims.matmul, matmul, checker=_always_executable) @@ -1410,6 +1411,36 @@ def _copy_with_setitem_impl(a, key, value): 
_register_implementation(ltorch.matmul, matmul, checker=_always_executable) _register_implementation(ltorch.outer, outer, checker=_always_executable) + +def _scaled_mm_transform( + a: TensorLike, + b: TensorLike, + scale_a: TensorLike, + scale_b: TensorLike, + bias: TensorLike | None = None, + scale_result: TensorLike | None = None, + out_dtype: dtypeLike | None = None, + use_fast_accum: bool = False, +): + + def is_column_major(mat: TensorLike) -> bool: + return mat.stride()[0] == 1 and mat.stride()[1] > 1 + + result_dtype: torch.dtype = to_torch_dtype(a.dtype if out_dtype is None else out_dtype) + if not is_column_major(b): + b = b.t().contiguous().t() + + return _scaled_mm(a, b, scale_a, scale_b, bias, scale_result, result_dtype, use_fast_accum) + + +_register_implementation( + ltorch._scaled_mm, _scaled_mm, checker=_always_executable, execution_transform=_scaled_mm_transform +) +_register_implementation( + ltorch.core_aten_scaled_mm, _scaled_mm, checker=_always_executable, execution_transform=_scaled_mm_transform +) + + # # Normalization operations # diff --git a/thunder/tests/test_tensor_subclass.py b/thunder/tests/test_tensor_subclass.py index da5cf1b8ac..fa92132065 100644 --- a/thunder/tests/test_tensor_subclass.py +++ b/thunder/tests/test_tensor_subclass.py @@ -1,18 +1,29 @@ from __future__ import annotations from typing import TYPE_CHECKING +from lightning_utilities.core.imports import package_available import pytest import torch +import torch.nn as nn from torch.utils import _pytree as pytree import thunder -from thunder.tests.framework import instantiate +from thunder.dynamo.compiler import ThunderCompiler +from thunder.tests.framework import ( + DynamoThunderExecutor, + TorchExecutor, + instantiate, + nvFuserExecutor, +) from thunder.tests.make_tensor import make_tensor if TYPE_CHECKING: from typing import Any +TORCHAO_AVAILABLE = package_available("torchao") + + @torch._dynamo.allow_in_graph class EncapsulateXandScale(torch.autograd.Function):
@staticmethod @@ -232,3 +243,71 @@ def g(x: ScaleTensorSubclass, data: torch.Tensor, scale: torch.Tensor) -> torch. torch.testing.assert_close(expected, actual) if requires_grad: actual.mean().backward() + + +@instantiate( + dtypes=(thunder.core.dtypes.float32, thunder.core.dtypes.bfloat16), + devicetypes=(thunder.core.devices.DeviceType.CUDA,), + executors=(TorchExecutor, nvFuserExecutor, DynamoThunderExecutor), + decorators=( + pytest.mark.skipif( + not (TORCHAO_AVAILABLE and torch.cuda.get_device_capability() >= (8, 9)), + reason="Requires capability >= 8.9 and torchao", + ), + pytest.mark.parametrize("bias", (True, False)), + ), +) +def test_torchao_float8_linear(executor, device, dtype, bias): + from torchao.float8 import convert_to_float8_training + + batch_size, in_features, out_features = 16, 32, 64 + device = torch.device("cuda") + torch_dtype = thunder.core.dtypes.to_torch_dtype(dtype) + + model = nn.Sequential( + nn.Linear(in_features, out_features, bias=bias), + nn.GELU(approximate="tanh"), + nn.Linear(out_features, out_features, bias=bias), + ).to(device=device, dtype=torch_dtype) + fp8_model = convert_to_float8_training(model) + x = make_tensor((batch_size, in_features), device=device, dtype=torch_dtype) + + expected: torch.Tensor + jitted: nn.Module + backend: ThunderCompiler | None = None + + if is_thunderfx := executor == DynamoThunderExecutor: + torch._dynamo.reset() + expected = torch.compile(fp8_model)(x) + backend = ThunderCompiler() + jitted = torch.compile(fp8_model, backend=backend) + else: + expected = fp8_model(x) + jitted = executor.make_callable(fp8_model) + + if bias and dtype == thunder.core.dtypes.bfloat16 and executor == nvFuserExecutor: + with pytest.raises( + RuntimeError, match="Failed to compute the min-cut on the graph due to a path with infinite capacity" + ): + jitted(x) + return + actual = jitted(x) + if bias and dtype == thunder.core.dtypes.bfloat16 and executor == DynamoThunderExecutor: + with pytest.raises(AssertionError, 
match="Tensor-likes are not close"): + torch.testing.assert_close(actual, expected) + return + + if (dtype == thunder.core.dtypes.bfloat16 and executor != DynamoThunderExecutor) or ( + not bias and dtype == thunder.core.dtypes.bfloat16 and executor == DynamoThunderExecutor + ): + pytest.xfail("numerical error") + torch.testing.assert_close(actual, expected) + + # TODO(crcrpar): Think of how to push tensor subclasses to `thunder.jit`. + # Currently no subgraphs go to thunder.jit. + if is_thunderfx: + for subgraph in backend.subgraph_infos: + if not bias and dtype == thunder.core.dtypes.bfloat16: + assert not subgraph.thunder_compiled_fns + else: + assert subgraph.thunder_compiled_fns diff --git a/thunder/torch/__init__.py b/thunder/torch/__init__.py index 34245270b8..6de3deb20b 100644 --- a/thunder/torch/__init__.py +++ b/thunder/torch/__init__.py @@ -1406,7 +1406,9 @@ def t(a: TensorLike, /) -> TensorLike: lambda: f"t() expects a tensor with <= 2 dimensions, but self is {a.ndim}D", RuntimeError, ) - return prims.transpose(a, (1, 0)) if a.ndim == 2 else a + if a.ndim != 2: + return a + return transpose(a, 0, 1) @torchsymbol(torch.ops.aten.t.default, id="torch.ops.aten.t.default") @@ -1480,6 +1482,17 @@ def core_aten_transpose(a: TensorProxy, dim0: int, dim1: int) -> TensorProxy: return _transpose_impl(a, dim0, dim1) +def _transpose_grad(a: TensorLike, /, dim0: int, dim1: int) -> TensorLike: + fwd = transpose(a, dim0, dim1) + g = get_grad(fwd) + a_grad = transpose(g, dim0, dim1) + put_grad(a, a_grad) + return fwd + + +register_grad(transpose, _transpose_grad) + + @torchsymbol(torch.unbind, is_method=True) def unbind(a: TensorLike, /, dim: int = 0) -> tuple[TensorLike, ...]: utils.check( @@ -4282,6 +4295,82 @@ def outer(a: TensorLike, b: TensorLike, /) -> TensorLike: return a[:, None] * b[None, :] +# TODO(crcrpar): Add nvfuser support of `matmul(a.float() * scale_a, b.float() * scale_b) + bias` +# So far I haven't managed to get a nice result from nvfuser region as I 
left +# https://github.com/Lightning-AI/lightning-thunder/pull/1415/files#r1892875183 +# reference: https://github.com/pytorch/pytorch/blob/6d4cd3e/torch/_meta_registrations.py#L5566 +def _scaled_mm_impl( + a: TensorLike, + b: TensorLike, + scale_a: TensorLike, + scale_b: TensorLike, + bias: TensorLike | None = None, + scale_result: TensorLike | None = None, + out_dtype: dtypeLike | None = None, + use_fast_accum: bool = False, +) -> TensorLike: + fp8_dtypes = {dtypes.float8_e4m3fn, dtypes.float8_e4m3fnuz, dtypes.float8_e5m2, dtypes.float8_e5m2fnuz} + # TODO(crcrpar): Devise a way to make sure `a` is row-major and `b` is column-major. + utils.check( + ( + (a.ndim == 2 and b.ndim == 2) + and (a.shape[1] == b.shape[0]) + and (a.shape[1] % 16 == 0 and b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0) + and (to_dtype(a.dtype) in fp8_dtypes and to_dtype(b.dtype) in fp8_dtypes) + and not (a.dtype == dtypes.float8_e5m2 and b.dtype == dtypes.float8_e5m2) + and to_device(a.device).type == "cuda" + ), + lambda: f"data matrices of {a=} and {b=} do not satisfy the condition.", + ) + args = [a, b, scale_a, scale_b] + if bias is not None: + args.append(bias) + utils.check_same_device(args) + utils.check( + ( + (scale_a.numel() == 1 and scale_b.numel() == 1) + and (scale_a.dtype == dtypes.float32 and scale_b.dtype == dtypes.float32) + ), + lambda: f"Only tensor-wise scaling is supported but {scale_a.shape = } and {scale_b.shape = }", + exception_type=NotImplementedError, + ) + result_dtype = a.dtype if out_dtype is None else to_dtype(out_dtype) + return TensorProxy( + like=a, + shape=(a.shape[0], b.shape[1]), + device=a.device, + dtype=result_dtype, + ) + + +@torchsymbol(torch._scaled_mm) +def _scaled_mm( + a: TensorLike, + b: TensorLike, + scale_a: TensorLike, + scale_b: TensorLike, + bias: TensorLike | None = None, + scale_result: TensorLike | None = None, + out_dtype: dtypeLike | None = None, + use_fast_accum: bool = False, +) -> TensorLike: + return _scaled_mm_impl(a, b,
scale_a, scale_b, bias, scale_result, out_dtype, use_fast_accum) + + +@torchsymbol(torch.ops.aten._scaled_mm.default, id="torch.ops.aten._scaled_mm") +def core_aten_scaled_mm( + a: TensorLike, + b: TensorLike, + scale_a: TensorLike, + scale_b: TensorLike, + bias: TensorLike | None = None, + scale_result: TensorLike | None = None, + out_dtype: dtypeLike | None = None, + use_fast_accum: bool = False, +) -> TensorLike: + return _scaled_mm_impl(a, b, scale_a, scale_b, bias, scale_result, out_dtype, use_fast_accum) + + # # Normalization operations #