
[torchao float8tensor] #1415

Draft: crcrpar wants to merge 101 commits into main from crpa/subclass-torchao_float8tensor
Conversation

crcrpar (Collaborator) commented Nov 8, 2024

What does this PR do?

Improves the tensor subclass support introduced in #1394 for TorchAO's Float8Tensor.

Note: pytorch/ao#1339 is required.

My environment:

  • torch: 2.6.0a0+git45ed7c1
  • nvfuser: 0.2.23+git21e3617
  • torchao: 0.8.0+git33d57af1
  • CUDA Version: 12.6
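
For reading the traces: torchao's dynamic float8 scaling turns each linear into an amax-based scale computation, a cast to float8_e4m3fn, and a torch._scaled_mm call with the reciprocal scales (the 448.0 and 57344.0 constants in the traces are the float8_e4m3fn and float8_e5m2 maxima). A minimal sketch of that recipe, mirroring the torch._scaled_mm call signature that appears in the traces; the helper names here are illustrative, not torchao's API:

import torch

E4M3_MAX = 448.0  # max magnitude representable in float8_e4m3fn, as used in the traces

def to_float8_e4m3(x: torch.Tensor):
    # scale = 448 / amax(|x|), clamped away from zero (the nvFusion regions below
    # additionally propagate NaNs explicitly via ne/where pairs)
    amax = x.abs().amax().to(torch.float64)
    scale = (E4M3_MAX / amax.clamp(min=1e-12)).to(torch.float32)
    x_fp8 = (x.float() * scale).clamp(-E4M3_MAX, E4M3_MAX).to(torch.float8_e4m3fn)
    return x_fp8, scale

def fp8_linear(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    x_fp8, x_scale = to_float8_e4m3(x)
    w_fp8, w_scale = to_float8_e4m3(w)
    # torch._scaled_mm(a, b, scale_a, scale_b, bias, scale_result, out_dtype, use_fast_accum)
    # expects b column-major (hence w_fp8.t()) and the reciprocal scales, which is what
    # the prims.reciprocal / transpose / stride_order chains in the traces produce.
    return torch._scaled_mm(
        x_fp8, w_fp8.t(), x_scale.reciprocal(), w_scale.reciprocal(),
        None, None, torch.float32, True,
    )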

Traces

Model def

import torch
from torch import nn
from torchao.float8 import convert_to_float8_training

batch_size, in_features, out_features = 16, 32, 64
device = torch.device("cuda")
dtype = torch.float32
bias = True
model = convert_to_float8_training(nn.Sequential(
    nn.Linear(in_features, out_features, bias=bias),
    nn.GELU(approximate="tanh"),
    nn.Linear(out_features, out_features, bias=bias),
).to(device=device, dtype=dtype))
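
The forward and backward traces in the sections below were presumably obtained by compiling this module with Thunder and running one forward/backward pass, roughly as follows (the input creation and the trace-inspection calls are my assumptions, not copied from the PR):

import thunder

x = torch.randn(batch_size, in_features, device=device, dtype=dtype)

jitted = thunder.jit(model)   # Thunder-compiled module exercising the tensor-subclass path
out = jitted(x)               # records the forward trace
out.sum().backward()          # records the backward trace

print(thunder.last_traces(jitted)[-1])           # final forward execution trace
print(thunder.last_backward_traces(jitted)[-1])  # final backward execution trace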
fp32, GELU -> working

Commit: 0893fbe
Model:

batch_size, in_features, out_features = 16, 32, 64
device = torch.device("cuda")
dtype = torch.float32
bias = True
model = convert_to_float8_training(nn.Sequential(
    nn.Linear(in_features, out_features, bias=bias),
    nn.GELU(approximate="tanh"),
    nn.Linear(out_features, out_features, bias=bias),
).to(device=device, dtype=dtype))
# forward
# Constructed by Delete Last Used (took 0 milliseconds)
import thunder.core.devices as devices
import thunder.core.dtypes as dtypes
import torch
from torchao.float8.float8_tensor import Float8Tensor
from torchao.float8.float8_tensor import GemmInputRole
from torchao.float8.float8_tensor import LinearMMConfig
from torchao.float8.float8_tensor import ScaledMMConfig
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def computation(input, t_0_bias, weight, t_2_bias, t_2_weight):
  # input: "cuda:0 f32[16, 32]"
  # t_0_bias: "cuda:0 f32[64]"
  # weight: "cuda:0 f32[64, 32]"
  # t_2_bias: "cuda:0 f32[64]"
  # t_2_weight: "cuda:0 f32[64, 64]"
  [scale, t340, t438, t441, t444, t445, t225, t452, t462, t466] = nvFusion0(input, weight, t_2_weight)
    # t3 = prims.abs(input)  # t3: "cuda:0 f32[16, 32]"
    # amax = prims.amax(t3, (0, 1))  # amax: "cuda:0 f32[]"
    # t5 = prims.convert_element_type(amax, dtypes.float64)  # t5: "cuda:0 f64[]"
    # t470 = prims.ne(t5, t5)  # t470: "cuda:0 b8[]"
    # t471 = prims.gt(t5, 1e-12)  # t471: "cuda:0 b8[]"
    # t472 = prims.where(t471, t5, 1e-12)  # t472: "cuda:0 f64[]"
    # t10 = prims.where(t470, t5, t472)  # t10: "cuda:0 f64[]"
    # res = prims.div(448.0, t10)  # res: "cuda:0 f64[]"
    # scale = prims.convert_element_type(res, dtypes.float32)  # scale: "cuda:0 f32[]"
    # t476 = prims.broadcast_in_dim(scale, (16, 32), ())  # t476: "cuda:0 f32[16, 32]"
    # t331 = prims.mul(input, t476)  # t331: "cuda:0 f32[16, 32]"
    # t478 = prims.ne(t331, t331)  # t478: "cuda:0 b8[16, 32]"
    # t479 = prims.gt(t331, -448.0)  # t479: "cuda:0 b8[16, 32]"
    # t480 = prims.where(t479, t331, -448.0)  # t480: "cuda:0 f32[16, 32]"
    # t481 = prims.where(t478, t331, t480)  # t481: "cuda:0 f32[16, 32]"
    # t482 = prims.ne(t481, t481)  # t482: "cuda:0 b8[16, 32]"
    # t483 = prims.lt(t481, 448.0)  # t483: "cuda:0 b8[16, 32]"
    # t484 = prims.where(t483, t481, 448.0)  # t484: "cuda:0 f32[16, 32]"
    # t339 = prims.where(t482, t481, t484)  # t339: "cuda:0 f32[16, 32]"
    # t340 = prims.convert_element_type(t339, dtypes.float8_e4m3fn)  # t340: "cuda:0 f8_e4m3fn[16, 32]"
    # t52 = prims.abs(weight)  # t52: "cuda:0 f32[64, 32]"
    # t53 = prims.amax(t52, (0, 1))  # t53: "cuda:0 f32[]"
    # t54 = prims.convert_element_type(t53, dtypes.float64)  # t54: "cuda:0 f64[]"
    # t490 = prims.ne(t54, t54)  # t490: "cuda:0 b8[]"
    # t491 = prims.gt(t54, 1e-12)  # t491: "cuda:0 b8[]"
    # t492 = prims.where(t491, t54, 1e-12)  # t492: "cuda:0 f64[]"
    # t58 = prims.where(t490, t54, t492)  # t58: "cuda:0 f64[]"
    # t59 = prims.div(448.0, t58)  # t59: "cuda:0 f64[]"
    # weight_scale = prims.convert_element_type(t59, dtypes.float32)  # weight_scale: "cuda:0 f32[]"
    # t496 = prims.broadcast_in_dim(weight_scale, (64, 32), ())  # t496: "cuda:0 f32[64, 32]"
    # t352 = prims.mul(weight, t496)  # t352: "cuda:0 f32[64, 32]"
    # t498 = prims.ne(t352, t352)  # t498: "cuda:0 b8[64, 32]"
    # t499 = prims.gt(t352, -448.0)  # t499: "cuda:0 b8[64, 32]"
    # t500 = prims.where(t499, t352, -448.0)  # t500: "cuda:0 f32[64, 32]"
    # t501 = prims.where(t498, t352, t500)  # t501: "cuda:0 f32[64, 32]"
    # t502 = prims.ne(t501, t501)  # t502: "cuda:0 b8[64, 32]"
    # t503 = prims.lt(t501, 448.0)  # t503: "cuda:0 b8[64, 32]"
    # t504 = prims.where(t503, t501, 448.0)  # t504: "cuda:0 f32[64, 32]"
    # t360 = prims.where(t502, t501, t504)  # t360: "cuda:0 f32[64, 32]"
    # t361 = prims.convert_element_type(t360, dtypes.float8_e4m3fn)  # t361: "cuda:0 f8_e4m3fn[64, 32]"
    # t431 = prims.transpose(t361, (1, 0))  # t431: "cuda:0 f8_e4m3fn[32, 64]"
    # t438 = prims.reshape(t340, (16, 32))  # t438: "cuda:0 f8_e4m3fn[16, 32]"
    # t441 = prims.transpose(t431, (1, 0))  # t441: "cuda:0 f8_e4m3fn[64, 32]"
    # t444 = prims.reciprocal(scale)  # t444: "cuda:0 f32[]"
    # t445 = prims.reciprocal(weight_scale)  # t445: "cuda:0 f32[]"
    # t217 = prims.abs(t_2_weight)  # t217: "cuda:0 f32[64, 64]"
    # t218 = prims.amax(t217, (0, 1))  # t218: "cuda:0 f32[]"
    # t219 = prims.convert_element_type(t218, dtypes.float64)  # t219: "cuda:0 f64[]"
    # t552 = prims.ne(t219, t219)  # t552: "cuda:0 b8[]"
    # t553 = prims.gt(t219, 1e-12)  # t553: "cuda:0 b8[]"
    # t554 = prims.where(t553, t219, 1e-12)  # t554: "cuda:0 f64[]"
    # t223 = prims.where(t552, t219, t554)  # t223: "cuda:0 f64[]"
    # t224 = prims.div(448.0, t223)  # t224: "cuda:0 f64[]"
    # t225 = prims.convert_element_type(t224, dtypes.float32)  # t225: "cuda:0 f32[]"
    # t558 = prims.broadcast_in_dim(t225, (64, 64), ())  # t558: "cuda:0 f32[64, 64]"
    # t409 = prims.mul(t_2_weight, t558)  # t409: "cuda:0 f32[64, 64]"
    # t560 = prims.ne(t409, t409)  # t560: "cuda:0 b8[64, 64]"
    # t561 = prims.gt(t409, -448.0)  # t561: "cuda:0 b8[64, 64]"
    # t562 = prims.where(t561, t409, -448.0)  # t562: "cuda:0 f32[64, 64]"
    # t563 = prims.where(t560, t409, t562)  # t563: "cuda:0 f32[64, 64]"
    # t564 = prims.ne(t563, t563)  # t564: "cuda:0 b8[64, 64]"
    # t565 = prims.lt(t563, 448.0)  # t565: "cuda:0 b8[64, 64]"
    # t566 = prims.where(t565, t563, 448.0)  # t566: "cuda:0 f32[64, 64]"
    # t417 = prims.where(t564, t563, t566)  # t417: "cuda:0 f32[64, 64]"
    # t418 = prims.convert_element_type(t417, dtypes.float8_e4m3fn)  # t418: "cuda:0 f8_e4m3fn[64, 64]"
    # t452 = prims.transpose(t418, (1, 0))  # t452: "cuda:0 f8_e4m3fn[64, 64]"
    # t462 = prims.transpose(t452, (1, 0))  # t462: "cuda:0 f8_e4m3fn[64, 64]"
    # t466 = prims.reciprocal(t225)  # t466: "cuda:0 f32[]"

  # /opt/pytorch/lightning-thunder/thunder/core/proxies.py:1966: 	                self.requires_grad,
  input_fp8 = Float8Tensor(t340, scale, torch.float32, LinearMMConfig(output=ScaledMMConfig(emulate=False, use_fast_accum=True, fp8_output=False, pad_inner_dim=False), grad_input=ScaledMMConfig(emulate=False, use_fast_accum=False, fp8_output=False, pad_inner_dim=False), grad_weight=ScaledMMConfig(emulate=False, use_fast_accum=False, fp8_output=False, pad_inner_dim=False)), _GemmInputRole_5, None)  # input_fp8: "cuda:0 f32[16, 32]"
  del t340, scale
  t442 = torch.clone(t441)  # t442: "cuda:0 f8_e4m3fn[64, 32]"
    # t442 = ltorch.clone(t441, memory_format=_torch_memory_format_7)  # t442: "cuda:0 f8_e4m3fn[64, 32]"
      # t442 = prims.clone(t441)  # t442: "cuda:0 f8_e4m3fn[64, 32]"
  del t441
  t453 = unflatten_tensor_subclass(_torch__C__TensorMeta_4, {'_data': t452, '_scale': t225}, {'_orig_dtype': torch.float32, '_linear_mm_config': LinearMMConfig(output=ScaledMMConfig(emulate=False, use_fast_accum=True, fp8_output=False, pad_inner_dim=False), grad_input=ScaledMMConfig(emulate=False, use_fast_accum=False, fp8_output=False, pad_inner_dim=False), grad_weight=ScaledMMConfig(emulate=False, use_fast_accum=False, fp8_output=False, pad_inner_dim=False)), '_gemm_input_role': _GemmInputRole_6, '_axiswise_dim': None})  # t453: "cuda:0 f32[64, 64]"
  del t452, t225
  t463 = torch.clone(t462)  # t463: "cuda:0 f8_e4m3fn[64, 64]"
    # t463 = ltorch.clone(t462, memory_format=_torch_memory_format_7)  # t463: "cuda:0 f8_e4m3fn[64, 64]"
      # t463 = prims.clone(t462)  # t463: "cuda:0 f8_e4m3fn[64, 64]"
  del t462
  [t515, t577] = nvFusion1(t442, t463)
    # t443 = prims.transpose(t442, (1, 0))  # t443: "cuda:0 f8_e4m3fn[32, 64]"
    # t513 = prims.transpose(t443, (1, 0))  # t513: "cuda:0 f8_e4m3fn[64, 32]"
    # t514 = prims.stride_order(t513, (1, 0))  # t514: "cuda:0 f8_e4m3fn[64, 32]"
    # t515 = prims.transpose(t514, (1, 0))  # t515: "cuda:0 f8_e4m3fn[32, 64]"
    # t464 = prims.transpose(t463, (1, 0))  # t464: "cuda:0 f8_e4m3fn[64, 64]"
    # t575 = prims.transpose(t464, (1, 0))  # t575: "cuda:0 f8_e4m3fn[64, 64]"
    # t576 = prims.stride_order(t575, (1, 0))  # t576: "cuda:0 f8_e4m3fn[64, 64]"
    # t577 = prims.transpose(t576, (1, 0))  # t577: "cuda:0 f8_e4m3fn[64, 64]"
  del t442, t463
  t446 = torch._scaled_mm(t438, t515, t444, t445, None, None, torch.float32, True)  # t446: "cuda:0 f32[16, 64]"
  del t438, t515, t444, t445

  # /usr/local/lib/python3.12/dist-packages/torchao/float8/float8_linear.py:104: 	        return grad_input, grad_weight.t()
  t110 = shallow_copy(t446)  # t110: "cuda:0 f32[16, 64]"
  del t446
  [t156, s, t397, t459, t465] = nvFusion2(t_0_bias, t110)
    # t518 = prims.broadcast_in_dim(t_0_bias, [1, 64], [1])  # t518: "cuda:0 f32[1, 64]"
    # t367 = prims.broadcast_in_dim(t518, (16, 64), (0, 1))  # t367: "cuda:0 f32[16, 64]"
    # t156 = prims.add(t110, t367)  # t156: "cuda:0 f32[16, 64]"
    # t369 = prims.mul(t156, t156)  # t369: "cuda:0 f32[16, 64]"
    # t370 = prims.mul(t369, t156)  # t370: "cuda:0 f32[16, 64]"
    # t371 = prims.mul(0.5, t156)  # t371: "cuda:0 f32[16, 64]"
    # t372 = prims.mul(0.044715, t370)  # t372: "cuda:0 f32[16, 64]"
    # t373 = prims.add(t156, t372)  # t373: "cuda:0 f32[16, 64]"
    # t374 = prims.mul(0.7978845608028654, t373)  # t374: "cuda:0 f32[16, 64]"
    # t375 = prims.tanh(t374)  # t375: "cuda:0 f32[16, 64]"
    # t376 = prims.add(1.0, t375)  # t376: "cuda:0 f32[16, 64]"
    # hp_tensor = prims.mul(t371, t376)  # hp_tensor: "cuda:0 f32[16, 64]"
    # t169 = prims.abs(hp_tensor)  # t169: "cuda:0 f32[16, 64]"
    # t170 = prims.amax(t169, (0, 1))  # t170: "cuda:0 f32[]"
    # t171 = prims.convert_element_type(t170, dtypes.float64)  # t171: "cuda:0 f64[]"
    # t532 = prims.ne(t171, t171)  # t532: "cuda:0 b8[]"
    # t533 = prims.gt(t171, 1e-12)  # t533: "cuda:0 b8[]"
    # t534 = prims.where(t533, t171, 1e-12)  # t534: "cuda:0 f64[]"
    # t175 = prims.where(t532, t171, t534)  # t175: "cuda:0 f64[]"
    # t176 = prims.div(448.0, t175)  # t176: "cuda:0 f64[]"
    # s = prims.convert_element_type(t176, dtypes.float32)  # s: "cuda:0 f32[]"
    # t538 = prims.broadcast_in_dim(s, (16, 64), ())  # t538: "cuda:0 f32[16, 64]"
    # t388 = prims.mul(hp_tensor, t538)  # t388: "cuda:0 f32[16, 64]"
    # t540 = prims.ne(t388, t388)  # t540: "cuda:0 b8[16, 64]"
    # t541 = prims.gt(t388, -448.0)  # t541: "cuda:0 b8[16, 64]"
    # t542 = prims.where(t541, t388, -448.0)  # t542: "cuda:0 f32[16, 64]"
    # t543 = prims.where(t540, t388, t542)  # t543: "cuda:0 f32[16, 64]"
    # t544 = prims.ne(t543, t543)  # t544: "cuda:0 b8[16, 64]"
    # t545 = prims.lt(t543, 448.0)  # t545: "cuda:0 b8[16, 64]"
    # t546 = prims.where(t545, t543, 448.0)  # t546: "cuda:0 f32[16, 64]"
    # t396 = prims.where(t544, t543, t546)  # t396: "cuda:0 f32[16, 64]"
    # t397 = prims.convert_element_type(t396, dtypes.float8_e4m3fn)  # t397: "cuda:0 f8_e4m3fn[16, 64]"
    # t459 = prims.reshape(t397, (16, 64))  # t459: "cuda:0 f8_e4m3fn[16, 64]"
    # t465 = prims.reciprocal(s)  # t465: "cuda:0 f32[]"
  del t110

  # /opt/pytorch/lightning-thunder/thunder/core/proxies.py:1966: 	                self.requires_grad,
  t207 = Float8Tensor(t397, s, torch.float32, LinearMMConfig(output=ScaledMMConfig(emulate=False, use_fast_accum=True, fp8_output=False, pad_inner_dim=False), grad_input=ScaledMMConfig(emulate=False, use_fast_accum=False, fp8_output=False, pad_inner_dim=False), grad_weight=ScaledMMConfig(emulate=False, use_fast_accum=False, fp8_output=False, pad_inner_dim=False)), _GemmInputRole_5, None)  # t207: "cuda:0 f32[16, 64]"
  del t397, s
  t467 = torch._scaled_mm(t459, t577, t465, t466, None, None, torch.float32, True)  # t467: "cuda:0 f32[16, 64]"
  del t459, t577, t465, t466

  # /usr/local/lib/python3.12/dist-packages/torchao/float8/float8_linear.py:104: 	        return grad_input, grad_weight.t()
  t274 = shallow_copy(t467)  # t274: "cuda:0 f32[16, 64]"
  del t467
  [t320] = nvFusion3(t_2_bias, t274)
    # t580 = prims.broadcast_in_dim(t_2_bias, [1, 64], [1])  # t580: "cuda:0 f32[1, 64]"
    # t424 = prims.broadcast_in_dim(t580, (16, 64), (0, 1))  # t424: "cuda:0 f32[16, 64]"
    # t320 = prims.add(t274, t424)  # t320: "cuda:0 f32[16, 64]"
  del t274
  return {'output': (t320,), 'flat_args': [input, t_0_bias, weight, t_2_bias, t_2_weight], 'flat_output': (t320,)}, ((input_fp8, t156, t207, t453), ())

# backward
# Constructed by Delete Last Used (took 0 milliseconds)
import thunder.core.devices as devices
import thunder.core.dtypes as dtypes
import torch
from torchao.float8.float8_tensor import Float8Tensor
from torchao.float8.float8_tensor import GemmInputRole
from torchao.float8.float8_tensor import LinearMMConfig
from torchao.float8.float8_tensor import ScaledMMConfig
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def backward_fn(saved_for_backward, cotangents):
  # saved_for_backward: "Collection"
  # cotangents: "Collection"
  C0, _, = saved_for_backward
  clear_mutable_collection(saved_for_backward)
  del saved_for_backward
  t0, = cotangents
  clear_mutable_collection(cotangents)
  del cotangents
  input_fp8, t156, t207, t453, = C0
  clear_mutable_collection(C0)
  del C0
  [t111, t153, t676, t597] = nvFusion0(t0)
    # t111 = prims.sum(t0, (0,))  # t111: "cuda:0 f32[64]"
    # t112 = prims.abs(t0)  # t112: "cuda:0 f32[16, 64]"
    # t122 = prims.amax(t112, (0, 1))  # t122: "cuda:0 f32[]"
    # t135 = prims.convert_element_type(t122, dtypes.float64)  # t135: "cuda:0 f64[]"
    # t660 = prims.ne(t135, t135)  # t660: "cuda:0 b8[]"
    # t661 = prims.gt(t135, 1e-12)  # t661: "cuda:0 b8[]"
    # t662 = prims.where(t661, t135, 1e-12)  # t662: "cuda:0 f64[]"
    # t151 = prims.where(t660, t135, t662)  # t151: "cuda:0 f64[]"
    # t152 = prims.div(57344.0, t151)  # t152: "cuda:0 f64[]"
    # t153 = prims.convert_element_type(t152, dtypes.float32)  # t153: "cuda:0 f32[]"
    # t666 = prims.broadcast_in_dim(t153, (16, 64), ())  # t666: "cuda:0 f32[16, 64]"
    # t667 = prims.mul(t0, t666)  # t667: "cuda:0 f32[16, 64]"
    # t668 = prims.ne(t667, t667)  # t668: "cuda:0 b8[16, 64]"
    # t669 = prims.gt(t667, -57344.0)  # t669: "cuda:0 b8[16, 64]"
    # t670 = prims.where(t669, t667, -57344.0)  # t670: "cuda:0 f32[16, 64]"
    # t671 = prims.where(t668, t667, t670)  # t671: "cuda:0 f32[16, 64]"
    # t672 = prims.ne(t671, t671)  # t672: "cuda:0 b8[16, 64]"
    # t673 = prims.lt(t671, 57344.0)  # t673: "cuda:0 b8[16, 64]"
    # t674 = prims.where(t673, t671, 57344.0)  # t674: "cuda:0 f32[16, 64]"
    # t675 = prims.where(t672, t671, t674)  # t675: "cuda:0 f32[16, 64]"
    # t676 = prims.convert_element_type(t675, dtypes.float8_e5m2)  # t676: "cuda:0 f8_e5m2[16, 64]"
    # t597 = prims.reciprocal(t153)  # t597: "cuda:0 f32[]"
  del t0
  (t41, t12) = flatten_tensor_subclass(input_fp8)
  del input_fp8
  (t452, t225) = flatten_tensor_subclass(t453)
  del t453
  (t206, t177) = flatten_tensor_subclass(t207)
  del t207
  t209 = Float8Tensor(t676, t153, torch.float32, LinearMMConfig(output=ScaledMMConfig(emulate=False, use_fast_accum=True, fp8_output=False, pad_inner_dim=False), grad_input=ScaledMMConfig(emulate=False, use_fast_accum=False, fp8_output=False, pad_inner_dim=False), grad_weight=ScaledMMConfig(emulate=False, use_fast_accum=False, fp8_output=False, pad_inner_dim=False)), _GemmInputRole_4, None)  # t209: "cuda:0 f32[16, 64]"
  del t676, t153
  [t594, t598, t613, t617, t651, t655] = nvFusion1(t452, t225, t206, t177, t41, t12)
    # t591 = prims.transpose(t452, (1, 0))  # t591: "cuda:0 f8_e4m3fn[64, 64]"
    # t594 = prims.transpose(t591, (1, 0))  # t594: "cuda:0 f8_e4m3fn[64, 64]"
    # t598 = prims.reciprocal(t225)  # t598: "cuda:0 f32[]"
    # t605 = prims.reshape(t206, (16, 64))  # t605: "cuda:0 f8_e4m3fn[16, 64]"
    # t613 = prims.transpose(t605, (1, 0))  # t613: "cuda:0 f8_e4m3fn[64, 16]"
    # t617 = prims.reciprocal(t177)  # t617: "cuda:0 f32[]"
    # t643 = prims.reshape(t41, (16, 32))  # t643: "cuda:0 f8_e4m3fn[16, 32]"
    # t651 = prims.transpose(t643, (1, 0))  # t651: "cuda:0 f8_e4m3fn[32, 16]"
    # t655 = prims.reciprocal(t12)  # t655: "cuda:0 f32[]"
  del t452, t225, t206, t177, t41, t12
  (t208, _) = flatten_tensor_subclass(t209)
  del t209
  t652 = torch.clone(t651)  # t652: "cuda:0 f8_e4m3fn[32, 16]"
    # t652 = ltorch.clone(t651, memory_format=_torch_memory_format_5)  # t652: "cuda:0 f8_e4m3fn[32, 16]"
      # t652 = prims.clone(t651)  # t652: "cuda:0 f8_e4m3fn[32, 16]"
  del t651
  t595 = torch.clone(t594)  # t595: "cuda:0 f8_e4m3fn[64, 64]"
    # t595 = ltorch.clone(t594, memory_format=_torch_memory_format_5)  # t595: "cuda:0 f8_e4m3fn[64, 64]"
      # t595 = prims.clone(t594)  # t595: "cuda:0 f8_e4m3fn[64, 64]"
  del t594
  t614 = torch.clone(t613)  # t614: "cuda:0 f8_e4m3fn[64, 16]"
    # t614 = ltorch.clone(t613, memory_format=_torch_memory_format_5)  # t614: "cuda:0 f8_e4m3fn[64, 16]"
      # t614 = prims.clone(t613)  # t614: "cuda:0 f8_e4m3fn[64, 16]"
  del t613
  [t586, t685, t610, t694, t742] = nvFusion2(t208, t595, t614, t652)
    # t586 = prims.reshape(t208, (16, 64))  # t586: "cuda:0 f8_e5m2[16, 64]"
    # t596 = prims.transpose(t595, (1, 0))  # t596: "cuda:0 f8_e4m3fn[64, 64]"
    # t683 = prims.transpose(t596, (1, 0))  # t683: "cuda:0 f8_e4m3fn[64, 64]"
    # t684 = prims.stride_order(t683, (1, 0))  # t684: "cuda:0 f8_e4m3fn[64, 64]"
    # t685 = prims.transpose(t684, (1, 0))  # t685: "cuda:0 f8_e4m3fn[64, 64]"
    # t610 = prims.transpose(t586, (1, 0))  # t610: "cuda:0 f8_e5m2[64, 16]"
    # t615 = prims.transpose(t614, (1, 0))  # t615: "cuda:0 f8_e4m3fn[16, 64]"
    # t692 = prims.transpose(t615, (1, 0))  # t692: "cuda:0 f8_e4m3fn[64, 16]"
    # t693 = prims.stride_order(t692, (1, 0))  # t693: "cuda:0 f8_e4m3fn[64, 16]"
    # t694 = prims.transpose(t693, (1, 0))  # t694: "cuda:0 f8_e4m3fn[16, 64]"
    # t653 = prims.transpose(t652, (1, 0))  # t653: "cuda:0 f8_e4m3fn[16, 32]"
    # t740 = prims.transpose(t653, (1, 0))  # t740: "cuda:0 f8_e4m3fn[32, 16]"
    # t741 = prims.stride_order(t740, (1, 0))  # t741: "cuda:0 f8_e4m3fn[32, 16]"
    # t742 = prims.transpose(t741, (1, 0))  # t742: "cuda:0 f8_e4m3fn[16, 32]"
  del t208, t595, t614, t652
  t599 = torch._scaled_mm(t586, t685, t597, t598, None, None, torch.float32, False)  # t599: "cuda:0 f32[16, 64]"
  del t586, t685, t598
  t618 = torch._scaled_mm(t610, t694, t597, t617, None, None, torch.float32, False)  # t618: "cuda:0 f32[64, 64]"
  del t610, t694, t597, t617
  [t226, t315, t431, t732, t654] = nvFusion3(t156, t599, t618)
    # t369 = prims.mul(t156, t156)  # t369: "cuda:0 f32[16, 64]"
    # t370 = prims.mul(t369, t156)  # t370: "cuda:0 f32[16, 64]"
    # t372 = prims.mul(0.044715, t370)  # t372: "cuda:0 f32[16, 64]"
    # t373 = prims.add(t156, t372)  # t373: "cuda:0 f32[16, 64]"
    # t374 = prims.mul(0.7978845608028654, t373)  # t374: "cuda:0 f32[16, 64]"
    # t375 = prims.tanh(t374)  # t375: "cuda:0 f32[16, 64]"
    # t254 = prims.mul(t375, t375)  # t254: "cuda:0 f32[16, 64]"
    # t371 = prims.mul(0.5, t156)  # t371: "cuda:0 f32[16, 64]"
    # t255 = prims.sub(1.0, t254)  # t255: "cuda:0 f32[16, 64]"
    # t240 = prims.mul(t371, t599)  # t240: "cuda:0 f32[16, 64]"
    # t256 = prims.mul(t240, t255)  # t256: "cuda:0 f32[16, 64]"
    # t376 = prims.add(1.0, t375)  # t376: "cuda:0 f32[16, 64]"
    # t258 = prims.mul(0.7978845608028654, t256)  # t258: "cuda:0 f32[16, 64]"
    # t239 = prims.mul(t376, t599)  # t239: "cuda:0 f32[16, 64]"
    # t270 = prims.mul(0.044715, t258)  # t270: "cuda:0 f32[16, 64]"
    # t272 = prims.mul(0.5, t239)  # t272: "cuda:0 f32[16, 64]"
    # t275 = prims.mul(t156, t270)  # t275: "cuda:0 f32[16, 64]"
    # t276 = prims.mul(t369, t270)  # t276: "cuda:0 f32[16, 64]"
    # t273 = prims.add(t258, t272)  # t273: "cuda:0 f32[16, 64]"
    # t299 = prims.mul(t156, t275)  # t299: "cuda:0 f32[16, 64]"
    # t286 = prims.add(t273, t276)  # t286: "cuda:0 f32[16, 64]"
    # t301 = prims.add(t286, t299)  # t301: "cuda:0 f32[16, 64]"
    # t314 = prims.add(t301, t299)  # t314: "cuda:0 f32[16, 64]"
    # t316 = prims.abs(t314)  # t316: "cuda:0 f32[16, 64]"
    # t317 = prims.amax(t316, (0, 1))  # t317: "cuda:0 f32[]"
    # t318 = prims.convert_element_type(t317, dtypes.float64)  # t318: "cuda:0 f64[]"
    # t717 = prims.gt(t318, 1e-12)  # t717: "cuda:0 b8[]"
    # t718 = prims.where(t717, t318, 1e-12)  # t718: "cuda:0 f64[]"
    # t716 = prims.ne(t318, t318)  # t716: "cuda:0 b8[]"
    # t429 = prims.where(t716, t318, t718)  # t429: "cuda:0 f64[]"
    # t430 = prims.div(57344.0, t429)  # t430: "cuda:0 f64[]"
    # t431 = prims.convert_element_type(t430, dtypes.float32)  # t431: "cuda:0 f32[]"
    # t722 = prims.broadcast_in_dim(t431, (16, 64), ())  # t722: "cuda:0 f32[16, 64]"
    # t723 = prims.mul(t314, t722)  # t723: "cuda:0 f32[16, 64]"
    # t725 = prims.gt(t723, -57344.0)  # t725: "cuda:0 b8[16, 64]"
    # t726 = prims.where(t725, t723, -57344.0)  # t726: "cuda:0 f32[16, 64]"
    # t724 = prims.ne(t723, t723)  # t724: "cuda:0 b8[16, 64]"
    # t727 = prims.where(t724, t723, t726)  # t727: "cuda:0 f32[16, 64]"
    # t729 = prims.lt(t727, 57344.0)  # t729: "cuda:0 b8[16, 64]"
    # t730 = prims.where(t729, t727, 57344.0)  # t730: "cuda:0 f32[16, 64]"
    # t728 = prims.ne(t727, t727)  # t728: "cuda:0 b8[16, 64]"
    # t731 = prims.where(t728, t727, t730)  # t731: "cuda:0 f32[16, 64]"
    # t216 = prims.transpose(t618, (1, 0))  # t216: "cuda:0 f32[64, 64]"
    # t654 = prims.reciprocal(t431)  # t654: "cuda:0 f32[]"
    # t732 = prims.convert_element_type(t731, dtypes.float8_e5m2)  # t732: "cuda:0 f8_e5m2[16, 64]"
    # t315 = prims.sum(t314, (0,))  # t315: "cuda:0 f32[64]"
    # t226 = prims.transpose(t216, (1, 0))  # t226: "cuda:0 f32[64, 64]"
  del t156, t599, t618
  t443 = Float8Tensor(t732, t431, torch.float32, LinearMMConfig(output=ScaledMMConfig(emulate=False, use_fast_accum=True, fp8_output=False, pad_inner_dim=False), grad_input=ScaledMMConfig(emulate=False, use_fast_accum=False, fp8_output=False, pad_inner_dim=False), grad_weight=ScaledMMConfig(emulate=False, use_fast_accum=False, fp8_output=False, pad_inner_dim=False)), _GemmInputRole_4, None)  # t443: "cuda:0 f32[16, 64]"
  del t732, t431
  (t442, _) = flatten_tensor_subclass(t443)
  del t443
  [t648] = nvFusion4(t442)
    # t624 = prims.reshape(t442, (16, 64))  # t624: "cuda:0 f8_e5m2[16, 64]"
    # t648 = prims.transpose(t624, (1, 0))  # t648: "cuda:0 f8_e5m2[64, 16]"
  del t442
  t656 = torch._scaled_mm(t648, t742, t654, t655, None, None, torch.float32, False)  # t656: "cuda:0 f32[64, 32]"
  del t648, t742, t654, t655
  [t451] = nvFusion5(t656)
    # t450 = prims.transpose(t656, (1, 0))  # t450: "cuda:0 f32[32, 64]"
    # t451 = prims.transpose(t450, (1, 0))  # t451: "cuda:0 f32[64, 32]"
  del t656
  return (None, t315, t451, t111, t226)
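
The Float8Tensor constructions and the flatten_tensor_subclass / unflatten_tensor_subclass calls in the traces above mirror PyTorch's tensor-subclass flattening protocol, which Float8Tensor implements: the subclass is decomposed into its plain inner tensors (_data, _scale) plus non-tensor metadata (_orig_dtype, _linear_mm_config, _gemm_input_role, _axiswise_dim) and rebuilt from them. A rough sketch of what the trace ops map onto, assuming the standard __tensor_flatten__ / __tensor_unflatten__ hooks; the two wrapper functions below are illustrative, not Thunder's implementation:

from torchao.float8.float8_tensor import Float8Tensor

def flatten_float8(t: Float8Tensor):
    # __tensor_flatten__ returns the attribute names of the inner plain tensors
    # ("_data", "_scale") and a metadata object carrying the non-tensor fields.
    inner_names, metadata = t.__tensor_flatten__()
    inner_tensors = {name: getattr(t, name) for name in inner_names}
    return inner_tensors, metadata

def unflatten_float8(inner_tensors, metadata, outer_size, outer_stride):
    # The inverse hook rebuilds the subclass from plain tensors + metadata, which is
    # what unflatten_tensor_subclass(...) in the forward trace corresponds to.
    return Float8Tensor.__tensor_unflatten__(inner_tensors, metadata, outer_size, outer_stride)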
bf16, ReLU / GELU -> not working

ReLU -> NVIDIA/Fuser#3609
GELU -> #1567

Commit: 0893fbe

Model:

batch_size, in_features, out_features = 16, 32, 64
device = torch.device("cuda")
dtype = torch.bfloat16
bias = True
model = convert_to_float8_training(nn.Sequential(
    nn.Linear(in_features, out_features, bias=bias),
    nn.ReLU(),
    nn.Linear(out_features, out_features, bias=bias),
).to(device=device, dtype=dtype))
# forward
# Constructed by Delete Last Used (took 0 milliseconds)
import thunder.core.devices as devices
import thunder.core.dtypes as dtypes
import torch
from torchao.float8.float8_tensor import Float8Tensor
from torchao.float8.float8_tensor import ScaledMMConfig
from torchao.float8.float8_tensor import GemmInputRole
from torchao.float8.float8_tensor import LinearMMConfig
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def computation(input, t_0_bias, weight, t_2_bias, t_2_weight):
  # input: "cuda:0 bf16[16, 32]"
  # t_0_bias: "cuda:0 bf16[64]"
  # weight: "cuda:0 bf16[64, 32]"
  # t_2_bias: "cuda:0 bf16[64]"
  # t_2_weight: "cuda:0 bf16[64, 64]"
  [scale, t380, t492, t495, t498, t499, t249, t506, t516, t520] = nvFusion0(input, weight, t_2_weight)
    # t522 = prims.convert_element_type(input, dtypes.float32)  # t522: "cuda:0 f32[16, 32]"
    # t523 = prims.abs(t522)  # t523: "cuda:0 f32[16, 32]"
    # t360 = prims.amax(t523, (0, 1))  # t360: "cuda:0 f32[]"
    # t9 = prims.convert_element_type(t360, dtypes.float64)  # t9: "cuda:0 f64[]"
    # t526 = prims.ne(t9, t9)  # t526: "cuda:0 b8[]"
    # t527 = prims.gt(t9, 1e-12)  # t527: "cuda:0 b8[]"
    # t528 = prims.where(t527, t9, 1e-12)  # t528: "cuda:0 f64[]"
    # t14 = prims.where(t526, t9, t528)  # t14: "cuda:0 f64[]"
    # res = prims.div(448.0, t14)  # res: "cuda:0 f64[]"
    # scale = prims.convert_element_type(res, dtypes.float32)  # scale: "cuda:0 f32[]"
    # t533 = prims.broadcast_in_dim(scale, (16, 32), ())  # t533: "cuda:0 f32[16, 32]"
    # t371 = prims.mul(t522, t533)  # t371: "cuda:0 f32[16, 32]"
    # t535 = prims.ne(t371, t371)  # t535: "cuda:0 b8[16, 32]"
    # t536 = prims.gt(t371, -448.0)  # t536: "cuda:0 b8[16, 32]"
    # t537 = prims.where(t536, t371, -448.0)  # t537: "cuda:0 f32[16, 32]"
    # t538 = prims.where(t535, t371, t537)  # t538: "cuda:0 f32[16, 32]"
    # t539 = prims.ne(t538, t538)  # t539: "cuda:0 b8[16, 32]"
    # t540 = prims.lt(t538, 448.0)  # t540: "cuda:0 b8[16, 32]"
    # t541 = prims.where(t540, t538, 448.0)  # t541: "cuda:0 f32[16, 32]"
    # t379 = prims.where(t539, t538, t541)  # t379: "cuda:0 f32[16, 32]"
    # t380 = prims.convert_element_type(t379, dtypes.float8_e4m3fn)  # t380: "cuda:0 f8_e4m3fn[16, 32]"
    # t545 = prims.convert_element_type(weight, dtypes.float32)  # t545: "cuda:0 f32[64, 32]"
    # t546 = prims.abs(t545)  # t546: "cuda:0 f32[64, 32]"
    # t386 = prims.amax(t546, (0, 1))  # t386: "cuda:0 f32[]"
    # t64 = prims.convert_element_type(t386, dtypes.float64)  # t64: "cuda:0 f64[]"
    # t549 = prims.ne(t64, t64)  # t549: "cuda:0 b8[]"
    # t550 = prims.gt(t64, 1e-12)  # t550: "cuda:0 b8[]"
    # t551 = prims.where(t550, t64, 1e-12)  # t551: "cuda:0 f64[]"
    # t68 = prims.where(t549, t64, t551)  # t68: "cuda:0 f64[]"
    # t69 = prims.div(448.0, t68)  # t69: "cuda:0 f64[]"
    # weight_scale = prims.convert_element_type(t69, dtypes.float32)  # weight_scale: "cuda:0 f32[]"
    # t556 = prims.broadcast_in_dim(weight_scale, (64, 32), ())  # t556: "cuda:0 f32[64, 32]"
    # t397 = prims.mul(t545, t556)  # t397: "cuda:0 f32[64, 32]"
    # t558 = prims.ne(t397, t397)  # t558: "cuda:0 b8[64, 32]"
    # t559 = prims.gt(t397, -448.0)  # t559: "cuda:0 b8[64, 32]"
    # t560 = prims.where(t559, t397, -448.0)  # t560: "cuda:0 f32[64, 32]"
    # t561 = prims.where(t558, t397, t560)  # t561: "cuda:0 f32[64, 32]"
    # t562 = prims.ne(t561, t561)  # t562: "cuda:0 b8[64, 32]"
    # t563 = prims.lt(t561, 448.0)  # t563: "cuda:0 b8[64, 32]"
    # t564 = prims.where(t563, t561, 448.0)  # t564: "cuda:0 f32[64, 32]"
    # t405 = prims.where(t562, t561, t564)  # t405: "cuda:0 f32[64, 32]"
    # t406 = prims.convert_element_type(t405, dtypes.float8_e4m3fn)  # t406: "cuda:0 f8_e4m3fn[64, 32]"
    # t485 = prims.transpose(t406, (1, 0))  # t485: "cuda:0 f8_e4m3fn[32, 64]"
    # t492 = prims.reshape(t380, (16, 32))  # t492: "cuda:0 f8_e4m3fn[16, 32]"
    # t495 = prims.transpose(t485, (1, 0))  # t495: "cuda:0 f8_e4m3fn[64, 32]"
    # t498 = prims.reciprocal(scale)  # t498: "cuda:0 f32[]"
    # t499 = prims.reciprocal(weight_scale)  # t499: "cuda:0 f32[]"
    # t609 = prims.convert_element_type(t_2_weight, dtypes.float32)  # t609: "cuda:0 f32[64, 64]"
    # t610 = prims.abs(t609)  # t610: "cuda:0 f32[64, 64]"
    # t449 = prims.amax(t610, (0, 1))  # t449: "cuda:0 f32[]"
    # t243 = prims.convert_element_type(t449, dtypes.float64)  # t243: "cuda:0 f64[]"
    # t613 = prims.ne(t243, t243)  # t613: "cuda:0 b8[]"
    # t614 = prims.gt(t243, 1e-12)  # t614: "cuda:0 b8[]"
    # t615 = prims.where(t614, t243, 1e-12)  # t615: "cuda:0 f64[]"
    # t247 = prims.where(t613, t243, t615)  # t247: "cuda:0 f64[]"
    # t248 = prims.div(448.0, t247)  # t248: "cuda:0 f64[]"
    # t249 = prims.convert_element_type(t248, dtypes.float32)  # t249: "cuda:0 f32[]"
    # t620 = prims.broadcast_in_dim(t249, (64, 64), ())  # t620: "cuda:0 f32[64, 64]"
    # t460 = prims.mul(t609, t620)  # t460: "cuda:0 f32[64, 64]"
    # t622 = prims.ne(t460, t460)  # t622: "cuda:0 b8[64, 64]"
    # t623 = prims.gt(t460, -448.0)  # t623: "cuda:0 b8[64, 64]"
    # t624 = prims.where(t623, t460, -448.0)  # t624: "cuda:0 f32[64, 64]"
    # t625 = prims.where(t622, t460, t624)  # t625: "cuda:0 f32[64, 64]"
    # t626 = prims.ne(t625, t625)  # t626: "cuda:0 b8[64, 64]"
    # t627 = prims.lt(t625, 448.0)  # t627: "cuda:0 b8[64, 64]"
    # t628 = prims.where(t627, t625, 448.0)  # t628: "cuda:0 f32[64, 64]"
    # t468 = prims.where(t626, t625, t628)  # t468: "cuda:0 f32[64, 64]"
    # t469 = prims.convert_element_type(t468, dtypes.float8_e4m3fn)  # t469: "cuda:0 f8_e4m3fn[64, 64]"
    # t506 = prims.transpose(t469, (1, 0))  # t506: "cuda:0 f8_e4m3fn[64, 64]"
    # t516 = prims.transpose(t506, (1, 0))  # t516: "cuda:0 f8_e4m3fn[64, 64]"
    # t520 = prims.reciprocal(t249)  # t520: "cuda:0 f32[]"

  # /opt/pytorch/lightning-thunder/thunder/core/proxies.py:1966: 	                self.requires_grad,
  input_fp8 = Float8Tensor(t380, scale, torch.bfloat16, LinearMMConfig(output=ScaledMMConfig(emulate=False, use_fast_accum=True, fp8_output=False, pad_inner_dim=False), grad_input=ScaledMMConfig(emulate=False, use_fast_accum=False, fp8_output=False, pad_inner_dim=False), grad_weight=ScaledMMConfig(emulate=False, use_fast_accum=False, fp8_output=False, pad_inner_dim=False)), _GemmInputRole_5, None)  # input_fp8: "cuda:0 bf16[16, 32]"
  del t380, scale
  t496 = torch.clone(t495)  # t496: "cuda:0 f8_e4m3fn[64, 32]"
    # t496 = ltorch.clone(t495, memory_format=_torch_memory_format_7)  # t496: "cuda:0 f8_e4m3fn[64, 32]"
      # t496 = prims.clone(t495)  # t496: "cuda:0 f8_e4m3fn[64, 32]"
  del t495
  t507 = unflatten_tensor_subclass(_torch__C__TensorMeta_4, {'_data': t506, '_scale': t249}, {'_orig_dtype': torch.bfloat16, '_linear_mm_config': LinearMMConfig(output=ScaledMMConfig(emulate=False, use_fast_accum=True, fp8_output=False, pad_inner_dim=False), grad_input=ScaledMMConfig(emulate=False, use_fast_accum=False, fp8_output=False, pad_inner_dim=False), grad_weight=ScaledMMConfig(emulate=False, use_fast_accum=False, fp8_output=False, pad_inner_dim=False)), '_gemm_input_role': _GemmInputRole_6, '_axiswise_dim': None})  # t507: "cuda:0 bf16[64, 64]"
  del t506, t249
  t517 = torch.clone(t516)  # t517: "cuda:0 f8_e4m3fn[64, 64]"
    # t517 = ltorch.clone(t516, memory_format=_torch_memory_format_7)  # t517: "cuda:0 f8_e4m3fn[64, 64]"
      # t517 = prims.clone(t516)  # t517: "cuda:0 f8_e4m3fn[64, 64]"
  del t516
  [t575, t639] = nvFusion1(t496, t517)
    # t497 = prims.transpose(t496, (1, 0))  # t497: "cuda:0 f8_e4m3fn[32, 64]"
    # t573 = prims.transpose(t497, (1, 0))  # t573: "cuda:0 f8_e4m3fn[64, 32]"
    # t574 = prims.stride_order(t573, (1, 0))  # t574: "cuda:0 f8_e4m3fn[64, 32]"
    # t575 = prims.transpose(t574, (1, 0))  # t575: "cuda:0 f8_e4m3fn[32, 64]"
    # t518 = prims.transpose(t517, (1, 0))  # t518: "cuda:0 f8_e4m3fn[64, 64]"
    # t637 = prims.transpose(t518, (1, 0))  # t637: "cuda:0 f8_e4m3fn[64, 64]"
    # t638 = prims.stride_order(t637, (1, 0))  # t638: "cuda:0 f8_e4m3fn[64, 64]"
    # t639 = prims.transpose(t638, (1, 0))  # t639: "cuda:0 f8_e4m3fn[64, 64]"
  del t496, t517
  t500 = torch._scaled_mm(t492, t575, t498, t499, None, None, torch.bfloat16, True)  # t500: "cuda:0 bf16[16, 64]"
  del t492, t575, t498, t499

  # /usr/local/lib/python3.12/dist-packages/torchao/float8/float8_linear.py:104: 	        return grad_input, grad_weight.t()
  t122 = shallow_copy(t500)  # t122: "cuda:0 bf16[16, 64]"
  del t500
  [t417, s, t443, t513, t519] = nvFusion2(t_0_bias, t122)
    # t578 = prims.broadcast_in_dim(t_0_bias, [1, 64], [1])  # t578: "cuda:0 bf16[1, 64]"
    # t412 = prims.broadcast_in_dim(t578, (16, 64), (0, 1))  # t412: "cuda:0 bf16[16, 64]"
    # t413 = prims.convert_element_type(t122, dtypes.float32)  # t413: "cuda:0 f32[16, 64]"
    # t414 = prims.convert_element_type(t412, dtypes.float32)  # t414: "cuda:0 f32[16, 64]"
    # t415 = prims.add(t413, t414)  # t415: "cuda:0 f32[16, 64]"
    # t177 = prims.convert_element_type(t415, dtypes.bfloat16)  # t177: "cuda:0 bf16[16, 64]"
    # t417 = prims.gt(t177, 0.0)  # t417: "cuda:0 b8[16, 64]"
    # hp_tensor = prims.where(t417, t177, 0.0)  # hp_tensor: "cuda:0 bf16[16, 64]"
    # t586 = prims.convert_element_type(hp_tensor, dtypes.float32)  # t586: "cuda:0 f32[16, 64]"
    # t587 = prims.abs(t586)  # t587: "cuda:0 f32[16, 64]"
    # t423 = prims.amax(t587, (0, 1))  # t423: "cuda:0 f32[]"
    # t189 = prims.convert_element_type(t423, dtypes.float64)  # t189: "cuda:0 f64[]"
    # t590 = prims.ne(t189, t189)  # t590: "cuda:0 b8[]"
    # t591 = prims.gt(t189, 1e-12)  # t591: "cuda:0 b8[]"
    # t592 = prims.where(t591, t189, 1e-12)  # t592: "cuda:0 f64[]"
    # t193 = prims.where(t590, t189, t592)  # t193: "cuda:0 f64[]"
    # t194 = prims.div(448.0, t193)  # t194: "cuda:0 f64[]"
    # s = prims.convert_element_type(t194, dtypes.float32)  # s: "cuda:0 f32[]"
    # t597 = prims.broadcast_in_dim(s, (16, 64), ())  # t597: "cuda:0 f32[16, 64]"
    # t434 = prims.mul(t586, t597)  # t434: "cuda:0 f32[16, 64]"
    # t599 = prims.ne(t434, t434)  # t599: "cuda:0 b8[16, 64]"
    # t600 = prims.gt(t434, -448.0)  # t600: "cuda:0 b8[16, 64]"
    # t601 = prims.where(t600, t434, -448.0)  # t601: "cuda:0 f32[16, 64]"
    # t602 = prims.where(t599, t434, t601)  # t602: "cuda:0 f32[16, 64]"
    # t603 = prims.ne(t602, t602)  # t603: "cuda:0 b8[16, 64]"
    # t604 = prims.lt(t602, 448.0)  # t604: "cuda:0 b8[16, 64]"
    # t605 = prims.where(t604, t602, 448.0)  # t605: "cuda:0 f32[16, 64]"
    # t442 = prims.where(t603, t602, t605)  # t442: "cuda:0 f32[16, 64]"
    # t443 = prims.convert_element_type(t442, dtypes.float8_e4m3fn)  # t443: "cuda:0 f8_e4m3fn[16, 64]"
    # t513 = prims.reshape(t443, (16, 64))  # t513: "cuda:0 f8_e4m3fn[16, 64]"
    # t519 = prims.reciprocal(s)  # t519: "cuda:0 f32[]"
  del t122

  # /opt/pytorch/lightning-thunder/thunder/core/proxies.py:1966: 	                self.requires_grad,
  t227 = Float8Tensor(t443, s, torch.bfloat16, LinearMMConfig(output=ScaledMMConfig(emulate=False, use_fast_accum=True, fp8_output=False, pad_inner_dim=False), grad_input=ScaledMMConfig(emulate=False, use_fast_accum=False, fp8_output=False, pad_inner_dim=False), grad_weight=ScaledMMConfig(emulate=False, use_fast_accum=False, fp8_output=False, pad_inner_dim=False)), _GemmInputRole_5, None)  # t227: "cuda:0 bf16[16, 64]"
  del t443, s
  t521 = torch._scaled_mm(t513, t639, t519, t520, None, None, torch.bfloat16, True)  # t521: "cuda:0 bf16[16, 64]"
  del t513, t639, t519, t520

  # /usr/local/lib/python3.12/dist-packages/torchao/float8/float8_linear.py:104: 	        return grad_input, grad_weight.t()
  t300 = shallow_copy(t521)  # t300: "cuda:0 bf16[16, 64]"
  del t521
  [t355] = nvFusion3(t_2_bias, t300)
    # t642 = prims.broadcast_in_dim(t_2_bias, [1, 64], [1])  # t642: "cuda:0 bf16[1, 64]"
    # t475 = prims.broadcast_in_dim(t642, (16, 64), (0, 1))  # t475: "cuda:0 bf16[16, 64]"
    # t476 = prims.convert_element_type(t300, dtypes.float32)  # t476: "cuda:0 f32[16, 64]"
    # t477 = prims.convert_element_type(t475, dtypes.float32)  # t477: "cuda:0 f32[16, 64]"
    # t478 = prims.add(t476, t477)  # t478: "cuda:0 f32[16, 64]"
    # t355 = prims.convert_element_type(t478, dtypes.bfloat16)  # t355: "cuda:0 bf16[16, 64]"
  del t300
  return {'output': (t355,), 'flat_args': [input, t_0_bias, weight, t_2_bias, t_2_weight], 'flat_output': (t355,)}, ((input_fp8, t227, t417, t507), ())

# backward
# Constructed by Delete Last Used (took 0 milliseconds)
import thunder.core.devices as devices
import thunder.core.dtypes as dtypes
import torch
from torchao.float8.float8_tensor import Float8Tensor
from torchao.float8.float8_tensor import ScaledMMConfig
from torchao.float8.float8_tensor import GemmInputRole
from torchao.float8.float8_tensor import LinearMMConfig
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def backward_fn(saved_for_backward, cotangents):
  # saved_for_backward: "Collection"
  # cotangents: "Collection"
  C0, _, = saved_for_backward
  clear_mutable_collection(saved_for_backward)
  del saved_for_backward
  t0, = cotangents
  clear_mutable_collection(cotangents)
  del cotangents
  input_fp8, t227, t417, t507, = C0
  clear_mutable_collection(C0)
  del C0
  [t104, t170, t746, t659] = nvFusion0(t0)
    # t86 = prims.convert_element_type(t0, dtypes.float32)  # t86: "cuda:0 f32[16, 64]"
    # t723 = prims.sum(t86, (0,))  # t723: "cuda:0 f32[64]"
    # t104 = prims.convert_element_type(t723, dtypes.bfloat16)  # t104: "cuda:0 bf16[64]"
    # t726 = prims.abs(t86)  # t726: "cuda:0 f32[16, 64]"
    # t123 = prims.amax(t726, (0, 1))  # t123: "cuda:0 f32[]"
    # t138 = prims.convert_element_type(t123, dtypes.float64)  # t138: "cuda:0 f64[]"
    # t729 = prims.ne(t138, t138)  # t729: "cuda:0 b8[]"
    # t730 = prims.gt(t138, 1e-12)  # t730: "cuda:0 b8[]"
    # t731 = prims.where(t730, t138, 1e-12)  # t731: "cuda:0 f64[]"
    # t168 = prims.where(t729, t138, t731)  # t168: "cuda:0 f64[]"
    # t169 = prims.div(57344.0, t168)  # t169: "cuda:0 f64[]"
    # t170 = prims.convert_element_type(t169, dtypes.float32)  # t170: "cuda:0 f32[]"
    # t736 = prims.broadcast_in_dim(t170, (16, 64), ())  # t736: "cuda:0 f32[16, 64]"
    # t737 = prims.mul(t86, t736)  # t737: "cuda:0 f32[16, 64]"
    # t738 = prims.ne(t737, t737)  # t738: "cuda:0 b8[16, 64]"
    # t739 = prims.gt(t737, -57344.0)  # t739: "cuda:0 b8[16, 64]"
    # t740 = prims.where(t739, t737, -57344.0)  # t740: "cuda:0 f32[16, 64]"
    # t741 = prims.where(t738, t737, t740)  # t741: "cuda:0 f32[16, 64]"
    # t742 = prims.ne(t741, t741)  # t742: "cuda:0 b8[16, 64]"
    # t743 = prims.lt(t741, 57344.0)  # t743: "cuda:0 b8[16, 64]"
    # t744 = prims.where(t743, t741, 57344.0)  # t744: "cuda:0 f32[16, 64]"
    # t745 = prims.where(t742, t741, t744)  # t745: "cuda:0 f32[16, 64]"
    # t746 = prims.convert_element_type(t745, dtypes.float8_e5m2)  # t746: "cuda:0 f8_e5m2[16, 64]"
    # t659 = prims.reciprocal(t170)  # t659: "cuda:0 f32[]"
  del t0
  (t47, t16) = flatten_tensor_subclass(input_fp8)
  del input_fp8
  (t506, t249) = flatten_tensor_subclass(t507)
  del t507
  (t226, t195) = flatten_tensor_subclass(t227)
  del t227
  t229 = Float8Tensor(t746, t170, torch.bfloat16, LinearMMConfig(output=ScaledMMConfig(emulate=False, use_fast_accum=True, fp8_output=False, pad_inner_dim=False), grad_input=ScaledMMConfig(emulate=False, use_fast_accum=False, fp8_output=False, pad_inner_dim=False), grad_weight=ScaledMMConfig(emulate=False, use_fast_accum=False, fp8_output=False, pad_inner_dim=False)), _GemmInputRole_4, None)  # t229: "cuda:0 bf16[16, 64]"
  del t746, t170
  [t656, t660, t675, t679, t713, t717] = nvFusion1(t506, t249, t226, t195, t47, t16)
    # t653 = prims.transpose(t506, (1, 0))  # t653: "cuda:0 f8_e4m3fn[64, 64]"
    # t656 = prims.transpose(t653, (1, 0))  # t656: "cuda:0 f8_e4m3fn[64, 64]"
    # t660 = prims.reciprocal(t249)  # t660: "cuda:0 f32[]"
    # t667 = prims.reshape(t226, (16, 64))  # t667: "cuda:0 f8_e4m3fn[16, 64]"
    # t675 = prims.transpose(t667, (1, 0))  # t675: "cuda:0 f8_e4m3fn[64, 16]"
    # t679 = prims.reciprocal(t195)  # t679: "cuda:0 f32[]"
    # t705 = prims.reshape(t47, (16, 32))  # t705: "cuda:0 f8_e4m3fn[16, 32]"
    # t713 = prims.transpose(t705, (1, 0))  # t713: "cuda:0 f8_e4m3fn[32, 16]"
    # t717 = prims.reciprocal(t16)  # t717: "cuda:0 f32[]"
  del t506, t249, t226, t195, t47, t16
  (t228, _) = flatten_tensor_subclass(t229)
  del t229
  t714 = torch.clone(t713)  # t714: "cuda:0 f8_e4m3fn[32, 16]"
    # t714 = ltorch.clone(t713, memory_format=_torch_memory_format_5)  # t714: "cuda:0 f8_e4m3fn[32, 16]"
      # t714 = prims.clone(t713)  # t714: "cuda:0 f8_e4m3fn[32, 16]"
  del t713
  t657 = torch.clone(t656)  # t657: "cuda:0 f8_e4m3fn[64, 64]"
    # t657 = ltorch.clone(t656, memory_format=_torch_memory_format_5)  # t657: "cuda:0 f8_e4m3fn[64, 64]"
      # t657 = prims.clone(t656)  # t657: "cuda:0 f8_e4m3fn[64, 64]"
  del t656
  t676 = torch.clone(t675)  # t676: "cuda:0 f8_e4m3fn[64, 16]"
    # t676 = ltorch.clone(t675, memory_format=_torch_memory_format_5)  # t676: "cuda:0 f8_e4m3fn[64, 16]"
      # t676 = prims.clone(t675)  # t676: "cuda:0 f8_e4m3fn[64, 16]"
  del t675
  [t648, t755, t672, t764, t805] = nvFusion2(t228, t657, t676, t714)
    # t648 = prims.reshape(t228, (16, 64))  # t648: "cuda:0 f8_e5m2[16, 64]"
    # t658 = prims.transpose(t657, (1, 0))  # t658: "cuda:0 f8_e4m3fn[64, 64]"
    # t753 = prims.transpose(t658, (1, 0))  # t753: "cuda:0 f8_e4m3fn[64, 64]"
    # t754 = prims.stride_order(t753, (1, 0))  # t754: "cuda:0 f8_e4m3fn[64, 64]"
    # t755 = prims.transpose(t754, (1, 0))  # t755: "cuda:0 f8_e4m3fn[64, 64]"
    # t672 = prims.transpose(t648, (1, 0))  # t672: "cuda:0 f8_e5m2[64, 16]"
    # t677 = prims.transpose(t676, (1, 0))  # t677: "cuda:0 f8_e4m3fn[16, 64]"
    # t762 = prims.transpose(t677, (1, 0))  # t762: "cuda:0 f8_e4m3fn[64, 16]"
    # t763 = prims.stride_order(t762, (1, 0))  # t763: "cuda:0 f8_e4m3fn[64, 16]"
    # t764 = prims.transpose(t763, (1, 0))  # t764: "cuda:0 f8_e4m3fn[16, 64]"
    # t715 = prims.transpose(t714, (1, 0))  # t715: "cuda:0 f8_e4m3fn[16, 32]"
    # t803 = prims.transpose(t715, (1, 0))  # t803: "cuda:0 f8_e4m3fn[32, 16]"
    # t804 = prims.stride_order(t803, (1, 0))  # t804: "cuda:0 f8_e4m3fn[32, 16]"
    # t805 = prims.transpose(t804, (1, 0))  # t805: "cuda:0 f8_e4m3fn[16, 32]"
  del t228, t657, t676, t714
  t661 = torch._scaled_mm(t648, t755, t659, t660, None, None, torch.bfloat16, False)  # t661: "cuda:0 bf16[16, 64]"
  del t648, t755, t660
  t680 = torch._scaled_mm(t672, t764, t659, t679, None, None, torch.bfloat16, False)  # t680: "cuda:0 bf16[64, 64]"
  del t672, t764, t659, t679
  [t250, t284, t349, t795, t716] = nvFusion3(t680, t417, t661)
    # t236 = prims.transpose(t680, (1, 0))  # t236: "cuda:0 bf16[64, 64]"
    # t250 = prims.transpose(t236, (1, 0))  # t250: "cuda:0 bf16[64, 64]"
    # t264 = prims.where(t417, t661, 0.0)  # t264: "cuda:0 bf16[16, 64]"
    # t265 = prims.convert_element_type(t264, dtypes.float32)  # t265: "cuda:0 f32[16, 64]"
    # t772 = prims.sum(t265, (0,))  # t772: "cuda:0 f32[64]"
    # t284 = prims.convert_element_type(t772, dtypes.bfloat16)  # t284: "cuda:0 bf16[64]"
    # t775 = prims.abs(t265)  # t775: "cuda:0 f32[16, 64]"
    # t302 = prims.amax(t775, (0, 1))  # t302: "cuda:0 f32[]"
    # t330 = prims.convert_element_type(t302, dtypes.float64)  # t330: "cuda:0 f64[]"
    # t778 = prims.ne(t330, t330)  # t778: "cuda:0 b8[]"
    # t779 = prims.gt(t330, 1e-12)  # t779: "cuda:0 b8[]"
    # t780 = prims.where(t779, t330, 1e-12)  # t780: "cuda:0 f64[]"
    # t347 = prims.where(t778, t330, t780)  # t347: "cuda:0 f64[]"
    # t348 = prims.div(57344.0, t347)  # t348: "cuda:0 f64[]"
    # t349 = prims.convert_element_type(t348, dtypes.float32)  # t349: "cuda:0 f32[]"
    # t785 = prims.broadcast_in_dim(t349, (16, 64), ())  # t785: "cuda:0 f32[16, 64]"
    # t786 = prims.mul(t265, t785)  # t786: "cuda:0 f32[16, 64]"
    # t787 = prims.ne(t786, t786)  # t787: "cuda:0 b8[16, 64]"
    # t788 = prims.gt(t786, -57344.0)  # t788: "cuda:0 b8[16, 64]"
    # t789 = prims.where(t788, t786, -57344.0)  # t789: "cuda:0 f32[16, 64]"
    # t790 = prims.where(t787, t786, t789)  # t790: "cuda:0 f32[16, 64]"
    # t791 = prims.ne(t790, t790)  # t791: "cuda:0 b8[16, 64]"
    # t792 = prims.lt(t790, 57344.0)  # t792: "cuda:0 b8[16, 64]"
    # t793 = prims.where(t792, t790, 57344.0)  # t793: "cuda:0 f32[16, 64]"
    # t794 = prims.where(t791, t790, t793)  # t794: "cuda:0 f32[16, 64]"
    # t795 = prims.convert_element_type(t794, dtypes.float8_e5m2)  # t795: "cuda:0 f8_e5m2[16, 64]"
    # t716 = prims.reciprocal(t349)  # t716: "cuda:0 f32[]"
  del t680, t417, t661
  t491 = Float8Tensor(t795, t349, torch.bfloat16, LinearMMConfig(output=ScaledMMConfig(emulate=False, use_fast_accum=True, fp8_output=False, pad_inner_dim=False), grad_input=ScaledMMConfig(emulate=False, use_fast_accum=False, fp8_output=False, pad_inner_dim=False), grad_weight=ScaledMMConfig(emulate=False, use_fast_accum=False, fp8_output=False, pad_inner_dim=False)), _GemmInputRole_4, None)  # t491: "cuda:0 bf16[16, 64]"
  del t795, t349
  (t490, _) = flatten_tensor_subclass(t491)
  del t491
  [t710] = nvFusion4(t490)
    # t686 = prims.reshape(t490, (16, 64))  # t686: "cuda:0 f8_e5m2[16, 64]"
    # t710 = prims.transpose(t686, (1, 0))  # t710: "cuda:0 f8_e5m2[64, 16]"
  del t490
  t718 = torch._scaled_mm(t710, t805, t716, t717, None, None, torch.bfloat16, False)  # t718: "cuda:0 bf16[64, 32]"
  del t710, t805, t716, t717
  [t499] = nvFusion5(t718)
    # t498 = prims.transpose(t718, (1, 0))  # t498: "cuda:0 bf16[32, 64]"
    # t499 = prims.transpose(t498, (1, 0))  # t499: "cuda:0 bf16[64, 32]"
  del t718
  return (None, t284, t499, t104, t250)

t-vi (Collaborator) commented Nov 25, 2024

@crcrpar if you merge main, the PyTorch nightly distributed CI tests should be fixed.
