[tensor wrapper subclass] Add support for torchao.float8 mlp #1585

Draft · wants to merge 1 commit into base: tensor_subclass_2
66 changes: 44 additions & 22 deletions thunder/core/jit_ext.py
@@ -667,6 +667,9 @@ def _general_jit_torch_autograd_function_apply_lookaside(obj: Any, *args, **kwar
So far, non-tensor ``ctx`` attributes seem to be folded into a trace.
"""
from thunder.core.baseutils import check, sequencify
from thunder.core.trace_interpreter import interpret_trace
from thunder.core.transforms import dce
from thunder.core.pytree import tree_flatten, tree_unflatten

custom_autograd_function_cls = unwrap(obj)
custom_forward = custom_autograd_function_cls.forward
@@ -679,25 +682,36 @@ def _general_jit_torch_autograd_function_apply_lookaside(obj: Any, *args, **kwar
if trace_of_fwd is INTERPRETER_SIGNALS.EXCEPTION_RAISED:
return trace_of_fwd

# Forward.
# augmented forward trace.
unwrapped_custom_forward_args = tree_map(lambda a: unwrap(a), args)
trace_of_fwd._siginfo = SigInfo.from_name_and_args(
custom_autograd_function_cls.__name__,
unwrapped_custom_forward_args,
)
trace_of_fwd.args = unwrapped_custom_forward_args
unpack_bsyms = [
prims.unpack_trivial.bind(a, name=a.name, output=a)
for a in filter(lambda a: isinstance(a, Proxy), trace_of_fwd.args)
for a in filter(lambda a: isinstance(a, Proxy), unwrapped_custom_forward_args)
]
trace_of_fwd.bound_symbols = unpack_bsyms + trace_of_fwd.bound_symbols

@wraps(trace_of_fwd.python_callable())
augmented_bsym_output: tuple[tuple[TensorProxy, ...], tuple[TensorProxy, ...]] = (
tuple(sequencify(trace_of_fwd.output)),
ctx_proxy.saved_tensors,
)
trace_of_augmented_fwd = TraceCtx()
trace_of_augmented_fwd.bound_symbols.extend((unpack_bsyms + trace_of_fwd.bound_symbols)[:-1])
with tracectx(trace_of_augmented_fwd):
prims.python_return(augmented_bsym_output)
trace_of_augmented_fwd._siginfo = SigInfo.from_name_and_args(
custom_autograd_function_cls.__name__, unwrapped_custom_forward_args
)
trace_of_augmented_fwd.args = unwrapped_custom_forward_args
trace_of_augmented_fwd = dce(trace_of_augmented_fwd)
_, spec_of_fwd_output = tree_flatten(trace_of_fwd.output)

@wraps(trace_of_augmented_fwd.python_callable())
def core_of_forward(*args, **kwargs):
return thunder.core.trace_interpreter.interpret_trace(trace_of_fwd, *args, **kwargs)
output, _ = interpret_trace(trace_of_augmented_fwd, *args, **kwargs)
flat_output, _ = tree_flatten(output)
return tree_unflatten(flat_output, spec_of_fwd_output)

custom_fwd_sym = get_jit_ctx().ad_hoc_executor.register_operator(
trace_of_fwd._siginfo.name,
custom_autograd_function_cls.__name__,
like=core_of_forward,
)
unwrapped_forward_result = custom_fwd_sym(*unwrapped_custom_forward_args)
@@ -706,17 +720,6 @@ def core_of_forward(*args, **kwargs):
provenance=ProvenanceRecord(PseudoInst.LOOKASIDE, inputs=[obj.provenance, fwd_output_provenance]),
)

augmented_bsym_output: tuple[tuple[TensorProxy, ...], tuple[TensorProxy, ...]] = (
tuple(sequencify(trace_of_fwd.output)),
ctx_proxy.saved_tensors,
)
trace_of_augmented_fwd = TraceCtx()
trace_of_augmented_fwd.bound_symbols.extend(trace_of_fwd.bound_symbols[:-1])
with tracectx(trace_of_augmented_fwd):
prims.python_return(augmented_bsym_output)
trace_of_augmented_fwd._siginfo = SigInfo.from_name_and_args(custom_fwd_sym.name, unwrapped_custom_forward_args)
trace_of_augmented_fwd.args = unwrapped_custom_forward_args

# Backward definition
custom_backward = custom_autograd_function_cls.backward
grads = tree_map(
@@ -745,6 +748,7 @@ def core_of_forward(*args, **kwargs):
ctx_proxy.saved_consts + ctx_proxy.saved_tensors + grads,
)
bwd_trace_impl.args = tuple(ctx_proxy.saved_consts + ctx_proxy.saved_tensors + grads)
bwd_trace_impl = dce(bwd_trace_impl)

@wraps(bwd_trace_impl.python_callable())
def bwd_impl_callable(*args, **kwargs):
@@ -770,6 +774,24 @@ def grad_transform(*args, **kwargs):
execution_transform=core_of_forward,
grad_transform=grad_transform,
)

added_bsym: BoundSymbol = get_jit_ctx().computation_trace.scopes[-1][-1]
import_ctx, call_ctx, object_ctx = {}, {}, {}
for bsym in trace_of_fwd.bound_symbols:
cur_import_ctx, cur_call_ctx, cur_object_ctx = bsym.gather_ctxs()
import_ctx.update(cur_import_ctx)
call_ctx.update(cur_call_ctx)
object_ctx.update(cur_object_ctx)

if import_ctx:
added_bsym._import_ctx.update(import_ctx)
if call_ctx:
if added_bsym._call_ctx is not None:
added_bsym._call_ctx.update(call_ctx)
else:
added_bsym._call_ctx = call_ctx
if object_ctx:
added_bsym._object_ctx.update(object_ctx)
return forward_result


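For orientation, a minimal sketch (not part of the diff) of the kind of `torch.autograd.Function` the lookaside above decomposes: the augmented forward trace returns `(outputs, saved_tensors)` and the backward trace consumes `saved_consts + saved_tensors + grads`. torchao.float8's matmul op follows this shape; the class below is a hypothetical stand-in.

```python
import torch


class ScaledMM(torch.autograd.Function):  # hypothetical stand-in for torchao.float8's mm op
    @staticmethod
    def forward(ctx, x, w):
        ctx.save_for_backward(x, w)  # becomes ctx_proxy.saved_tensors in the augmented forward trace
        return x @ w

    @staticmethod
    def backward(ctx, grad_out):
        x, w = ctx.saved_tensors
        return grad_out @ w.t(), x.t() @ grad_out
```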
20 changes: 20 additions & 0 deletions thunder/core/proxies.py
@@ -1502,6 +1502,26 @@ def distparallel_type(self):
def thunder_fsdp_padding_size(self):
return self._thunder_fsdp_padding_size

# n.b.(crcrpar): just returning contiguous for `_make_wrapper_subclasses`
def stride(self) -> Sequence[int]:
shape = self.shape
if len(shape) == 1:
return (1,)
elif len(shape) == 0:
return tuple()
else:
import numpy

_stride = reversed(numpy.cumprod([1] + list(shape[1:])).tolist())
return tuple(_stride)

def storage_offset(self) -> int:
return -1

@property
def layout(self) -> torch.layout:
return torch.strided

# We need to implement `__len__` as
# > In addition to bypassing any instance attributes in the
# > interest of correctness, implicit special method lookup
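The added `stride()` mirrors a contiguous layout, which is all `_make_wrapper_subclass` needs here. A minimal check of what it returns for a 2-D shape (the helper name below is hypothetical and simply restates the method body):

```python
import numpy


def contiguous_stride(shape):  # hypothetical helper mirroring TensorProxy.stride above
    if len(shape) == 0:
        return ()
    if len(shape) == 1:
        return (1,)
    # cumulative products of the trailing dims, reversed
    return tuple(reversed(numpy.cumprod([1] + list(shape[1:])).tolist()))


assert contiguous_stride((16, 32)) == (32, 1)  # matches torch.empty(16, 32).stride()
```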
2 changes: 2 additions & 0 deletions thunder/core/pytree.py
@@ -1,3 +1,4 @@
from enum import Enum
from functools import partial
from types import FunctionType
import dataclasses
@@ -64,6 +65,7 @@ def tree_flatten(args, namespace=OPTREE_NAMESPACE):
and not is_likely_from_collections_namedtuple(args)
and not dataclasses.is_dataclass(args)
and not type(args).__module__.startswith("torch.return_types")
and not issubclass(type(args), Enum)
):
raise TypeError(f"tree_flatten of type {type(args)} is not supported.")
return optree.tree_flatten(args, none_is_leaf=True, namespace=namespace)
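The relaxed guard lets an enum member pass through `tree_flatten` as a leaf instead of raising `TypeError`; float8 tensor subclass metadata carries such members. A minimal sketch with a hypothetical enum, assuming the guard only inspects the top-level argument as in the surrounding code:

```python
from enum import Enum

from thunder.core.pytree import tree_flatten


class Role(Enum):  # hypothetical stand-in for an enum carried in float8 tensor metadata
    INPUT = "input"
    WEIGHT = "weight"


# Previously this raised TypeError; with the added `issubclass(type(args), Enum)` clause
# the member is treated as a leaf by optree.
leaves, spec = tree_flatten(Role.WEIGHT)
assert leaves == [Role.WEIGHT]
```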
7 changes: 5 additions & 2 deletions thunder/core/trace_interpreter.py
@@ -128,8 +128,11 @@ def add_to_swap_map(old, new):
old = old.replace(shape=new._shape)

if isinstance(new, VJPDual):
swap_map[variableify(new.primal)] = old
new.primal = old
# note(crcrpar): Without this sanity check, `subclass.__tensor_flatten__`
# seems to cause `new.primal` == `old`, leading to a cycle in swapping.
if (key := variableify(new.primal)) != variableify(old):
swap_map[variableify(new.primal)] = old
new.primal = old
else:
assert isinstance(new, ProxyInterface), (old, new)
swap_map[variableify(new)] = old
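A hedged illustration with plain strings (not Thunder proxies) of what the identity check above avoids: recording a swap of a variable with itself.

```python
swap_map = {}
old, new_primal = "t5", "t5"  # hypothetical: __tensor_flatten__ hands back the same proxy
if new_primal != old:  # the added guard only records genuine renames
    swap_map[new_primal] = old
assert swap_map == {}  # no self-referential entry for later swapping to cycle on
```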
2 changes: 1 addition & 1 deletion thunder/core/transform_common.py
@@ -165,7 +165,7 @@ def dce(trace: Trace, needed_proxies: None | set[Variable] = None) -> Trace:
# may mark some of the operation's outputs as unused
some_unused = False
for out in bsym.flat_proxy_outs:
if variableify(out) in needed_proxies and producer_map[out] == bsym:
if variableify(out) in needed_proxies and producer_map.get(out, None) == bsym:
needed = True
else:
some_unused = True
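The switch to `dict.get` matters when an output proxy has no entry in `producer_map`; a minimal illustration of the difference (names are hypothetical):

```python
producer_map = {}  # hypothetical: output "t3" was never registered with a producer
out, bsym = "t3", object()

# Before: producer_map[out] raised KeyError.
# After: the lookup degrades to a failed comparison and the output is treated as unused.
needed = producer_map.get(out, None) == bsym
assert needed is False
```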
6 changes: 6 additions & 0 deletions thunder/executors/torch_compile.py
@@ -56,6 +56,12 @@ def _to_torch(*args, **kwargs) -> Any:
if torch_op is None:
raise RuntimeError(f"op not found for {bsym.sym.name}")

# NOTE(crcrpar): Currently `ltorch.t` is mapped to `torchex.transpose`
# thus `args` needs to be updated to have dim0 and dim1
if bsym.sym.id == "torch.t":
utils.check(len(args) == 1, lambda: f"{bsym.sym.id} takes only one argument but {args=}")
args = args + (0, 1)

return torch_op(*args, **kwargs)

return _to_torch
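The special case works because `torch.t` on a 2-D tensor is exactly `torch.transpose(x, 0, 1)`, so appending `(0, 1)` supplies the `dim0`/`dim1` arguments that the transpose-based mapping expects; a quick check:

```python
import torch

x = torch.randn(3, 4)
assert torch.equal(torch.t(x), torch.transpose(x, 0, 1))
```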
31 changes: 31 additions & 0 deletions thunder/executors/torchex.py
@@ -1403,13 +1403,44 @@ def _copy_with_setitem_impl(a, key, value):
#

matmul = _register_torch_operation("matmul")
_scaled_mm = _register_torch_operation("_scaled_mm")
outer = _register_torch_operation("outer")

_register_implementation(prims.matmul, matmul, checker=_always_executable)

_register_implementation(ltorch.matmul, matmul, checker=_always_executable)
_register_implementation(ltorch.outer, outer, checker=_always_executable)


def _scaled_mm_transform(
a: TensorLike,
b: TensorLike,
scale_a: TensorLike,
scale_b: TensorLike,
bias: TensorLike | None = None,
scale_result: TensorLike | None = None,
out_dtype: dtypeLike | None = None,
use_fast_accum: bool = False,
):

def is_column_major(mat: TensorLike) -> bool:
return mat.stride()[0] == 1 and mat.stride()[1] > 1

result_dtype: torch.dtype = to_torch_dtype(a.dtype if out_dtype is None else out_dtype)
if not is_column_major(b):
b = b.t().contiguous().t()

return _scaled_mm(a, b, scale_a, scale_b, bias, scale_result, result_dtype, use_fast_accum)


_register_implementation(
ltorch._scaled_mm, _scaled_mm, checker=_always_executable, execution_transform=_scaled_mm_transform
)
_register_implementation(
ltorch.core_aten_scaled_mm, _scaled_mm, checker=_always_executable, execution_transform=_scaled_mm_transform
)


#
# Normalization operations
#
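A hedged sketch of what the execution transform prepares for: `torch._scaled_mm` wants its second operand column-major, so a row-major `b` is re-laid out with `.t().contiguous().t()` before the call. The shapes, scales, and fp8 dtype below are illustrative and require a CUDA device with compute capability >= 8.9:

```python
import torch

a = torch.randn(16, 32, device="cuda").to(torch.float8_e4m3fn)
b = torch.randn(32, 64, device="cuda").to(torch.float8_e4m3fn)
b = b.t().contiguous().t()  # column-major layout: stride (1, 32)
scale_a = torch.tensor(1.0, device="cuda")  # float32 per-tensor scales
scale_b = torch.tensor(1.0, device="cuda")

out = torch._scaled_mm(a, b, scale_a, scale_b, out_dtype=torch.bfloat16)
assert out.shape == (16, 64)
```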
81 changes: 80 additions & 1 deletion thunder/tests/test_tensor_subclass.py
@@ -1,18 +1,29 @@
from __future__ import annotations
from typing import TYPE_CHECKING

from lightning_utilities.core.imports import package_available
import pytest
import torch
import torch.nn as nn
from torch.utils import _pytree as pytree

import thunder
from thunder.tests.framework import instantiate
from thunder.dynamo.compiler import ThunderCompiler
from thunder.tests.framework import (
DynamoThunderExecutor,
TorchExecutor,
instantiate,
nvFuserExecutor,
)
from thunder.tests.make_tensor import make_tensor

if TYPE_CHECKING:
from typing import Any


TORCHAO_AVAILABLE = package_available("torchao")


@torch._dynamo.allow_in_graph
class EncapsulateXandScale(torch.autograd.Function):
@staticmethod
@@ -232,3 +243,71 @@ def g(x: ScaleTensorSubclass, data: torch.Tensor, scale: torch.Tensor) -> torch.
torch.testing.assert_close(expected, actual)
if requires_grad:
actual.mean().backward()


@instantiate(
dtypes=(thunder.core.dtypes.float32, thunder.core.dtypes.bfloat16),
devicetypes=(thunder.core.devices.DeviceType.CUDA,),
executors=(TorchExecutor, nvFuserExecutor, DynamoThunderExecutor),
decorators=(
pytest.mark.skipif(
not (TORCHAO_AVAILABLE and torch.cuda.get_device_capability() >= (8, 9)),
reason="Requires capability >= 8.9 and torchao",
),
pytest.mark.parametrize("bias", (True, False)),
),
)
def test_torchao_float8_linear(executor, device, dtype, bias):
from torchao.float8 import convert_to_float8_training

batch_size, in_features, out_features = 16, 32, 64
device = torch.device("cuda")
torch_dtype = thunder.core.dtypes.to_torch_dtype(dtype)

model = nn.Sequential(
nn.Linear(in_features, out_features, bias=bias),
nn.GELU(approximate="tanh"),
nn.Linear(out_features, out_features, bias=bias),
).to(device=device, dtype=torch_dtype)
fp8_model = convert_to_float8_training(model)
x = make_tensor((batch_size, in_features), device=device, dtype=torch_dtype)

expected: torch.Tensor
jitted: nn.Module
backend: ThunderCompiler | None = None

if is_thunderfx := executor == DynamoThunderExecutor:
torch._dynamo.reset()
expected = torch.compile(fp8_model)(x)
backend = ThunderCompiler()
jitted = torch.compile(fp8_model, backend=backend)
else:
expected = fp8_model(x)
jitted = executor.make_callable(fp8_model)

if bias and dtype == thunder.core.dtypes.bfloat16 and executor == nvFuserExecutor:
with pytest.raises(
RuntimeError, match="Failed to compute the min-cut on the graph due to a path with infinite capacity"
):
jitted(x)
return
actual = jitted(x)
if bias and dtype == thunder.core.dtypes.bfloat16 and executor == DynamoThunderExecutor:
with pytest.raises(AssertionError, match="Tensor-likes are not close"):
torch.testing.assert_close(actual, expected)
return

if (dtype == thunder.core.dtypes.bfloat16 and executor != DynamoThunderExecutor) or (
not bias and dtype == thunder.core.dtypes.bfloat16 and executor == DynamoThunderExecutor
):
pytest.xfail("numerical error")
torch.testing.assert_close(actual, expected)

# TODO(crcrpar): Think of how to push tensor subclasses to `thunder.jit`.
# Currently no subgraphs go to thunder.jit.
if is_thunderfx:
for subgraph in backend.subgraph_infos:
if not bias and dtype == thunder.core.dtypes.bfloat16:
assert not subgraph.thunder_compiled_fns
else:
assert subgraph.thunder_compiled_fns