FP8 training support #184
base: main
Conversation
Seem to be getting an error when running multi-GPU training:
Possibly helpful references:
Will try to investigate more when I find time. Accelerate seems to be able to handle fp8 mixed precision just fine (at least on Hopper), so will try to poke around.
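
For reference, a minimal sketch of what "Accelerate handling fp8 mixed precision" looks like, assuming a Hopper-class GPU with TransformerEngine or MS-AMP installed; `model`, `optimizer`, and `dataloader` are placeholders, not names from this PR:

```python
# Hypothetical setup: FP8 mixed precision via Accelerate (not this PR's code).
from accelerate import Accelerator

accelerator = Accelerator(mixed_precision="fp8")
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

for batch in dataloader:
    loss = model(**batch).loss   # forward under FP8 autocast (HF-style model assumed)
    accelerator.backward(loss)   # handles scaling for the backward pass
    optimizer.step()
    optimizer.zero_grad()
```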
Working fine for all (Hunyuan, LTX, Cog) on a single GPU. Tested with 49x480x720 resolution.
BF16 trace with precomputation (about 45 GB required without validation):
FP8 trace with precomputation (about 32 GB required without validation):
Notes:
The above numbers were with precomputation. This is what we have without precomputation at
For all cases, the numbers are without performing validation at all. If we do validation, memory blows up further.
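
A minimal sketch, assuming CUDA, of how peak-memory numbers like the ones above can be captured around a training step; `train_step` is a placeholder, not a function from this codebase:

```python
import torch

torch.cuda.reset_peak_memory_stats()
train_step()  # placeholder: one forward/backward/optimizer step
peak_gib = torch.cuda.max_memory_allocated() / 1024**3
print(f"peak allocated: {peak_gib:.1f} GiB")
```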
Training with single images works completely under 20 GB unless validation is performed (in which case it peaks around 57 GB). HunyuanVideo can be finetuned for styles/characters and a lot more with just images, so this is a great win! logs
Multi-GPU layerwise upcasting training seems to work without errors for Hopper (tested on 8x H100) and Ada (tested on 2x RTX 4090). Seems like Ampere doesn't have the relevant bits implemented for fp8 DDP, as mentioned in the linked issue, which is where the above error stack trace comes from. A sketch of the layerwise-upcasting idea follows.
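
A minimal, hook-based sketch of layerwise upcasting, assuming FP8 storage with bf16 compute; this is illustrative, not this PR's implementation, and `enable_layerwise_upcasting` is a hypothetical helper:

```python
import torch
import torch.nn as nn

def enable_layerwise_upcasting(module: nn.Module,
                               storage_dtype: torch.dtype = torch.float8_e4m3fn,
                               compute_dtype: torch.dtype = torch.bfloat16) -> None:
    module.to(storage_dtype)  # weights live in FP8 between forward passes

    def upcast(mod, args):
        mod.to(compute_dtype)  # upcast just-in-time for the forward pass

    def downcast(mod, args, output):
        mod.to(storage_dtype)  # drop back to FP8 storage immediately after

    module.register_forward_pre_hook(upcast)
    module.register_forward_hook(downcast)
```

In practice this would be applied per linear layer, skipping normalization layers, which (as the dtype dump below shows) stay in bf16.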
Oh no, it failed on Hopper too (8x DDP)... Training works, but it failed when validation started. This is because I did not account for some cases in the pipeline like guidance preparation.
I've done a brief test with LTX-Video with this branch, and I'm seeing a ~13% VRAM reduction when using float8_e4m3fn. Speed-wise it seems to be about the same. Edit: Oh, and keep up the good work! ;) Edit 2: Actually, during the short time I ran it, it seems to be about 5-10% faster per step as well.
@neph1 I think that is expected because, from what I understand, we're not launching native FP8 kernels.
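
A small illustration of why speed stays roughly the same, assuming (as I understand this branch) that FP8 is purely a storage format: the weight is upcast to bf16 before the matmul, so a regular bf16 kernel runs rather than a native FP8 one. Shapes here are arbitrary:

```python
import torch

w_fp8 = torch.randn(4096, 4096, device="cuda").to(torch.float8_e4m3fn)  # stored weight
x = torch.randn(8, 4096, device="cuda", dtype=torch.bfloat16)

y = x @ w_fp8.to(torch.bfloat16).t()  # compute happens in bf16, not FP8
```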
Just for future self-reference for debugging, dtypes of different layers (a sketch of the loop that can produce such a listing follows the dump):
x_embedder.proj torch.float8_e4m3fn
context_embedder.time_text_embed.timestep_embedder.linear_1 torch.float8_e4m3fn
context_embedder.time_text_embed.timestep_embedder.linear_2 torch.float8_e4m3fn
context_embedder.time_text_embed.text_embedder.linear_1 torch.float8_e4m3fn
context_embedder.time_text_embed.text_embedder.linear_2 torch.float8_e4m3fn
context_embedder.proj_in torch.float8_e4m3fn
context_embedder.token_refiner.refiner_blocks.0.norm1 torch.bfloat16
context_embedder.token_refiner.refiner_blocks.0.attn.to_q torch.float8_e4m3fn
context_embedder.token_refiner.refiner_blocks.0.attn.to_q.base_layer torch.float8_e4m3fn
context_embedder.token_refiner.refiner_blocks.0.attn.to_q.lora_A.default torch.float32
context_embedder.token_refiner.refiner_blocks.0.attn.to_q.lora_B.default torch.float32
context_embedder.token_refiner.refiner_blocks.0.attn.to_k torch.float8_e4m3fn
context_embedder.token_refiner.refiner_blocks.0.attn.to_k.base_layer torch.float8_e4m3fn
context_embedder.token_refiner.refiner_blocks.0.attn.to_k.lora_A.default torch.float32
context_embedder.token_refiner.refiner_blocks.0.attn.to_k.lora_B.default torch.float32
context_embedder.token_refiner.refiner_blocks.0.attn.to_v torch.float8_e4m3fn
context_embedder.token_refiner.refiner_blocks.0.attn.to_v.base_layer torch.float8_e4m3fn
context_embedder.token_refiner.refiner_blocks.0.attn.to_v.lora_A.default torch.float32
context_embedder.token_refiner.refiner_blocks.0.attn.to_v.lora_B.default torch.float32
context_embedder.token_refiner.refiner_blocks.0.attn.to_out.0 torch.float8_e4m3fn
context_embedder.token_refiner.refiner_blocks.0.attn.to_out.0.base_layer torch.float8_e4m3fn
context_embedder.token_refiner.refiner_blocks.0.attn.to_out.0.lora_A.default torch.float32
context_embedder.token_refiner.refiner_blocks.0.attn.to_out.0.lora_B.default torch.float32
context_embedder.token_refiner.refiner_blocks.0.norm2 torch.bfloat16
context_embedder.token_refiner.refiner_blocks.0.ff.net.0.proj torch.float8_e4m3fn
context_embedder.token_refiner.refiner_blocks.0.ff.net.2 torch.float8_e4m3fn
context_embedder.token_refiner.refiner_blocks.0.norm_out.linear torch.bfloat16
context_embedder.token_refiner.refiner_blocks.1.norm1 torch.bfloat16
context_embedder.token_refiner.refiner_blocks.1.attn.to_q torch.float8_e4m3fn
context_embedder.token_refiner.refiner_blocks.1.attn.to_q.base_layer torch.float8_e4m3fn
context_embedder.token_refiner.refiner_blocks.1.attn.to_q.lora_A.default torch.float32
context_embedder.token_refiner.refiner_blocks.1.attn.to_q.lora_B.default torch.float32
context_embedder.token_refiner.refiner_blocks.1.attn.to_k torch.float8_e4m3fn
context_embedder.token_refiner.refiner_blocks.1.attn.to_k.base_layer torch.float8_e4m3fn
context_embedder.token_refiner.refiner_blocks.1.attn.to_k.lora_A.default torch.float32
context_embedder.token_refiner.refiner_blocks.1.attn.to_k.lora_B.default torch.float32
context_embedder.token_refiner.refiner_blocks.1.attn.to_v torch.float8_e4m3fn
context_embedder.token_refiner.refiner_blocks.1.attn.to_v.base_layer torch.float8_e4m3fn
context_embedder.token_refiner.refiner_blocks.1.attn.to_v.lora_A.default torch.float32
context_embedder.token_refiner.refiner_blocks.1.attn.to_v.lora_B.default torch.float32
context_embedder.token_refiner.refiner_blocks.1.attn.to_out.0 torch.float8_e4m3fn
context_embedder.token_refiner.refiner_blocks.1.attn.to_out.0.base_layer torch.float8_e4m3fn
context_embedder.token_refiner.refiner_blocks.1.attn.to_out.0.lora_A.default torch.float32
context_embedder.token_refiner.refiner_blocks.1.attn.to_out.0.lora_B.default torch.float32
context_embedder.token_refiner.refiner_blocks.1.norm2 torch.bfloat16
context_embedder.token_refiner.refiner_blocks.1.ff.net.0.proj torch.float8_e4m3fn
context_embedder.token_refiner.refiner_blocks.1.ff.net.2 torch.float8_e4m3fn
context_embedder.token_refiner.refiner_blocks.1.norm_out.linear torch.bfloat16
time_text_embed.timestep_embedder.linear_1 torch.float8_e4m3fn
time_text_embed.timestep_embedder.linear_2 torch.float8_e4m3fn
time_text_embed.guidance_embedder.linear_1 torch.float8_e4m3fn
time_text_embed.guidance_embedder.linear_2 torch.float8_e4m3fn
time_text_embed.text_embedder.linear_1 torch.float8_e4m3fn
time_text_embed.text_embedder.linear_2 torch.float8_e4m3fn
transformer_blocks.0.norm1.linear torch.bfloat16
transformer_blocks.0.norm1_context.linear torch.bfloat16
transformer_blocks.0.attn.norm_q torch.bfloat16
transformer_blocks.0.attn.norm_k torch.bfloat16
transformer_blocks.0.attn.to_q torch.float8_e4m3fn
transformer_blocks.0.attn.to_q.base_layer torch.float8_e4m3fn
transformer_blocks.0.attn.to_q.lora_A.default torch.float32
transformer_blocks.0.attn.to_q.lora_B.default torch.float32
transformer_blocks.0.attn.to_k torch.float8_e4m3fn
transformer_blocks.0.attn.to_k.base_layer torch.float8_e4m3fn
transformer_blocks.0.attn.to_k.lora_A.default torch.float32
transformer_blocks.0.attn.to_k.lora_B.default torch.float32
transformer_blocks.0.attn.to_v torch.float8_e4m3fn
transformer_blocks.0.attn.to_v.base_layer torch.float8_e4m3fn
transformer_blocks.0.attn.to_v.lora_A.default torch.float32
transformer_blocks.0.attn.to_v.lora_B.default torch.float32
transformer_blocks.0.attn.add_k_proj torch.float8_e4m3fn
transformer_blocks.0.attn.add_k_proj.base_layer torch.float8_e4m3fn
transformer_blocks.0.attn.add_k_proj.lora_A.default torch.float32
transformer_blocks.0.attn.add_k_proj.lora_B.default torch.float32
transformer_blocks.0.attn.add_v_proj torch.float8_e4m3fn
transformer_blocks.0.attn.add_v_proj.base_layer torch.float8_e4m3fn
transformer_blocks.0.attn.add_v_proj.lora_A.default torch.float32
transformer_blocks.0.attn.add_v_proj.lora_B.default torch.float32
transformer_blocks.0.attn.add_q_proj torch.float8_e4m3fn
transformer_blocks.0.attn.add_q_proj.base_layer torch.float8_e4m3fn
transformer_blocks.0.attn.add_q_proj.lora_A.default torch.float32
transformer_blocks.0.attn.add_q_proj.lora_B.default torch.float32
transformer_blocks.0.attn.to_out.0 torch.float8_e4m3fn
transformer_blocks.0.attn.to_out.0.base_layer torch.float8_e4m3fn
transformer_blocks.0.attn.to_out.0.lora_A.default torch.float32
transformer_blocks.0.attn.to_out.0.lora_B.default torch.float32
transformer_blocks.0.attn.to_add_out torch.float8_e4m3fn
transformer_blocks.0.attn.to_add_out.base_layer torch.float8_e4m3fn
transformer_blocks.0.attn.to_add_out.lora_A.default torch.float32
transformer_blocks.0.attn.to_add_out.lora_B.default torch.float32
transformer_blocks.0.attn.norm_added_q torch.bfloat16
transformer_blocks.0.attn.norm_added_k torch.bfloat16
transformer_blocks.0.ff.net.0.proj torch.float8_e4m3fn
transformer_blocks.0.ff.net.2 torch.float8_e4m3fn
transformer_blocks.0.ff_context.net.0.proj torch.float8_e4m3fn
transformer_blocks.0.ff_context.net.2 torch.float8_e4m3fn
transformer_blocks.1.norm1.linear torch.bfloat16
transformer_blocks.1.norm1_context.linear torch.bfloat16
transformer_blocks.1.attn.norm_q torch.bfloat16
transformer_blocks.1.attn.norm_k torch.bfloat16
transformer_blocks.1.attn.to_q torch.float8_e4m3fn
transformer_blocks.1.attn.to_q.base_layer torch.float8_e4m3fn
transformer_blocks.1.attn.to_q.lora_A.default torch.float32
transformer_blocks.1.attn.to_q.lora_B.default torch.float32
transformer_blocks.1.attn.to_k torch.float8_e4m3fn
transformer_blocks.1.attn.to_k.base_layer torch.float8_e4m3fn
transformer_blocks.1.attn.to_k.lora_A.default torch.float32
transformer_blocks.1.attn.to_k.lora_B.default torch.float32
transformer_blocks.1.attn.to_v torch.float8_e4m3fn
transformer_blocks.1.attn.to_v.base_layer torch.float8_e4m3fn
transformer_blocks.1.attn.to_v.lora_A.default torch.float32
transformer_blocks.1.attn.to_v.lora_B.default torch.float32
transformer_blocks.1.attn.add_k_proj torch.float8_e4m3fn
transformer_blocks.1.attn.add_k_proj.base_layer torch.float8_e4m3fn
transformer_blocks.1.attn.add_k_proj.lora_A.default torch.float32
transformer_blocks.1.attn.add_k_proj.lora_B.default torch.float32
transformer_blocks.1.attn.add_v_proj torch.float8_e4m3fn
transformer_blocks.1.attn.add_v_proj.base_layer torch.float8_e4m3fn
transformer_blocks.1.attn.add_v_proj.lora_A.default torch.float32
transformer_blocks.1.attn.add_v_proj.lora_B.default torch.float32
transformer_blocks.1.attn.add_q_proj torch.float8_e4m3fn
transformer_blocks.1.attn.add_q_proj.base_layer torch.float8_e4m3fn
transformer_blocks.1.attn.add_q_proj.lora_A.default torch.float32
transformer_blocks.1.attn.add_q_proj.lora_B.default torch.float32
transformer_blocks.1.attn.to_out.0 torch.float8_e4m3fn
transformer_blocks.1.attn.to_out.0.base_layer torch.float8_e4m3fn
transformer_blocks.1.attn.to_out.0.lora_A.default torch.float32
transformer_blocks.1.attn.to_out.0.lora_B.default torch.float32
transformer_blocks.1.attn.to_add_out torch.float8_e4m3fn
transformer_blocks.1.attn.to_add_out.base_layer torch.float8_e4m3fn
transformer_blocks.1.attn.to_add_out.lora_A.default torch.float32
transformer_blocks.1.attn.to_add_out.lora_B.default torch.float32
transformer_blocks.1.attn.norm_added_q torch.bfloat16
transformer_blocks.1.attn.norm_added_k torch.bfloat16
transformer_blocks.1.ff.net.0.proj torch.float8_e4m3fn
transformer_blocks.1.ff.net.2 torch.float8_e4m3fn
transformer_blocks.1.ff_context.net.0.proj torch.float8_e4m3fn
transformer_blocks.1.ff_context.net.2 torch.float8_e4m3fn
transformer_blocks.2.norm1.linear torch.bfloat16
transformer_blocks.2.norm1_context.linear torch.bfloat16
transformer_blocks.2.attn.norm_q torch.bfloat16
transformer_blocks.2.attn.norm_k torch.bfloat16
transformer_blocks.2.attn.to_q torch.float8_e4m3fn
transformer_blocks.2.attn.to_q.base_layer torch.float8_e4m3fn
transformer_blocks.2.attn.to_q.lora_A.default torch.float32
transformer_blocks.2.attn.to_q.lora_B.default torch.float32
transformer_blocks.2.attn.to_k torch.float8_e4m3fn
transformer_blocks.2.attn.to_k.base_layer torch.float8_e4m3fn
transformer_blocks.2.attn.to_k.lora_A.default torch.float32
transformer_blocks.2.attn.to_k.lora_B.default torch.float32
transformer_blocks.2.attn.to_v torch.float8_e4m3fn
transformer_blocks.2.attn.to_v.base_layer torch.float8_e4m3fn
transformer_blocks.2.attn.to_v.lora_A.default torch.float32
transformer_blocks.2.attn.to_v.lora_B.default torch.float32
transformer_blocks.2.attn.add_k_proj torch.float8_e4m3fn
transformer_blocks.2.attn.add_k_proj.base_layer torch.float8_e4m3fn
transformer_blocks.2.attn.add_k_proj.lora_A.default torch.float32
transformer_blocks.2.attn.add_k_proj.lora_B.default torch.float32
transformer_blocks.2.attn.add_v_proj torch.float8_e4m3fn
transformer_blocks.2.attn.add_v_proj.base_layer torch.float8_e4m3fn
transformer_blocks.2.attn.add_v_proj.lora_A.default torch.float32
transformer_blocks.2.attn.add_v_proj.lora_B.default torch.float32
transformer_blocks.2.attn.add_q_proj torch.float8_e4m3fn
transformer_blocks.2.attn.add_q_proj.base_layer torch.float8_e4m3fn
transformer_blocks.2.attn.add_q_proj.lora_A.default torch.float32
transformer_blocks.2.attn.add_q_proj.lora_B.default torch.float32
transformer_blocks.2.attn.to_out.0 torch.float8_e4m3fn
transformer_blocks.2.attn.to_out.0.base_layer torch.float8_e4m3fn
transformer_blocks.2.attn.to_out.0.lora_A.default torch.float32
transformer_blocks.2.attn.to_out.0.lora_B.default torch.float32
transformer_blocks.2.attn.to_add_out torch.float8_e4m3fn
transformer_blocks.2.attn.to_add_out.base_layer torch.float8_e4m3fn
transformer_blocks.2.attn.to_add_out.lora_A.default torch.float32
transformer_blocks.2.attn.to_add_out.lora_B.default torch.float32
transformer_blocks.2.attn.norm_added_q torch.bfloat16
transformer_blocks.2.attn.norm_added_k torch.bfloat16
transformer_blocks.2.ff.net.0.proj torch.float8_e4m3fn
transformer_blocks.2.ff.net.2 torch.float8_e4m3fn
transformer_blocks.2.ff_context.net.0.proj torch.float8_e4m3fn
transformer_blocks.2.ff_context.net.2 torch.float8_e4m3fn
transformer_blocks.3.norm1.linear torch.bfloat16
transformer_blocks.3.norm1_context.linear torch.bfloat16
transformer_blocks.3.attn.norm_q torch.bfloat16
transformer_blocks.3.attn.norm_k torch.bfloat16
transformer_blocks.3.attn.to_q torch.float8_e4m3fn
transformer_blocks.3.attn.to_q.base_layer torch.float8_e4m3fn
transformer_blocks.3.attn.to_q.lora_A.default torch.float32
transformer_blocks.3.attn.to_q.lora_B.default torch.float32
transformer_blocks.3.attn.to_k torch.float8_e4m3fn
transformer_blocks.3.attn.to_k.base_layer torch.float8_e4m3fn
transformer_blocks.3.attn.to_k.lora_A.default torch.float32
transformer_blocks.3.attn.to_k.lora_B.default torch.float32
transformer_blocks.3.attn.to_v torch.float8_e4m3fn
transformer_blocks.3.attn.to_v.base_layer torch.float8_e4m3fn
transformer_blocks.3.attn.to_v.lora_A.default torch.float32
transformer_blocks.3.attn.to_v.lora_B.default torch.float32
transformer_blocks.3.attn.add_k_proj torch.float8_e4m3fn
transformer_blocks.3.attn.add_k_proj.base_layer torch.float8_e4m3fn
transformer_blocks.3.attn.add_k_proj.lora_A.default torch.float32
transformer_blocks.3.attn.add_k_proj.lora_B.default torch.float32
transformer_blocks.3.attn.add_v_proj torch.float8_e4m3fn
transformer_blocks.3.attn.add_v_proj.base_layer torch.float8_e4m3fn
transformer_blocks.3.attn.add_v_proj.lora_A.default torch.float32
transformer_blocks.3.attn.add_v_proj.lora_B.default torch.float32
transformer_blocks.3.attn.add_q_proj torch.float8_e4m3fn
transformer_blocks.3.attn.add_q_proj.base_layer torch.float8_e4m3fn
transformer_blocks.3.attn.add_q_proj.lora_A.default torch.float32
transformer_blocks.3.attn.add_q_proj.lora_B.default torch.float32
transformer_blocks.3.attn.to_out.0 torch.float8_e4m3fn
transformer_blocks.3.attn.to_out.0.base_layer torch.float8_e4m3fn
transformer_blocks.3.attn.to_out.0.lora_A.default torch.float32
transformer_blocks.3.attn.to_out.0.lora_B.default torch.float32
transformer_blocks.3.attn.to_add_out torch.float8_e4m3fn
transformer_blocks.3.attn.to_add_out.base_layer torch.float8_e4m3fn
transformer_blocks.3.attn.to_add_out.lora_A.default torch.float32
transformer_blocks.3.attn.to_add_out.lora_B.default torch.float32
transformer_blocks.3.attn.norm_added_q torch.bfloat16
transformer_blocks.3.attn.norm_added_k torch.bfloat16
transformer_blocks.3.ff.net.0.proj torch.float8_e4m3fn
transformer_blocks.3.ff.net.2 torch.float8_e4m3fn
transformer_blocks.3.ff_context.net.0.proj torch.float8_e4m3fn
transformer_blocks.3.ff_context.net.2 torch.float8_e4m3fn
transformer_blocks.4.norm1.linear torch.bfloat16
transformer_blocks.4.norm1_context.linear torch.bfloat16
transformer_blocks.4.attn.norm_q torch.bfloat16
transformer_blocks.4.attn.norm_k torch.bfloat16
transformer_blocks.4.attn.to_q torch.float8_e4m3fn
transformer_blocks.4.attn.to_q.base_layer torch.float8_e4m3fn
transformer_blocks.4.attn.to_q.lora_A.default torch.float32
transformer_blocks.4.attn.to_q.lora_B.default torch.float32
transformer_blocks.4.attn.to_k torch.float8_e4m3fn
transformer_blocks.4.attn.to_k.base_layer torch.float8_e4m3fn
transformer_blocks.4.attn.to_k.lora_A.default torch.float32
transformer_blocks.4.attn.to_k.lora_B.default torch.float32
transformer_blocks.4.attn.to_v torch.float8_e4m3fn
transformer_blocks.4.attn.to_v.base_layer torch.float8_e4m3fn
transformer_blocks.4.attn.to_v.lora_A.default torch.float32
transformer_blocks.4.attn.to_v.lora_B.default torch.float32
transformer_blocks.4.attn.add_k_proj torch.float8_e4m3fn
transformer_blocks.4.attn.add_k_proj.base_layer torch.float8_e4m3fn
transformer_blocks.4.attn.add_k_proj.lora_A.default torch.float32
transformer_blocks.4.attn.add_k_proj.lora_B.default torch.float32
transformer_blocks.4.attn.add_v_proj torch.float8_e4m3fn
transformer_blocks.4.attn.add_v_proj.base_layer torch.float8_e4m3fn
transformer_blocks.4.attn.add_v_proj.lora_A.default torch.float32
transformer_blocks.4.attn.add_v_proj.lora_B.default torch.float32
transformer_blocks.4.attn.add_q_proj torch.float8_e4m3fn
transformer_blocks.4.attn.add_q_proj.base_layer torch.float8_e4m3fn
transformer_blocks.4.attn.add_q_proj.lora_A.default torch.float32
transformer_blocks.4.attn.add_q_proj.lora_B.default torch.float32
transformer_blocks.4.attn.to_out.0 torch.float8_e4m3fn
transformer_blocks.4.attn.to_out.0.base_layer torch.float8_e4m3fn
transformer_blocks.4.attn.to_out.0.lora_A.default torch.float32
transformer_blocks.4.attn.to_out.0.lora_B.default torch.float32
transformer_blocks.4.attn.to_add_out torch.float8_e4m3fn
transformer_blocks.4.attn.to_add_out.base_layer torch.float8_e4m3fn
transformer_blocks.4.attn.to_add_out.lora_A.default torch.float32
transformer_blocks.4.attn.to_add_out.lora_B.default torch.float32
transformer_blocks.4.attn.norm_added_q torch.bfloat16
transformer_blocks.4.attn.norm_added_k torch.bfloat16
transformer_blocks.4.ff.net.0.proj torch.float8_e4m3fn
transformer_blocks.4.ff.net.2 torch.float8_e4m3fn
transformer_blocks.4.ff_context.net.0.proj torch.float8_e4m3fn
transformer_blocks.4.ff_context.net.2 torch.float8_e4m3fn
transformer_blocks.5.norm1.linear torch.bfloat16
transformer_blocks.5.norm1_context.linear torch.bfloat16
transformer_blocks.5.attn.norm_q torch.bfloat16
transformer_blocks.5.attn.norm_k torch.bfloat16
transformer_blocks.5.attn.to_q torch.float8_e4m3fn
transformer_blocks.5.attn.to_q.base_layer torch.float8_e4m3fn
transformer_blocks.5.attn.to_q.lora_A.default torch.float32
transformer_blocks.5.attn.to_q.lora_B.default torch.float32
transformer_blocks.5.attn.to_k torch.float8_e4m3fn
transformer_blocks.5.attn.to_k.base_layer torch.float8_e4m3fn
transformer_blocks.5.attn.to_k.lora_A.default torch.float32
transformer_blocks.5.attn.to_k.lora_B.default torch.float32
transformer_blocks.5.attn.to_v torch.float8_e4m3fn
transformer_blocks.5.attn.to_v.base_layer torch.float8_e4m3fn
transformer_blocks.5.attn.to_v.lora_A.default torch.float32
transformer_blocks.5.attn.to_v.lora_B.default torch.float32
transformer_blocks.5.attn.add_k_proj torch.float8_e4m3fn
transformer_blocks.5.attn.add_k_proj.base_layer torch.float8_e4m3fn
transformer_blocks.5.attn.add_k_proj.lora_A.default torch.float32
transformer_blocks.5.attn.add_k_proj.lora_B.default torch.float32
transformer_blocks.5.attn.add_v_proj torch.float8_e4m3fn
transformer_blocks.5.attn.add_v_proj.base_layer torch.float8_e4m3fn
transformer_blocks.5.attn.add_v_proj.lora_A.default torch.float32
transformer_blocks.5.attn.add_v_proj.lora_B.default torch.float32
transformer_blocks.5.attn.add_q_proj torch.float8_e4m3fn
transformer_blocks.5.attn.add_q_proj.base_layer torch.float8_e4m3fn
transformer_blocks.5.attn.add_q_proj.lora_A.default torch.float32
transformer_blocks.5.attn.add_q_proj.lora_B.default torch.float32
transformer_blocks.5.attn.to_out.0 torch.float8_e4m3fn
transformer_blocks.5.attn.to_out.0.base_layer torch.float8_e4m3fn
transformer_blocks.5.attn.to_out.0.lora_A.default torch.float32
transformer_blocks.5.attn.to_out.0.lora_B.default torch.float32
transformer_blocks.5.attn.to_add_out torch.float8_e4m3fn
transformer_blocks.5.attn.to_add_out.base_layer torch.float8_e4m3fn
transformer_blocks.5.attn.to_add_out.lora_A.default torch.float32
transformer_blocks.5.attn.to_add_out.lora_B.default torch.float32
transformer_blocks.5.attn.norm_added_q torch.bfloat16
transformer_blocks.5.attn.norm_added_k torch.bfloat16
transformer_blocks.5.ff.net.0.proj torch.float8_e4m3fn
transformer_blocks.5.ff.net.2 torch.float8_e4m3fn
transformer_blocks.5.ff_context.net.0.proj torch.float8_e4m3fn
transformer_blocks.5.ff_context.net.2 torch.float8_e4m3fn
transformer_blocks.6.norm1.linear torch.bfloat16
transformer_blocks.6.norm1_context.linear torch.bfloat16
transformer_blocks.6.attn.norm_q torch.bfloat16
transformer_blocks.6.attn.norm_k torch.bfloat16
transformer_blocks.6.attn.to_q torch.float8_e4m3fn
transformer_blocks.6.attn.to_q.base_layer torch.float8_e4m3fn
transformer_blocks.6.attn.to_q.lora_A.default torch.float32
transformer_blocks.6.attn.to_q.lora_B.default torch.float32
transformer_blocks.6.attn.to_k torch.float8_e4m3fn
transformer_blocks.6.attn.to_k.base_layer torch.float8_e4m3fn
transformer_blocks.6.attn.to_k.lora_A.default torch.float32
transformer_blocks.6.attn.to_k.lora_B.default torch.float32
transformer_blocks.6.attn.to_v torch.float8_e4m3fn
transformer_blocks.6.attn.to_v.base_layer torch.float8_e4m3fn
transformer_blocks.6.attn.to_v.lora_A.default torch.float32
transformer_blocks.6.attn.to_v.lora_B.default torch.float32
transformer_blocks.6.attn.add_k_proj torch.float8_e4m3fn
transformer_blocks.6.attn.add_k_proj.base_layer torch.float8_e4m3fn
transformer_blocks.6.attn.add_k_proj.lora_A.default torch.float32
transformer_blocks.6.attn.add_k_proj.lora_B.default torch.float32
transformer_blocks.6.attn.add_v_proj torch.float8_e4m3fn
transformer_blocks.6.attn.add_v_proj.base_layer torch.float8_e4m3fn
transformer_blocks.6.attn.add_v_proj.lora_A.default torch.float32
transformer_blocks.6.attn.add_v_proj.lora_B.default torch.float32
transformer_blocks.6.attn.add_q_proj torch.float8_e4m3fn
transformer_blocks.6.attn.add_q_proj.base_layer torch.float8_e4m3fn
transformer_blocks.6.attn.add_q_proj.lora_A.default torch.float32
transformer_blocks.6.attn.add_q_proj.lora_B.default torch.float32
transformer_blocks.6.attn.to_out.0 torch.float8_e4m3fn
transformer_blocks.6.attn.to_out.0.base_layer torch.float8_e4m3fn
transformer_blocks.6.attn.to_out.0.lora_A.default torch.float32
transformer_blocks.6.attn.to_out.0.lora_B.default torch.float32
transformer_blocks.6.attn.to_add_out torch.float8_e4m3fn
transformer_blocks.6.attn.to_add_out.base_layer torch.float8_e4m3fn
transformer_blocks.6.attn.to_add_out.lora_A.default torch.float32
transformer_blocks.6.attn.to_add_out.lora_B.default torch.float32
transformer_blocks.6.attn.norm_added_q torch.bfloat16
transformer_blocks.6.attn.norm_added_k torch.bfloat16
transformer_blocks.6.ff.net.0.proj torch.float8_e4m3fn
transformer_blocks.6.ff.net.2 torch.float8_e4m3fn
transformer_blocks.6.ff_context.net.0.proj torch.float8_e4m3fn
transformer_blocks.6.ff_context.net.2 torch.float8_e4m3fn
transformer_blocks.7.norm1.linear torch.bfloat16
transformer_blocks.7.norm1_context.linear torch.bfloat16
transformer_blocks.7.attn.norm_q torch.bfloat16
transformer_blocks.7.attn.norm_k torch.bfloat16
transformer_blocks.7.attn.to_q torch.float8_e4m3fn
transformer_blocks.7.attn.to_q.base_layer torch.float8_e4m3fn
transformer_blocks.7.attn.to_q.lora_A.default torch.float32
transformer_blocks.7.attn.to_q.lora_B.default torch.float32
transformer_blocks.7.attn.to_k torch.float8_e4m3fn
transformer_blocks.7.attn.to_k.base_layer torch.float8_e4m3fn
transformer_blocks.7.attn.to_k.lora_A.default torch.float32
transformer_blocks.7.attn.to_k.lora_B.default torch.float32
transformer_blocks.7.attn.to_v torch.float8_e4m3fn
transformer_blocks.7.attn.to_v.base_layer torch.float8_e4m3fn
transformer_blocks.7.attn.to_v.lora_A.default torch.float32
transformer_blocks.7.attn.to_v.lora_B.default torch.float32
transformer_blocks.7.attn.add_k_proj torch.float8_e4m3fn
transformer_blocks.7.attn.add_k_proj.base_layer torch.float8_e4m3fn
transformer_blocks.7.attn.add_k_proj.lora_A.default torch.float32
transformer_blocks.7.attn.add_k_proj.lora_B.default torch.float32
transformer_blocks.7.attn.add_v_proj torch.float8_e4m3fn
transformer_blocks.7.attn.add_v_proj.base_layer torch.float8_e4m3fn
transformer_blocks.7.attn.add_v_proj.lora_A.default torch.float32
transformer_blocks.7.attn.add_v_proj.lora_B.default torch.float32
transformer_blocks.7.attn.add_q_proj torch.float8_e4m3fn
transformer_blocks.7.attn.add_q_proj.base_layer torch.float8_e4m3fn
transformer_blocks.7.attn.add_q_proj.lora_A.default torch.float32
transformer_blocks.7.attn.add_q_proj.lora_B.default torch.float32
transformer_blocks.7.attn.to_out.0 torch.float8_e4m3fn
transformer_blocks.7.attn.to_out.0.base_layer torch.float8_e4m3fn
transformer_blocks.7.attn.to_out.0.lora_A.default torch.float32
transformer_blocks.7.attn.to_out.0.lora_B.default torch.float32
transformer_blocks.7.attn.to_add_out torch.float8_e4m3fn
transformer_blocks.7.attn.to_add_out.base_layer torch.float8_e4m3fn
transformer_blocks.7.attn.to_add_out.lora_A.default torch.float32
transformer_blocks.7.attn.to_add_out.lora_B.default torch.float32
transformer_blocks.7.attn.norm_added_q torch.bfloat16
transformer_blocks.7.attn.norm_added_k torch.bfloat16
transformer_blocks.7.ff.net.0.proj torch.float8_e4m3fn
transformer_blocks.7.ff.net.2 torch.float8_e4m3fn
transformer_blocks.7.ff_context.net.0.proj torch.float8_e4m3fn
transformer_blocks.7.ff_context.net.2 torch.float8_e4m3fn
transformer_blocks.8.norm1.linear torch.bfloat16
transformer_blocks.8.norm1_context.linear torch.bfloat16
transformer_blocks.8.attn.norm_q torch.bfloat16
transformer_blocks.8.attn.norm_k torch.bfloat16
transformer_blocks.8.attn.to_q torch.float8_e4m3fn
transformer_blocks.8.attn.to_q.base_layer torch.float8_e4m3fn
transformer_blocks.8.attn.to_q.lora_A.default torch.float32
transformer_blocks.8.attn.to_q.lora_B.default torch.float32
transformer_blocks.8.attn.to_k torch.float8_e4m3fn
transformer_blocks.8.attn.to_k.base_layer torch.float8_e4m3fn
transformer_blocks.8.attn.to_k.lora_A.default torch.float32
transformer_blocks.8.attn.to_k.lora_B.default torch.float32
transformer_blocks.8.attn.to_v torch.float8_e4m3fn
transformer_blocks.8.attn.to_v.base_layer torch.float8_e4m3fn
transformer_blocks.8.attn.to_v.lora_A.default torch.float32
transformer_blocks.8.attn.to_v.lora_B.default torch.float32
transformer_blocks.8.attn.add_k_proj torch.float8_e4m3fn
transformer_blocks.8.attn.add_k_proj.base_layer torch.float8_e4m3fn
transformer_blocks.8.attn.add_k_proj.lora_A.default torch.float32
transformer_blocks.8.attn.add_k_proj.lora_B.default torch.float32
transformer_blocks.8.attn.add_v_proj torch.float8_e4m3fn
transformer_blocks.8.attn.add_v_proj.base_layer torch.float8_e4m3fn
transformer_blocks.8.attn.add_v_proj.lora_A.default torch.float32
transformer_blocks.8.attn.add_v_proj.lora_B.default torch.float32
transformer_blocks.8.attn.add_q_proj torch.float8_e4m3fn
transformer_blocks.8.attn.add_q_proj.base_layer torch.float8_e4m3fn
transformer_blocks.8.attn.add_q_proj.lora_A.default torch.float32
transformer_blocks.8.attn.add_q_proj.lora_B.default torch.float32
transformer_blocks.8.attn.to_out.0 torch.float8_e4m3fn
transformer_blocks.8.attn.to_out.0.base_layer torch.float8_e4m3fn
transformer_blocks.8.attn.to_out.0.lora_A.default torch.float32
transformer_blocks.8.attn.to_out.0.lora_B.default torch.float32
transformer_blocks.8.attn.to_add_out torch.float8_e4m3fn
transformer_blocks.8.attn.to_add_out.base_layer torch.float8_e4m3fn
transformer_blocks.8.attn.to_add_out.lora_A.default torch.float32
transformer_blocks.8.attn.to_add_out.lora_B.default torch.float32
transformer_blocks.8.attn.norm_added_q torch.bfloat16
transformer_blocks.8.attn.norm_added_k torch.bfloat16
transformer_blocks.8.ff.net.0.proj torch.float8_e4m3fn
transformer_blocks.8.ff.net.2 torch.float8_e4m3fn
transformer_blocks.8.ff_context.net.0.proj torch.float8_e4m3fn
transformer_blocks.8.ff_context.net.2 torch.float8_e4m3fn
transformer_blocks.9.norm1.linear torch.bfloat16
transformer_blocks.9.norm1_context.linear torch.bfloat16
transformer_blocks.9.attn.norm_q torch.bfloat16
transformer_blocks.9.attn.norm_k torch.bfloat16
transformer_blocks.9.attn.to_q torch.float8_e4m3fn
transformer_blocks.9.attn.to_q.base_layer torch.float8_e4m3fn
transformer_blocks.9.attn.to_q.lora_A.default torch.float32
transformer_blocks.9.attn.to_q.lora_B.default torch.float32
transformer_blocks.9.attn.to_k torch.float8_e4m3fn
transformer_blocks.9.attn.to_k.base_layer torch.float8_e4m3fn
transformer_blocks.9.attn.to_k.lora_A.default torch.float32
transformer_blocks.9.attn.to_k.lora_B.default torch.float32
transformer_blocks.9.attn.to_v torch.float8_e4m3fn
transformer_blocks.9.attn.to_v.base_layer torch.float8_e4m3fn
transformer_blocks.9.attn.to_v.lora_A.default torch.float32
transformer_blocks.9.attn.to_v.lora_B.default torch.float32
transformer_blocks.9.attn.add_k_proj torch.float8_e4m3fn
transformer_blocks.9.attn.add_k_proj.base_layer torch.float8_e4m3fn
transformer_blocks.9.attn.add_k_proj.lora_A.default torch.float32
transformer_blocks.9.attn.add_k_proj.lora_B.default torch.float32
transformer_blocks.9.attn.add_v_proj torch.float8_e4m3fn
transformer_blocks.9.attn.add_v_proj.base_layer torch.float8_e4m3fn
transformer_blocks.9.attn.add_v_proj.lora_A.default torch.float32
transformer_blocks.9.attn.add_v_proj.lora_B.default torch.float32
transformer_blocks.9.attn.add_q_proj torch.float8_e4m3fn
transformer_blocks.9.attn.add_q_proj.base_layer torch.float8_e4m3fn
transformer_blocks.9.attn.add_q_proj.lora_A.default torch.float32
transformer_blocks.9.attn.add_q_proj.lora_B.default torch.float32
transformer_blocks.9.attn.to_out.0 torch.float8_e4m3fn
transformer_blocks.9.attn.to_out.0.base_layer torch.float8_e4m3fn
transformer_blocks.9.attn.to_out.0.lora_A.default torch.float32
transformer_blocks.9.attn.to_out.0.lora_B.default torch.float32
transformer_blocks.9.attn.to_add_out torch.float8_e4m3fn
transformer_blocks.9.attn.to_add_out.base_layer torch.float8_e4m3fn
transformer_blocks.9.attn.to_add_out.lora_A.default torch.float32
transformer_blocks.9.attn.to_add_out.lora_B.default torch.float32
transformer_blocks.9.attn.norm_added_q torch.bfloat16
transformer_blocks.9.attn.norm_added_k torch.bfloat16
transformer_blocks.9.ff.net.0.proj torch.float8_e4m3fn
transformer_blocks.9.ff.net.2 torch.float8_e4m3fn
transformer_blocks.9.ff_context.net.0.proj torch.float8_e4m3fn
transformer_blocks.9.ff_context.net.2 torch.float8_e4m3fn
transformer_blocks.10.norm1.linear torch.bfloat16
transformer_blocks.10.norm1_context.linear torch.bfloat16
transformer_blocks.10.attn.norm_q torch.bfloat16
transformer_blocks.10.attn.norm_k torch.bfloat16
transformer_blocks.10.attn.to_q torch.float8_e4m3fn
transformer_blocks.10.attn.to_q.base_layer torch.float8_e4m3fn
transformer_blocks.10.attn.to_q.lora_A.default torch.float32
transformer_blocks.10.attn.to_q.lora_B.default torch.float32
transformer_blocks.10.attn.to_k torch.float8_e4m3fn
transformer_blocks.10.attn.to_k.base_layer torch.float8_e4m3fn
transformer_blocks.10.attn.to_k.lora_A.default torch.float32
transformer_blocks.10.attn.to_k.lora_B.default torch.float32
transformer_blocks.10.attn.to_v torch.float8_e4m3fn
transformer_blocks.10.attn.to_v.base_layer torch.float8_e4m3fn
transformer_blocks.10.attn.to_v.lora_A.default torch.float32
transformer_blocks.10.attn.to_v.lora_B.default torch.float32
transformer_blocks.10.attn.add_k_proj torch.float8_e4m3fn
transformer_blocks.10.attn.add_k_proj.base_layer torch.float8_e4m3fn
transformer_blocks.10.attn.add_k_proj.lora_A.default torch.float32
transformer_blocks.10.attn.add_k_proj.lora_B.default torch.float32
transformer_blocks.10.attn.add_v_proj torch.float8_e4m3fn
transformer_blocks.10.attn.add_v_proj.base_layer torch.float8_e4m3fn
transformer_blocks.10.attn.add_v_proj.lora_A.default torch.float32
transformer_blocks.10.attn.add_v_proj.lora_B.default torch.float32
transformer_blocks.10.attn.add_q_proj torch.float8_e4m3fn
transformer_blocks.10.attn.add_q_proj.base_layer torch.float8_e4m3fn
transformer_blocks.10.attn.add_q_proj.lora_A.default torch.float32
transformer_blocks.10.attn.add_q_proj.lora_B.default torch.float32
transformer_blocks.10.attn.to_out.0 torch.float8_e4m3fn
transformer_blocks.10.attn.to_out.0.base_layer torch.float8_e4m3fn
transformer_blocks.10.attn.to_out.0.lora_A.default torch.float32
transformer_blocks.10.attn.to_out.0.lora_B.default torch.float32
transformer_blocks.10.attn.to_add_out torch.float8_e4m3fn
transformer_blocks.10.attn.to_add_out.base_layer torch.float8_e4m3fn
transformer_blocks.10.attn.to_add_out.lora_A.default torch.float32
transformer_blocks.10.attn.to_add_out.lora_B.default torch.float32
transformer_blocks.10.attn.norm_added_q torch.bfloat16
transformer_blocks.10.attn.norm_added_k torch.bfloat16
transformer_blocks.10.ff.net.0.proj torch.float8_e4m3fn
transformer_blocks.10.ff.net.2 torch.float8_e4m3fn
transformer_blocks.10.ff_context.net.0.proj torch.float8_e4m3fn
transformer_blocks.10.ff_context.net.2 torch.float8_e4m3fn
transformer_blocks.11.norm1.linear torch.bfloat16
transformer_blocks.11.norm1_context.linear torch.bfloat16
transformer_blocks.11.attn.norm_q torch.bfloat16
transformer_blocks.11.attn.norm_k torch.bfloat16
transformer_blocks.11.attn.to_q torch.float8_e4m3fn
transformer_blocks.11.attn.to_q.base_layer torch.float8_e4m3fn
transformer_blocks.11.attn.to_q.lora_A.default torch.float32
transformer_blocks.11.attn.to_q.lora_B.default torch.float32
transformer_blocks.11.attn.to_k torch.float8_e4m3fn
transformer_blocks.11.attn.to_k.base_layer torch.float8_e4m3fn
transformer_blocks.11.attn.to_k.lora_A.default torch.float32
transformer_blocks.11.attn.to_k.lora_B.default torch.float32
transformer_blocks.11.attn.to_v torch.float8_e4m3fn
transformer_blocks.11.attn.to_v.base_layer torch.float8_e4m3fn
transformer_blocks.11.attn.to_v.lora_A.default torch.float32
transformer_blocks.11.attn.to_v.lora_B.default torch.float32
transformer_blocks.11.attn.add_k_proj torch.float8_e4m3fn
transformer_blocks.11.attn.add_k_proj.base_layer torch.float8_e4m3fn
transformer_blocks.11.attn.add_k_proj.lora_A.default torch.float32
transformer_blocks.11.attn.add_k_proj.lora_B.default torch.float32
transformer_blocks.11.attn.add_v_proj torch.float8_e4m3fn
transformer_blocks.11.attn.add_v_proj.base_layer torch.float8_e4m3fn
transformer_blocks.11.attn.add_v_proj.lora_A.default torch.float32
transformer_blocks.11.attn.add_v_proj.lora_B.default torch.float32
transformer_blocks.11.attn.add_q_proj torch.float8_e4m3fn
transformer_blocks.11.attn.add_q_proj.base_layer torch.float8_e4m3fn
transformer_blocks.11.attn.add_q_proj.lora_A.default torch.float32
transformer_blocks.11.attn.add_q_proj.lora_B.default torch.float32
transformer_blocks.11.attn.to_out.0 torch.float8_e4m3fn
transformer_blocks.11.attn.to_out.0.base_layer torch.float8_e4m3fn
transformer_blocks.11.attn.to_out.0.lora_A.default torch.float32
transformer_blocks.11.attn.to_out.0.lora_B.default torch.float32
transformer_blocks.11.attn.to_add_out torch.float8_e4m3fn
transformer_blocks.11.attn.to_add_out.base_layer torch.float8_e4m3fn
transformer_blocks.11.attn.to_add_out.lora_A.default torch.float32
transformer_blocks.11.attn.to_add_out.lora_B.default torch.float32
transformer_blocks.11.attn.norm_added_q torch.bfloat16
transformer_blocks.11.attn.norm_added_k torch.bfloat16
transformer_blocks.11.ff.net.0.proj torch.float8_e4m3fn
transformer_blocks.11.ff.net.2 torch.float8_e4m3fn
transformer_blocks.11.ff_context.net.0.proj torch.float8_e4m3fn
transformer_blocks.11.ff_context.net.2 torch.float8_e4m3fn
transformer_blocks.12.norm1.linear torch.bfloat16
transformer_blocks.12.norm1_context.linear torch.bfloat16
transformer_blocks.12.attn.norm_q torch.bfloat16
transformer_blocks.12.attn.norm_k torch.bfloat16
transformer_blocks.12.attn.to_q torch.float8_e4m3fn
transformer_blocks.12.attn.to_q.base_layer torch.float8_e4m3fn
transformer_blocks.12.attn.to_q.lora_A.default torch.float32
transformer_blocks.12.attn.to_q.lora_B.default torch.float32
transformer_blocks.12.attn.to_k torch.float8_e4m3fn
transformer_blocks.12.attn.to_k.base_layer torch.float8_e4m3fn
transformer_blocks.12.attn.to_k.lora_A.default torch.float32
transformer_blocks.12.attn.to_k.lora_B.default torch.float32
transformer_blocks.12.attn.to_v torch.float8_e4m3fn
transformer_blocks.12.attn.to_v.base_layer torch.float8_e4m3fn
transformer_blocks.12.attn.to_v.lora_A.default torch.float32
transformer_blocks.12.attn.to_v.lora_B.default torch.float32
transformer_blocks.12.attn.add_k_proj torch.float8_e4m3fn
transformer_blocks.12.attn.add_k_proj.base_layer torch.float8_e4m3fn
transformer_blocks.12.attn.add_k_proj.lora_A.default torch.float32
transformer_blocks.12.attn.add_k_proj.lora_B.default torch.float32
transformer_blocks.12.attn.add_v_proj torch.float8_e4m3fn
transformer_blocks.12.attn.add_v_proj.base_layer torch.float8_e4m3fn
transformer_blocks.12.attn.add_v_proj.lora_A.default torch.float32
transformer_blocks.12.attn.add_v_proj.lora_B.default torch.float32
transformer_blocks.12.attn.add_q_proj torch.float8_e4m3fn
transformer_blocks.12.attn.add_q_proj.base_layer torch.float8_e4m3fn
transformer_blocks.12.attn.add_q_proj.lora_A.default torch.float32
transformer_blocks.12.attn.add_q_proj.lora_B.default torch.float32
transformer_blocks.12.attn.to_out.0 torch.float8_e4m3fn
transformer_blocks.12.attn.to_out.0.base_layer torch.float8_e4m3fn
transformer_blocks.12.attn.to_out.0.lora_A.default torch.float32
transformer_blocks.12.attn.to_out.0.lora_B.default torch.float32
transformer_blocks.12.attn.to_add_out torch.float8_e4m3fn
transformer_blocks.12.attn.to_add_out.base_layer torch.float8_e4m3fn
transformer_blocks.12.attn.to_add_out.lora_A.default torch.float32
transformer_blocks.12.attn.to_add_out.lora_B.default torch.float32
transformer_blocks.12.attn.norm_added_q torch.bfloat16
transformer_blocks.12.attn.norm_added_k torch.bfloat16
transformer_blocks.12.ff.net.0.proj torch.float8_e4m3fn
transformer_blocks.12.ff.net.2 torch.float8_e4m3fn
transformer_blocks.12.ff_context.net.0.proj torch.float8_e4m3fn
transformer_blocks.12.ff_context.net.2 torch.float8_e4m3fn
transformer_blocks.13.norm1.linear torch.bfloat16
transformer_blocks.13.norm1_context.linear torch.bfloat16
transformer_blocks.13.attn.norm_q torch.bfloat16
transformer_blocks.13.attn.norm_k torch.bfloat16
transformer_blocks.13.attn.to_q torch.float8_e4m3fn
transformer_blocks.13.attn.to_q.base_layer torch.float8_e4m3fn
transformer_blocks.13.attn.to_q.lora_A.default torch.float32
transformer_blocks.13.attn.to_q.lora_B.default torch.float32
transformer_blocks.13.attn.to_k torch.float8_e4m3fn
transformer_blocks.13.attn.to_k.base_layer torch.float8_e4m3fn
transformer_blocks.13.attn.to_k.lora_A.default torch.float32
transformer_blocks.13.attn.to_k.lora_B.default torch.float32
transformer_blocks.13.attn.to_v torch.float8_e4m3fn
transformer_blocks.13.attn.to_v.base_layer torch.float8_e4m3fn
transformer_blocks.13.attn.to_v.lora_A.default torch.float32
transformer_blocks.13.attn.to_v.lora_B.default torch.float32
transformer_blocks.13.attn.add_k_proj torch.float8_e4m3fn
transformer_blocks.13.attn.add_k_proj.base_layer torch.float8_e4m3fn
transformer_blocks.13.attn.add_k_proj.lora_A.default torch.float32
transformer_blocks.13.attn.add_k_proj.lora_B.default torch.float32
transformer_blocks.13.attn.add_v_proj torch.float8_e4m3fn
transformer_blocks.13.attn.add_v_proj.base_layer torch.float8_e4m3fn
transformer_blocks.13.attn.add_v_proj.lora_A.default torch.float32
transformer_blocks.13.attn.add_v_proj.lora_B.default torch.float32
transformer_blocks.13.attn.add_q_proj torch.float8_e4m3fn
transformer_blocks.13.attn.add_q_proj.base_layer torch.float8_e4m3fn
transformer_blocks.13.attn.add_q_proj.lora_A.default torch.float32
transformer_blocks.13.attn.add_q_proj.lora_B.default torch.float32
transformer_blocks.13.attn.to_out.0 torch.float8_e4m3fn
transformer_blocks.13.attn.to_out.0.base_layer torch.float8_e4m3fn
transformer_blocks.13.attn.to_out.0.lora_A.default torch.float32
transformer_blocks.13.attn.to_out.0.lora_B.default torch.float32
transformer_blocks.13.attn.to_add_out torch.float8_e4m3fn
transformer_blocks.13.attn.to_add_out.base_layer torch.float8_e4m3fn
transformer_blocks.13.attn.to_add_out.lora_A.default torch.float32
transformer_blocks.13.attn.to_add_out.lora_B.default torch.float32
transformer_blocks.13.attn.norm_added_q torch.bfloat16
transformer_blocks.13.attn.norm_added_k torch.bfloat16
transformer_blocks.13.ff.net.0.proj torch.float8_e4m3fn
transformer_blocks.13.ff.net.2 torch.float8_e4m3fn
transformer_blocks.13.ff_context.net.0.proj torch.float8_e4m3fn
transformer_blocks.13.ff_context.net.2 torch.float8_e4m3fn
transformer_blocks.14.norm1.linear torch.bfloat16
transformer_blocks.14.norm1_context.linear torch.bfloat16
transformer_blocks.14.attn.norm_q torch.bfloat16
transformer_blocks.14.attn.norm_k torch.bfloat16
transformer_blocks.14.attn.to_q torch.float8_e4m3fn
transformer_blocks.14.attn.to_q.base_layer torch.float8_e4m3fn
transformer_blocks.14.attn.to_q.lora_A.default torch.float32
transformer_blocks.14.attn.to_q.lora_B.default torch.float32
transformer_blocks.14.attn.to_k torch.float8_e4m3fn
transformer_blocks.14.attn.to_k.base_layer torch.float8_e4m3fn
transformer_blocks.14.attn.to_k.lora_A.default torch.float32
transformer_blocks.14.attn.to_k.lora_B.default torch.float32
transformer_blocks.14.attn.to_v torch.float8_e4m3fn
transformer_blocks.14.attn.to_v.base_layer torch.float8_e4m3fn
transformer_blocks.14.attn.to_v.lora_A.default torch.float32
transformer_blocks.14.attn.to_v.lora_B.default torch.float32
transformer_blocks.14.attn.add_k_proj torch.float8_e4m3fn
transformer_blocks.14.attn.add_k_proj.base_layer torch.float8_e4m3fn
transformer_blocks.14.attn.add_k_proj.lora_A.default torch.float32
transformer_blocks.14.attn.add_k_proj.lora_B.default torch.float32
transformer_blocks.14.attn.add_v_proj torch.float8_e4m3fn
transformer_blocks.14.attn.add_v_proj.base_layer torch.float8_e4m3fn
transformer_blocks.14.attn.add_v_proj.lora_A.default torch.float32
transformer_blocks.14.attn.add_v_proj.lora_B.default torch.float32
transformer_blocks.14.attn.add_q_proj torch.float8_e4m3fn
transformer_blocks.14.attn.add_q_proj.base_layer torch.float8_e4m3fn
transformer_blocks.14.attn.add_q_proj.lora_A.default torch.float32
transformer_blocks.14.attn.add_q_proj.lora_B.default torch.float32
transformer_blocks.14.attn.to_out.0 torch.float8_e4m3fn
transformer_blocks.14.attn.to_out.0.base_layer torch.float8_e4m3fn
transformer_blocks.14.attn.to_out.0.lora_A.default torch.float32
transformer_blocks.14.attn.to_out.0.lora_B.default torch.float32
transformer_blocks.14.attn.to_add_out torch.float8_e4m3fn
transformer_blocks.14.attn.to_add_out.base_layer torch.float8_e4m3fn
transformer_blocks.14.attn.to_add_out.lora_A.default torch.float32
transformer_blocks.14.attn.to_add_out.lora_B.default torch.float32
transformer_blocks.14.attn.norm_added_q torch.bfloat16
transformer_blocks.14.attn.norm_added_k torch.bfloat16
transformer_blocks.14.ff.net.0.proj torch.float8_e4m3fn
transformer_blocks.14.ff.net.2 torch.float8_e4m3fn
transformer_blocks.14.ff_context.net.0.proj torch.float8_e4m3fn
transformer_blocks.14.ff_context.net.2 torch.float8_e4m3fn
transformer_blocks.15.norm1.linear torch.bfloat16
transformer_blocks.15.norm1_context.linear torch.bfloat16
transformer_blocks.15.attn.norm_q torch.bfloat16
transformer_blocks.15.attn.norm_k torch.bfloat16
transformer_blocks.15.attn.to_q torch.float8_e4m3fn
transformer_blocks.15.attn.to_q.base_layer torch.float8_e4m3fn
transformer_blocks.15.attn.to_q.lora_A.default torch.float32
transformer_blocks.15.attn.to_q.lora_B.default torch.float32
transformer_blocks.15.attn.to_k torch.float8_e4m3fn
transformer_blocks.15.attn.to_k.base_layer torch.float8_e4m3fn
transformer_blocks.15.attn.to_k.lora_A.default torch.float32
transformer_blocks.15.attn.to_k.lora_B.default torch.float32
transformer_blocks.15.attn.to_v torch.float8_e4m3fn
transformer_blocks.15.attn.to_v.base_layer torch.float8_e4m3fn
transformer_blocks.15.attn.to_v.lora_A.default torch.float32
transformer_blocks.15.attn.to_v.lora_B.default torch.float32
transformer_blocks.15.attn.add_k_proj torch.float8_e4m3fn
transformer_blocks.15.attn.add_k_proj.base_layer torch.float8_e4m3fn
transformer_blocks.15.attn.add_k_proj.lora_A.default torch.float32
transformer_blocks.15.attn.add_k_proj.lora_B.default torch.float32
transformer_blocks.15.attn.add_v_proj torch.float8_e4m3fn
transformer_blocks.15.attn.add_v_proj.base_layer torch.float8_e4m3fn
transformer_blocks.15.attn.add_v_proj.lora_A.default torch.float32
transformer_blocks.15.attn.add_v_proj.lora_B.default torch.float32
transformer_blocks.15.attn.add_q_proj torch.float8_e4m3fn
transformer_blocks.15.attn.add_q_proj.base_layer torch.float8_e4m3fn
transformer_blocks.15.attn.add_q_proj.lora_A.default torch.float32
transformer_blocks.15.attn.add_q_proj.lora_B.default torch.float32
transformer_blocks.15.attn.to_out.0 torch.float8_e4m3fn
transformer_blocks.15.attn.to_out.0.base_layer torch.float8_e4m3fn
transformer_blocks.15.attn.to_out.0.lora_A.default torch.float32
transformer_blocks.15.attn.to_out.0.lora_B.default torch.float32
transformer_blocks.15.attn.to_add_out torch.float8_e4m3fn
transformer_blocks.15.attn.to_add_out.base_layer torch.float8_e4m3fn
transformer_blocks.15.attn.to_add_out.lora_A.default torch.float32
transformer_blocks.15.attn.to_add_out.lora_B.default torch.float32
transformer_blocks.15.attn.norm_added_q torch.bfloat16
transformer_blocks.15.attn.norm_added_k torch.bfloat16
transformer_blocks.15.ff.net.0.proj torch.float8_e4m3fn
transformer_blocks.15.ff.net.2 torch.float8_e4m3fn
transformer_blocks.15.ff_context.net.0.proj torch.float8_e4m3fn
transformer_blocks.15.ff_context.net.2 torch.float8_e4m3fn
transformer_blocks.16.norm1.linear torch.bfloat16
transformer_blocks.16.norm1_context.linear torch.bfloat16
transformer_blocks.16.attn.norm_q torch.bfloat16
transformer_blocks.16.attn.norm_k torch.bfloat16
transformer_blocks.16.attn.to_q torch.float8_e4m3fn
transformer_blocks.16.attn.to_q.base_layer torch.float8_e4m3fn
transformer_blocks.16.attn.to_q.lora_A.default torch.float32
transformer_blocks.16.attn.to_q.lora_B.default torch.float32
transformer_blocks.16.attn.to_k torch.float8_e4m3fn
transformer_blocks.16.attn.to_k.base_layer torch.float8_e4m3fn
transformer_blocks.16.attn.to_k.lora_A.default torch.float32
transformer_blocks.16.attn.to_k.lora_B.default torch.float32
transformer_blocks.16.attn.to_v torch.float8_e4m3fn
transformer_blocks.16.attn.to_v.base_layer torch.float8_e4m3fn
transformer_blocks.16.attn.to_v.lora_A.default torch.float32
transformer_blocks.16.attn.to_v.lora_B.default torch.float32
transformer_blocks.16.attn.add_k_proj torch.float8_e4m3fn
transformer_blocks.16.attn.add_k_proj.base_layer torch.float8_e4m3fn
transformer_blocks.16.attn.add_k_proj.lora_A.default torch.float32
transformer_blocks.16.attn.add_k_proj.lora_B.default torch.float32
transformer_blocks.16.attn.add_v_proj torch.float8_e4m3fn
transformer_blocks.16.attn.add_v_proj.base_layer torch.float8_e4m3fn
transformer_blocks.16.attn.add_v_proj.lora_A.default torch.float32
transformer_blocks.16.attn.add_v_proj.lora_B.default torch.float32
transformer_blocks.16.attn.add_q_proj torch.float8_e4m3fn
transformer_blocks.16.attn.add_q_proj.base_layer torch.float8_e4m3fn
transformer_blocks.16.attn.add_q_proj.lora_A.default torch.float32
transformer_blocks.16.attn.add_q_proj.lora_B.default torch.float32
transformer_blocks.16.attn.to_out.0 torch.float8_e4m3fn
transformer_blocks.16.attn.to_out.0.base_layer torch.float8_e4m3fn
transformer_blocks.16.attn.to_out.0.lora_A.default torch.float32
transformer_blocks.16.attn.to_out.0.lora_B.default torch.float32
transformer_blocks.16.attn.to_add_out torch.float8_e4m3fn
transformer_blocks.16.attn.to_add_out.base_layer torch.float8_e4m3fn
transformer_blocks.16.attn.to_add_out.lora_A.default torch.float32
transformer_blocks.16.attn.to_add_out.lora_B.default torch.float32
transformer_blocks.16.attn.norm_added_q torch.bfloat16
transformer_blocks.16.attn.norm_added_k torch.bfloat16
transformer_blocks.16.ff.net.0.proj torch.float8_e4m3fn
transformer_blocks.16.ff.net.2 torch.float8_e4m3fn
transformer_blocks.16.ff_context.net.0.proj torch.float8_e4m3fn
transformer_blocks.16.ff_context.net.2 torch.float8_e4m3fn
transformer_blocks.17.norm1.linear torch.bfloat16
transformer_blocks.17.norm1_context.linear torch.bfloat16
transformer_blocks.17.attn.norm_q torch.bfloat16
transformer_blocks.17.attn.norm_k torch.bfloat16
transformer_blocks.17.attn.to_q torch.float8_e4m3fn
transformer_blocks.17.attn.to_q.base_layer torch.float8_e4m3fn
transformer_blocks.17.attn.to_q.lora_A.default torch.float32
transformer_blocks.17.attn.to_q.lora_B.default torch.float32
transformer_blocks.17.attn.to_k torch.float8_e4m3fn
transformer_blocks.17.attn.to_k.base_layer torch.float8_e4m3fn
transformer_blocks.17.attn.to_k.lora_A.default torch.float32
transformer_blocks.17.attn.to_k.lora_B.default torch.float32
transformer_blocks.17.attn.to_v torch.float8_e4m3fn
transformer_blocks.17.attn.to_v.base_layer torch.float8_e4m3fn
transformer_blocks.17.attn.to_v.lora_A.default torch.float32
transformer_blocks.17.attn.to_v.lora_B.default torch.float32
transformer_blocks.17.attn.add_k_proj torch.float8_e4m3fn
transformer_blocks.17.attn.add_k_proj.base_layer torch.float8_e4m3fn
transformer_blocks.17.attn.add_k_proj.lora_A.default torch.float32
transformer_blocks.17.attn.add_k_proj.lora_B.default torch.float32
transformer_blocks.17.attn.add_v_proj torch.float8_e4m3fn
transformer_blocks.17.attn.add_v_proj.base_layer torch.float8_e4m3fn
transformer_blocks.17.attn.add_v_proj.lora_A.default torch.float32
transformer_blocks.17.attn.add_v_proj.lora_B.default torch.float32
transformer_blocks.17.attn.add_q_proj torch.float8_e4m3fn
transformer_blocks.17.attn.add_q_proj.base_layer torch.float8_e4m3fn
transformer_blocks.17.attn.add_q_proj.lora_A.default torch.float32
transformer_blocks.17.attn.add_q_proj.lora_B.default torch.float32
transformer_blocks.17.attn.to_out.0 torch.float8_e4m3fn
transformer_blocks.17.attn.to_out.0.base_layer torch.float8_e4m3fn
transformer_blocks.17.attn.to_out.0.lora_A.default torch.float32
transformer_blocks.17.attn.to_out.0.lora_B.default torch.float32
transformer_blocks.17.attn.to_add_out torch.float8_e4m3fn
transformer_blocks.17.attn.to_add_out.base_layer torch.float8_e4m3fn
transformer_blocks.17.attn.to_add_out.lora_A.default torch.float32
transformer_blocks.17.attn.to_add_out.lora_B.default torch.float32
transformer_blocks.17.attn.norm_added_q torch.bfloat16
transformer_blocks.17.attn.norm_added_k torch.bfloat16
transformer_blocks.17.ff.net.0.proj torch.float8_e4m3fn
transformer_blocks.17.ff.net.2 torch.float8_e4m3fn
transformer_blocks.17.ff_context.net.0.proj torch.float8_e4m3fn
transformer_blocks.17.ff_context.net.2 torch.float8_e4m3fn
transformer_blocks.18.norm1.linear torch.bfloat16
transformer_blocks.18.norm1_context.linear torch.bfloat16
transformer_blocks.18.attn.norm_q torch.bfloat16
transformer_blocks.18.attn.norm_k torch.bfloat16
transformer_blocks.18.attn.to_q torch.float8_e4m3fn
transformer_blocks.18.attn.to_q.base_layer torch.float8_e4m3fn
transformer_blocks.18.attn.to_q.lora_A.default torch.float32
transformer_blocks.18.attn.to_q.lora_B.default torch.float32
transformer_blocks.18.attn.to_k torch.float8_e4m3fn
transformer_blocks.18.attn.to_k.base_layer torch.float8_e4m3fn
transformer_blocks.18.attn.to_k.lora_A.default torch.float32
transformer_blocks.18.attn.to_k.lora_B.default torch.float32
transformer_blocks.18.attn.to_v torch.float8_e4m3fn
transformer_blocks.18.attn.to_v.base_layer torch.float8_e4m3fn
transformer_blocks.18.attn.to_v.lora_A.default torch.float32
transformer_blocks.18.attn.to_v.lora_B.default torch.float32
transformer_blocks.18.attn.add_k_proj torch.float8_e4m3fn
transformer_blocks.18.attn.add_k_proj.base_layer torch.float8_e4m3fn
transformer_blocks.18.attn.add_k_proj.lora_A.default torch.float32
transformer_blocks.18.attn.add_k_proj.lora_B.default torch.float32
transformer_blocks.18.attn.add_v_proj torch.float8_e4m3fn
transformer_blocks.18.attn.add_v_proj.base_layer torch.float8_e4m3fn
transformer_blocks.18.attn.add_v_proj.lora_A.default torch.float32
transformer_blocks.18.attn.add_v_proj.lora_B.default torch.float32
transformer_blocks.18.attn.add_q_proj torch.float8_e4m3fn
transformer_blocks.18.attn.add_q_proj.base_layer torch.float8_e4m3fn
transformer_blocks.18.attn.add_q_proj.lora_A.default torch.float32
transformer_blocks.18.attn.add_q_proj.lora_B.default torch.float32
transformer_blocks.18.attn.to_out.0 torch.float8_e4m3fn
transformer_blocks.18.attn.to_out.0.base_layer torch.float8_e4m3fn
transformer_blocks.18.attn.to_out.0.lora_A.default torch.float32
transformer_blocks.18.attn.to_out.0.lora_B.default torch.float32
transformer_blocks.18.attn.to_add_out torch.float8_e4m3fn
transformer_blocks.18.attn.to_add_out.base_layer torch.float8_e4m3fn
transformer_blocks.18.attn.to_add_out.lora_A.default torch.float32
transformer_blocks.18.attn.to_add_out.lora_B.default torch.float32
transformer_blocks.18.attn.norm_added_q torch.bfloat16
transformer_blocks.18.attn.norm_added_k torch.bfloat16
transformer_blocks.18.ff.net.0.proj torch.float8_e4m3fn
transformer_blocks.18.ff.net.2 torch.float8_e4m3fn
transformer_blocks.18.ff_context.net.0.proj torch.float8_e4m3fn
transformer_blocks.18.ff_context.net.2 torch.float8_e4m3fn
transformer_blocks.19.norm1.linear torch.bfloat16
transformer_blocks.19.norm1_context.linear torch.bfloat16
transformer_blocks.19.attn.norm_q torch.bfloat16
transformer_blocks.19.attn.norm_k torch.bfloat16
transformer_blocks.19.attn.to_q torch.float8_e4m3fn
transformer_blocks.19.attn.to_q.base_layer torch.float8_e4m3fn
transformer_blocks.19.attn.to_q.lora_A.default torch.float32
transformer_blocks.19.attn.to_q.lora_B.default torch.float32
transformer_blocks.19.attn.to_k torch.float8_e4m3fn
transformer_blocks.19.attn.to_k.base_layer torch.float8_e4m3fn
transformer_blocks.19.attn.to_k.lora_A.default torch.float32
transformer_blocks.19.attn.to_k.lora_B.default torch.float32
transformer_blocks.19.attn.to_v torch.float8_e4m3fn
transformer_blocks.19.attn.to_v.base_layer torch.float8_e4m3fn
transformer_blocks.19.attn.to_v.lora_A.default torch.float32
transformer_blocks.19.attn.to_v.lora_B.default torch.float32
transformer_blocks.19.attn.add_k_proj torch.float8_e4m3fn
transformer_blocks.19.attn.add_k_proj.base_layer torch.float8_e4m3fn
transformer_blocks.19.attn.add_k_proj.lora_A.default torch.float32
transformer_blocks.19.attn.add_k_proj.lora_B.default torch.float32
transformer_blocks.19.attn.add_v_proj torch.float8_e4m3fn
transformer_blocks.19.attn.add_v_proj.base_layer torch.float8_e4m3fn
transformer_blocks.19.attn.add_v_proj.lora_A.default torch.float32
transformer_blocks.19.attn.add_v_proj.lora_B.default torch.float32
transformer_blocks.19.attn.add_q_proj torch.float8_e4m3fn
transformer_blocks.19.attn.add_q_proj.base_layer torch.float8_e4m3fn
transformer_blocks.19.attn.add_q_proj.lora_A.default torch.float32
transformer_blocks.19.attn.add_q_proj.lora_B.default torch.float32
transformer_blocks.19.attn.to_out.0 torch.float8_e4m3fn
transformer_blocks.19.attn.to_out.0.base_layer torch.float8_e4m3fn
transformer_blocks.19.attn.to_out.0.lora_A.default torch.float32
transformer_blocks.19.attn.to_out.0.lora_B.default torch.float32
transformer_blocks.19.attn.to_add_out torch.float8_e4m3fn
transformer_blocks.19.attn.to_add_out.base_layer torch.float8_e4m3fn
transformer_blocks.19.attn.to_add_out.lora_A.default torch.float32
transformer_blocks.19.attn.to_add_out.lora_B.default torch.float32
transformer_blocks.19.attn.norm_added_q torch.bfloat16
transformer_blocks.19.attn.norm_added_k torch.bfloat16
transformer_blocks.19.ff.net.0.proj torch.float8_e4m3fn
transformer_blocks.19.ff.net.2 torch.float8_e4m3fn
transformer_blocks.19.ff_context.net.0.proj torch.float8_e4m3fn
transformer_blocks.19.ff_context.net.2 torch.float8_e4m3fn
single_transformer_blocks.0.attn.norm_q torch.bfloat16
single_transformer_blocks.0.attn.norm_k torch.bfloat16
single_transformer_blocks.0.attn.to_q torch.float8_e4m3fn
single_transformer_blocks.0.attn.to_q.base_layer torch.float8_e4m3fn
single_transformer_blocks.0.attn.to_q.lora_A.default torch.float32
single_transformer_blocks.0.attn.to_q.lora_B.default torch.float32
single_transformer_blocks.0.attn.to_k torch.float8_e4m3fn
single_transformer_blocks.0.attn.to_k.base_layer torch.float8_e4m3fn
single_transformer_blocks.0.attn.to_k.lora_A.default torch.float32
single_transformer_blocks.0.attn.to_k.lora_B.default torch.float32
single_transformer_blocks.0.attn.to_v torch.float8_e4m3fn
single_transformer_blocks.0.attn.to_v.base_layer torch.float8_e4m3fn
single_transformer_blocks.0.attn.to_v.lora_A.default torch.float32
single_transformer_blocks.0.attn.to_v.lora_B.default torch.float32
single_transformer_blocks.0.norm.linear torch.bfloat16
single_transformer_blocks.0.proj_mlp torch.float8_e4m3fn
single_transformer_blocks.0.proj_out torch.float8_e4m3fn
single_transformer_blocks.1.attn.norm_q torch.bfloat16
single_transformer_blocks.1.attn.norm_k torch.bfloat16
single_transformer_blocks.1.attn.to_q torch.float8_e4m3fn
single_transformer_blocks.1.attn.to_q.base_layer torch.float8_e4m3fn
single_transformer_blocks.1.attn.to_q.lora_A.default torch.float32
single_transformer_blocks.1.attn.to_q.lora_B.default torch.float32
single_transformer_blocks.1.attn.to_k torch.float8_e4m3fn
single_transformer_blocks.1.attn.to_k.base_layer torch.float8_e4m3fn
single_transformer_blocks.1.attn.to_k.lora_A.default torch.float32
single_transformer_blocks.1.attn.to_k.lora_B.default torch.float32
single_transformer_blocks.1.attn.to_v torch.float8_e4m3fn
single_transformer_blocks.1.attn.to_v.base_layer torch.float8_e4m3fn
single_transformer_blocks.1.attn.to_v.lora_A.default torch.float32
single_transformer_blocks.1.attn.to_v.lora_B.default torch.float32
single_transformer_blocks.1.norm.linear torch.bfloat16
single_transformer_blocks.1.proj_mlp torch.float8_e4m3fn
single_transformer_blocks.1.proj_out torch.float8_e4m3fn
single_transformer_blocks.2.attn.norm_q torch.bfloat16
single_transformer_blocks.2.attn.norm_k torch.bfloat16
single_transformer_blocks.2.attn.to_q torch.float8_e4m3fn
single_transformer_blocks.2.attn.to_q.base_layer torch.float8_e4m3fn
single_transformer_blocks.2.attn.to_q.lora_A.default torch.float32
single_transformer_blocks.2.attn.to_q.lora_B.default torch.float32
single_transformer_blocks.2.attn.to_k torch.float8_e4m3fn
single_transformer_blocks.2.attn.to_k.base_layer torch.float8_e4m3fn
single_transformer_blocks.2.attn.to_k.lora_A.default torch.float32
single_transformer_blocks.2.attn.to_k.lora_B.default torch.float32
single_transformer_blocks.2.attn.to_v torch.float8_e4m3fn
single_transformer_blocks.2.attn.to_v.base_layer torch.float8_e4m3fn
single_transformer_blocks.2.attn.to_v.lora_A.default torch.float32
single_transformer_blocks.2.attn.to_v.lora_B.default torch.float32
single_transformer_blocks.2.norm.linear torch.bfloat16
single_transformer_blocks.2.proj_mlp torch.float8_e4m3fn
single_transformer_blocks.2.proj_out torch.float8_e4m3fn
single_transformer_blocks.3.attn.norm_q torch.bfloat16
single_transformer_blocks.3.attn.norm_k torch.bfloat16
single_transformer_blocks.3.attn.to_q torch.float8_e4m3fn
single_transformer_blocks.3.attn.to_q.base_layer torch.float8_e4m3fn
single_transformer_blocks.3.attn.to_q.lora_A.default torch.float32
single_transformer_blocks.3.attn.to_q.lora_B.default torch.float32
single_transformer_blocks.3.attn.to_k torch.float8_e4m3fn
single_transformer_blocks.3.attn.to_k.base_layer torch.float8_e4m3fn
single_transformer_blocks.3.attn.to_k.lora_A.default torch.float32
single_transformer_blocks.3.attn.to_k.lora_B.default torch.float32
single_transformer_blocks.3.attn.to_v torch.float8_e4m3fn
single_transformer_blocks.3.attn.to_v.base_layer torch.float8_e4m3fn
single_transformer_blocks.3.attn.to_v.lora_A.default torch.float32
single_transformer_blocks.3.attn.to_v.lora_B.default torch.float32
single_transformer_blocks.3.norm.linear torch.bfloat16
single_transformer_blocks.3.proj_mlp torch.float8_e4m3fn
single_transformer_blocks.3.proj_out torch.float8_e4m3fn
single_transformer_blocks.4.attn.norm_q torch.bfloat16
single_transformer_blocks.4.attn.norm_k torch.bfloat16
single_transformer_blocks.4.attn.to_q torch.float8_e4m3fn
single_transformer_blocks.4.attn.to_q.base_layer torch.float8_e4m3fn
single_transformer_blocks.4.attn.to_q.lora_A.default torch.float32
single_transformer_blocks.4.attn.to_q.lora_B.default torch.float32
single_transformer_blocks.4.attn.to_k torch.float8_e4m3fn
single_transformer_blocks.4.attn.to_k.base_layer torch.float8_e4m3fn
single_transformer_blocks.4.attn.to_k.lora_A.default torch.float32
single_transformer_blocks.4.attn.to_k.lora_B.default torch.float32
single_transformer_blocks.4.attn.to_v torch.float8_e4m3fn
single_transformer_blocks.4.attn.to_v.base_layer torch.float8_e4m3fn
single_transformer_blocks.4.attn.to_v.lora_A.default torch.float32
single_transformer_blocks.4.attn.to_v.lora_B.default torch.float32
single_transformer_blocks.4.norm.linear torch.bfloat16
single_transformer_blocks.4.proj_mlp torch.float8_e4m3fn
single_transformer_blocks.4.proj_out torch.float8_e4m3fn
single_transformer_blocks.5.attn.norm_q torch.bfloat16
single_transformer_blocks.5.attn.norm_k torch.bfloat16
single_transformer_blocks.5.attn.to_q torch.float8_e4m3fn
single_transformer_blocks.5.attn.to_q.base_layer torch.float8_e4m3fn
single_transformer_blocks.5.attn.to_q.lora_A.default torch.float32
single_transformer_blocks.5.attn.to_q.lora_B.default torch.float32
single_transformer_blocks.5.attn.to_k torch.float8_e4m3fn
single_transformer_blocks.5.attn.to_k.base_layer torch.float8_e4m3fn
single_transformer_blocks.5.attn.to_k.lora_A.default torch.float32
single_transformer_blocks.5.attn.to_k.lora_B.default torch.float32
single_transformer_blocks.5.attn.to_v torch.float8_e4m3fn
single_transformer_blocks.5.attn.to_v.base_layer torch.float8_e4m3fn
single_transformer_blocks.5.attn.to_v.lora_A.default torch.float32
single_transformer_blocks.5.attn.to_v.lora_B.default torch.float32
single_transformer_blocks.5.norm.linear torch.bfloat16
single_transformer_blocks.5.proj_mlp torch.float8_e4m3fn
single_transformer_blocks.5.proj_out torch.float8_e4m3fn
single_transformer_blocks.6.attn.norm_q torch.bfloat16
single_transformer_blocks.6.attn.norm_k torch.bfloat16
single_transformer_blocks.6.attn.to_q torch.float8_e4m3fn
single_transformer_blocks.6.attn.to_q.base_layer torch.float8_e4m3fn
single_transformer_blocks.6.attn.to_q.lora_A.default torch.float32
single_transformer_blocks.6.attn.to_q.lora_B.default torch.float32
single_transformer_blocks.6.attn.to_k torch.float8_e4m3fn
single_transformer_blocks.6.attn.to_k.base_layer torch.float8_e4m3fn
single_transformer_blocks.6.attn.to_k.lora_A.default torch.float32
single_transformer_blocks.6.attn.to_k.lora_B.default torch.float32
single_transformer_blocks.6.attn.to_v torch.float8_e4m3fn
single_transformer_blocks.6.attn.to_v.base_layer torch.float8_e4m3fn
single_transformer_blocks.6.attn.to_v.lora_A.default torch.float32
single_transformer_blocks.6.attn.to_v.lora_B.default torch.float32
single_transformer_blocks.6.norm.linear torch.bfloat16
single_transformer_blocks.6.proj_mlp torch.float8_e4m3fn
single_transformer_blocks.6.proj_out torch.float8_e4m3fn
single_transformer_blocks.7.attn.norm_q torch.bfloat16
single_transformer_blocks.7.attn.norm_k torch.bfloat16
single_transformer_blocks.7.attn.to_q torch.float8_e4m3fn
single_transformer_blocks.7.attn.to_q.base_layer torch.float8_e4m3fn
single_transformer_blocks.7.attn.to_q.lora_A.default torch.float32
single_transformer_blocks.7.attn.to_q.lora_B.default torch.float32
single_transformer_blocks.7.attn.to_k torch.float8_e4m3fn
single_transformer_blocks.7.attn.to_k.base_layer torch.float8_e4m3fn
single_transformer_blocks.7.attn.to_k.lora_A.default torch.float32
single_transformer_blocks.7.attn.to_k.lora_B.default torch.float32
single_transformer_blocks.7.attn.to_v torch.float8_e4m3fn
single_transformer_blocks.7.attn.to_v.base_layer torch.float8_e4m3fn
single_transformer_blocks.7.attn.to_v.lora_A.default torch.float32
single_transformer_blocks.7.attn.to_v.lora_B.default torch.float32
single_transformer_blocks.7.norm.linear torch.bfloat16
single_transformer_blocks.7.proj_mlp torch.float8_e4m3fn
single_transformer_blocks.7.proj_out torch.float8_e4m3fn
single_transformer_blocks.8.attn.norm_q torch.bfloat16
single_transformer_blocks.8.attn.norm_k torch.bfloat16
single_transformer_blocks.8.attn.to_q torch.float8_e4m3fn
single_transformer_blocks.8.attn.to_q.base_layer torch.float8_e4m3fn
single_transformer_blocks.8.attn.to_q.lora_A.default torch.float32
single_transformer_blocks.8.attn.to_q.lora_B.default torch.float32
single_transformer_blocks.8.attn.to_k torch.float8_e4m3fn
single_transformer_blocks.8.attn.to_k.base_layer torch.float8_e4m3fn
single_transformer_blocks.8.attn.to_k.lora_A.default torch.float32
single_transformer_blocks.8.attn.to_k.lora_B.default torch.float32
single_transformer_blocks.8.attn.to_v torch.float8_e4m3fn
single_transformer_blocks.8.attn.to_v.base_layer torch.float8_e4m3fn
single_transformer_blocks.8.attn.to_v.lora_A.default torch.float32
single_transformer_blocks.8.attn.to_v.lora_B.default torch.float32
single_transformer_blocks.8.norm.linear torch.bfloat16
single_transformer_blocks.8.proj_mlp torch.float8_e4m3fn
single_transformer_blocks.8.proj_out torch.float8_e4m3fn
single_transformer_blocks.9.attn.norm_q torch.bfloat16
single_transformer_blocks.9.attn.norm_k torch.bfloat16
single_transformer_blocks.9.attn.to_q torch.float8_e4m3fn
single_transformer_blocks.9.attn.to_q.base_layer torch.float8_e4m3fn
single_transformer_blocks.9.attn.to_q.lora_A.default torch.float32
single_transformer_blocks.9.attn.to_q.lora_B.default torch.float32
single_transformer_blocks.9.attn.to_k torch.float8_e4m3fn
single_transformer_blocks.9.attn.to_k.base_layer torch.float8_e4m3fn
single_transformer_blocks.9.attn.to_k.lora_A.default torch.float32
single_transformer_blocks.9.attn.to_k.lora_B.default torch.float32
single_transformer_blocks.9.attn.to_v torch.float8_e4m3fn
single_transformer_blocks.9.attn.to_v.base_layer torch.float8_e4m3fn
single_transformer_blocks.9.attn.to_v.lora_A.default torch.float32
single_transformer_blocks.9.attn.to_v.lora_B.default torch.float32
single_transformer_blocks.9.norm.linear torch.bfloat16
single_transformer_blocks.9.proj_mlp torch.float8_e4m3fn
single_transformer_blocks.9.proj_out torch.float8_e4m3fn
single_transformer_blocks.10.attn.norm_q torch.bfloat16
single_transformer_blocks.10.attn.norm_k torch.bfloat16
single_transformer_blocks.10.attn.to_q torch.float8_e4m3fn
single_transformer_blocks.10.attn.to_q.base_layer torch.float8_e4m3fn
single_transformer_blocks.10.attn.to_q.lora_A.default torch.float32
single_transformer_blocks.10.attn.to_q.lora_B.default torch.float32
single_transformer_blocks.10.attn.to_k torch.float8_e4m3fn
single_transformer_blocks.10.attn.to_k.base_layer torch.float8_e4m3fn
single_transformer_blocks.10.attn.to_k.lora_A.default torch.float32
single_transformer_blocks.10.attn.to_k.lora_B.default torch.float32
single_transformer_blocks.10.attn.to_v torch.float8_e4m3fn
single_transformer_blocks.10.attn.to_v.base_layer torch.float8_e4m3fn
single_transformer_blocks.10.attn.to_v.lora_A.default torch.float32
single_transformer_blocks.10.attn.to_v.lora_B.default torch.float32
single_transformer_blocks.10.norm.linear torch.bfloat16
single_transformer_blocks.10.proj_mlp torch.float8_e4m3fn
single_transformer_blocks.10.proj_out torch.float8_e4m3fn
single_transformer_blocks.11.attn.norm_q torch.bfloat16
single_transformer_blocks.11.attn.norm_k torch.bfloat16
single_transformer_blocks.11.attn.to_q torch.float8_e4m3fn
single_transformer_blocks.11.attn.to_q.base_layer torch.float8_e4m3fn
single_transformer_blocks.11.attn.to_q.lora_A.default torch.float32
single_transformer_blocks.11.attn.to_q.lora_B.default torch.float32
single_transformer_blocks.11.attn.to_k torch.float8_e4m3fn
single_transformer_blocks.11.attn.to_k.base_layer torch.float8_e4m3fn
single_transformer_blocks.11.attn.to_k.lora_A.default torch.float32
single_transformer_blocks.11.attn.to_k.lora_B.default torch.float32
single_transformer_blocks.11.attn.to_v torch.float8_e4m3fn
single_transformer_blocks.11.attn.to_v.base_layer torch.float8_e4m3fn
single_transformer_blocks.11.attn.to_v.lora_A.default torch.float32
single_transformer_blocks.11.attn.to_v.lora_B.default torch.float32
single_transformer_blocks.11.norm.linear torch.bfloat16
single_transformer_blocks.11.proj_mlp torch.float8_e4m3fn
single_transformer_blocks.11.proj_out torch.float8_e4m3fn
single_transformer_blocks.12.attn.norm_q torch.bfloat16
single_transformer_blocks.12.attn.norm_k torch.bfloat16
single_transformer_blocks.12.attn.to_q torch.float8_e4m3fn
single_transformer_blocks.12.attn.to_q.base_layer torch.float8_e4m3fn
single_transformer_blocks.12.attn.to_q.lora_A.default torch.float32
single_transformer_blocks.12.attn.to_q.lora_B.default torch.float32
single_transformer_blocks.12.attn.to_k torch.float8_e4m3fn
single_transformer_blocks.12.attn.to_k.base_layer torch.float8_e4m3fn
single_transformer_blocks.12.attn.to_k.lora_A.default torch.float32
single_transformer_blocks.12.attn.to_k.lora_B.default torch.float32
single_transformer_blocks.12.attn.to_v torch.float8_e4m3fn
single_transformer_blocks.12.attn.to_v.base_layer torch.float8_e4m3fn
single_transformer_blocks.12.attn.to_v.lora_A.default torch.float32
single_transformer_blocks.12.attn.to_v.lora_B.default torch.float32
single_transformer_blocks.12.norm.linear torch.bfloat16
single_transformer_blocks.12.proj_mlp torch.float8_e4m3fn
single_transformer_blocks.12.proj_out torch.float8_e4m3fn
single_transformer_blocks.13.attn.norm_q torch.bfloat16
single_transformer_blocks.13.attn.norm_k torch.bfloat16
single_transformer_blocks.13.attn.to_q torch.float8_e4m3fn
single_transformer_blocks.13.attn.to_q.base_layer torch.float8_e4m3fn
single_transformer_blocks.13.attn.to_q.lora_A.default torch.float32
single_transformer_blocks.13.attn.to_q.lora_B.default torch.float32
single_transformer_blocks.13.attn.to_k torch.float8_e4m3fn
single_transformer_blocks.13.attn.to_k.base_layer torch.float8_e4m3fn
single_transformer_blocks.13.attn.to_k.lora_A.default torch.float32
single_transformer_blocks.13.attn.to_k.lora_B.default torch.float32
single_transformer_blocks.13.attn.to_v torch.float8_e4m3fn
single_transformer_blocks.13.attn.to_v.base_layer torch.float8_e4m3fn
single_transformer_blocks.13.attn.to_v.lora_A.default torch.float32
single_transformer_blocks.13.attn.to_v.lora_B.default torch.float32
single_transformer_blocks.13.norm.linear torch.bfloat16
single_transformer_blocks.13.proj_mlp torch.float8_e4m3fn
single_transformer_blocks.13.proj_out torch.float8_e4m3fn
single_transformer_blocks.14.attn.norm_q torch.bfloat16
single_transformer_blocks.14.attn.norm_k torch.bfloat16
single_transformer_blocks.14.attn.to_q torch.float8_e4m3fn
single_transformer_blocks.14.attn.to_q.base_layer torch.float8_e4m3fn
single_transformer_blocks.14.attn.to_q.lora_A.default torch.float32
single_transformer_blocks.14.attn.to_q.lora_B.default torch.float32
single_transformer_blocks.14.attn.to_k torch.float8_e4m3fn
single_transformer_blocks.14.attn.to_k.base_layer torch.float8_e4m3fn
single_transformer_blocks.14.attn.to_k.lora_A.default torch.float32
single_transformer_blocks.14.attn.to_k.lora_B.default torch.float32
single_transformer_blocks.14.attn.to_v torch.float8_e4m3fn
single_transformer_blocks.14.attn.to_v.base_layer torch.float8_e4m3fn
single_transformer_blocks.14.attn.to_v.lora_A.default torch.float32
single_transformer_blocks.14.attn.to_v.lora_B.default torch.float32
single_transformer_blocks.14.norm.linear torch.bfloat16
single_transformer_blocks.14.proj_mlp torch.float8_e4m3fn
single_transformer_blocks.14.proj_out torch.float8_e4m3fn
single_transformer_blocks.15.attn.norm_q torch.bfloat16
single_transformer_blocks.15.attn.norm_k torch.bfloat16
single_transformer_blocks.15.attn.to_q torch.float8_e4m3fn
single_transformer_blocks.15.attn.to_q.base_layer torch.float8_e4m3fn
single_transformer_blocks.15.attn.to_q.lora_A.default torch.float32
single_transformer_blocks.15.attn.to_q.lora_B.default torch.float32
single_transformer_blocks.15.attn.to_k torch.float8_e4m3fn
single_transformer_blocks.15.attn.to_k.base_layer torch.float8_e4m3fn
single_transformer_blocks.15.attn.to_k.lora_A.default torch.float32
single_transformer_blocks.15.attn.to_k.lora_B.default torch.float32
single_transformer_blocks.15.attn.to_v torch.float8_e4m3fn
single_transformer_blocks.15.attn.to_v.base_layer torch.float8_e4m3fn
single_transformer_blocks.15.attn.to_v.lora_A.default torch.float32
single_transformer_blocks.15.attn.to_v.lora_B.default torch.float32
single_transformer_blocks.15.norm.linear torch.bfloat16
single_transformer_blocks.15.proj_mlp torch.float8_e4m3fn
single_transformer_blocks.15.proj_out torch.float8_e4m3fn
single_transformer_blocks.16.attn.norm_q torch.bfloat16
single_transformer_blocks.16.attn.norm_k torch.bfloat16
single_transformer_blocks.16.attn.to_q torch.float8_e4m3fn
single_transformer_blocks.16.attn.to_q.base_layer torch.float8_e4m3fn
single_transformer_blocks.16.attn.to_q.lora_A.default torch.float32
single_transformer_blocks.16.attn.to_q.lora_B.default torch.float32
single_transformer_blocks.16.attn.to_k torch.float8_e4m3fn
single_transformer_blocks.16.attn.to_k.base_layer torch.float8_e4m3fn
single_transformer_blocks.16.attn.to_k.lora_A.default torch.float32
single_transformer_blocks.16.attn.to_k.lora_B.default torch.float32
single_transformer_blocks.16.attn.to_v torch.float8_e4m3fn
single_transformer_blocks.16.attn.to_v.base_layer torch.float8_e4m3fn
single_transformer_blocks.16.attn.to_v.lora_A.default torch.float32
single_transformer_blocks.16.attn.to_v.lora_B.default torch.float32
single_transformer_blocks.16.norm.linear torch.bfloat16
single_transformer_blocks.16.proj_mlp torch.float8_e4m3fn
single_transformer_blocks.16.proj_out torch.float8_e4m3fn
single_transformer_blocks.17.attn.norm_q torch.bfloat16
single_transformer_blocks.17.attn.norm_k torch.bfloat16
single_transformer_blocks.17.attn.to_q torch.float8_e4m3fn
single_transformer_blocks.17.attn.to_q.base_layer torch.float8_e4m3fn
single_transformer_blocks.17.attn.to_q.lora_A.default torch.float32
single_transformer_blocks.17.attn.to_q.lora_B.default torch.float32
single_transformer_blocks.17.attn.to_k torch.float8_e4m3fn
single_transformer_blocks.17.attn.to_k.base_layer torch.float8_e4m3fn
single_transformer_blocks.17.attn.to_k.lora_A.default torch.float32
single_transformer_blocks.17.attn.to_k.lora_B.default torch.float32
single_transformer_blocks.17.attn.to_v torch.float8_e4m3fn
single_transformer_blocks.17.attn.to_v.base_layer torch.float8_e4m3fn
single_transformer_blocks.17.attn.to_v.lora_A.default torch.float32
single_transformer_blocks.17.attn.to_v.lora_B.default torch.float32
single_transformer_blocks.17.norm.linear torch.bfloat16
single_transformer_blocks.17.proj_mlp torch.float8_e4m3fn
single_transformer_blocks.17.proj_out torch.float8_e4m3fn
single_transformer_blocks.18.attn.norm_q torch.bfloat16
single_transformer_blocks.18.attn.norm_k torch.bfloat16
single_transformer_blocks.18.attn.to_q torch.float8_e4m3fn
single_transformer_blocks.18.attn.to_q.base_layer torch.float8_e4m3fn
single_transformer_blocks.18.attn.to_q.lora_A.default torch.float32
single_transformer_blocks.18.attn.to_q.lora_B.default torch.float32
single_transformer_blocks.18.attn.to_k torch.float8_e4m3fn
single_transformer_blocks.18.attn.to_k.base_layer torch.float8_e4m3fn
single_transformer_blocks.18.attn.to_k.lora_A.default torch.float32
single_transformer_blocks.18.attn.to_k.lora_B.default torch.float32
single_transformer_blocks.18.attn.to_v torch.float8_e4m3fn
single_transformer_blocks.18.attn.to_v.base_layer torch.float8_e4m3fn
single_transformer_blocks.18.attn.to_v.lora_A.default torch.float32
single_transformer_blocks.18.attn.to_v.lora_B.default torch.float32
single_transformer_blocks.18.norm.linear torch.bfloat16
single_transformer_blocks.18.proj_mlp torch.float8_e4m3fn
single_transformer_blocks.18.proj_out torch.float8_e4m3fn
single_transformer_blocks.19.attn.norm_q torch.bfloat16
single_transformer_blocks.19.attn.norm_k torch.bfloat16
single_transformer_blocks.19.attn.to_q torch.float8_e4m3fn
single_transformer_blocks.19.attn.to_q.base_layer torch.float8_e4m3fn
single_transformer_blocks.19.attn.to_q.lora_A.default torch.float32
single_transformer_blocks.19.attn.to_q.lora_B.default torch.float32
single_transformer_blocks.19.attn.to_k torch.float8_e4m3fn
single_transformer_blocks.19.attn.to_k.base_layer torch.float8_e4m3fn
single_transformer_blocks.19.attn.to_k.lora_A.default torch.float32
single_transformer_blocks.19.attn.to_k.lora_B.default torch.float32
single_transformer_blocks.19.attn.to_v torch.float8_e4m3fn
single_transformer_blocks.19.attn.to_v.base_layer torch.float8_e4m3fn
single_transformer_blocks.19.attn.to_v.lora_A.default torch.float32
single_transformer_blocks.19.attn.to_v.lora_B.default torch.float32
single_transformer_blocks.19.norm.linear torch.bfloat16
single_transformer_blocks.19.proj_mlp torch.float8_e4m3fn
single_transformer_blocks.19.proj_out torch.float8_e4m3fn
single_transformer_blocks.20.attn.norm_q torch.bfloat16
single_transformer_blocks.20.attn.norm_k torch.bfloat16
single_transformer_blocks.20.attn.to_q torch.float8_e4m3fn
single_transformer_blocks.20.attn.to_q.base_layer torch.float8_e4m3fn
single_transformer_blocks.20.attn.to_q.lora_A.default torch.float32
single_transformer_blocks.20.attn.to_q.lora_B.default torch.float32
single_transformer_blocks.20.attn.to_k torch.float8_e4m3fn
single_transformer_blocks.20.attn.to_k.base_layer torch.float8_e4m3fn
single_transformer_blocks.20.attn.to_k.lora_A.default torch.float32
single_transformer_blocks.20.attn.to_k.lora_B.default torch.float32
single_transformer_blocks.20.attn.to_v torch.float8_e4m3fn
single_transformer_blocks.20.attn.to_v.base_layer torch.float8_e4m3fn
single_transformer_blocks.20.attn.to_v.lora_A.default torch.float32
single_transformer_blocks.20.attn.to_v.lora_B.default torch.float32
single_transformer_blocks.20.norm.linear torch.bfloat16
single_transformer_blocks.20.proj_mlp torch.float8_e4m3fn
single_transformer_blocks.20.proj_out torch.float8_e4m3fn
single_transformer_blocks.21.attn.norm_q torch.bfloat16
single_transformer_blocks.21.attn.norm_k torch.bfloat16
single_transformer_blocks.21.attn.to_q torch.float8_e4m3fn
single_transformer_blocks.21.attn.to_q.base_layer torch.float8_e4m3fn
single_transformer_blocks.21.attn.to_q.lora_A.default torch.float32
single_transformer_blocks.21.attn.to_q.lora_B.default torch.float32
single_transformer_blocks.21.attn.to_k torch.float8_e4m3fn
single_transformer_blocks.21.attn.to_k.base_layer torch.float8_e4m3fn
single_transformer_blocks.21.attn.to_k.lora_A.default torch.float32
single_transformer_blocks.21.attn.to_k.lora_B.default torch.float32
single_transformer_blocks.21.attn.to_v torch.float8_e4m3fn
single_transformer_blocks.21.attn.to_v.base_layer torch.float8_e4m3fn
single_transformer_blocks.21.attn.to_v.lora_A.default torch.float32
single_transformer_blocks.21.attn.to_v.lora_B.default torch.float32
single_transformer_blocks.21.norm.linear torch.bfloat16
single_transformer_blocks.21.proj_mlp torch.float8_e4m3fn
single_transformer_blocks.21.proj_out torch.float8_e4m3fn
single_transformer_blocks.22.attn.norm_q torch.bfloat16
single_transformer_blocks.22.attn.norm_k torch.bfloat16
single_transformer_blocks.22.attn.to_q torch.float8_e4m3fn
single_transformer_blocks.22.attn.to_q.base_layer torch.float8_e4m3fn
single_transformer_blocks.22.attn.to_q.lora_A.default torch.float32
single_transformer_blocks.22.attn.to_q.lora_B.default torch.float32
single_transformer_blocks.22.attn.to_k torch.float8_e4m3fn
single_transformer_blocks.22.attn.to_k.base_layer torch.float8_e4m3fn
single_transformer_blocks.22.attn.to_k.lora_A.default torch.float32
single_transformer_blocks.22.attn.to_k.lora_B.default torch.float32
single_transformer_blocks.22.attn.to_v torch.float8_e4m3fn
single_transformer_blocks.22.attn.to_v.base_layer torch.float8_e4m3fn
single_transformer_blocks.22.attn.to_v.lora_A.default torch.float32
single_transformer_blocks.22.attn.to_v.lora_B.default torch.float32
single_transformer_blocks.22.norm.linear torch.bfloat16
single_transformer_blocks.22.proj_mlp torch.float8_e4m3fn
single_transformer_blocks.22.proj_out torch.float8_e4m3fn
single_transformer_blocks.23.attn.norm_q torch.bfloat16
single_transformer_blocks.23.attn.norm_k torch.bfloat16
single_transformer_blocks.23.attn.to_q torch.float8_e4m3fn
single_transformer_blocks.23.attn.to_q.base_layer torch.float8_e4m3fn
single_transformer_blocks.23.attn.to_q.lora_A.default torch.float32
single_transformer_blocks.23.attn.to_q.lora_B.default torch.float32
single_transformer_blocks.23.attn.to_k torch.float8_e4m3fn
single_transformer_blocks.23.attn.to_k.base_layer torch.float8_e4m3fn
single_transformer_blocks.23.attn.to_k.lora_A.default torch.float32
single_transformer_blocks.23.attn.to_k.lora_B.default torch.float32
single_transformer_blocks.23.attn.to_v torch.float8_e4m3fn
single_transformer_blocks.23.attn.to_v.base_layer torch.float8_e4m3fn
single_transformer_blocks.23.attn.to_v.lora_A.default torch.float32
single_transformer_blocks.23.attn.to_v.lora_B.default torch.float32
single_transformer_blocks.23.norm.linear torch.bfloat16
single_transformer_blocks.23.proj_mlp torch.float8_e4m3fn
single_transformer_blocks.23.proj_out torch.float8_e4m3fn
single_transformer_blocks.24.attn.norm_q torch.bfloat16
single_transformer_blocks.24.attn.norm_k torch.bfloat16
single_transformer_blocks.24.attn.to_q torch.float8_e4m3fn
single_transformer_blocks.24.attn.to_q.base_layer torch.float8_e4m3fn
single_transformer_blocks.24.attn.to_q.lora_A.default torch.float32
single_transformer_blocks.24.attn.to_q.lora_B.default torch.float32
single_transformer_blocks.24.attn.to_k torch.float8_e4m3fn
single_transformer_blocks.24.attn.to_k.base_layer torch.float8_e4m3fn
single_transformer_blocks.24.attn.to_k.lora_A.default torch.float32
single_transformer_blocks.24.attn.to_k.lora_B.default torch.float32
single_transformer_blocks.24.attn.to_v torch.float8_e4m3fn
single_transformer_blocks.24.attn.to_v.base_layer torch.float8_e4m3fn
single_transformer_blocks.24.attn.to_v.lora_A.default torch.float32
single_transformer_blocks.24.attn.to_v.lora_B.default torch.float32
single_transformer_blocks.24.norm.linear torch.bfloat16
single_transformer_blocks.24.proj_mlp torch.float8_e4m3fn
single_transformer_blocks.24.proj_out torch.float8_e4m3fn
single_transformer_blocks.25.attn.norm_q torch.bfloat16
single_transformer_blocks.25.attn.norm_k torch.bfloat16
single_transformer_blocks.25.attn.to_q torch.float8_e4m3fn
single_transformer_blocks.25.attn.to_q.base_layer torch.float8_e4m3fn
single_transformer_blocks.25.attn.to_q.lora_A.default torch.float32
single_transformer_blocks.25.attn.to_q.lora_B.default torch.float32
single_transformer_blocks.25.attn.to_k torch.float8_e4m3fn
single_transformer_blocks.25.attn.to_k.base_layer torch.float8_e4m3fn
single_transformer_blocks.25.attn.to_k.lora_A.default torch.float32
single_transformer_blocks.25.attn.to_k.lora_B.default torch.float32
single_transformer_blocks.25.attn.to_v torch.float8_e4m3fn
single_transformer_blocks.25.attn.to_v.base_layer torch.float8_e4m3fn
single_transformer_blocks.25.attn.to_v.lora_A.default torch.float32
single_transformer_blocks.25.attn.to_v.lora_B.default torch.float32
single_transformer_blocks.25.norm.linear torch.bfloat16
single_transformer_blocks.25.proj_mlp torch.float8_e4m3fn
single_transformer_blocks.25.proj_out torch.float8_e4m3fn
single_transformer_blocks.26.attn.norm_q torch.bfloat16
single_transformer_blocks.26.attn.norm_k torch.bfloat16
single_transformer_blocks.26.attn.to_q torch.float8_e4m3fn
single_transformer_blocks.26.attn.to_q.base_layer torch.float8_e4m3fn
single_transformer_blocks.26.attn.to_q.lora_A.default torch.float32
single_transformer_blocks.26.attn.to_q.lora_B.default torch.float32
single_transformer_blocks.26.attn.to_k torch.float8_e4m3fn
single_transformer_blocks.26.attn.to_k.base_layer torch.float8_e4m3fn
single_transformer_blocks.26.attn.to_k.lora_A.default torch.float32
single_transformer_blocks.26.attn.to_k.lora_B.default torch.float32
single_transformer_blocks.26.attn.to_v torch.float8_e4m3fn
single_transformer_blocks.26.attn.to_v.base_layer torch.float8_e4m3fn
single_transformer_blocks.26.attn.to_v.lora_A.default torch.float32
single_transformer_blocks.26.attn.to_v.lora_B.default torch.float32
single_transformer_blocks.26.norm.linear torch.bfloat16
single_transformer_blocks.26.proj_mlp torch.float8_e4m3fn
single_transformer_blocks.26.proj_out torch.float8_e4m3fn
single_transformer_blocks.27.attn.norm_q torch.bfloat16
single_transformer_blocks.27.attn.norm_k torch.bfloat16
single_transformer_blocks.27.attn.to_q torch.float8_e4m3fn
single_transformer_blocks.27.attn.to_q.base_layer torch.float8_e4m3fn
single_transformer_blocks.27.attn.to_q.lora_A.default torch.float32
single_transformer_blocks.27.attn.to_q.lora_B.default torch.float32
single_transformer_blocks.27.attn.to_k torch.float8_e4m3fn
single_transformer_blocks.27.attn.to_k.base_layer torch.float8_e4m3fn
single_transformer_blocks.27.attn.to_k.lora_A.default torch.float32
single_transformer_blocks.27.attn.to_k.lora_B.default torch.float32
single_transformer_blocks.27.attn.to_v torch.float8_e4m3fn
single_transformer_blocks.27.attn.to_v.base_layer torch.float8_e4m3fn
single_transformer_blocks.27.attn.to_v.lora_A.default torch.float32
single_transformer_blocks.27.attn.to_v.lora_B.default torch.float32
single_transformer_blocks.27.norm.linear torch.bfloat16
single_transformer_blocks.27.proj_mlp torch.float8_e4m3fn
single_transformer_blocks.27.proj_out torch.float8_e4m3fn
single_transformer_blocks.28.attn.norm_q torch.bfloat16
single_transformer_blocks.28.attn.norm_k torch.bfloat16
single_transformer_blocks.28.attn.to_q torch.float8_e4m3fn
single_transformer_blocks.28.attn.to_q.base_layer torch.float8_e4m3fn
single_transformer_blocks.28.attn.to_q.lora_A.default torch.float32
single_transformer_blocks.28.attn.to_q.lora_B.default torch.float32
single_transformer_blocks.28.attn.to_k torch.float8_e4m3fn
single_transformer_blocks.28.attn.to_k.base_layer torch.float8_e4m3fn
single_transformer_blocks.28.attn.to_k.lora_A.default torch.float32
single_transformer_blocks.28.attn.to_k.lora_B.default torch.float32
single_transformer_blocks.28.attn.to_v torch.float8_e4m3fn
single_transformer_blocks.28.attn.to_v.base_layer torch.float8_e4m3fn
single_transformer_blocks.28.attn.to_v.lora_A.default torch.float32
single_transformer_blocks.28.attn.to_v.lora_B.default torch.float32
single_transformer_blocks.28.norm.linear torch.bfloat16
single_transformer_blocks.28.proj_mlp torch.float8_e4m3fn
single_transformer_blocks.28.proj_out torch.float8_e4m3fn
single_transformer_blocks.29.attn.norm_q torch.bfloat16
single_transformer_blocks.29.attn.norm_k torch.bfloat16
single_transformer_blocks.29.attn.to_q torch.float8_e4m3fn
single_transformer_blocks.29.attn.to_q.base_layer torch.float8_e4m3fn
single_transformer_blocks.29.attn.to_q.lora_A.default torch.float32
single_transformer_blocks.29.attn.to_q.lora_B.default torch.float32
single_transformer_blocks.29.attn.to_k torch.float8_e4m3fn
single_transformer_blocks.29.attn.to_k.base_layer torch.float8_e4m3fn
single_transformer_blocks.29.attn.to_k.lora_A.default torch.float32
single_transformer_blocks.29.attn.to_k.lora_B.default torch.float32
single_transformer_blocks.29.attn.to_v torch.float8_e4m3fn
single_transformer_blocks.29.attn.to_v.base_layer torch.float8_e4m3fn
single_transformer_blocks.29.attn.to_v.lora_A.default torch.float32
single_transformer_blocks.29.attn.to_v.lora_B.default torch.float32
single_transformer_blocks.29.norm.linear torch.bfloat16
single_transformer_blocks.29.proj_mlp torch.float8_e4m3fn
single_transformer_blocks.29.proj_out torch.float8_e4m3fn
single_transformer_blocks.30.attn.norm_q torch.bfloat16
single_transformer_blocks.30.attn.norm_k torch.bfloat16
single_transformer_blocks.30.attn.to_q torch.float8_e4m3fn
single_transformer_blocks.30.attn.to_q.base_layer torch.float8_e4m3fn
single_transformer_blocks.30.attn.to_q.lora_A.default torch.float32
single_transformer_blocks.30.attn.to_q.lora_B.default torch.float32
single_transformer_blocks.30.attn.to_k torch.float8_e4m3fn
single_transformer_blocks.30.attn.to_k.base_layer torch.float8_e4m3fn
single_transformer_blocks.30.attn.to_k.lora_A.default torch.float32
single_transformer_blocks.30.attn.to_k.lora_B.default torch.float32
single_transformer_blocks.30.attn.to_v torch.float8_e4m3fn
single_transformer_blocks.30.attn.to_v.base_layer torch.float8_e4m3fn
single_transformer_blocks.30.attn.to_v.lora_A.default torch.float32
single_transformer_blocks.30.attn.to_v.lora_B.default torch.float32
single_transformer_blocks.30.norm.linear torch.bfloat16
single_transformer_blocks.30.proj_mlp torch.float8_e4m3fn
single_transformer_blocks.30.proj_out torch.float8_e4m3fn
single_transformer_blocks.31.attn.norm_q torch.bfloat16
single_transformer_blocks.31.attn.norm_k torch.bfloat16
single_transformer_blocks.31.attn.to_q torch.float8_e4m3fn
single_transformer_blocks.31.attn.to_q.base_layer torch.float8_e4m3fn
single_transformer_blocks.31.attn.to_q.lora_A.default torch.float32
single_transformer_blocks.31.attn.to_q.lora_B.default torch.float32
single_transformer_blocks.31.attn.to_k torch.float8_e4m3fn
single_transformer_blocks.31.attn.to_k.base_layer torch.float8_e4m3fn
single_transformer_blocks.31.attn.to_k.lora_A.default torch.float32
single_transformer_blocks.31.attn.to_k.lora_B.default torch.float32
single_transformer_blocks.31.attn.to_v torch.float8_e4m3fn
single_transformer_blocks.31.attn.to_v.base_layer torch.float8_e4m3fn
single_transformer_blocks.31.attn.to_v.lora_A.default torch.float32
single_transformer_blocks.31.attn.to_v.lora_B.default torch.float32
single_transformer_blocks.31.norm.linear torch.bfloat16
single_transformer_blocks.31.proj_mlp torch.float8_e4m3fn
single_transformer_blocks.31.proj_out torch.float8_e4m3fn
single_transformer_blocks.32.attn.norm_q torch.bfloat16
single_transformer_blocks.32.attn.norm_k torch.bfloat16
single_transformer_blocks.32.attn.to_q torch.float8_e4m3fn
single_transformer_blocks.32.attn.to_q.base_layer torch.float8_e4m3fn
single_transformer_blocks.32.attn.to_q.lora_A.default torch.float32
single_transformer_blocks.32.attn.to_q.lora_B.default torch.float32
single_transformer_blocks.32.attn.to_k torch.float8_e4m3fn
single_transformer_blocks.32.attn.to_k.base_layer torch.float8_e4m3fn
single_transformer_blocks.32.attn.to_k.lora_A.default torch.float32
single_transformer_blocks.32.attn.to_k.lora_B.default torch.float32
single_transformer_blocks.32.attn.to_v torch.float8_e4m3fn
single_transformer_blocks.32.attn.to_v.base_layer torch.float8_e4m3fn
single_transformer_blocks.32.attn.to_v.lora_A.default torch.float32
single_transformer_blocks.32.attn.to_v.lora_B.default torch.float32
single_transformer_blocks.32.norm.linear torch.bfloat16
single_transformer_blocks.32.proj_mlp torch.float8_e4m3fn
single_transformer_blocks.32.proj_out torch.float8_e4m3fn
single_transformer_blocks.33.attn.norm_q torch.bfloat16
single_transformer_blocks.33.attn.norm_k torch.bfloat16
single_transformer_blocks.33.attn.to_q torch.float8_e4m3fn
single_transformer_blocks.33.attn.to_q.base_layer torch.float8_e4m3fn
single_transformer_blocks.33.attn.to_q.lora_A.default torch.float32
single_transformer_blocks.33.attn.to_q.lora_B.default torch.float32
single_transformer_blocks.33.attn.to_k torch.float8_e4m3fn
single_transformer_blocks.33.attn.to_k.base_layer torch.float8_e4m3fn
single_transformer_blocks.33.attn.to_k.lora_A.default torch.float32
single_transformer_blocks.33.attn.to_k.lora_B.default torch.float32
single_transformer_blocks.33.attn.to_v torch.float8_e4m3fn
single_transformer_blocks.33.attn.to_v.base_layer torch.float8_e4m3fn
single_transformer_blocks.33.attn.to_v.lora_A.default torch.float32
single_transformer_blocks.33.attn.to_v.lora_B.default torch.float32
single_transformer_blocks.33.norm.linear torch.bfloat16
single_transformer_blocks.33.proj_mlp torch.float8_e4m3fn
single_transformer_blocks.33.proj_out torch.float8_e4m3fn
single_transformer_blocks.34.attn.norm_q torch.bfloat16
single_transformer_blocks.34.attn.norm_k torch.bfloat16
single_transformer_blocks.34.attn.to_q torch.float8_e4m3fn
single_transformer_blocks.34.attn.to_q.base_layer torch.float8_e4m3fn
single_transformer_blocks.34.attn.to_q.lora_A.default torch.float32
single_transformer_blocks.34.attn.to_q.lora_B.default torch.float32
single_transformer_blocks.34.attn.to_k torch.float8_e4m3fn
single_transformer_blocks.34.attn.to_k.base_layer torch.float8_e4m3fn
single_transformer_blocks.34.attn.to_k.lora_A.default torch.float32
single_transformer_blocks.34.attn.to_k.lora_B.default torch.float32
single_transformer_blocks.34.attn.to_v torch.float8_e4m3fn
single_transformer_blocks.34.attn.to_v.base_layer torch.float8_e4m3fn
single_transformer_blocks.34.attn.to_v.lora_A.default torch.float32
single_transformer_blocks.34.attn.to_v.lora_B.default torch.float32
single_transformer_blocks.34.norm.linear torch.bfloat16
single_transformer_blocks.34.proj_mlp torch.float8_e4m3fn
single_transformer_blocks.34.proj_out torch.float8_e4m3fn
single_transformer_blocks.35.attn.norm_q torch.bfloat16
single_transformer_blocks.35.attn.norm_k torch.bfloat16
single_transformer_blocks.35.attn.to_q torch.float8_e4m3fn
single_transformer_blocks.35.attn.to_q.base_layer torch.float8_e4m3fn
single_transformer_blocks.35.attn.to_q.lora_A.default torch.float32
single_transformer_blocks.35.attn.to_q.lora_B.default torch.float32
single_transformer_blocks.35.attn.to_k torch.float8_e4m3fn
single_transformer_blocks.35.attn.to_k.base_layer torch.float8_e4m3fn
single_transformer_blocks.35.attn.to_k.lora_A.default torch.float32
single_transformer_blocks.35.attn.to_k.lora_B.default torch.float32
single_transformer_blocks.35.attn.to_v torch.float8_e4m3fn
single_transformer_blocks.35.attn.to_v.base_layer torch.float8_e4m3fn
single_transformer_blocks.35.attn.to_v.lora_A.default torch.float32
single_transformer_blocks.35.attn.to_v.lora_B.default torch.float32
single_transformer_blocks.35.norm.linear torch.bfloat16
single_transformer_blocks.35.proj_mlp torch.float8_e4m3fn
single_transformer_blocks.35.proj_out torch.float8_e4m3fn
single_transformer_blocks.36.attn.norm_q torch.bfloat16
single_transformer_blocks.36.attn.norm_k torch.bfloat16
single_transformer_blocks.36.attn.to_q torch.float8_e4m3fn
single_transformer_blocks.36.attn.to_q.base_layer torch.float8_e4m3fn
single_transformer_blocks.36.attn.to_q.lora_A.default torch.float32
single_transformer_blocks.36.attn.to_q.lora_B.default torch.float32
single_transformer_blocks.36.attn.to_k torch.float8_e4m3fn
single_transformer_blocks.36.attn.to_k.base_layer torch.float8_e4m3fn
single_transformer_blocks.36.attn.to_k.lora_A.default torch.float32
single_transformer_blocks.36.attn.to_k.lora_B.default torch.float32
single_transformer_blocks.36.attn.to_v torch.float8_e4m3fn
single_transformer_blocks.36.attn.to_v.base_layer torch.float8_e4m3fn
single_transformer_blocks.36.attn.to_v.lora_A.default torch.float32
single_transformer_blocks.36.attn.to_v.lora_B.default torch.float32
single_transformer_blocks.36.norm.linear torch.bfloat16
single_transformer_blocks.36.proj_mlp torch.float8_e4m3fn
single_transformer_blocks.36.proj_out torch.float8_e4m3fn
single_transformer_blocks.37.attn.norm_q torch.bfloat16
single_transformer_blocks.37.attn.norm_k torch.bfloat16
single_transformer_blocks.37.attn.to_q torch.float8_e4m3fn
single_transformer_blocks.37.attn.to_q.base_layer torch.float8_e4m3fn
single_transformer_blocks.37.attn.to_q.lora_A.default torch.float32
single_transformer_blocks.37.attn.to_q.lora_B.default torch.float32
single_transformer_blocks.37.attn.to_k torch.float8_e4m3fn
single_transformer_blocks.37.attn.to_k.base_layer torch.float8_e4m3fn
single_transformer_blocks.37.attn.to_k.lora_A.default torch.float32
single_transformer_blocks.37.attn.to_k.lora_B.default torch.float32
single_transformer_blocks.37.attn.to_v torch.float8_e4m3fn
single_transformer_blocks.37.attn.to_v.base_layer torch.float8_e4m3fn
single_transformer_blocks.37.attn.to_v.lora_A.default torch.float32
single_transformer_blocks.37.attn.to_v.lora_B.default torch.float32
single_transformer_blocks.37.norm.linear torch.bfloat16
single_transformer_blocks.37.proj_mlp torch.float8_e4m3fn
single_transformer_blocks.37.proj_out torch.float8_e4m3fn
single_transformer_blocks.38.attn.norm_q torch.bfloat16
single_transformer_blocks.38.attn.norm_k torch.bfloat16
single_transformer_blocks.38.attn.to_q torch.float8_e4m3fn
single_transformer_blocks.38.attn.to_q.base_layer torch.float8_e4m3fn
single_transformer_blocks.38.attn.to_q.lora_A.default torch.float32
single_transformer_blocks.38.attn.to_q.lora_B.default torch.float32
single_transformer_blocks.38.attn.to_k torch.float8_e4m3fn
single_transformer_blocks.38.attn.to_k.base_layer torch.float8_e4m3fn
single_transformer_blocks.38.attn.to_k.lora_A.default torch.float32
single_transformer_blocks.38.attn.to_k.lora_B.default torch.float32
single_transformer_blocks.38.attn.to_v torch.float8_e4m3fn
single_transformer_blocks.38.attn.to_v.base_layer torch.float8_e4m3fn
single_transformer_blocks.38.attn.to_v.lora_A.default torch.float32
single_transformer_blocks.38.attn.to_v.lora_B.default torch.float32
single_transformer_blocks.38.norm.linear torch.bfloat16
single_transformer_blocks.38.proj_mlp torch.float8_e4m3fn
single_transformer_blocks.38.proj_out torch.float8_e4m3fn
single_transformer_blocks.39.attn.norm_q torch.bfloat16
single_transformer_blocks.39.attn.norm_k torch.bfloat16
single_transformer_blocks.39.attn.to_q torch.float8_e4m3fn
single_transformer_blocks.39.attn.to_q.base_layer torch.float8_e4m3fn
single_transformer_blocks.39.attn.to_q.lora_A.default torch.float32
single_transformer_blocks.39.attn.to_q.lora_B.default torch.float32
single_transformer_blocks.39.attn.to_k torch.float8_e4m3fn
single_transformer_blocks.39.attn.to_k.base_layer torch.float8_e4m3fn
single_transformer_blocks.39.attn.to_k.lora_A.default torch.float32
single_transformer_blocks.39.attn.to_k.lora_B.default torch.float32
single_transformer_blocks.39.attn.to_v torch.float8_e4m3fn
single_transformer_blocks.39.attn.to_v.base_layer torch.float8_e4m3fn
single_transformer_blocks.39.attn.to_v.lora_A.default torch.float32
single_transformer_blocks.39.attn.to_v.lora_B.default torch.float32
single_transformer_blocks.39.norm.linear torch.bfloat16
single_transformer_blocks.39.proj_mlp torch.float8_e4m3fn
single_transformer_blocks.39.proj_out torch.float8_e4m3fn
norm_out.linear torch.bfloat16
proj_out torch.float8_e4m3fn

TODO: Allow LoRA to be in bf16 and only apply layerwise upcasting to the base weights.
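A minimal sketch of what that TODO could look like (assuming peft's `lora_A`/`lora_B` parameter naming, as seen in the dump above; not the PR's actual code):

```python
import torch

def cast_lora_params(model: torch.nn.Module, dtype: torch.dtype = torch.bfloat16) -> None:
    # Leave the fp8 base layers untouched; only move the LoRA adapter params
    # (fp32 in the dump above) down to bf16.
    for name, param in model.named_parameters():
        if "lora_A" in name or "lora_B" in name:
            param.data = param.data.to(dtype)
```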
Hmm... so while the loss curves match bf16 training in the beginning (they diverge later on), the results are overly smoothed and it looks like the LoRA didn't learn anything. I am not sure if this is because of a validation + hooks related bug, or a training bug. Will look into it a little later. As we can see in the diffusers PR, inference is almost unaffected at fp8, so I highly doubt that we can't train this way. One way to train would be to just use the fp8 Hunyuan checkpoint directly with their custom code, but that is less ideal, and this approach is more generally applicable to all supported models. |
@a-r-r-o-w allow me to help :) Sometimes a fresh pair of eyes helps. So, how about I look into it a bit, ask you questions, and run your experiments, while you keep polishing the PR? WDYT? |
It's an open PR, so anyone is welcome to try/help lol. I don't yet have an idea of what's causing the bad results:
So any help is appreciated |
FWIW, I've done several experiments with the exact same code for inference purposes with LoRA (comments in the diffusers PR) and did not find any problems other than the ones fixed by stupid workarounds, so it could just be something in training, and the graphs coinciding at the start attributable to 0-weights of LoRA, which eventually do start to diverge (after > ~1000 steps) and become garbage |
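To see why the graphs coincide at the start: peft zero-initializes `lora_B`, so the adapter's contribution is exactly zero at step 0 regardless of the base-weight dtype. A tiny illustration (arbitrary shapes):

```python
import torch

lora_A = torch.randn(16, 64)   # peft uses Kaiming init here; random for brevity
lora_B = torch.zeros(64, 16)   # peft zero-initializes lora_B by default

x = torch.randn(4, 64)
delta = x @ lora_A.t() @ lora_B.t()  # the adapter's additive contribution
assert torch.all(delta == 0)         # zero delta -> identical starting loss
```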
Agreed. Even if this takes time, I think it's worth it because it's general.
Could you please point me to some? That will give me a better idea of what's been tried already. For my testing, can I use the command you provided in #184 (comment), or is there a different one I should use? You mentioned
One test I would try to run on the |
FWIW, I ran LTX-V for 400 steps in fp8, and the LoRA definitely did learn. But since you say they 'diverge later on', maybe it's an accumulating error. |
@neph1 Thanks! In that case, it most likely has something to do with validation. I could try running a long LTX training run and see if I get the expected behaviour. For Hunyuan, I was using images for the FP8 training, and it seems to collapse the model within 1000 steps of the total 10000 training steps. Will try and power through the debugging today and hopefully something comes out. @sayakpaul Yes, here you go: https://wandb.ai/aryanvs/finetrainers-hunyuan-video/reports/FP8-vs-BF16-Hunyuan-Video--VmlldzoxMDg0OTMwMw?accessToken=vqwlt7y899u0qb25fyhma612khuk81o5ggi2pmds2y13xv78dfh1es97q0ksaprg. This contains the two runs with the exact same starting conditions. The FP8 run did not have any validation performed because it errors out due to an unimplemented fp8 multiplication kernel. This is partly due to how we infer the dtype in ModelMixin in diffusers and how the pipelines are written. Will address these concerns in the diffusers PR.
Yes, you should modify the command provided in the description to do a longer training run. The only relevant parameter different from previous scripts is The peft patches are enabled by default when you launch fp8 upcasting training. Without it, we will:
The diffusers PR linked in the description contains a number of experiments with all the code/ideas for the workarounds needed to make this work with transformers and peft. Flux works as expected when this is enabled for LoRA inference, with a little over half the memory needed by the transformer in bf16 |
The latest commit fixes the following error when resuming fp8 training after validation is performed.

stacktrace

01/08/2025 01:53:23 - ERROR - finetrainers - Traceback (most recent call last):
File "/raid/aryan/cogvideox-distillation/train.py", line 35, in main
trainer.train()
File "/raid/aryan/cogvideox-distillation/finetrainers/trainer.py", line 813, in train
self.optimizer.step()
File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/accelerate/optimizer.py", line 171, in step
self.optimizer.step(closure)
File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/optim/lr_scheduler.py", line 137, in wrapper
return func.__get__(opt, opt.__class__)(*args, **kwargs)
File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/optim/optimizer.py", line 487, in wrapper
out = func(*args, **kwargs)
File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/optim/optimizer.py", line 91, in _use_grad
ret = func(self, *args, **kwargs)
File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/optim/adamw.py", line 220, in step
adamw(
File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/optim/optimizer.py", line 154, in maybe_fallback
return func(*args, **kwargs)
File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/optim/adamw.py", line 782, in adamw
func(
File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/optim/adamw.py", line 531, in _multi_tensor_adamw
torch._foreach_lerp_(device_exp_avgs, device_grads, 1 - beta1)
RuntimeError: expected dtype float for `end` but got dtype c10::BFloat16

The device_grads should be in fp32, but due to a combination of fp8 hooks and accelerate's convert_to_fp32 hooks, the gradients remain in bf16. Will investigate deeper later, but the current fix seems like an okay thing to do |
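A sketch of what such a fix amounts to (an assumption about the commit, not its exact code): upcast any lingering bf16 gradients to fp32 before stepping, so AdamW's multi-tensor ops see matching dtypes.

```python
import torch

def upcast_grads_(model: torch.nn.Module) -> None:
    # fp8 hooks combined with accelerate's convert_to_fp32 hooks can leave
    # grads in bf16; AdamW's _foreach path expects fp32 here.
    for param in model.parameters():
        if param.grad is not None and param.grad.dtype != torch.float32:
            param.grad = param.grad.to(torch.float32)

# hypothetical call site: upcast_grads_(transformer) just before optimizer.step()
```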
List of supported FP8 ops: pytorch/pytorch#107256 (comment) (as of PT 2.1) |
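A small probe of that limitation (assumes a CUDA build exposing the fp8 dtypes): most ops have no fp8 kernels, which is exactly why weights must be upcast to bf16 for compute.

```python
import torch

a = torch.randn(8, 8, device="cuda").to(torch.float8_e4m3fn)
b = torch.randn(8, 8, device="cuda").to(torch.float8_e4m3fn)

try:
    _ = a @ b  # plain matmul has no fp8 kernel
except RuntimeError as err:
    print("fp8 matmul unsupported:", err)

# the layerwise-upcasting pattern: cast to the compute dtype first
_ = a.to(torch.bfloat16) @ b.to(torch.bfloat16)
```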
@a-r-r-o-w this could be. Very well could be. When I was doing FP8 training with torchao here, the validation results were extremely depressing, so I just let it train, and when I performed inference with the trained LoRA, things were fine. I believe we could do something similar here? |
Sure, that works. But:
If what you say is true for this PR, then:
So, I would really like to get to the bottom of this and find the right solution, even if it takes a little more time to move this PR forward. Will test what you mentioned too and get back soon |
Started a small fp8 20000-step run for LTX on a single GPU (5000 training steps with 4 gradient accumulation): https://wandb.ai/aryanvs/finetrainers-ltxv/runs/uy0evi7m If this works as expected, then there's something fishy going on in Hunyuan, which I anticipate will be awful to debug :/ |
@a-r-r-o-w sounds good! Yes, my comment was to confirm that it's the validation we need to debug, if what I said and observed turns out to be true.
Oh, I thought you mentioned it starts to diverge after a certain number of steps? How are we doing for Cog with FP8? Or is it currently not robust enough to be trained with FP8? |
You can check this report for a better understanding of what I mean. It's not exactly the same to begin with, but the values are almost the same, so I would expect them to converge to roughly similar weights. Edit: I just saw the report that was created. I'm not sure what wandb is doing here, but it does not show the loss at every step and is skipping some data points in the report. This is what the true graph looks like when overlapped (brown is bf16 and green is fp8):
I have yet to test Cog. Will do it soon. Inference works really well with FP8 Cog, though |
Will run some tests myself today. Thanks a lot for bearing with my questions and providing detailed answers. |
Okay, the LTX training run looks extremely promising in just 1500 steps so far (actually 6000 steps because of 4 gradient accumulation): https://wandb.ai/aryanvs/finetrainers-ltxv/runs/uy0evi7m?nw=nwuseraryanvs. Definitely a Hunyuan pipeline problem or a dataset problem - I'm using the same 400 images that I did for the bf16 run. Memory required:
The LTX team has some amazing chefs, I must say. With framewise encoding/decoding support coming soon to diffusers (huggingface/diffusers#10488) and group offloading, this will need almost negligible memory lol |
Some more reports: https://wandb.ai/aryanvs/finetrainers-ltxv/reports/fp8-uniform-vs-fp8-logit_normal-vs-bf16-logit_normal--VmlldzoxMDg2NDU3Ng |
WIP.
Mostly copied code from here, which adds support for FP8 inference. It works for training as well if we make some simple peft patches. For now, the copied code will remain here, but once the Diffusers PR is merged, we can start using that directly.
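For intuition, a minimal sketch of the layerwise-upcasting idea (a simplification, not the copied code's exact implementation): weights live in fp8 and are bounced to the compute dtype around each forward.

```python
import torch

def apply_layerwise_upcasting(
    module: torch.nn.Module,
    storage_dtype: torch.dtype = torch.float8_e4m3fn,
    compute_dtype: torch.dtype = torch.bfloat16,
) -> None:
    # Simplified: the real code keeps norm layers and LoRA params in higher
    # precision, as the dtype dump above shows.
    module.to(storage_dtype)

    def pre_hook(mod, args):
        mod.to(compute_dtype)   # upcast just-in-time for the matmuls

    def post_hook(mod, args, output):
        mod.to(storage_dtype)   # downcast again to keep memory low
        return output

    module.register_forward_pre_hook(pre_hook)
    module.register_forward_hook(post_hook)
```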
Script