ThunderFX fails with FP8 and Activation Checkpointing #1424
This seems to be caused by an interaction between TransformerEngine and activation checkpointing.

Minimal repro:

```python
import torch
import torch.utils.checkpoint

def checkpointed_fn(x):
    y = x.cos()
    return torch.nn.functional.linear(x, y)

def fn(x):
    return torch.utils.checkpoint.checkpoint(checkpointed_fn, x)

from thunder.dynamo import ThunderCompiler
from thunder.executors.transformer_engineex import transformer_engine_ex
import thunder

backend = ThunderCompiler(executors=[transformer_engine_ex,])
x = torch.randn(16, 16, device='cuda', requires_grad=True)
o = torch.compile(fn, backend=backend)(x)

assert len(backend.subgraph_infos) == 1
subgraph_info = backend.subgraph_infos[0]
tfn = subgraph_info.thunder_compiled_fns[0]
print(thunder.last_traces(tfn)[-1])
print(thunder.last_backward_traces(tfn)[-1])

o.sum().backward()  # KeyError: 'scaling_fwd'
```

This happens because the forward graph calls `torch.nn.functional.linear`. The backward graph, however, recomputes the checkpointed region with TE's `te_linear_0` and then calls `te_functional_linear_backward`.

Forward graph:

```python
def computation(l_x_):
  # l_x_: "cuda:0 f32[16, 16]"
  t4 = torch.cos(l_x_)  # t4: "cuda:0 f32[16, 16]"
    # t4 = ltorch.cos(l_x_)  # t4: "cuda:0 f32[16, 16]"
      # t4 = prims.cos(l_x_)  # t4: "cuda:0 f32[16, 16]"
  getitem = torch.nn.functional.linear(l_x_, t4, None)  # getitem: "cuda:0 f32[16, 16]"
    # getitem = ltorch.linear(l_x_, t4, None)  # getitem: "cuda:0 f32[16, 16]"
      # getitem = prims.linear(l_x_, t4, None)  # getitem: "cuda:0 f32[16, 16]"
  del t4
  return {'output': getitem, 'flat_args': [l_x_], 'flat_output': (getitem,)}, ((l_x_,), ())
```

Backward graph:

```python
def backward_fn(saved_for_backward, cotangents):
  # saved_for_backward: "Collection"
  # cotangents: "Collection"
  C0, _, = saved_for_backward
  clear_mutable_collection(saved_for_backward)
  del saved_for_backward
  t0, = cotangents
  clear_mutable_collection(cotangents)
  del cotangents
  l_x_, = C0
  clear_mutable_collection(C0)
  del C0
  t6 = torch.cos(l_x_)  # t6: "cuda:0 f32[16, 16]"
    # t6 = ltorch.cos(l_x_)  # t6: "cuda:0 f32[16, 16]"
      # t6 = prims.cos(l_x_)  # t6: "cuda:0 f32[16, 16]"
  (_, (t10, t11, t12, t13, t14, _), ctx_te_1) = te_linear_0(l_x_, t6, None)
  del t6
  (t19, t20, _) = te_functional_linear_backward((16, 16), (16, 16), None, ctx_te_1, (t10, t11, t12, t13, t14, None), t0)
  del ctx_te_1, t10, t11, t12, t13, t14, t0
  t21 = torch.sin(l_x_)  # t21: "cuda:0 f32[16, 16]"
    # t21 = ltorch.sin(l_x_)  # t21: "cuda:0 f32[16, 16]"
      # t21 = prims.sin(l_x_)  # t21: "cuda:0 f32[16, 16]"
  del l_x_
  t22 = torch.neg(t21)  # t22: "cuda:0 f32[16, 16]"
    # t22 = ltorch.neg(t21)  # t22: "cuda:0 f32[16, 16]"
      # t22 = prims.neg(t21)  # t22: "cuda:0 f32[16, 16]"
  del t21
  t23 = torch.mul(t20, t22)  # t23: "cuda:0 f32[16, 16]"
    # t23 = ltorch.mul(t20, t22)  # t23: "cuda:0 f32[16, 16]"
      # t23 = prims.mul(t20, t22)  # t23: "cuda:0 f32[16, 16]"
  del t20, t22
  t24 = torch.add(t19, t23)  # t24: "cuda:0 f32[16, 16]"
    # t24 = ltorch.add(t19, t23, alpha=1)  # t24: "cuda:0 f32[16, 16]"
      # t24 = prims.add(t19, t23)  # t24: "cuda:0 f32[16, 16]"
  del t19, t23
  te_sync_fp8_meta_bwd()
  return (t24,)
```

@kiya00 do you know why this could be happening? Thanks!
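The backward graph is easier to follow next to a hypothetical scalar analogue in plain Python. Here `linear(x, w)` is replaced by `x * w`, so `checkpointed_fn(x)` becomes `x * cos(x)`. The variable names (`t4`, `t6`, `t19`, ...) are taken from the traces above, but the functions `forward` and `backward` are illustrative only, not Thunder code. The sketch shows that only `l_x_` is saved, so the backward must recompute `cos` before applying the chain rule. Because `x` feeds both the input and the weight of the linear, the two gradient contributions add up in `t24`.

```python
import math

# Hypothetical scalar analogue of the traces above: linear(x, w) is
# replaced by x * w, so checkpointed_fn(x) becomes x * cos(x).

def forward(x):
    """Mirrors `computation`: only the input x is saved for backward."""
    t4 = math.cos(x)
    getitem = x * t4        # stand-in for linear(l_x_, t4, None)
    return getitem, (x,)    # t4 is not saved, as in `del t4`


def backward(saved, t0):
    """Mirrors `backward_fn`; t0 is the incoming cotangent."""
    (l_x_,) = saved
    t6 = math.cos(l_x_)     # recompute the checkpointed activation
    # The real trace re-runs the linear here (te_linear_0) to get the saved
    # tensors and ctx consumed by te_functional_linear_backward; this scalar
    # sketch does not need that output, so the step is skipped.
    t19 = t0 * t6           # grad of linear w.r.t. its input
    t20 = t0 * l_x_         # grad of linear w.r.t. its weight (t6)
    t21 = math.sin(l_x_)
    t22 = -t21              # d/dx cos(x) = -sin(x)
    t23 = t20 * t22         # chain rule through cos
    t24 = t19 + t23         # x is used twice, so the contributions add
    return t24
```

A central finite difference of `forward` agrees with `backward((x,), 1.0)`, which is a quick way to check that the trace's sequence of ops computes `cos(x) - x * sin(x)`.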
lightning-thunder/thunder/torch/__init__.py Lines 5319 to 5331 in 60f3ee1
Checkpointing uses `vjp`. Is the `te_linear_0` in the backward trace the original `torch.nn.functional.linear`?
lightning-thunder/thunder/core/transforms.py Line 2819 in 60f3ee1
The input trace is:
and after line 2819, it seems the linear becomes:
🐛 Bug
When training the models 'vicuna-7b-v1.5-16k', 'longchat-13b-16k', 'Mistral-7B-v0.2', 'falcon-180B', 'Llama-3-70B', and 'CodeLlama-34b-hf' with FSDP and FP8, we get `KeyError: 'scaling_fwd'`. This might also be an issue with Transformer Engine, so I'm happy to move this issue to TE if needed.
Full traceback:
To Reproduce
Please use:
1 node(s), each with 8 GPUs.
Image "INTERNAL_IMAGE:pjnl_20241107"
Training script:
```shell
python /opt/pytorch/lightning-thunder/thunder/benchmarks/benchmark_litgpt.py \
    --model_name Mistral-7B-v0.2 \
    --distributed_mode fsdp \
    --shard_mode zero2 \
    --compile dynamo_thunder \
    --checkpoint_activations True \
    --low_precision_mode fp8-delayed-te \
    --micro_batch_size 1
```
Environment
system.device_product_name DGXH100
system.gpu_driver_version 535.129.03
libraries.cuda 12.6.98.001
libraries.pip.lightning 2.4.0.dev20240728
libraries.pip.lightning-thunder 0.2.0.dev0
libraries.pip.lightning-utilities 0.11.8
libraries.pip.litgpt 0.4.11
libraries.pip.nvfuser 0.2.22+gitba4f7d4
libraries.pip.pytorch-lightning 2.4.0
libraries.pip.torch 2.6.0a0+gita9b4989
libraries.pip.torchao 0.6.1
libraries.pip.torchmetrics 1.5.1
libraries.pip.torchvision 0.19.0a0+d23a6e1