
FP8 training support #184

Open · wants to merge 10 commits into main

Conversation

@a-r-r-o-w (Owner) commented Jan 5, 2025

WIP.

Mostly copied code from here, which adds support for FP8 inference. It works for training as well if we make some simple peft patches. For now, the copied code will remain here, but once the Diffusers PR is merged, we can start using that directly.
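
For context, a minimal sketch of what layerwise upcasting boils down to (an illustration under my own assumptions, not the exact code copied into this PR): weights are stored in torch.float8_e4m3fn and temporarily upcast to the compute dtype inside forward pre-hooks, then cast back to fp8 after the forward pass, so only the layer currently executing holds bf16 weights.

import torch
import torch.nn as nn

STORAGE_DTYPE = torch.float8_e4m3fn  # weights live in fp8 between forward passes
COMPUTE_DTYPE = torch.bfloat16       # the matmuls themselves still run in bf16

def _upcast(module, args):
    # forward pre-hook: bring the weights up to the compute dtype
    module.to(COMPUTE_DTYPE)

def _downcast(module, args, output):
    # forward hook: put the weights back into fp8 storage
    module.to(STORAGE_DTYPE)

def apply_layerwise_upcasting(model: nn.Module) -> None:
    # Only weight-bearing leaf modules are patched; norms are left alone so
    # their small parameters are not rounded to fp8 (sketch-level filter).
    for submodule in model.modules():
        if isinstance(submodule, (nn.Linear, nn.Conv2d, nn.Conv3d)):
            submodule.to(STORAGE_DTYPE)
            submodule.register_forward_pre_hook(_upcast)
            submodule.register_forward_hook(_downcast)

The real implementation additionally has to skip or special-case the PEFT LoRA submodules (which stay in fp32, as the dtype dump further down shows); apply_layerwise_upcasting and the module filter above are only illustrative.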

Script
#!/bin/bash

# export TORCH_LOGS="+dynamo,recompiles,graph_breaks"
# export TORCHDYNAMO_VERBOSE=1
export WANDB_MODE="offline"
export NCCL_P2P_DISABLE=1
export TORCH_NCCL_ENABLE_MONITORING=0

GPU_IDS="3"

DATA_ROOT="/raid/aryan/video-dataset-disney"
CAPTION_COLUMN="prompts2.txt"
VIDEO_COLUMN="videos2.txt"

# Model arguments
model_cmd="--model_name hunyuan_video \
  --pretrained_model_name_or_path hunyuanvideo-community/HunyuanVideo \
  --layerwise_upcasting_modules transformer"

# Dataset arguments
dataset_cmd="--data_root $DATA_ROOT \
  --video_column $VIDEO_COLUMN \
  --caption_column $CAPTION_COLUMN \
  --id_token afkx \
  --video_resolution_buckets 49x480x720 \
  --caption_dropout_p 0.00 \
  --precompute_conditions"

# Dataloader arguments
dataloader_cmd="--dataloader_num_workers 0"

# Diffusion arguments
diffusion_cmd="--flow_weighting_scheme logit_normal"

# Training arguments
training_cmd="--training_type lora \
  --seed 42 \
  --batch_size 1 \
  --train_steps 2 \
  --rank 128 \
  --lora_alpha 128 \
  --target_modules to_q to_k to_v to_out.0 \
  --gradient_accumulation_steps 1 \
  --gradient_checkpointing \
  --checkpointing_steps 500 \
  --checkpointing_limit 2 \
  --enable_slicing \
  --enable_tiling"

# Optimizer arguments
optimizer_cmd="--optimizer adamw \
  --lr 1e-4 \
  --lr_scheduler cosine_with_restarts \
  --lr_warmup_steps 100 \
  --lr_num_cycles 1 \
  --beta1 0.9 \
  --beta2 0.95 \
  --weight_decay 1e-4 \
  --epsilon 1e-8 \
  --max_grad_norm 1.0"

# Validation arguments
validation_cmd="--validation_prompts \"afkx A black and white animated scene unfolds with an anthropomorphic goat surrounded by musical notes and symbols, suggesting a playful environment. Mickey Mouse appears, leaning forward in curiosity as the goat remains still. The goat then engages with Mickey, who bends down to converse or react. The dynamics shift as Mickey grabs the goat, potentially in surprise or playfulness, amidst a minimalistic background. The scene captures the evolving relationship between the two characters in a whimsical, animated setting, emphasizing their interactions and emotions.@@@49x480x720\" \
  --num_validation_videos 1 \
  --validation_steps 10"

# Miscellaneous arguments
miscellaneous_cmd="--tracker_name finetrainers-hunyuan-video \
  --output_dir /raid/aryan/hunyuan-video \
  --nccl_timeout 1800 \
  --report_to wandb"

cmd="accelerate launch --config_file accelerate_configs/uncompiled_1.yaml --gpu_ids $GPU_IDS train.py \
  $model_cmd \
  $dataset_cmd \
  $dataloader_cmd \
  $diffusion_cmd \
  $training_cmd \
  $optimizer_cmd \
  $validation_cmd \
  $miscellaneous_cmd"

echo "Running command: $cmd"
eval $cmd
echo -ne "-------------------- Finished executing script --------------------\n\n"

@a-r-r-o-w (Owner, Author)

I seem to be getting an error when running multi-GPU training:

01/06/2025 00:56:10 - ERROR - finetrainers - An error occurred during training: "cat_cuda" not implemented for 'Float8_e4m3fn'
01/06/2025 00:56:10 - ERROR - finetrainers - Traceback (most recent call last):
  File "/raid/aryan/cogvideox-distillation/train.py", line 33, in main
    trainer.prepare_for_training()
  File "/raid/aryan/cogvideox-distillation/finetrainers/trainer.py", line 563, in prepare_for_training
    self.transformer, self.optimizer, self.dataloader, self.lr_scheduler = self.state.accelerator.prepare(
  File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/accelerate/accelerator.py", line 1328, in prepare
    result = tuple(
  File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/accelerate/accelerator.py", line 1329, in <genexpr>
    self._prepare_one(obj, first_pass=True, device_placement=d) for obj, d in zip(args, device_placement)
  File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/accelerate/accelerator.py", line 1204, in _prepare_one
    return self.prepare_model(obj, device_placement=device_placement)
  File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/accelerate/accelerator.py", line 1458, in prepare_model
    model = torch.nn.parallel.DistributedDataParallel(
  File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/nn/parallel/distributed.py", line 827, in __init__
    _sync_module_states(
  File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/distributed/utils.py", line 317, in _sync_module_states
    _sync_params_and_buffers(process_group, module_states, broadcast_bucket_size, src)
  File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/distributed/utils.py", line 328, in _sync_params_and_buffers
    dist._broadcast_coalesced(
RuntimeError: "cat_cuda" not implemented for 'Float8_e4m3fn'

Possibly helpful references:

Will try to investigate more when I find time. Accelerate seems to be able to handle fp8 mixed precision just fine (at least on Hopper), so I will try to poke around.
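
For the record, the failure can be reproduced outside of DDP, since the coalesced broadcast flattens parameters with a concatenation under the hood. A tiny repro (assuming a PyTorch build where the fp8 CUDA cat kernel is missing, as in the trace above):

import torch

a = torch.randn(4, device="cuda").to(torch.float8_e4m3fn)
b = torch.randn(4, device="cuda").to(torch.float8_e4m3fn)

try:
    # presumably the same op that DDP's _sync_module_states hits when it
    # coalesces parameters for the initial broadcast
    torch.cat([a, b])
except RuntimeError as e:
    print(e)  # "cat_cuda" not implemented for 'Float8_e4m3fn'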

@a-r-r-o-w (Owner, Author) commented Jan 6, 2025

Working fine for all (Hunyuan, LTX, Cog) on a single GPU. Tested with 49x480x720 resolution.

BF16 Trace with precomputation (about 45 GB required without validation):

image

FP8 Trace with precomputation (about 32 GB required without validation):

image

Notes:

  • Validation takes about the same amount of memory with and without layerwise upcasting
  • The peaks come from F.scaled_dot_product_attention. This makes sense because we have a very large sequence length, and the default backend is not very memory-optimized. We should explore adding support for flash attention and sage attention. The amount of extra memory added is ~3 GB.

image

  • If we ignore the peaks from SDPA, the next candidate for optimization is the FeedForward intermediate activations. We can probably get rid of those by splitting the forward pass across non-channel dimensions; a rough sketch of the idea follows below. This should save us maybe 0.5-1 GB.
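
Here is what that feed-forward chunking could look like (a hypothetical helper, not code from this PR): because the FF acts pointwise over tokens, the sequence dimension can be processed in slices so that only a fraction of the intermediate activations is alive at any time.

import torch
import torch.nn as nn

def chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_size: int = 2048) -> torch.Tensor:
    # `ff` must be pointwise over dim=1 (batch, seq, channels), which holds for
    # the usual Linear -> activation -> Linear MLPs in DiT-style blocks.
    outputs = [
        ff(hidden_states[:, start : start + chunk_size])
        for start in range(0, hidden_states.shape[1], chunk_size)
    ]
    return torch.cat(outputs, dim=1)

Usage would roughly be replacing block.ff(hidden_states) with chunked_feed_forward(block.ff, hidden_states); chunk_size trades peak activation memory against a little extra kernel-launch overhead.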

@a-r-r-o-w (Owner, Author) commented Jan 6, 2025

The above numbers were with precomputation. This is what we have without precomputation at 49x480x720 on Hunyuan (gradient_checkpointing, adamw, rank=128):

| Peak memory (GB) | no precomputation | with precomputation |
| ---------------- | ----------------- | ------------------- |
| bf16             | 54.5              | 45.8                |
| fp8              | 45.7              | 31.4                |

For all cases, the numbers are without performing validation at all. If we do validation, memory blows up further.

@a-r-r-o-w (Owner, Author)

Training with single images fits completely under 20 GB unless validation is performed (in which case it peaks around 57 GB). HunyuanVideo can be finetuned for styles/characters and a lot more with just images, so this is a great win!

logs
(nightly-venv) aryan@hf-dgx-01:/raid/aryan/cogvideox-distillation$ ./dump_training/hunyuan_video/test_fp8_short.sh
Running command: accelerate launch --config_file accelerate_configs/uncompiled_1.yaml --gpu_ids 3 train.py   --model_name hunyuan_video   --pretrained_model_name_or_path hunyuanvideo-community/HunyuanVideo   --layerwise_upcasting_modules transformer   --data_root /raid/aryan/video-dataset-disney   --video_column images.txt   --caption_column prompts.txt   --id_token afkx   --video_resolution_buckets 1x480x720   --caption_dropout_p 0.00   --precompute_conditions   --dataloader_num_workers 0   --flow_weighting_scheme logit_normal   --training_type lora   --seed 42   --batch_size 1   --train_steps 2   --rank 128   --lora_alpha 128   --target_modules to_q to_k to_v to_out.0   --gradient_accumulation_steps 1   --gradient_checkpointing   --checkpointing_steps 500   --checkpointing_limit 2   --enable_slicing   --enable_tiling   --optimizer adamw   --lr 1e-4   --lr_scheduler cosine_with_restarts   --lr_warmup_steps 100   --lr_num_cycles 1   --beta1 0.9   --beta2 0.95   --weight_decay 1e-4   --epsilon 1e-8   --max_grad_norm 1.0   --validation_prompts "afkx A black and white animated scene unfolds with an anthropomorphic goat surrounded by musical notes and symbols, suggesting a playful environment. Mickey Mouse appears, leaning forward in curiosity as the goat remains still. The goat then engages with Mickey, who bends down to converse or react. The dynamics shift as Mickey grabs the goat, potentially in surprise or playfulness, amidst a minimalistic background. The scene captures the evolving relationship between the two characters in a whimsical, animated setting, emphasizing their interactions and emotions.@@@1x480x720"   --num_validation_videos 1   --validation_steps 10   --tracker_name finetrainers-hunyuan-video   --output_dir /raid/aryan/hunyuan-video   --nccl_timeout 1800   --report_to wandb
Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.
01/06/2025 02:11:50 - INFO - finetrainers - Initialized FineTrainers
01/06/2025 02:11:50 - INFO - finetrainers - Distributed environment: NO
Num processes: 1
Process index: 0
Local process index: 0
Device: cuda

Mixed precision type: bf16

01/06/2025 02:11:50 - INFO - finetrainers - Initializing dataset and dataloader
01/06/2025 02:11:50 - INFO - finetrainers - Initializing models
01/06/2025 02:11:50 - INFO - finetrainers - Initializing precomputations
01/06/2025 02:11:50 - INFO - finetrainers - Precomputed data not found. Running precomputation.
01/06/2025 02:11:50 - INFO - finetrainers - Precomputed conditions and latents not found. Running precomputation.
Downloading shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 14513.16it/s]
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:14<00:00,  3.75s/it]
Precomputing conditions: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:03<00:00,  3.93s/it]
01/06/2025 02:12:36 - INFO - finetrainers - Memory after precomputing conditions: {
    "memory_allocated": 14.232,
    "memory_reserved": 14.336,
    "max_memory_allocated": 14.395,
    "max_memory_reserved": 14.461
}
Precomputing conditions: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:06<00:00,  6.83s/it]
Precomputing latents:   0%|          | 0/1 [00:00<?, ?it/s]
01/06/2025 02:12:41 - INFO - finetrainers - Precomputation complete
01/06/2025 02:12:41 - INFO - finetrainers - Memory after precomputing latents: {
    "memory_allocated": 14.713,
    "memory_reserved": 14.762,
    "max_memory_allocated": 14.89,
    "max_memory_reserved": 14.998
}
Precomputing latents: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:03<00:00,  3.38s/it]
01/06/2025 02:12:41 - INFO - finetrainers - Initializing trainable parameters
01/06/2025 02:13:03 - INFO - finetrainers - Initializing optimizer and lr scheduler
01/06/2025 02:13:03 - INFO - finetrainers - Initializing trackers
wandb: Tracking run with wandb version 0.17.7
wandb: W&B syncing is set to `offline` in this directory.
wandb: Run `wandb online` or set WANDB_MODE=online to enable cloud syncing.
01/06/2025 02:13:05 - INFO - finetrainers - Starting training
01/06/2025 02:13:05 - INFO - finetrainers - Memory before training start: {
    "memory_allocated": 15.934,
    "memory_reserved": 16.656,
    "max_memory_allocated": 15.934,
    "max_memory_reserved": 16.656
}
01/06/2025 02:13:05 - INFO - finetrainers - Training configuration: {
    "trainable parameters": 163577856,
    "total samples": 1,
    "train epochs": 2,
    "train steps": 2,
    "batches per device": 1,
    "total batches observed per epoch": 1,
    "train batch size": 1,
    "gradient accumulation steps": 1
}
Training steps:  50%|████████████████████████████████████████████████████████████████████████████████████████████████▌                                                                                                | 1/2 [00:12<00:12, 12.22s/it, grad_norm=0.132, loss=1.23, lr=1e-6]
01/06/2025 02:13:17 - INFO - finetrainers - Memory after epoch 1: {
    "memory_allocated": 17.177,
    "memory_reserved": 18.889,
    "max_memory_allocated": 18.45,
    "max_memory_reserved": 18.889
}
Training steps: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:15<00:00,  6.80s/it, grad_norm=0.142, loss=1.33, lr=2e-6]
01/06/2025 02:13:20 - INFO - finetrainers - Memory after epoch 2: {
    "memory_allocated": 17.178,
    "memory_reserved": 19.486,
    "max_memory_allocated": 18.944,
    "max_memory_reserved": 19.486
}
[2025-01-06 02:13:20,938] [INFO] [real_accelerator.py:203:get_accelerator] Setting ds_accelerator to cuda (auto detect)
Model weights saved in /raid/aryan/hunyuan-video/pytorch_lora_weights.safetensors
01/06/2025 02:13:23 - INFO - finetrainers - Starting validation
01/06/2025 02:13:23 - INFO - finetrainers - Memory before validation start: {
    "memory_allocated": 17.178,
    "memory_reserved": 19.486,
    "max_memory_allocated": 18.944,
    "max_memory_reserved": 19.486
}
Loaded tokenizer_2 as CLIPTokenizer from `tokenizer_2` subfolder of hunyuanvideo-community/HunyuanVideo.
Loaded tokenizer as LlamaTokenizerFast from `tokenizer` subfolder of hunyuanvideo-community/HunyuanVideo.
Loaded transformer as HunyuanVideoTransformer3DModel from `transformer` subfolder of hunyuanvideo-community/HunyuanVideo.
Loaded text_encoder_2 as CLIPTextModel from `text_encoder_2` subfolder of hunyuanvideo-community/HunyuanVideo.
Loaded scheduler as FlowMatchEulerDiscreteScheduler from `scheduler` subfolder of hunyuanvideo-community/HunyuanVideo.
Loaded vae as AutoencoderKLHunyuanVideo from `vae` subfolder of hunyuanvideo-community/HunyuanVideo.
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:09<00:00,  2.39s/it]
Loaded text_encoder as LlamaModel from `text_encoder` subfolder of hunyuanvideo-community/HunyuanVideo.
Loading pipeline components...: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:28<00:00,  4.11s/it]
Loading transformer.
Token indices sequence length is longer than the specified maximum sequence length for this model (108 > 77). Running this sequence through the model will result in indexing errors
The following part of your input was truncated because CLIP can only handle sequences up to 77 tokens: [', amidst a minimalistic background . the scene captures the evolving relationship between the two characters in a whimsical , animated setting , emphasizing their interactions and emotions .']
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:01<00:00,  1.44it/s]
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
To disable this warning, you can either:
        - Avoid using `tokenizers` before the fork if possible
        - Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
01/06/2025 02:15:37 - INFO - finetrainers - Memory after validation end: {
    "memory_allocated": 17.178,
    "memory_reserved": 17.932,
    "max_memory_allocated": 56.735,
    "max_memory_reserved": 57.203
}
01/06/2025 02:15:37 - INFO - finetrainers - Memory after training end: {
    "memory_allocated": 17.178,
    "memory_reserved": 17.932,
    "max_memory_allocated": 17.178,
    "max_memory_reserved": 17.932
}

@a-r-r-o-w (Owner, Author)

Multi-GPU layerwise upcasting training seems to work without errors for Hopper (tested on 8xH100) and Ada (tested on 2x RTX 4090).

Seems like Ampere doesn't have the relevant bits implemented for fp8 DDP, as mentioned in the linked issue, which is where the above error stack trace comes from.

@a-r-r-o-w (Owner, Author) commented Jan 6, 2025

Oh no, it failed on Hopper too (8x DDP)... Training works, but it failed when validation started. This is because I did not account for some cases in the pipeline like guidance preparation.

ERROR:finetrainers:Traceback (most recent call last):
  File "/fsx/aryan/finetrainers/train.py", line 35, in main
    trainer.train()
  File "/fsx/aryan/finetrainers/finetrainers/trainer.py", line 838, in train
    self.validate(global_step)
  File "/fsx/aryan/finetrainers/finetrainers/trainer.py", line 952, in validate
    validation_artifacts = self.model_config["validation"](
  File "/fsx/aryan/finetrainers/finetrainers/hunyuan_video/hunyuan_video_lora.py", line 266, in validation
    output = pipeline(**generation_kwargs).frames[0]
  File "/fsx/aryan/nightly-venv/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/fsx/aryan/diffusers/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py", line 632, in __call__
    guidance = torch.tensor([guidance_scale] * latents.shape[0], dtype=transformer_dtype, device=device) * 1000.0
RuntimeError: "mul_cuda" not implemented for 'Float8_e4m3fn'
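
The likely fix (a hedged sketch, not the final patch) is to stop doing arithmetic in the fp8 storage dtype when preparing the guidance embedding input: build the small tensor in a dtype that has CUDA arithmetic kernels and let layerwise upcasting worry about the weights. The helper name below is hypothetical.

import torch

def prepare_guidance(guidance_scale: float, batch_size: int, transformer_dtype: torch.dtype, device: torch.device) -> torch.Tensor:
    # do the * 1000.0 scaling in a dtype with CUDA mul kernels; fp8 only needs
    # to exist at the weight-storage level, not for these small tensors
    compute_dtype = torch.bfloat16 if transformer_dtype == torch.float8_e4m3fn else transformer_dtype
    return torch.tensor([guidance_scale] * batch_size, dtype=compute_dtype, device=device) * 1000.0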

@neph1 commented Jan 6, 2025

I've done a brief test with LTX-Video on this branch, and I'm seeing a ~13% VRAM reduction when using float8_e4m3fn. Speed-wise, it seems to be about the same.
Single-GPU, 3090.

Edit: Oh, and keep up the good work! ;)

Edit 2: Actually, during the short time I ran it, it seems to be about 5-10% faster per step as well.

@sayakpaul (Collaborator)

> Speed-wise, it seems to be about the same.

@neph1 I think that is expected because, from what I understand, we're not launching native FP8 kernels.

@a-r-r-o-w (Owner, Author) commented Jan 7, 2025

Just for future self-reference for debugging, dtypes of different layers:
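
(The dump below was collected with something along these lines; a trivial sketch, assuming `transformer` is the loaded HunyuanVideoTransformer3DModel with the LoRA adapter attached:)

for name, submodule in transformer.named_modules():
    # print the storage dtype of every weight-bearing submodule
    weight = getattr(submodule, "weight", None)
    if weight is not None:
        print(name, weight.dtype)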

dtypes
x_embedder.proj torch.float8_e4m3fn
context_embedder.time_text_embed.timestep_embedder.linear_1 torch.float8_e4m3fn
context_embedder.time_text_embed.timestep_embedder.linear_2 torch.float8_e4m3fn
context_embedder.time_text_embed.text_embedder.linear_1 torch.float8_e4m3fn
context_embedder.time_text_embed.text_embedder.linear_2 torch.float8_e4m3fn
context_embedder.proj_in torch.float8_e4m3fn
context_embedder.token_refiner.refiner_blocks.0.norm1 torch.bfloat16
context_embedder.token_refiner.refiner_blocks.0.attn.to_q torch.float8_e4m3fn
context_embedder.token_refiner.refiner_blocks.0.attn.to_q.base_layer torch.float8_e4m3fn
context_embedder.token_refiner.refiner_blocks.0.attn.to_q.lora_A.default torch.float32
context_embedder.token_refiner.refiner_blocks.0.attn.to_q.lora_B.default torch.float32
context_embedder.token_refiner.refiner_blocks.0.attn.to_k torch.float8_e4m3fn
context_embedder.token_refiner.refiner_blocks.0.attn.to_k.base_layer torch.float8_e4m3fn
context_embedder.token_refiner.refiner_blocks.0.attn.to_k.lora_A.default torch.float32
context_embedder.token_refiner.refiner_blocks.0.attn.to_k.lora_B.default torch.float32
context_embedder.token_refiner.refiner_blocks.0.attn.to_v torch.float8_e4m3fn
context_embedder.token_refiner.refiner_blocks.0.attn.to_v.base_layer torch.float8_e4m3fn
context_embedder.token_refiner.refiner_blocks.0.attn.to_v.lora_A.default torch.float32
context_embedder.token_refiner.refiner_blocks.0.attn.to_v.lora_B.default torch.float32
context_embedder.token_refiner.refiner_blocks.0.attn.to_out.0 torch.float8_e4m3fn
context_embedder.token_refiner.refiner_blocks.0.attn.to_out.0.base_layer torch.float8_e4m3fn
context_embedder.token_refiner.refiner_blocks.0.attn.to_out.0.lora_A.default torch.float32
context_embedder.token_refiner.refiner_blocks.0.attn.to_out.0.lora_B.default torch.float32
context_embedder.token_refiner.refiner_blocks.0.norm2 torch.bfloat16
context_embedder.token_refiner.refiner_blocks.0.ff.net.0.proj torch.float8_e4m3fn
context_embedder.token_refiner.refiner_blocks.0.ff.net.2 torch.float8_e4m3fn
context_embedder.token_refiner.refiner_blocks.0.norm_out.linear torch.bfloat16
context_embedder.token_refiner.refiner_blocks.1.norm1 torch.bfloat16
context_embedder.token_refiner.refiner_blocks.1.attn.to_q torch.float8_e4m3fn
context_embedder.token_refiner.refiner_blocks.1.attn.to_q.base_layer torch.float8_e4m3fn
context_embedder.token_refiner.refiner_blocks.1.attn.to_q.lora_A.default torch.float32
context_embedder.token_refiner.refiner_blocks.1.attn.to_q.lora_B.default torch.float32
context_embedder.token_refiner.refiner_blocks.1.attn.to_k torch.float8_e4m3fn
context_embedder.token_refiner.refiner_blocks.1.attn.to_k.base_layer torch.float8_e4m3fn
context_embedder.token_refiner.refiner_blocks.1.attn.to_k.lora_A.default torch.float32
context_embedder.token_refiner.refiner_blocks.1.attn.to_k.lora_B.default torch.float32
context_embedder.token_refiner.refiner_blocks.1.attn.to_v torch.float8_e4m3fn
context_embedder.token_refiner.refiner_blocks.1.attn.to_v.base_layer torch.float8_e4m3fn
context_embedder.token_refiner.refiner_blocks.1.attn.to_v.lora_A.default torch.float32
context_embedder.token_refiner.refiner_blocks.1.attn.to_v.lora_B.default torch.float32
context_embedder.token_refiner.refiner_blocks.1.attn.to_out.0 torch.float8_e4m3fn
context_embedder.token_refiner.refiner_blocks.1.attn.to_out.0.base_layer torch.float8_e4m3fn
context_embedder.token_refiner.refiner_blocks.1.attn.to_out.0.lora_A.default torch.float32
context_embedder.token_refiner.refiner_blocks.1.attn.to_out.0.lora_B.default torch.float32
context_embedder.token_refiner.refiner_blocks.1.norm2 torch.bfloat16
context_embedder.token_refiner.refiner_blocks.1.ff.net.0.proj torch.float8_e4m3fn
context_embedder.token_refiner.refiner_blocks.1.ff.net.2 torch.float8_e4m3fn
context_embedder.token_refiner.refiner_blocks.1.norm_out.linear torch.bfloat16
time_text_embed.timestep_embedder.linear_1 torch.float8_e4m3fn
time_text_embed.timestep_embedder.linear_2 torch.float8_e4m3fn
time_text_embed.guidance_embedder.linear_1 torch.float8_e4m3fn
time_text_embed.guidance_embedder.linear_2 torch.float8_e4m3fn
time_text_embed.text_embedder.linear_1 torch.float8_e4m3fn
time_text_embed.text_embedder.linear_2 torch.float8_e4m3fn
transformer_blocks.0.norm1.linear torch.bfloat16
transformer_blocks.0.norm1_context.linear torch.bfloat16
transformer_blocks.0.attn.norm_q torch.bfloat16
transformer_blocks.0.attn.norm_k torch.bfloat16
transformer_blocks.0.attn.to_q torch.float8_e4m3fn
transformer_blocks.0.attn.to_q.base_layer torch.float8_e4m3fn
transformer_blocks.0.attn.to_q.lora_A.default torch.float32
transformer_blocks.0.attn.to_q.lora_B.default torch.float32
transformer_blocks.0.attn.to_k torch.float8_e4m3fn
transformer_blocks.0.attn.to_k.base_layer torch.float8_e4m3fn
transformer_blocks.0.attn.to_k.lora_A.default torch.float32
transformer_blocks.0.attn.to_k.lora_B.default torch.float32
transformer_blocks.0.attn.to_v torch.float8_e4m3fn
transformer_blocks.0.attn.to_v.base_layer torch.float8_e4m3fn
transformer_blocks.0.attn.to_v.lora_A.default torch.float32
transformer_blocks.0.attn.to_v.lora_B.default torch.float32
transformer_blocks.0.attn.add_k_proj torch.float8_e4m3fn
transformer_blocks.0.attn.add_k_proj.base_layer torch.float8_e4m3fn
transformer_blocks.0.attn.add_k_proj.lora_A.default torch.float32
transformer_blocks.0.attn.add_k_proj.lora_B.default torch.float32
transformer_blocks.0.attn.add_v_proj torch.float8_e4m3fn
transformer_blocks.0.attn.add_v_proj.base_layer torch.float8_e4m3fn
transformer_blocks.0.attn.add_v_proj.lora_A.default torch.float32
transformer_blocks.0.attn.add_v_proj.lora_B.default torch.float32
transformer_blocks.0.attn.add_q_proj torch.float8_e4m3fn
transformer_blocks.0.attn.add_q_proj.base_layer torch.float8_e4m3fn
transformer_blocks.0.attn.add_q_proj.lora_A.default torch.float32
transformer_blocks.0.attn.add_q_proj.lora_B.default torch.float32
transformer_blocks.0.attn.to_out.0 torch.float8_e4m3fn
transformer_blocks.0.attn.to_out.0.base_layer torch.float8_e4m3fn
transformer_blocks.0.attn.to_out.0.lora_A.default torch.float32
transformer_blocks.0.attn.to_out.0.lora_B.default torch.float32
transformer_blocks.0.attn.to_add_out torch.float8_e4m3fn
transformer_blocks.0.attn.to_add_out.base_layer torch.float8_e4m3fn
transformer_blocks.0.attn.to_add_out.lora_A.default torch.float32
transformer_blocks.0.attn.to_add_out.lora_B.default torch.float32
transformer_blocks.0.attn.norm_added_q torch.bfloat16
transformer_blocks.0.attn.norm_added_k torch.bfloat16
transformer_blocks.0.ff.net.0.proj torch.float8_e4m3fn
transformer_blocks.0.ff.net.2 torch.float8_e4m3fn
transformer_blocks.0.ff_context.net.0.proj torch.float8_e4m3fn
transformer_blocks.0.ff_context.net.2 torch.float8_e4m3fn
transformer_blocks.1.norm1.linear torch.bfloat16
transformer_blocks.1.norm1_context.linear torch.bfloat16
transformer_blocks.1.attn.norm_q torch.bfloat16
transformer_blocks.1.attn.norm_k torch.bfloat16
transformer_blocks.1.attn.to_q torch.float8_e4m3fn
transformer_blocks.1.attn.to_q.base_layer torch.float8_e4m3fn
transformer_blocks.1.attn.to_q.lora_A.default torch.float32
transformer_blocks.1.attn.to_q.lora_B.default torch.float32
transformer_blocks.1.attn.to_k torch.float8_e4m3fn
transformer_blocks.1.attn.to_k.base_layer torch.float8_e4m3fn
transformer_blocks.1.attn.to_k.lora_A.default torch.float32
transformer_blocks.1.attn.to_k.lora_B.default torch.float32
transformer_blocks.1.attn.to_v torch.float8_e4m3fn
transformer_blocks.1.attn.to_v.base_layer torch.float8_e4m3fn
transformer_blocks.1.attn.to_v.lora_A.default torch.float32
transformer_blocks.1.attn.to_v.lora_B.default torch.float32
transformer_blocks.1.attn.add_k_proj torch.float8_e4m3fn
transformer_blocks.1.attn.add_k_proj.base_layer torch.float8_e4m3fn
transformer_blocks.1.attn.add_k_proj.lora_A.default torch.float32
transformer_blocks.1.attn.add_k_proj.lora_B.default torch.float32
transformer_blocks.1.attn.add_v_proj torch.float8_e4m3fn
transformer_blocks.1.attn.add_v_proj.base_layer torch.float8_e4m3fn
transformer_blocks.1.attn.add_v_proj.lora_A.default torch.float32
transformer_blocks.1.attn.add_v_proj.lora_B.default torch.float32
transformer_blocks.1.attn.add_q_proj torch.float8_e4m3fn
transformer_blocks.1.attn.add_q_proj.base_layer torch.float8_e4m3fn
transformer_blocks.1.attn.add_q_proj.lora_A.default torch.float32
transformer_blocks.1.attn.add_q_proj.lora_B.default torch.float32
transformer_blocks.1.attn.to_out.0 torch.float8_e4m3fn
transformer_blocks.1.attn.to_out.0.base_layer torch.float8_e4m3fn
transformer_blocks.1.attn.to_out.0.lora_A.default torch.float32
transformer_blocks.1.attn.to_out.0.lora_B.default torch.float32
transformer_blocks.1.attn.to_add_out torch.float8_e4m3fn
transformer_blocks.1.attn.to_add_out.base_layer torch.float8_e4m3fn
transformer_blocks.1.attn.to_add_out.lora_A.default torch.float32
transformer_blocks.1.attn.to_add_out.lora_B.default torch.float32
transformer_blocks.1.attn.norm_added_q torch.bfloat16
transformer_blocks.1.attn.norm_added_k torch.bfloat16
transformer_blocks.1.ff.net.0.proj torch.float8_e4m3fn
transformer_blocks.1.ff.net.2 torch.float8_e4m3fn
transformer_blocks.1.ff_context.net.0.proj torch.float8_e4m3fn
transformer_blocks.1.ff_context.net.2 torch.float8_e4m3fn
transformer_blocks.2.norm1.linear torch.bfloat16
transformer_blocks.2.norm1_context.linear torch.bfloat16
transformer_blocks.2.attn.norm_q torch.bfloat16
transformer_blocks.2.attn.norm_k torch.bfloat16
transformer_blocks.2.attn.to_q torch.float8_e4m3fn
transformer_blocks.2.attn.to_q.base_layer torch.float8_e4m3fn
transformer_blocks.2.attn.to_q.lora_A.default torch.float32
transformer_blocks.2.attn.to_q.lora_B.default torch.float32
transformer_blocks.2.attn.to_k torch.float8_e4m3fn
transformer_blocks.2.attn.to_k.base_layer torch.float8_e4m3fn
transformer_blocks.2.attn.to_k.lora_A.default torch.float32
transformer_blocks.2.attn.to_k.lora_B.default torch.float32
transformer_blocks.2.attn.to_v torch.float8_e4m3fn
transformer_blocks.2.attn.to_v.base_layer torch.float8_e4m3fn
transformer_blocks.2.attn.to_v.lora_A.default torch.float32
transformer_blocks.2.attn.to_v.lora_B.default torch.float32
transformer_blocks.2.attn.add_k_proj torch.float8_e4m3fn
transformer_blocks.2.attn.add_k_proj.base_layer torch.float8_e4m3fn
transformer_blocks.2.attn.add_k_proj.lora_A.default torch.float32
transformer_blocks.2.attn.add_k_proj.lora_B.default torch.float32
transformer_blocks.2.attn.add_v_proj torch.float8_e4m3fn
transformer_blocks.2.attn.add_v_proj.base_layer torch.float8_e4m3fn
transformer_blocks.2.attn.add_v_proj.lora_A.default torch.float32
transformer_blocks.2.attn.add_v_proj.lora_B.default torch.float32
transformer_blocks.2.attn.add_q_proj torch.float8_e4m3fn
transformer_blocks.2.attn.add_q_proj.base_layer torch.float8_e4m3fn
transformer_blocks.2.attn.add_q_proj.lora_A.default torch.float32
transformer_blocks.2.attn.add_q_proj.lora_B.default torch.float32
transformer_blocks.2.attn.to_out.0 torch.float8_e4m3fn
transformer_blocks.2.attn.to_out.0.base_layer torch.float8_e4m3fn
transformer_blocks.2.attn.to_out.0.lora_A.default torch.float32
transformer_blocks.2.attn.to_out.0.lora_B.default torch.float32
transformer_blocks.2.attn.to_add_out torch.float8_e4m3fn
transformer_blocks.2.attn.to_add_out.base_layer torch.float8_e4m3fn
transformer_blocks.2.attn.to_add_out.lora_A.default torch.float32
transformer_blocks.2.attn.to_add_out.lora_B.default torch.float32
transformer_blocks.2.attn.norm_added_q torch.bfloat16
transformer_blocks.2.attn.norm_added_k torch.bfloat16
transformer_blocks.2.ff.net.0.proj torch.float8_e4m3fn
transformer_blocks.2.ff.net.2 torch.float8_e4m3fn
transformer_blocks.2.ff_context.net.0.proj torch.float8_e4m3fn
transformer_blocks.2.ff_context.net.2 torch.float8_e4m3fn
transformer_blocks.3.norm1.linear torch.bfloat16
transformer_blocks.3.norm1_context.linear torch.bfloat16
transformer_blocks.3.attn.norm_q torch.bfloat16
transformer_blocks.3.attn.norm_k torch.bfloat16
transformer_blocks.3.attn.to_q torch.float8_e4m3fn
transformer_blocks.3.attn.to_q.base_layer torch.float8_e4m3fn
transformer_blocks.3.attn.to_q.lora_A.default torch.float32
transformer_blocks.3.attn.to_q.lora_B.default torch.float32
transformer_blocks.3.attn.to_k torch.float8_e4m3fn
transformer_blocks.3.attn.to_k.base_layer torch.float8_e4m3fn
transformer_blocks.3.attn.to_k.lora_A.default torch.float32
transformer_blocks.3.attn.to_k.lora_B.default torch.float32
transformer_blocks.3.attn.to_v torch.float8_e4m3fn
transformer_blocks.3.attn.to_v.base_layer torch.float8_e4m3fn
transformer_blocks.3.attn.to_v.lora_A.default torch.float32
transformer_blocks.3.attn.to_v.lora_B.default torch.float32
transformer_blocks.3.attn.add_k_proj torch.float8_e4m3fn
transformer_blocks.3.attn.add_k_proj.base_layer torch.float8_e4m3fn
transformer_blocks.3.attn.add_k_proj.lora_A.default torch.float32
transformer_blocks.3.attn.add_k_proj.lora_B.default torch.float32
transformer_blocks.3.attn.add_v_proj torch.float8_e4m3fn
transformer_blocks.3.attn.add_v_proj.base_layer torch.float8_e4m3fn
transformer_blocks.3.attn.add_v_proj.lora_A.default torch.float32
transformer_blocks.3.attn.add_v_proj.lora_B.default torch.float32
transformer_blocks.3.attn.add_q_proj torch.float8_e4m3fn
transformer_blocks.3.attn.add_q_proj.base_layer torch.float8_e4m3fn
transformer_blocks.3.attn.add_q_proj.lora_A.default torch.float32
transformer_blocks.3.attn.add_q_proj.lora_B.default torch.float32
transformer_blocks.3.attn.to_out.0 torch.float8_e4m3fn
transformer_blocks.3.attn.to_out.0.base_layer torch.float8_e4m3fn
transformer_blocks.3.attn.to_out.0.lora_A.default torch.float32
transformer_blocks.3.attn.to_out.0.lora_B.default torch.float32
transformer_blocks.3.attn.to_add_out torch.float8_e4m3fn
transformer_blocks.3.attn.to_add_out.base_layer torch.float8_e4m3fn
transformer_blocks.3.attn.to_add_out.lora_A.default torch.float32
transformer_blocks.3.attn.to_add_out.lora_B.default torch.float32
transformer_blocks.3.attn.norm_added_q torch.bfloat16
transformer_blocks.3.attn.norm_added_k torch.bfloat16
transformer_blocks.3.ff.net.0.proj torch.float8_e4m3fn
transformer_blocks.3.ff.net.2 torch.float8_e4m3fn
transformer_blocks.3.ff_context.net.0.proj torch.float8_e4m3fn
transformer_blocks.3.ff_context.net.2 torch.float8_e4m3fn
transformer_blocks.4.norm1.linear torch.bfloat16
transformer_blocks.4.norm1_context.linear torch.bfloat16
transformer_blocks.4.attn.norm_q torch.bfloat16
transformer_blocks.4.attn.norm_k torch.bfloat16
transformer_blocks.4.attn.to_q torch.float8_e4m3fn
transformer_blocks.4.attn.to_q.base_layer torch.float8_e4m3fn
transformer_blocks.4.attn.to_q.lora_A.default torch.float32
transformer_blocks.4.attn.to_q.lora_B.default torch.float32
transformer_blocks.4.attn.to_k torch.float8_e4m3fn
transformer_blocks.4.attn.to_k.base_layer torch.float8_e4m3fn
transformer_blocks.4.attn.to_k.lora_A.default torch.float32
transformer_blocks.4.attn.to_k.lora_B.default torch.float32
transformer_blocks.4.attn.to_v torch.float8_e4m3fn
transformer_blocks.4.attn.to_v.base_layer torch.float8_e4m3fn
transformer_blocks.4.attn.to_v.lora_A.default torch.float32
transformer_blocks.4.attn.to_v.lora_B.default torch.float32
transformer_blocks.4.attn.add_k_proj torch.float8_e4m3fn
transformer_blocks.4.attn.add_k_proj.base_layer torch.float8_e4m3fn
transformer_blocks.4.attn.add_k_proj.lora_A.default torch.float32
transformer_blocks.4.attn.add_k_proj.lora_B.default torch.float32
transformer_blocks.4.attn.add_v_proj torch.float8_e4m3fn
transformer_blocks.4.attn.add_v_proj.base_layer torch.float8_e4m3fn
transformer_blocks.4.attn.add_v_proj.lora_A.default torch.float32
transformer_blocks.4.attn.add_v_proj.lora_B.default torch.float32
transformer_blocks.4.attn.add_q_proj torch.float8_e4m3fn
transformer_blocks.4.attn.add_q_proj.base_layer torch.float8_e4m3fn
transformer_blocks.4.attn.add_q_proj.lora_A.default torch.float32
transformer_blocks.4.attn.add_q_proj.lora_B.default torch.float32
transformer_blocks.4.attn.to_out.0 torch.float8_e4m3fn
transformer_blocks.4.attn.to_out.0.base_layer torch.float8_e4m3fn
transformer_blocks.4.attn.to_out.0.lora_A.default torch.float32
transformer_blocks.4.attn.to_out.0.lora_B.default torch.float32
transformer_blocks.4.attn.to_add_out torch.float8_e4m3fn
transformer_blocks.4.attn.to_add_out.base_layer torch.float8_e4m3fn
transformer_blocks.4.attn.to_add_out.lora_A.default torch.float32
transformer_blocks.4.attn.to_add_out.lora_B.default torch.float32
transformer_blocks.4.attn.norm_added_q torch.bfloat16
transformer_blocks.4.attn.norm_added_k torch.bfloat16
transformer_blocks.4.ff.net.0.proj torch.float8_e4m3fn
transformer_blocks.4.ff.net.2 torch.float8_e4m3fn
transformer_blocks.4.ff_context.net.0.proj torch.float8_e4m3fn
transformer_blocks.4.ff_context.net.2 torch.float8_e4m3fn
transformer_blocks.5.norm1.linear torch.bfloat16
transformer_blocks.5.norm1_context.linear torch.bfloat16
transformer_blocks.5.attn.norm_q torch.bfloat16
transformer_blocks.5.attn.norm_k torch.bfloat16
transformer_blocks.5.attn.to_q torch.float8_e4m3fn
transformer_blocks.5.attn.to_q.base_layer torch.float8_e4m3fn
transformer_blocks.5.attn.to_q.lora_A.default torch.float32
transformer_blocks.5.attn.to_q.lora_B.default torch.float32
transformer_blocks.5.attn.to_k torch.float8_e4m3fn
transformer_blocks.5.attn.to_k.base_layer torch.float8_e4m3fn
transformer_blocks.5.attn.to_k.lora_A.default torch.float32
transformer_blocks.5.attn.to_k.lora_B.default torch.float32
transformer_blocks.5.attn.to_v torch.float8_e4m3fn
transformer_blocks.5.attn.to_v.base_layer torch.float8_e4m3fn
transformer_blocks.5.attn.to_v.lora_A.default torch.float32
transformer_blocks.5.attn.to_v.lora_B.default torch.float32
transformer_blocks.5.attn.add_k_proj torch.float8_e4m3fn
transformer_blocks.5.attn.add_k_proj.base_layer torch.float8_e4m3fn
transformer_blocks.5.attn.add_k_proj.lora_A.default torch.float32
transformer_blocks.5.attn.add_k_proj.lora_B.default torch.float32
transformer_blocks.5.attn.add_v_proj torch.float8_e4m3fn
transformer_blocks.5.attn.add_v_proj.base_layer torch.float8_e4m3fn
transformer_blocks.5.attn.add_v_proj.lora_A.default torch.float32
transformer_blocks.5.attn.add_v_proj.lora_B.default torch.float32
transformer_blocks.5.attn.add_q_proj torch.float8_e4m3fn
transformer_blocks.5.attn.add_q_proj.base_layer torch.float8_e4m3fn
transformer_blocks.5.attn.add_q_proj.lora_A.default torch.float32
transformer_blocks.5.attn.add_q_proj.lora_B.default torch.float32
transformer_blocks.5.attn.to_out.0 torch.float8_e4m3fn
transformer_blocks.5.attn.to_out.0.base_layer torch.float8_e4m3fn
transformer_blocks.5.attn.to_out.0.lora_A.default torch.float32
transformer_blocks.5.attn.to_out.0.lora_B.default torch.float32
transformer_blocks.5.attn.to_add_out torch.float8_e4m3fn
transformer_blocks.5.attn.to_add_out.base_layer torch.float8_e4m3fn
transformer_blocks.5.attn.to_add_out.lora_A.default torch.float32
transformer_blocks.5.attn.to_add_out.lora_B.default torch.float32
transformer_blocks.5.attn.norm_added_q torch.bfloat16
transformer_blocks.5.attn.norm_added_k torch.bfloat16
transformer_blocks.5.ff.net.0.proj torch.float8_e4m3fn
transformer_blocks.5.ff.net.2 torch.float8_e4m3fn
transformer_blocks.5.ff_context.net.0.proj torch.float8_e4m3fn
transformer_blocks.5.ff_context.net.2 torch.float8_e4m3fn
transformer_blocks.6.norm1.linear torch.bfloat16
transformer_blocks.6.norm1_context.linear torch.bfloat16
transformer_blocks.6.attn.norm_q torch.bfloat16
transformer_blocks.6.attn.norm_k torch.bfloat16
transformer_blocks.6.attn.to_q torch.float8_e4m3fn
transformer_blocks.6.attn.to_q.base_layer torch.float8_e4m3fn
transformer_blocks.6.attn.to_q.lora_A.default torch.float32
transformer_blocks.6.attn.to_q.lora_B.default torch.float32
transformer_blocks.6.attn.to_k torch.float8_e4m3fn
transformer_blocks.6.attn.to_k.base_layer torch.float8_e4m3fn
transformer_blocks.6.attn.to_k.lora_A.default torch.float32
transformer_blocks.6.attn.to_k.lora_B.default torch.float32
transformer_blocks.6.attn.to_v torch.float8_e4m3fn
transformer_blocks.6.attn.to_v.base_layer torch.float8_e4m3fn
transformer_blocks.6.attn.to_v.lora_A.default torch.float32
transformer_blocks.6.attn.to_v.lora_B.default torch.float32
transformer_blocks.6.attn.add_k_proj torch.float8_e4m3fn
transformer_blocks.6.attn.add_k_proj.base_layer torch.float8_e4m3fn
transformer_blocks.6.attn.add_k_proj.lora_A.default torch.float32
transformer_blocks.6.attn.add_k_proj.lora_B.default torch.float32
transformer_blocks.6.attn.add_v_proj torch.float8_e4m3fn
transformer_blocks.6.attn.add_v_proj.base_layer torch.float8_e4m3fn
transformer_blocks.6.attn.add_v_proj.lora_A.default torch.float32
transformer_blocks.6.attn.add_v_proj.lora_B.default torch.float32
transformer_blocks.6.attn.add_q_proj torch.float8_e4m3fn
transformer_blocks.6.attn.add_q_proj.base_layer torch.float8_e4m3fn
transformer_blocks.6.attn.add_q_proj.lora_A.default torch.float32
transformer_blocks.6.attn.add_q_proj.lora_B.default torch.float32
transformer_blocks.6.attn.to_out.0 torch.float8_e4m3fn
transformer_blocks.6.attn.to_out.0.base_layer torch.float8_e4m3fn
transformer_blocks.6.attn.to_out.0.lora_A.default torch.float32
transformer_blocks.6.attn.to_out.0.lora_B.default torch.float32
transformer_blocks.6.attn.to_add_out torch.float8_e4m3fn
transformer_blocks.6.attn.to_add_out.base_layer torch.float8_e4m3fn
transformer_blocks.6.attn.to_add_out.lora_A.default torch.float32
transformer_blocks.6.attn.to_add_out.lora_B.default torch.float32
transformer_blocks.6.attn.norm_added_q torch.bfloat16
transformer_blocks.6.attn.norm_added_k torch.bfloat16
transformer_blocks.6.ff.net.0.proj torch.float8_e4m3fn
transformer_blocks.6.ff.net.2 torch.float8_e4m3fn
transformer_blocks.6.ff_context.net.0.proj torch.float8_e4m3fn
transformer_blocks.6.ff_context.net.2 torch.float8_e4m3fn
transformer_blocks.7.norm1.linear torch.bfloat16
transformer_blocks.7.norm1_context.linear torch.bfloat16
transformer_blocks.7.attn.norm_q torch.bfloat16
transformer_blocks.7.attn.norm_k torch.bfloat16
transformer_blocks.7.attn.to_q torch.float8_e4m3fn
transformer_blocks.7.attn.to_q.base_layer torch.float8_e4m3fn
transformer_blocks.7.attn.to_q.lora_A.default torch.float32
transformer_blocks.7.attn.to_q.lora_B.default torch.float32
transformer_blocks.7.attn.to_k torch.float8_e4m3fn
transformer_blocks.7.attn.to_k.base_layer torch.float8_e4m3fn
transformer_blocks.7.attn.to_k.lora_A.default torch.float32
transformer_blocks.7.attn.to_k.lora_B.default torch.float32
transformer_blocks.7.attn.to_v torch.float8_e4m3fn
transformer_blocks.7.attn.to_v.base_layer torch.float8_e4m3fn
transformer_blocks.7.attn.to_v.lora_A.default torch.float32
transformer_blocks.7.attn.to_v.lora_B.default torch.float32
transformer_blocks.7.attn.add_k_proj torch.float8_e4m3fn
transformer_blocks.7.attn.add_k_proj.base_layer torch.float8_e4m3fn
transformer_blocks.7.attn.add_k_proj.lora_A.default torch.float32
transformer_blocks.7.attn.add_k_proj.lora_B.default torch.float32
transformer_blocks.7.attn.add_v_proj torch.float8_e4m3fn
transformer_blocks.7.attn.add_v_proj.base_layer torch.float8_e4m3fn
transformer_blocks.7.attn.add_v_proj.lora_A.default torch.float32
transformer_blocks.7.attn.add_v_proj.lora_B.default torch.float32
transformer_blocks.7.attn.add_q_proj torch.float8_e4m3fn
transformer_blocks.7.attn.add_q_proj.base_layer torch.float8_e4m3fn
transformer_blocks.7.attn.add_q_proj.lora_A.default torch.float32
transformer_blocks.7.attn.add_q_proj.lora_B.default torch.float32
transformer_blocks.7.attn.to_out.0 torch.float8_e4m3fn
transformer_blocks.7.attn.to_out.0.base_layer torch.float8_e4m3fn
transformer_blocks.7.attn.to_out.0.lora_A.default torch.float32
transformer_blocks.7.attn.to_out.0.lora_B.default torch.float32
transformer_blocks.7.attn.to_add_out torch.float8_e4m3fn
transformer_blocks.7.attn.to_add_out.base_layer torch.float8_e4m3fn
transformer_blocks.7.attn.to_add_out.lora_A.default torch.float32
transformer_blocks.7.attn.to_add_out.lora_B.default torch.float32
transformer_blocks.7.attn.norm_added_q torch.bfloat16
transformer_blocks.7.attn.norm_added_k torch.bfloat16
transformer_blocks.7.ff.net.0.proj torch.float8_e4m3fn
transformer_blocks.7.ff.net.2 torch.float8_e4m3fn
transformer_blocks.7.ff_context.net.0.proj torch.float8_e4m3fn
transformer_blocks.7.ff_context.net.2 torch.float8_e4m3fn
transformer_blocks.8.norm1.linear torch.bfloat16
transformer_blocks.8.norm1_context.linear torch.bfloat16
transformer_blocks.8.attn.norm_q torch.bfloat16
transformer_blocks.8.attn.norm_k torch.bfloat16
transformer_blocks.8.attn.to_q torch.float8_e4m3fn
transformer_blocks.8.attn.to_q.base_layer torch.float8_e4m3fn
transformer_blocks.8.attn.to_q.lora_A.default torch.float32
transformer_blocks.8.attn.to_q.lora_B.default torch.float32
transformer_blocks.8.attn.to_k torch.float8_e4m3fn
transformer_blocks.8.attn.to_k.base_layer torch.float8_e4m3fn
transformer_blocks.8.attn.to_k.lora_A.default torch.float32
transformer_blocks.8.attn.to_k.lora_B.default torch.float32
transformer_blocks.8.attn.to_v torch.float8_e4m3fn
transformer_blocks.8.attn.to_v.base_layer torch.float8_e4m3fn
transformer_blocks.8.attn.to_v.lora_A.default torch.float32
transformer_blocks.8.attn.to_v.lora_B.default torch.float32
transformer_blocks.8.attn.add_k_proj torch.float8_e4m3fn
transformer_blocks.8.attn.add_k_proj.base_layer torch.float8_e4m3fn
transformer_blocks.8.attn.add_k_proj.lora_A.default torch.float32
transformer_blocks.8.attn.add_k_proj.lora_B.default torch.float32
transformer_blocks.8.attn.add_v_proj torch.float8_e4m3fn
transformer_blocks.8.attn.add_v_proj.base_layer torch.float8_e4m3fn
transformer_blocks.8.attn.add_v_proj.lora_A.default torch.float32
transformer_blocks.8.attn.add_v_proj.lora_B.default torch.float32
transformer_blocks.8.attn.add_q_proj torch.float8_e4m3fn
transformer_blocks.8.attn.add_q_proj.base_layer torch.float8_e4m3fn
transformer_blocks.8.attn.add_q_proj.lora_A.default torch.float32
transformer_blocks.8.attn.add_q_proj.lora_B.default torch.float32
transformer_blocks.8.attn.to_out.0 torch.float8_e4m3fn
transformer_blocks.8.attn.to_out.0.base_layer torch.float8_e4m3fn
transformer_blocks.8.attn.to_out.0.lora_A.default torch.float32
transformer_blocks.8.attn.to_out.0.lora_B.default torch.float32
transformer_blocks.8.attn.to_add_out torch.float8_e4m3fn
transformer_blocks.8.attn.to_add_out.base_layer torch.float8_e4m3fn
transformer_blocks.8.attn.to_add_out.lora_A.default torch.float32
transformer_blocks.8.attn.to_add_out.lora_B.default torch.float32
transformer_blocks.8.attn.norm_added_q torch.bfloat16
transformer_blocks.8.attn.norm_added_k torch.bfloat16
transformer_blocks.8.ff.net.0.proj torch.float8_e4m3fn
transformer_blocks.8.ff.net.2 torch.float8_e4m3fn
transformer_blocks.8.ff_context.net.0.proj torch.float8_e4m3fn
transformer_blocks.8.ff_context.net.2 torch.float8_e4m3fn
transformer_blocks.9.norm1.linear torch.bfloat16
transformer_blocks.9.norm1_context.linear torch.bfloat16
transformer_blocks.9.attn.norm_q torch.bfloat16
transformer_blocks.9.attn.norm_k torch.bfloat16
transformer_blocks.9.attn.to_q torch.float8_e4m3fn
transformer_blocks.9.attn.to_q.base_layer torch.float8_e4m3fn
transformer_blocks.9.attn.to_q.lora_A.default torch.float32
transformer_blocks.9.attn.to_q.lora_B.default torch.float32
transformer_blocks.9.attn.to_k torch.float8_e4m3fn
transformer_blocks.9.attn.to_k.base_layer torch.float8_e4m3fn
transformer_blocks.9.attn.to_k.lora_A.default torch.float32
transformer_blocks.9.attn.to_k.lora_B.default torch.float32
transformer_blocks.9.attn.to_v torch.float8_e4m3fn
transformer_blocks.9.attn.to_v.base_layer torch.float8_e4m3fn
transformer_blocks.9.attn.to_v.lora_A.default torch.float32
transformer_blocks.9.attn.to_v.lora_B.default torch.float32
transformer_blocks.9.attn.add_k_proj torch.float8_e4m3fn
transformer_blocks.9.attn.add_k_proj.base_layer torch.float8_e4m3fn
transformer_blocks.9.attn.add_k_proj.lora_A.default torch.float32
transformer_blocks.9.attn.add_k_proj.lora_B.default torch.float32
transformer_blocks.9.attn.add_v_proj torch.float8_e4m3fn
transformer_blocks.9.attn.add_v_proj.base_layer torch.float8_e4m3fn
transformer_blocks.9.attn.add_v_proj.lora_A.default torch.float32
transformer_blocks.9.attn.add_v_proj.lora_B.default torch.float32
transformer_blocks.9.attn.add_q_proj torch.float8_e4m3fn
transformer_blocks.9.attn.add_q_proj.base_layer torch.float8_e4m3fn
transformer_blocks.9.attn.add_q_proj.lora_A.default torch.float32
transformer_blocks.9.attn.add_q_proj.lora_B.default torch.float32
transformer_blocks.9.attn.to_out.0 torch.float8_e4m3fn
transformer_blocks.9.attn.to_out.0.base_layer torch.float8_e4m3fn
transformer_blocks.9.attn.to_out.0.lora_A.default torch.float32
transformer_blocks.9.attn.to_out.0.lora_B.default torch.float32
transformer_blocks.9.attn.to_add_out torch.float8_e4m3fn
transformer_blocks.9.attn.to_add_out.base_layer torch.float8_e4m3fn
transformer_blocks.9.attn.to_add_out.lora_A.default torch.float32
transformer_blocks.9.attn.to_add_out.lora_B.default torch.float32
transformer_blocks.9.attn.norm_added_q torch.bfloat16
transformer_blocks.9.attn.norm_added_k torch.bfloat16
transformer_blocks.9.ff.net.0.proj torch.float8_e4m3fn
transformer_blocks.9.ff.net.2 torch.float8_e4m3fn
transformer_blocks.9.ff_context.net.0.proj torch.float8_e4m3fn
transformer_blocks.9.ff_context.net.2 torch.float8_e4m3fn
transformer_blocks.10.norm1.linear torch.bfloat16
transformer_blocks.10.norm1_context.linear torch.bfloat16
transformer_blocks.10.attn.norm_q torch.bfloat16
transformer_blocks.10.attn.norm_k torch.bfloat16
transformer_blocks.10.attn.to_q torch.float8_e4m3fn
transformer_blocks.10.attn.to_q.base_layer torch.float8_e4m3fn
transformer_blocks.10.attn.to_q.lora_A.default torch.float32
transformer_blocks.10.attn.to_q.lora_B.default torch.float32
transformer_blocks.10.attn.to_k torch.float8_e4m3fn
transformer_blocks.10.attn.to_k.base_layer torch.float8_e4m3fn
transformer_blocks.10.attn.to_k.lora_A.default torch.float32
transformer_blocks.10.attn.to_k.lora_B.default torch.float32
transformer_blocks.10.attn.to_v torch.float8_e4m3fn
transformer_blocks.10.attn.to_v.base_layer torch.float8_e4m3fn
transformer_blocks.10.attn.to_v.lora_A.default torch.float32
transformer_blocks.10.attn.to_v.lora_B.default torch.float32
transformer_blocks.10.attn.add_k_proj torch.float8_e4m3fn
transformer_blocks.10.attn.add_k_proj.base_layer torch.float8_e4m3fn
transformer_blocks.10.attn.add_k_proj.lora_A.default torch.float32
transformer_blocks.10.attn.add_k_proj.lora_B.default torch.float32
transformer_blocks.10.attn.add_v_proj torch.float8_e4m3fn
transformer_blocks.10.attn.add_v_proj.base_layer torch.float8_e4m3fn
transformer_blocks.10.attn.add_v_proj.lora_A.default torch.float32
transformer_blocks.10.attn.add_v_proj.lora_B.default torch.float32
transformer_blocks.10.attn.add_q_proj torch.float8_e4m3fn
transformer_blocks.10.attn.add_q_proj.base_layer torch.float8_e4m3fn
transformer_blocks.10.attn.add_q_proj.lora_A.default torch.float32
transformer_blocks.10.attn.add_q_proj.lora_B.default torch.float32
transformer_blocks.10.attn.to_out.0 torch.float8_e4m3fn
transformer_blocks.10.attn.to_out.0.base_layer torch.float8_e4m3fn
transformer_blocks.10.attn.to_out.0.lora_A.default torch.float32
transformer_blocks.10.attn.to_out.0.lora_B.default torch.float32
transformer_blocks.10.attn.to_add_out torch.float8_e4m3fn
transformer_blocks.10.attn.to_add_out.base_layer torch.float8_e4m3fn
transformer_blocks.10.attn.to_add_out.lora_A.default torch.float32
transformer_blocks.10.attn.to_add_out.lora_B.default torch.float32
transformer_blocks.10.attn.norm_added_q torch.bfloat16
transformer_blocks.10.attn.norm_added_k torch.bfloat16
transformer_blocks.10.ff.net.0.proj torch.float8_e4m3fn
transformer_blocks.10.ff.net.2 torch.float8_e4m3fn
transformer_blocks.10.ff_context.net.0.proj torch.float8_e4m3fn
transformer_blocks.10.ff_context.net.2 torch.float8_e4m3fn
transformer_blocks.11.norm1.linear torch.bfloat16
transformer_blocks.11.norm1_context.linear torch.bfloat16
transformer_blocks.11.attn.norm_q torch.bfloat16
transformer_blocks.11.attn.norm_k torch.bfloat16
transformer_blocks.11.attn.to_q torch.float8_e4m3fn
transformer_blocks.11.attn.to_q.base_layer torch.float8_e4m3fn
transformer_blocks.11.attn.to_q.lora_A.default torch.float32
transformer_blocks.11.attn.to_q.lora_B.default torch.float32
transformer_blocks.11.attn.to_k torch.float8_e4m3fn
transformer_blocks.11.attn.to_k.base_layer torch.float8_e4m3fn
transformer_blocks.11.attn.to_k.lora_A.default torch.float32
transformer_blocks.11.attn.to_k.lora_B.default torch.float32
transformer_blocks.11.attn.to_v torch.float8_e4m3fn
transformer_blocks.11.attn.to_v.base_layer torch.float8_e4m3fn
transformer_blocks.11.attn.to_v.lora_A.default torch.float32
transformer_blocks.11.attn.to_v.lora_B.default torch.float32
transformer_blocks.11.attn.add_k_proj torch.float8_e4m3fn
transformer_blocks.11.attn.add_k_proj.base_layer torch.float8_e4m3fn
transformer_blocks.11.attn.add_k_proj.lora_A.default torch.float32
transformer_blocks.11.attn.add_k_proj.lora_B.default torch.float32
transformer_blocks.11.attn.add_v_proj torch.float8_e4m3fn
transformer_blocks.11.attn.add_v_proj.base_layer torch.float8_e4m3fn
transformer_blocks.11.attn.add_v_proj.lora_A.default torch.float32
transformer_blocks.11.attn.add_v_proj.lora_B.default torch.float32
transformer_blocks.11.attn.add_q_proj torch.float8_e4m3fn
transformer_blocks.11.attn.add_q_proj.base_layer torch.float8_e4m3fn
transformer_blocks.11.attn.add_q_proj.lora_A.default torch.float32
transformer_blocks.11.attn.add_q_proj.lora_B.default torch.float32
transformer_blocks.11.attn.to_out.0 torch.float8_e4m3fn
transformer_blocks.11.attn.to_out.0.base_layer torch.float8_e4m3fn
transformer_blocks.11.attn.to_out.0.lora_A.default torch.float32
transformer_blocks.11.attn.to_out.0.lora_B.default torch.float32
transformer_blocks.11.attn.to_add_out torch.float8_e4m3fn
transformer_blocks.11.attn.to_add_out.base_layer torch.float8_e4m3fn
transformer_blocks.11.attn.to_add_out.lora_A.default torch.float32
transformer_blocks.11.attn.to_add_out.lora_B.default torch.float32
transformer_blocks.11.attn.norm_added_q torch.bfloat16
transformer_blocks.11.attn.norm_added_k torch.bfloat16
transformer_blocks.11.ff.net.0.proj torch.float8_e4m3fn
transformer_blocks.11.ff.net.2 torch.float8_e4m3fn
transformer_blocks.11.ff_context.net.0.proj torch.float8_e4m3fn
transformer_blocks.11.ff_context.net.2 torch.float8_e4m3fn
transformer_blocks.12.norm1.linear torch.bfloat16
transformer_blocks.12.norm1_context.linear torch.bfloat16
transformer_blocks.12.attn.norm_q torch.bfloat16
transformer_blocks.12.attn.norm_k torch.bfloat16
transformer_blocks.12.attn.to_q torch.float8_e4m3fn
transformer_blocks.12.attn.to_q.base_layer torch.float8_e4m3fn
transformer_blocks.12.attn.to_q.lora_A.default torch.float32
transformer_blocks.12.attn.to_q.lora_B.default torch.float32
transformer_blocks.12.attn.to_k torch.float8_e4m3fn
transformer_blocks.12.attn.to_k.base_layer torch.float8_e4m3fn
transformer_blocks.12.attn.to_k.lora_A.default torch.float32
transformer_blocks.12.attn.to_k.lora_B.default torch.float32
transformer_blocks.12.attn.to_v torch.float8_e4m3fn
transformer_blocks.12.attn.to_v.base_layer torch.float8_e4m3fn
transformer_blocks.12.attn.to_v.lora_A.default torch.float32
transformer_blocks.12.attn.to_v.lora_B.default torch.float32
transformer_blocks.12.attn.add_k_proj torch.float8_e4m3fn
transformer_blocks.12.attn.add_k_proj.base_layer torch.float8_e4m3fn
transformer_blocks.12.attn.add_k_proj.lora_A.default torch.float32
transformer_blocks.12.attn.add_k_proj.lora_B.default torch.float32
transformer_blocks.12.attn.add_v_proj torch.float8_e4m3fn
transformer_blocks.12.attn.add_v_proj.base_layer torch.float8_e4m3fn
transformer_blocks.12.attn.add_v_proj.lora_A.default torch.float32
transformer_blocks.12.attn.add_v_proj.lora_B.default torch.float32
transformer_blocks.12.attn.add_q_proj torch.float8_e4m3fn
transformer_blocks.12.attn.add_q_proj.base_layer torch.float8_e4m3fn
transformer_blocks.12.attn.add_q_proj.lora_A.default torch.float32
transformer_blocks.12.attn.add_q_proj.lora_B.default torch.float32
transformer_blocks.12.attn.to_out.0 torch.float8_e4m3fn
transformer_blocks.12.attn.to_out.0.base_layer torch.float8_e4m3fn
transformer_blocks.12.attn.to_out.0.lora_A.default torch.float32
transformer_blocks.12.attn.to_out.0.lora_B.default torch.float32
transformer_blocks.12.attn.to_add_out torch.float8_e4m3fn
transformer_blocks.12.attn.to_add_out.base_layer torch.float8_e4m3fn
transformer_blocks.12.attn.to_add_out.lora_A.default torch.float32
transformer_blocks.12.attn.to_add_out.lora_B.default torch.float32
transformer_blocks.12.attn.norm_added_q torch.bfloat16
transformer_blocks.12.attn.norm_added_k torch.bfloat16
transformer_blocks.12.ff.net.0.proj torch.float8_e4m3fn
transformer_blocks.12.ff.net.2 torch.float8_e4m3fn
transformer_blocks.12.ff_context.net.0.proj torch.float8_e4m3fn
transformer_blocks.12.ff_context.net.2 torch.float8_e4m3fn
transformer_blocks.13.norm1.linear torch.bfloat16
transformer_blocks.13.norm1_context.linear torch.bfloat16
transformer_blocks.13.attn.norm_q torch.bfloat16
transformer_blocks.13.attn.norm_k torch.bfloat16
transformer_blocks.13.attn.to_q torch.float8_e4m3fn
transformer_blocks.13.attn.to_q.base_layer torch.float8_e4m3fn
transformer_blocks.13.attn.to_q.lora_A.default torch.float32
transformer_blocks.13.attn.to_q.lora_B.default torch.float32
transformer_blocks.13.attn.to_k torch.float8_e4m3fn
transformer_blocks.13.attn.to_k.base_layer torch.float8_e4m3fn
transformer_blocks.13.attn.to_k.lora_A.default torch.float32
transformer_blocks.13.attn.to_k.lora_B.default torch.float32
transformer_blocks.13.attn.to_v torch.float8_e4m3fn
transformer_blocks.13.attn.to_v.base_layer torch.float8_e4m3fn
transformer_blocks.13.attn.to_v.lora_A.default torch.float32
transformer_blocks.13.attn.to_v.lora_B.default torch.float32
transformer_blocks.13.attn.add_k_proj torch.float8_e4m3fn
transformer_blocks.13.attn.add_k_proj.base_layer torch.float8_e4m3fn
transformer_blocks.13.attn.add_k_proj.lora_A.default torch.float32
transformer_blocks.13.attn.add_k_proj.lora_B.default torch.float32
transformer_blocks.13.attn.add_v_proj torch.float8_e4m3fn
transformer_blocks.13.attn.add_v_proj.base_layer torch.float8_e4m3fn
transformer_blocks.13.attn.add_v_proj.lora_A.default torch.float32
transformer_blocks.13.attn.add_v_proj.lora_B.default torch.float32
transformer_blocks.13.attn.add_q_proj torch.float8_e4m3fn
transformer_blocks.13.attn.add_q_proj.base_layer torch.float8_e4m3fn
transformer_blocks.13.attn.add_q_proj.lora_A.default torch.float32
transformer_blocks.13.attn.add_q_proj.lora_B.default torch.float32
transformer_blocks.13.attn.to_out.0 torch.float8_e4m3fn
transformer_blocks.13.attn.to_out.0.base_layer torch.float8_e4m3fn
transformer_blocks.13.attn.to_out.0.lora_A.default torch.float32
transformer_blocks.13.attn.to_out.0.lora_B.default torch.float32
transformer_blocks.13.attn.to_add_out torch.float8_e4m3fn
transformer_blocks.13.attn.to_add_out.base_layer torch.float8_e4m3fn
transformer_blocks.13.attn.to_add_out.lora_A.default torch.float32
transformer_blocks.13.attn.to_add_out.lora_B.default torch.float32
transformer_blocks.13.attn.norm_added_q torch.bfloat16
transformer_blocks.13.attn.norm_added_k torch.bfloat16
transformer_blocks.13.ff.net.0.proj torch.float8_e4m3fn
transformer_blocks.13.ff.net.2 torch.float8_e4m3fn
transformer_blocks.13.ff_context.net.0.proj torch.float8_e4m3fn
transformer_blocks.13.ff_context.net.2 torch.float8_e4m3fn
transformer_blocks.14.norm1.linear torch.bfloat16
transformer_blocks.14.norm1_context.linear torch.bfloat16
transformer_blocks.14.attn.norm_q torch.bfloat16
transformer_blocks.14.attn.norm_k torch.bfloat16
transformer_blocks.14.attn.to_q torch.float8_e4m3fn
transformer_blocks.14.attn.to_q.base_layer torch.float8_e4m3fn
transformer_blocks.14.attn.to_q.lora_A.default torch.float32
transformer_blocks.14.attn.to_q.lora_B.default torch.float32
transformer_blocks.14.attn.to_k torch.float8_e4m3fn
transformer_blocks.14.attn.to_k.base_layer torch.float8_e4m3fn
transformer_blocks.14.attn.to_k.lora_A.default torch.float32
transformer_blocks.14.attn.to_k.lora_B.default torch.float32
transformer_blocks.14.attn.to_v torch.float8_e4m3fn
transformer_blocks.14.attn.to_v.base_layer torch.float8_e4m3fn
transformer_blocks.14.attn.to_v.lora_A.default torch.float32
transformer_blocks.14.attn.to_v.lora_B.default torch.float32
transformer_blocks.14.attn.add_k_proj torch.float8_e4m3fn
transformer_blocks.14.attn.add_k_proj.base_layer torch.float8_e4m3fn
transformer_blocks.14.attn.add_k_proj.lora_A.default torch.float32
transformer_blocks.14.attn.add_k_proj.lora_B.default torch.float32
transformer_blocks.14.attn.add_v_proj torch.float8_e4m3fn
transformer_blocks.14.attn.add_v_proj.base_layer torch.float8_e4m3fn
transformer_blocks.14.attn.add_v_proj.lora_A.default torch.float32
transformer_blocks.14.attn.add_v_proj.lora_B.default torch.float32
transformer_blocks.14.attn.add_q_proj torch.float8_e4m3fn
transformer_blocks.14.attn.add_q_proj.base_layer torch.float8_e4m3fn
transformer_blocks.14.attn.add_q_proj.lora_A.default torch.float32
transformer_blocks.14.attn.add_q_proj.lora_B.default torch.float32
transformer_blocks.14.attn.to_out.0 torch.float8_e4m3fn
transformer_blocks.14.attn.to_out.0.base_layer torch.float8_e4m3fn
transformer_blocks.14.attn.to_out.0.lora_A.default torch.float32
transformer_blocks.14.attn.to_out.0.lora_B.default torch.float32
transformer_blocks.14.attn.to_add_out torch.float8_e4m3fn
transformer_blocks.14.attn.to_add_out.base_layer torch.float8_e4m3fn
transformer_blocks.14.attn.to_add_out.lora_A.default torch.float32
transformer_blocks.14.attn.to_add_out.lora_B.default torch.float32
transformer_blocks.14.attn.norm_added_q torch.bfloat16
transformer_blocks.14.attn.norm_added_k torch.bfloat16
transformer_blocks.14.ff.net.0.proj torch.float8_e4m3fn
transformer_blocks.14.ff.net.2 torch.float8_e4m3fn
transformer_blocks.14.ff_context.net.0.proj torch.float8_e4m3fn
transformer_blocks.14.ff_context.net.2 torch.float8_e4m3fn
transformer_blocks.15.norm1.linear torch.bfloat16
transformer_blocks.15.norm1_context.linear torch.bfloat16
transformer_blocks.15.attn.norm_q torch.bfloat16
transformer_blocks.15.attn.norm_k torch.bfloat16
transformer_blocks.15.attn.to_q torch.float8_e4m3fn
transformer_blocks.15.attn.to_q.base_layer torch.float8_e4m3fn
transformer_blocks.15.attn.to_q.lora_A.default torch.float32
transformer_blocks.15.attn.to_q.lora_B.default torch.float32
transformer_blocks.15.attn.to_k torch.float8_e4m3fn
transformer_blocks.15.attn.to_k.base_layer torch.float8_e4m3fn
transformer_blocks.15.attn.to_k.lora_A.default torch.float32
transformer_blocks.15.attn.to_k.lora_B.default torch.float32
transformer_blocks.15.attn.to_v torch.float8_e4m3fn
transformer_blocks.15.attn.to_v.base_layer torch.float8_e4m3fn
transformer_blocks.15.attn.to_v.lora_A.default torch.float32
transformer_blocks.15.attn.to_v.lora_B.default torch.float32
transformer_blocks.15.attn.add_k_proj torch.float8_e4m3fn
transformer_blocks.15.attn.add_k_proj.base_layer torch.float8_e4m3fn
transformer_blocks.15.attn.add_k_proj.lora_A.default torch.float32
transformer_blocks.15.attn.add_k_proj.lora_B.default torch.float32
transformer_blocks.15.attn.add_v_proj torch.float8_e4m3fn
transformer_blocks.15.attn.add_v_proj.base_layer torch.float8_e4m3fn
transformer_blocks.15.attn.add_v_proj.lora_A.default torch.float32
transformer_blocks.15.attn.add_v_proj.lora_B.default torch.float32
transformer_blocks.15.attn.add_q_proj torch.float8_e4m3fn
transformer_blocks.15.attn.add_q_proj.base_layer torch.float8_e4m3fn
transformer_blocks.15.attn.add_q_proj.lora_A.default torch.float32
transformer_blocks.15.attn.add_q_proj.lora_B.default torch.float32
transformer_blocks.15.attn.to_out.0 torch.float8_e4m3fn
transformer_blocks.15.attn.to_out.0.base_layer torch.float8_e4m3fn
transformer_blocks.15.attn.to_out.0.lora_A.default torch.float32
transformer_blocks.15.attn.to_out.0.lora_B.default torch.float32
transformer_blocks.15.attn.to_add_out torch.float8_e4m3fn
transformer_blocks.15.attn.to_add_out.base_layer torch.float8_e4m3fn
transformer_blocks.15.attn.to_add_out.lora_A.default torch.float32
transformer_blocks.15.attn.to_add_out.lora_B.default torch.float32
transformer_blocks.15.attn.norm_added_q torch.bfloat16
transformer_blocks.15.attn.norm_added_k torch.bfloat16
transformer_blocks.15.ff.net.0.proj torch.float8_e4m3fn
transformer_blocks.15.ff.net.2 torch.float8_e4m3fn
transformer_blocks.15.ff_context.net.0.proj torch.float8_e4m3fn
transformer_blocks.15.ff_context.net.2 torch.float8_e4m3fn
transformer_blocks.16.norm1.linear torch.bfloat16
transformer_blocks.16.norm1_context.linear torch.bfloat16
transformer_blocks.16.attn.norm_q torch.bfloat16
transformer_blocks.16.attn.norm_k torch.bfloat16
transformer_blocks.16.attn.to_q torch.float8_e4m3fn
transformer_blocks.16.attn.to_q.base_layer torch.float8_e4m3fn
transformer_blocks.16.attn.to_q.lora_A.default torch.float32
transformer_blocks.16.attn.to_q.lora_B.default torch.float32
transformer_blocks.16.attn.to_k torch.float8_e4m3fn
transformer_blocks.16.attn.to_k.base_layer torch.float8_e4m3fn
transformer_blocks.16.attn.to_k.lora_A.default torch.float32
transformer_blocks.16.attn.to_k.lora_B.default torch.float32
transformer_blocks.16.attn.to_v torch.float8_e4m3fn
transformer_blocks.16.attn.to_v.base_layer torch.float8_e4m3fn
transformer_blocks.16.attn.to_v.lora_A.default torch.float32
transformer_blocks.16.attn.to_v.lora_B.default torch.float32
transformer_blocks.16.attn.add_k_proj torch.float8_e4m3fn
transformer_blocks.16.attn.add_k_proj.base_layer torch.float8_e4m3fn
transformer_blocks.16.attn.add_k_proj.lora_A.default torch.float32
transformer_blocks.16.attn.add_k_proj.lora_B.default torch.float32
transformer_blocks.16.attn.add_v_proj torch.float8_e4m3fn
transformer_blocks.16.attn.add_v_proj.base_layer torch.float8_e4m3fn
transformer_blocks.16.attn.add_v_proj.lora_A.default torch.float32
transformer_blocks.16.attn.add_v_proj.lora_B.default torch.float32
transformer_blocks.16.attn.add_q_proj torch.float8_e4m3fn
transformer_blocks.16.attn.add_q_proj.base_layer torch.float8_e4m3fn
transformer_blocks.16.attn.add_q_proj.lora_A.default torch.float32
transformer_blocks.16.attn.add_q_proj.lora_B.default torch.float32
transformer_blocks.16.attn.to_out.0 torch.float8_e4m3fn
transformer_blocks.16.attn.to_out.0.base_layer torch.float8_e4m3fn
transformer_blocks.16.attn.to_out.0.lora_A.default torch.float32
transformer_blocks.16.attn.to_out.0.lora_B.default torch.float32
transformer_blocks.16.attn.to_add_out torch.float8_e4m3fn
transformer_blocks.16.attn.to_add_out.base_layer torch.float8_e4m3fn
transformer_blocks.16.attn.to_add_out.lora_A.default torch.float32
transformer_blocks.16.attn.to_add_out.lora_B.default torch.float32
transformer_blocks.16.attn.norm_added_q torch.bfloat16
transformer_blocks.16.attn.norm_added_k torch.bfloat16
transformer_blocks.16.ff.net.0.proj torch.float8_e4m3fn
transformer_blocks.16.ff.net.2 torch.float8_e4m3fn
transformer_blocks.16.ff_context.net.0.proj torch.float8_e4m3fn
transformer_blocks.16.ff_context.net.2 torch.float8_e4m3fn
transformer_blocks.17.norm1.linear torch.bfloat16
transformer_blocks.17.norm1_context.linear torch.bfloat16
transformer_blocks.17.attn.norm_q torch.bfloat16
transformer_blocks.17.attn.norm_k torch.bfloat16
transformer_blocks.17.attn.to_q torch.float8_e4m3fn
transformer_blocks.17.attn.to_q.base_layer torch.float8_e4m3fn
transformer_blocks.17.attn.to_q.lora_A.default torch.float32
transformer_blocks.17.attn.to_q.lora_B.default torch.float32
transformer_blocks.17.attn.to_k torch.float8_e4m3fn
transformer_blocks.17.attn.to_k.base_layer torch.float8_e4m3fn
transformer_blocks.17.attn.to_k.lora_A.default torch.float32
transformer_blocks.17.attn.to_k.lora_B.default torch.float32
transformer_blocks.17.attn.to_v torch.float8_e4m3fn
transformer_blocks.17.attn.to_v.base_layer torch.float8_e4m3fn
transformer_blocks.17.attn.to_v.lora_A.default torch.float32
transformer_blocks.17.attn.to_v.lora_B.default torch.float32
transformer_blocks.17.attn.add_k_proj torch.float8_e4m3fn
transformer_blocks.17.attn.add_k_proj.base_layer torch.float8_e4m3fn
transformer_blocks.17.attn.add_k_proj.lora_A.default torch.float32
transformer_blocks.17.attn.add_k_proj.lora_B.default torch.float32
transformer_blocks.17.attn.add_v_proj torch.float8_e4m3fn
transformer_blocks.17.attn.add_v_proj.base_layer torch.float8_e4m3fn
transformer_blocks.17.attn.add_v_proj.lora_A.default torch.float32
transformer_blocks.17.attn.add_v_proj.lora_B.default torch.float32
transformer_blocks.17.attn.add_q_proj torch.float8_e4m3fn
transformer_blocks.17.attn.add_q_proj.base_layer torch.float8_e4m3fn
transformer_blocks.17.attn.add_q_proj.lora_A.default torch.float32
transformer_blocks.17.attn.add_q_proj.lora_B.default torch.float32
transformer_blocks.17.attn.to_out.0 torch.float8_e4m3fn
transformer_blocks.17.attn.to_out.0.base_layer torch.float8_e4m3fn
transformer_blocks.17.attn.to_out.0.lora_A.default torch.float32
transformer_blocks.17.attn.to_out.0.lora_B.default torch.float32
transformer_blocks.17.attn.to_add_out torch.float8_e4m3fn
transformer_blocks.17.attn.to_add_out.base_layer torch.float8_e4m3fn
transformer_blocks.17.attn.to_add_out.lora_A.default torch.float32
transformer_blocks.17.attn.to_add_out.lora_B.default torch.float32
transformer_blocks.17.attn.norm_added_q torch.bfloat16
transformer_blocks.17.attn.norm_added_k torch.bfloat16
transformer_blocks.17.ff.net.0.proj torch.float8_e4m3fn
transformer_blocks.17.ff.net.2 torch.float8_e4m3fn
transformer_blocks.17.ff_context.net.0.proj torch.float8_e4m3fn
transformer_blocks.17.ff_context.net.2 torch.float8_e4m3fn
transformer_blocks.18.norm1.linear torch.bfloat16
transformer_blocks.18.norm1_context.linear torch.bfloat16
transformer_blocks.18.attn.norm_q torch.bfloat16
transformer_blocks.18.attn.norm_k torch.bfloat16
transformer_blocks.18.attn.to_q torch.float8_e4m3fn
transformer_blocks.18.attn.to_q.base_layer torch.float8_e4m3fn
transformer_blocks.18.attn.to_q.lora_A.default torch.float32
transformer_blocks.18.attn.to_q.lora_B.default torch.float32
transformer_blocks.18.attn.to_k torch.float8_e4m3fn
transformer_blocks.18.attn.to_k.base_layer torch.float8_e4m3fn
transformer_blocks.18.attn.to_k.lora_A.default torch.float32
transformer_blocks.18.attn.to_k.lora_B.default torch.float32
transformer_blocks.18.attn.to_v torch.float8_e4m3fn
transformer_blocks.18.attn.to_v.base_layer torch.float8_e4m3fn
transformer_blocks.18.attn.to_v.lora_A.default torch.float32
transformer_blocks.18.attn.to_v.lora_B.default torch.float32
transformer_blocks.18.attn.add_k_proj torch.float8_e4m3fn
transformer_blocks.18.attn.add_k_proj.base_layer torch.float8_e4m3fn
transformer_blocks.18.attn.add_k_proj.lora_A.default torch.float32
transformer_blocks.18.attn.add_k_proj.lora_B.default torch.float32
transformer_blocks.18.attn.add_v_proj torch.float8_e4m3fn
transformer_blocks.18.attn.add_v_proj.base_layer torch.float8_e4m3fn
transformer_blocks.18.attn.add_v_proj.lora_A.default torch.float32
transformer_blocks.18.attn.add_v_proj.lora_B.default torch.float32
transformer_blocks.18.attn.add_q_proj torch.float8_e4m3fn
transformer_blocks.18.attn.add_q_proj.base_layer torch.float8_e4m3fn
transformer_blocks.18.attn.add_q_proj.lora_A.default torch.float32
transformer_blocks.18.attn.add_q_proj.lora_B.default torch.float32
transformer_blocks.18.attn.to_out.0 torch.float8_e4m3fn
transformer_blocks.18.attn.to_out.0.base_layer torch.float8_e4m3fn
transformer_blocks.18.attn.to_out.0.lora_A.default torch.float32
transformer_blocks.18.attn.to_out.0.lora_B.default torch.float32
transformer_blocks.18.attn.to_add_out torch.float8_e4m3fn
transformer_blocks.18.attn.to_add_out.base_layer torch.float8_e4m3fn
transformer_blocks.18.attn.to_add_out.lora_A.default torch.float32
transformer_blocks.18.attn.to_add_out.lora_B.default torch.float32
transformer_blocks.18.attn.norm_added_q torch.bfloat16
transformer_blocks.18.attn.norm_added_k torch.bfloat16
transformer_blocks.18.ff.net.0.proj torch.float8_e4m3fn
transformer_blocks.18.ff.net.2 torch.float8_e4m3fn
transformer_blocks.18.ff_context.net.0.proj torch.float8_e4m3fn
transformer_blocks.18.ff_context.net.2 torch.float8_e4m3fn
transformer_blocks.19.norm1.linear torch.bfloat16
transformer_blocks.19.norm1_context.linear torch.bfloat16
transformer_blocks.19.attn.norm_q torch.bfloat16
transformer_blocks.19.attn.norm_k torch.bfloat16
transformer_blocks.19.attn.to_q torch.float8_e4m3fn
transformer_blocks.19.attn.to_q.base_layer torch.float8_e4m3fn
transformer_blocks.19.attn.to_q.lora_A.default torch.float32
transformer_blocks.19.attn.to_q.lora_B.default torch.float32
transformer_blocks.19.attn.to_k torch.float8_e4m3fn
transformer_blocks.19.attn.to_k.base_layer torch.float8_e4m3fn
transformer_blocks.19.attn.to_k.lora_A.default torch.float32
transformer_blocks.19.attn.to_k.lora_B.default torch.float32
transformer_blocks.19.attn.to_v torch.float8_e4m3fn
transformer_blocks.19.attn.to_v.base_layer torch.float8_e4m3fn
transformer_blocks.19.attn.to_v.lora_A.default torch.float32
transformer_blocks.19.attn.to_v.lora_B.default torch.float32
transformer_blocks.19.attn.add_k_proj torch.float8_e4m3fn
transformer_blocks.19.attn.add_k_proj.base_layer torch.float8_e4m3fn
transformer_blocks.19.attn.add_k_proj.lora_A.default torch.float32
transformer_blocks.19.attn.add_k_proj.lora_B.default torch.float32
transformer_blocks.19.attn.add_v_proj torch.float8_e4m3fn
transformer_blocks.19.attn.add_v_proj.base_layer torch.float8_e4m3fn
transformer_blocks.19.attn.add_v_proj.lora_A.default torch.float32
transformer_blocks.19.attn.add_v_proj.lora_B.default torch.float32
transformer_blocks.19.attn.add_q_proj torch.float8_e4m3fn
transformer_blocks.19.attn.add_q_proj.base_layer torch.float8_e4m3fn
transformer_blocks.19.attn.add_q_proj.lora_A.default torch.float32
transformer_blocks.19.attn.add_q_proj.lora_B.default torch.float32
transformer_blocks.19.attn.to_out.0 torch.float8_e4m3fn
transformer_blocks.19.attn.to_out.0.base_layer torch.float8_e4m3fn
transformer_blocks.19.attn.to_out.0.lora_A.default torch.float32
transformer_blocks.19.attn.to_out.0.lora_B.default torch.float32
transformer_blocks.19.attn.to_add_out torch.float8_e4m3fn
transformer_blocks.19.attn.to_add_out.base_layer torch.float8_e4m3fn
transformer_blocks.19.attn.to_add_out.lora_A.default torch.float32
transformer_blocks.19.attn.to_add_out.lora_B.default torch.float32
transformer_blocks.19.attn.norm_added_q torch.bfloat16
transformer_blocks.19.attn.norm_added_k torch.bfloat16
transformer_blocks.19.ff.net.0.proj torch.float8_e4m3fn
transformer_blocks.19.ff.net.2 torch.float8_e4m3fn
transformer_blocks.19.ff_context.net.0.proj torch.float8_e4m3fn
transformer_blocks.19.ff_context.net.2 torch.float8_e4m3fn
single_transformer_blocks.0.attn.norm_q torch.bfloat16
single_transformer_blocks.0.attn.norm_k torch.bfloat16
single_transformer_blocks.0.attn.to_q torch.float8_e4m3fn
single_transformer_blocks.0.attn.to_q.base_layer torch.float8_e4m3fn
single_transformer_blocks.0.attn.to_q.lora_A.default torch.float32
single_transformer_blocks.0.attn.to_q.lora_B.default torch.float32
single_transformer_blocks.0.attn.to_k torch.float8_e4m3fn
single_transformer_blocks.0.attn.to_k.base_layer torch.float8_e4m3fn
single_transformer_blocks.0.attn.to_k.lora_A.default torch.float32
single_transformer_blocks.0.attn.to_k.lora_B.default torch.float32
single_transformer_blocks.0.attn.to_v torch.float8_e4m3fn
single_transformer_blocks.0.attn.to_v.base_layer torch.float8_e4m3fn
single_transformer_blocks.0.attn.to_v.lora_A.default torch.float32
single_transformer_blocks.0.attn.to_v.lora_B.default torch.float32
single_transformer_blocks.0.norm.linear torch.bfloat16
single_transformer_blocks.0.proj_mlp torch.float8_e4m3fn
single_transformer_blocks.0.proj_out torch.float8_e4m3fn
single_transformer_blocks.1.attn.norm_q torch.bfloat16
single_transformer_blocks.1.attn.norm_k torch.bfloat16
single_transformer_blocks.1.attn.to_q torch.float8_e4m3fn
single_transformer_blocks.1.attn.to_q.base_layer torch.float8_e4m3fn
single_transformer_blocks.1.attn.to_q.lora_A.default torch.float32
single_transformer_blocks.1.attn.to_q.lora_B.default torch.float32
single_transformer_blocks.1.attn.to_k torch.float8_e4m3fn
single_transformer_blocks.1.attn.to_k.base_layer torch.float8_e4m3fn
single_transformer_blocks.1.attn.to_k.lora_A.default torch.float32
single_transformer_blocks.1.attn.to_k.lora_B.default torch.float32
single_transformer_blocks.1.attn.to_v torch.float8_e4m3fn
single_transformer_blocks.1.attn.to_v.base_layer torch.float8_e4m3fn
single_transformer_blocks.1.attn.to_v.lora_A.default torch.float32
single_transformer_blocks.1.attn.to_v.lora_B.default torch.float32
single_transformer_blocks.1.norm.linear torch.bfloat16
single_transformer_blocks.1.proj_mlp torch.float8_e4m3fn
single_transformer_blocks.1.proj_out torch.float8_e4m3fn
single_transformer_blocks.2.attn.norm_q torch.bfloat16
single_transformer_blocks.2.attn.norm_k torch.bfloat16
single_transformer_blocks.2.attn.to_q torch.float8_e4m3fn
single_transformer_blocks.2.attn.to_q.base_layer torch.float8_e4m3fn
single_transformer_blocks.2.attn.to_q.lora_A.default torch.float32
single_transformer_blocks.2.attn.to_q.lora_B.default torch.float32
single_transformer_blocks.2.attn.to_k torch.float8_e4m3fn
single_transformer_blocks.2.attn.to_k.base_layer torch.float8_e4m3fn
single_transformer_blocks.2.attn.to_k.lora_A.default torch.float32
single_transformer_blocks.2.attn.to_k.lora_B.default torch.float32
single_transformer_blocks.2.attn.to_v torch.float8_e4m3fn
single_transformer_blocks.2.attn.to_v.base_layer torch.float8_e4m3fn
single_transformer_blocks.2.attn.to_v.lora_A.default torch.float32
single_transformer_blocks.2.attn.to_v.lora_B.default torch.float32
single_transformer_blocks.2.norm.linear torch.bfloat16
single_transformer_blocks.2.proj_mlp torch.float8_e4m3fn
single_transformer_blocks.2.proj_out torch.float8_e4m3fn
single_transformer_blocks.3.attn.norm_q torch.bfloat16
single_transformer_blocks.3.attn.norm_k torch.bfloat16
single_transformer_blocks.3.attn.to_q torch.float8_e4m3fn
single_transformer_blocks.3.attn.to_q.base_layer torch.float8_e4m3fn
single_transformer_blocks.3.attn.to_q.lora_A.default torch.float32
single_transformer_blocks.3.attn.to_q.lora_B.default torch.float32
single_transformer_blocks.3.attn.to_k torch.float8_e4m3fn
single_transformer_blocks.3.attn.to_k.base_layer torch.float8_e4m3fn
single_transformer_blocks.3.attn.to_k.lora_A.default torch.float32
single_transformer_blocks.3.attn.to_k.lora_B.default torch.float32
single_transformer_blocks.3.attn.to_v torch.float8_e4m3fn
single_transformer_blocks.3.attn.to_v.base_layer torch.float8_e4m3fn
single_transformer_blocks.3.attn.to_v.lora_A.default torch.float32
single_transformer_blocks.3.attn.to_v.lora_B.default torch.float32
single_transformer_blocks.3.norm.linear torch.bfloat16
single_transformer_blocks.3.proj_mlp torch.float8_e4m3fn
single_transformer_blocks.3.proj_out torch.float8_e4m3fn
single_transformer_blocks.4.attn.norm_q torch.bfloat16
single_transformer_blocks.4.attn.norm_k torch.bfloat16
single_transformer_blocks.4.attn.to_q torch.float8_e4m3fn
single_transformer_blocks.4.attn.to_q.base_layer torch.float8_e4m3fn
single_transformer_blocks.4.attn.to_q.lora_A.default torch.float32
single_transformer_blocks.4.attn.to_q.lora_B.default torch.float32
single_transformer_blocks.4.attn.to_k torch.float8_e4m3fn
single_transformer_blocks.4.attn.to_k.base_layer torch.float8_e4m3fn
single_transformer_blocks.4.attn.to_k.lora_A.default torch.float32
single_transformer_blocks.4.attn.to_k.lora_B.default torch.float32
single_transformer_blocks.4.attn.to_v torch.float8_e4m3fn
single_transformer_blocks.4.attn.to_v.base_layer torch.float8_e4m3fn
single_transformer_blocks.4.attn.to_v.lora_A.default torch.float32
single_transformer_blocks.4.attn.to_v.lora_B.default torch.float32
single_transformer_blocks.4.norm.linear torch.bfloat16
single_transformer_blocks.4.proj_mlp torch.float8_e4m3fn
single_transformer_blocks.4.proj_out torch.float8_e4m3fn
single_transformer_blocks.5.attn.norm_q torch.bfloat16
single_transformer_blocks.5.attn.norm_k torch.bfloat16
single_transformer_blocks.5.attn.to_q torch.float8_e4m3fn
single_transformer_blocks.5.attn.to_q.base_layer torch.float8_e4m3fn
single_transformer_blocks.5.attn.to_q.lora_A.default torch.float32
single_transformer_blocks.5.attn.to_q.lora_B.default torch.float32
single_transformer_blocks.5.attn.to_k torch.float8_e4m3fn
single_transformer_blocks.5.attn.to_k.base_layer torch.float8_e4m3fn
single_transformer_blocks.5.attn.to_k.lora_A.default torch.float32
single_transformer_blocks.5.attn.to_k.lora_B.default torch.float32
single_transformer_blocks.5.attn.to_v torch.float8_e4m3fn
single_transformer_blocks.5.attn.to_v.base_layer torch.float8_e4m3fn
single_transformer_blocks.5.attn.to_v.lora_A.default torch.float32
single_transformer_blocks.5.attn.to_v.lora_B.default torch.float32
single_transformer_blocks.5.norm.linear torch.bfloat16
single_transformer_blocks.5.proj_mlp torch.float8_e4m3fn
single_transformer_blocks.5.proj_out torch.float8_e4m3fn
single_transformer_blocks.6.attn.norm_q torch.bfloat16
single_transformer_blocks.6.attn.norm_k torch.bfloat16
single_transformer_blocks.6.attn.to_q torch.float8_e4m3fn
single_transformer_blocks.6.attn.to_q.base_layer torch.float8_e4m3fn
single_transformer_blocks.6.attn.to_q.lora_A.default torch.float32
single_transformer_blocks.6.attn.to_q.lora_B.default torch.float32
single_transformer_blocks.6.attn.to_k torch.float8_e4m3fn
single_transformer_blocks.6.attn.to_k.base_layer torch.float8_e4m3fn
single_transformer_blocks.6.attn.to_k.lora_A.default torch.float32
single_transformer_blocks.6.attn.to_k.lora_B.default torch.float32
single_transformer_blocks.6.attn.to_v torch.float8_e4m3fn
single_transformer_blocks.6.attn.to_v.base_layer torch.float8_e4m3fn
single_transformer_blocks.6.attn.to_v.lora_A.default torch.float32
single_transformer_blocks.6.attn.to_v.lora_B.default torch.float32
single_transformer_blocks.6.norm.linear torch.bfloat16
single_transformer_blocks.6.proj_mlp torch.float8_e4m3fn
single_transformer_blocks.6.proj_out torch.float8_e4m3fn
single_transformer_blocks.7.attn.norm_q torch.bfloat16
single_transformer_blocks.7.attn.norm_k torch.bfloat16
single_transformer_blocks.7.attn.to_q torch.float8_e4m3fn
single_transformer_blocks.7.attn.to_q.base_layer torch.float8_e4m3fn
single_transformer_blocks.7.attn.to_q.lora_A.default torch.float32
single_transformer_blocks.7.attn.to_q.lora_B.default torch.float32
single_transformer_blocks.7.attn.to_k torch.float8_e4m3fn
single_transformer_blocks.7.attn.to_k.base_layer torch.float8_e4m3fn
single_transformer_blocks.7.attn.to_k.lora_A.default torch.float32
single_transformer_blocks.7.attn.to_k.lora_B.default torch.float32
single_transformer_blocks.7.attn.to_v torch.float8_e4m3fn
single_transformer_blocks.7.attn.to_v.base_layer torch.float8_e4m3fn
single_transformer_blocks.7.attn.to_v.lora_A.default torch.float32
single_transformer_blocks.7.attn.to_v.lora_B.default torch.float32
single_transformer_blocks.7.norm.linear torch.bfloat16
single_transformer_blocks.7.proj_mlp torch.float8_e4m3fn
single_transformer_blocks.7.proj_out torch.float8_e4m3fn
single_transformer_blocks.8.attn.norm_q torch.bfloat16
single_transformer_blocks.8.attn.norm_k torch.bfloat16
single_transformer_blocks.8.attn.to_q torch.float8_e4m3fn
single_transformer_blocks.8.attn.to_q.base_layer torch.float8_e4m3fn
single_transformer_blocks.8.attn.to_q.lora_A.default torch.float32
single_transformer_blocks.8.attn.to_q.lora_B.default torch.float32
single_transformer_blocks.8.attn.to_k torch.float8_e4m3fn
single_transformer_blocks.8.attn.to_k.base_layer torch.float8_e4m3fn
single_transformer_blocks.8.attn.to_k.lora_A.default torch.float32
single_transformer_blocks.8.attn.to_k.lora_B.default torch.float32
single_transformer_blocks.8.attn.to_v torch.float8_e4m3fn
single_transformer_blocks.8.attn.to_v.base_layer torch.float8_e4m3fn
single_transformer_blocks.8.attn.to_v.lora_A.default torch.float32
single_transformer_blocks.8.attn.to_v.lora_B.default torch.float32
single_transformer_blocks.8.norm.linear torch.bfloat16
single_transformer_blocks.8.proj_mlp torch.float8_e4m3fn
single_transformer_blocks.8.proj_out torch.float8_e4m3fn
single_transformer_blocks.9.attn.norm_q torch.bfloat16
single_transformer_blocks.9.attn.norm_k torch.bfloat16
single_transformer_blocks.9.attn.to_q torch.float8_e4m3fn
single_transformer_blocks.9.attn.to_q.base_layer torch.float8_e4m3fn
single_transformer_blocks.9.attn.to_q.lora_A.default torch.float32
single_transformer_blocks.9.attn.to_q.lora_B.default torch.float32
single_transformer_blocks.9.attn.to_k torch.float8_e4m3fn
single_transformer_blocks.9.attn.to_k.base_layer torch.float8_e4m3fn
single_transformer_blocks.9.attn.to_k.lora_A.default torch.float32
single_transformer_blocks.9.attn.to_k.lora_B.default torch.float32
single_transformer_blocks.9.attn.to_v torch.float8_e4m3fn
single_transformer_blocks.9.attn.to_v.base_layer torch.float8_e4m3fn
single_transformer_blocks.9.attn.to_v.lora_A.default torch.float32
single_transformer_blocks.9.attn.to_v.lora_B.default torch.float32
single_transformer_blocks.9.norm.linear torch.bfloat16
single_transformer_blocks.9.proj_mlp torch.float8_e4m3fn
single_transformer_blocks.9.proj_out torch.float8_e4m3fn
single_transformer_blocks.10.attn.norm_q torch.bfloat16
single_transformer_blocks.10.attn.norm_k torch.bfloat16
single_transformer_blocks.10.attn.to_q torch.float8_e4m3fn
single_transformer_blocks.10.attn.to_q.base_layer torch.float8_e4m3fn
single_transformer_blocks.10.attn.to_q.lora_A.default torch.float32
single_transformer_blocks.10.attn.to_q.lora_B.default torch.float32
single_transformer_blocks.10.attn.to_k torch.float8_e4m3fn
single_transformer_blocks.10.attn.to_k.base_layer torch.float8_e4m3fn
single_transformer_blocks.10.attn.to_k.lora_A.default torch.float32
single_transformer_blocks.10.attn.to_k.lora_B.default torch.float32
single_transformer_blocks.10.attn.to_v torch.float8_e4m3fn
single_transformer_blocks.10.attn.to_v.base_layer torch.float8_e4m3fn
single_transformer_blocks.10.attn.to_v.lora_A.default torch.float32
single_transformer_blocks.10.attn.to_v.lora_B.default torch.float32
single_transformer_blocks.10.norm.linear torch.bfloat16
single_transformer_blocks.10.proj_mlp torch.float8_e4m3fn
single_transformer_blocks.10.proj_out torch.float8_e4m3fn
single_transformer_blocks.11.attn.norm_q torch.bfloat16
single_transformer_blocks.11.attn.norm_k torch.bfloat16
single_transformer_blocks.11.attn.to_q torch.float8_e4m3fn
single_transformer_blocks.11.attn.to_q.base_layer torch.float8_e4m3fn
single_transformer_blocks.11.attn.to_q.lora_A.default torch.float32
single_transformer_blocks.11.attn.to_q.lora_B.default torch.float32
single_transformer_blocks.11.attn.to_k torch.float8_e4m3fn
single_transformer_blocks.11.attn.to_k.base_layer torch.float8_e4m3fn
single_transformer_blocks.11.attn.to_k.lora_A.default torch.float32
single_transformer_blocks.11.attn.to_k.lora_B.default torch.float32
single_transformer_blocks.11.attn.to_v torch.float8_e4m3fn
single_transformer_blocks.11.attn.to_v.base_layer torch.float8_e4m3fn
single_transformer_blocks.11.attn.to_v.lora_A.default torch.float32
single_transformer_blocks.11.attn.to_v.lora_B.default torch.float32
single_transformer_blocks.11.norm.linear torch.bfloat16
single_transformer_blocks.11.proj_mlp torch.float8_e4m3fn
single_transformer_blocks.11.proj_out torch.float8_e4m3fn
single_transformer_blocks.12.attn.norm_q torch.bfloat16
single_transformer_blocks.12.attn.norm_k torch.bfloat16
single_transformer_blocks.12.attn.to_q torch.float8_e4m3fn
single_transformer_blocks.12.attn.to_q.base_layer torch.float8_e4m3fn
single_transformer_blocks.12.attn.to_q.lora_A.default torch.float32
single_transformer_blocks.12.attn.to_q.lora_B.default torch.float32
single_transformer_blocks.12.attn.to_k torch.float8_e4m3fn
single_transformer_blocks.12.attn.to_k.base_layer torch.float8_e4m3fn
single_transformer_blocks.12.attn.to_k.lora_A.default torch.float32
single_transformer_blocks.12.attn.to_k.lora_B.default torch.float32
single_transformer_blocks.12.attn.to_v torch.float8_e4m3fn
single_transformer_blocks.12.attn.to_v.base_layer torch.float8_e4m3fn
single_transformer_blocks.12.attn.to_v.lora_A.default torch.float32
single_transformer_blocks.12.attn.to_v.lora_B.default torch.float32
single_transformer_blocks.12.norm.linear torch.bfloat16
single_transformer_blocks.12.proj_mlp torch.float8_e4m3fn
single_transformer_blocks.12.proj_out torch.float8_e4m3fn
single_transformer_blocks.13.attn.norm_q torch.bfloat16
single_transformer_blocks.13.attn.norm_k torch.bfloat16
single_transformer_blocks.13.attn.to_q torch.float8_e4m3fn
single_transformer_blocks.13.attn.to_q.base_layer torch.float8_e4m3fn
single_transformer_blocks.13.attn.to_q.lora_A.default torch.float32
single_transformer_blocks.13.attn.to_q.lora_B.default torch.float32
single_transformer_blocks.13.attn.to_k torch.float8_e4m3fn
single_transformer_blocks.13.attn.to_k.base_layer torch.float8_e4m3fn
single_transformer_blocks.13.attn.to_k.lora_A.default torch.float32
single_transformer_blocks.13.attn.to_k.lora_B.default torch.float32
single_transformer_blocks.13.attn.to_v torch.float8_e4m3fn
single_transformer_blocks.13.attn.to_v.base_layer torch.float8_e4m3fn
single_transformer_blocks.13.attn.to_v.lora_A.default torch.float32
single_transformer_blocks.13.attn.to_v.lora_B.default torch.float32
single_transformer_blocks.13.norm.linear torch.bfloat16
single_transformer_blocks.13.proj_mlp torch.float8_e4m3fn
single_transformer_blocks.13.proj_out torch.float8_e4m3fn
single_transformer_blocks.14.attn.norm_q torch.bfloat16
single_transformer_blocks.14.attn.norm_k torch.bfloat16
single_transformer_blocks.14.attn.to_q torch.float8_e4m3fn
single_transformer_blocks.14.attn.to_q.base_layer torch.float8_e4m3fn
single_transformer_blocks.14.attn.to_q.lora_A.default torch.float32
single_transformer_blocks.14.attn.to_q.lora_B.default torch.float32
single_transformer_blocks.14.attn.to_k torch.float8_e4m3fn
single_transformer_blocks.14.attn.to_k.base_layer torch.float8_e4m3fn
single_transformer_blocks.14.attn.to_k.lora_A.default torch.float32
single_transformer_blocks.14.attn.to_k.lora_B.default torch.float32
single_transformer_blocks.14.attn.to_v torch.float8_e4m3fn
single_transformer_blocks.14.attn.to_v.base_layer torch.float8_e4m3fn
single_transformer_blocks.14.attn.to_v.lora_A.default torch.float32
single_transformer_blocks.14.attn.to_v.lora_B.default torch.float32
single_transformer_blocks.14.norm.linear torch.bfloat16
single_transformer_blocks.14.proj_mlp torch.float8_e4m3fn
single_transformer_blocks.14.proj_out torch.float8_e4m3fn
single_transformer_blocks.15.attn.norm_q torch.bfloat16
single_transformer_blocks.15.attn.norm_k torch.bfloat16
single_transformer_blocks.15.attn.to_q torch.float8_e4m3fn
single_transformer_blocks.15.attn.to_q.base_layer torch.float8_e4m3fn
single_transformer_blocks.15.attn.to_q.lora_A.default torch.float32
single_transformer_blocks.15.attn.to_q.lora_B.default torch.float32
single_transformer_blocks.15.attn.to_k torch.float8_e4m3fn
single_transformer_blocks.15.attn.to_k.base_layer torch.float8_e4m3fn
single_transformer_blocks.15.attn.to_k.lora_A.default torch.float32
single_transformer_blocks.15.attn.to_k.lora_B.default torch.float32
single_transformer_blocks.15.attn.to_v torch.float8_e4m3fn
single_transformer_blocks.15.attn.to_v.base_layer torch.float8_e4m3fn
single_transformer_blocks.15.attn.to_v.lora_A.default torch.float32
single_transformer_blocks.15.attn.to_v.lora_B.default torch.float32
single_transformer_blocks.15.norm.linear torch.bfloat16
single_transformer_blocks.15.proj_mlp torch.float8_e4m3fn
single_transformer_blocks.15.proj_out torch.float8_e4m3fn
single_transformer_blocks.16.attn.norm_q torch.bfloat16
single_transformer_blocks.16.attn.norm_k torch.bfloat16
single_transformer_blocks.16.attn.to_q torch.float8_e4m3fn
single_transformer_blocks.16.attn.to_q.base_layer torch.float8_e4m3fn
single_transformer_blocks.16.attn.to_q.lora_A.default torch.float32
single_transformer_blocks.16.attn.to_q.lora_B.default torch.float32
single_transformer_blocks.16.attn.to_k torch.float8_e4m3fn
single_transformer_blocks.16.attn.to_k.base_layer torch.float8_e4m3fn
single_transformer_blocks.16.attn.to_k.lora_A.default torch.float32
single_transformer_blocks.16.attn.to_k.lora_B.default torch.float32
single_transformer_blocks.16.attn.to_v torch.float8_e4m3fn
single_transformer_blocks.16.attn.to_v.base_layer torch.float8_e4m3fn
single_transformer_blocks.16.attn.to_v.lora_A.default torch.float32
single_transformer_blocks.16.attn.to_v.lora_B.default torch.float32
single_transformer_blocks.16.norm.linear torch.bfloat16
single_transformer_blocks.16.proj_mlp torch.float8_e4m3fn
single_transformer_blocks.16.proj_out torch.float8_e4m3fn
single_transformer_blocks.17.attn.norm_q torch.bfloat16
single_transformer_blocks.17.attn.norm_k torch.bfloat16
single_transformer_blocks.17.attn.to_q torch.float8_e4m3fn
single_transformer_blocks.17.attn.to_q.base_layer torch.float8_e4m3fn
single_transformer_blocks.17.attn.to_q.lora_A.default torch.float32
single_transformer_blocks.17.attn.to_q.lora_B.default torch.float32
single_transformer_blocks.17.attn.to_k torch.float8_e4m3fn
single_transformer_blocks.17.attn.to_k.base_layer torch.float8_e4m3fn
single_transformer_blocks.17.attn.to_k.lora_A.default torch.float32
single_transformer_blocks.17.attn.to_k.lora_B.default torch.float32
single_transformer_blocks.17.attn.to_v torch.float8_e4m3fn
single_transformer_blocks.17.attn.to_v.base_layer torch.float8_e4m3fn
single_transformer_blocks.17.attn.to_v.lora_A.default torch.float32
single_transformer_blocks.17.attn.to_v.lora_B.default torch.float32
single_transformer_blocks.17.norm.linear torch.bfloat16
single_transformer_blocks.17.proj_mlp torch.float8_e4m3fn
single_transformer_blocks.17.proj_out torch.float8_e4m3fn
single_transformer_blocks.18.attn.norm_q torch.bfloat16
single_transformer_blocks.18.attn.norm_k torch.bfloat16
single_transformer_blocks.18.attn.to_q torch.float8_e4m3fn
single_transformer_blocks.18.attn.to_q.base_layer torch.float8_e4m3fn
single_transformer_blocks.18.attn.to_q.lora_A.default torch.float32
single_transformer_blocks.18.attn.to_q.lora_B.default torch.float32
single_transformer_blocks.18.attn.to_k torch.float8_e4m3fn
single_transformer_blocks.18.attn.to_k.base_layer torch.float8_e4m3fn
single_transformer_blocks.18.attn.to_k.lora_A.default torch.float32
single_transformer_blocks.18.attn.to_k.lora_B.default torch.float32
single_transformer_blocks.18.attn.to_v torch.float8_e4m3fn
single_transformer_blocks.18.attn.to_v.base_layer torch.float8_e4m3fn
single_transformer_blocks.18.attn.to_v.lora_A.default torch.float32
single_transformer_blocks.18.attn.to_v.lora_B.default torch.float32
single_transformer_blocks.18.norm.linear torch.bfloat16
single_transformer_blocks.18.proj_mlp torch.float8_e4m3fn
single_transformer_blocks.18.proj_out torch.float8_e4m3fn
single_transformer_blocks.19.attn.norm_q torch.bfloat16
single_transformer_blocks.19.attn.norm_k torch.bfloat16
single_transformer_blocks.19.attn.to_q torch.float8_e4m3fn
single_transformer_blocks.19.attn.to_q.base_layer torch.float8_e4m3fn
single_transformer_blocks.19.attn.to_q.lora_A.default torch.float32
single_transformer_blocks.19.attn.to_q.lora_B.default torch.float32
single_transformer_blocks.19.attn.to_k torch.float8_e4m3fn
single_transformer_blocks.19.attn.to_k.base_layer torch.float8_e4m3fn
single_transformer_blocks.19.attn.to_k.lora_A.default torch.float32
single_transformer_blocks.19.attn.to_k.lora_B.default torch.float32
single_transformer_blocks.19.attn.to_v torch.float8_e4m3fn
single_transformer_blocks.19.attn.to_v.base_layer torch.float8_e4m3fn
single_transformer_blocks.19.attn.to_v.lora_A.default torch.float32
single_transformer_blocks.19.attn.to_v.lora_B.default torch.float32
single_transformer_blocks.19.norm.linear torch.bfloat16
single_transformer_blocks.19.proj_mlp torch.float8_e4m3fn
single_transformer_blocks.19.proj_out torch.float8_e4m3fn
single_transformer_blocks.20.attn.norm_q torch.bfloat16
single_transformer_blocks.20.attn.norm_k torch.bfloat16
single_transformer_blocks.20.attn.to_q torch.float8_e4m3fn
single_transformer_blocks.20.attn.to_q.base_layer torch.float8_e4m3fn
single_transformer_blocks.20.attn.to_q.lora_A.default torch.float32
single_transformer_blocks.20.attn.to_q.lora_B.default torch.float32
single_transformer_blocks.20.attn.to_k torch.float8_e4m3fn
single_transformer_blocks.20.attn.to_k.base_layer torch.float8_e4m3fn
single_transformer_blocks.20.attn.to_k.lora_A.default torch.float32
single_transformer_blocks.20.attn.to_k.lora_B.default torch.float32
single_transformer_blocks.20.attn.to_v torch.float8_e4m3fn
single_transformer_blocks.20.attn.to_v.base_layer torch.float8_e4m3fn
single_transformer_blocks.20.attn.to_v.lora_A.default torch.float32
single_transformer_blocks.20.attn.to_v.lora_B.default torch.float32
single_transformer_blocks.20.norm.linear torch.bfloat16
single_transformer_blocks.20.proj_mlp torch.float8_e4m3fn
single_transformer_blocks.20.proj_out torch.float8_e4m3fn
single_transformer_blocks.21.attn.norm_q torch.bfloat16
single_transformer_blocks.21.attn.norm_k torch.bfloat16
single_transformer_blocks.21.attn.to_q torch.float8_e4m3fn
single_transformer_blocks.21.attn.to_q.base_layer torch.float8_e4m3fn
single_transformer_blocks.21.attn.to_q.lora_A.default torch.float32
single_transformer_blocks.21.attn.to_q.lora_B.default torch.float32
single_transformer_blocks.21.attn.to_k torch.float8_e4m3fn
single_transformer_blocks.21.attn.to_k.base_layer torch.float8_e4m3fn
single_transformer_blocks.21.attn.to_k.lora_A.default torch.float32
single_transformer_blocks.21.attn.to_k.lora_B.default torch.float32
single_transformer_blocks.21.attn.to_v torch.float8_e4m3fn
single_transformer_blocks.21.attn.to_v.base_layer torch.float8_e4m3fn
single_transformer_blocks.21.attn.to_v.lora_A.default torch.float32
single_transformer_blocks.21.attn.to_v.lora_B.default torch.float32
single_transformer_blocks.21.norm.linear torch.bfloat16
single_transformer_blocks.21.proj_mlp torch.float8_e4m3fn
single_transformer_blocks.21.proj_out torch.float8_e4m3fn
single_transformer_blocks.22.attn.norm_q torch.bfloat16
single_transformer_blocks.22.attn.norm_k torch.bfloat16
single_transformer_blocks.22.attn.to_q torch.float8_e4m3fn
single_transformer_blocks.22.attn.to_q.base_layer torch.float8_e4m3fn
single_transformer_blocks.22.attn.to_q.lora_A.default torch.float32
single_transformer_blocks.22.attn.to_q.lora_B.default torch.float32
single_transformer_blocks.22.attn.to_k torch.float8_e4m3fn
single_transformer_blocks.22.attn.to_k.base_layer torch.float8_e4m3fn
single_transformer_blocks.22.attn.to_k.lora_A.default torch.float32
single_transformer_blocks.22.attn.to_k.lora_B.default torch.float32
single_transformer_blocks.22.attn.to_v torch.float8_e4m3fn
single_transformer_blocks.22.attn.to_v.base_layer torch.float8_e4m3fn
single_transformer_blocks.22.attn.to_v.lora_A.default torch.float32
single_transformer_blocks.22.attn.to_v.lora_B.default torch.float32
single_transformer_blocks.22.norm.linear torch.bfloat16
single_transformer_blocks.22.proj_mlp torch.float8_e4m3fn
single_transformer_blocks.22.proj_out torch.float8_e4m3fn
single_transformer_blocks.23.attn.norm_q torch.bfloat16
single_transformer_blocks.23.attn.norm_k torch.bfloat16
single_transformer_blocks.23.attn.to_q torch.float8_e4m3fn
single_transformer_blocks.23.attn.to_q.base_layer torch.float8_e4m3fn
single_transformer_blocks.23.attn.to_q.lora_A.default torch.float32
single_transformer_blocks.23.attn.to_q.lora_B.default torch.float32
single_transformer_blocks.23.attn.to_k torch.float8_e4m3fn
single_transformer_blocks.23.attn.to_k.base_layer torch.float8_e4m3fn
single_transformer_blocks.23.attn.to_k.lora_A.default torch.float32
single_transformer_blocks.23.attn.to_k.lora_B.default torch.float32
single_transformer_blocks.23.attn.to_v torch.float8_e4m3fn
single_transformer_blocks.23.attn.to_v.base_layer torch.float8_e4m3fn
single_transformer_blocks.23.attn.to_v.lora_A.default torch.float32
single_transformer_blocks.23.attn.to_v.lora_B.default torch.float32
single_transformer_blocks.23.norm.linear torch.bfloat16
single_transformer_blocks.23.proj_mlp torch.float8_e4m3fn
single_transformer_blocks.23.proj_out torch.float8_e4m3fn
single_transformer_blocks.24.attn.norm_q torch.bfloat16
single_transformer_blocks.24.attn.norm_k torch.bfloat16
single_transformer_blocks.24.attn.to_q torch.float8_e4m3fn
single_transformer_blocks.24.attn.to_q.base_layer torch.float8_e4m3fn
single_transformer_blocks.24.attn.to_q.lora_A.default torch.float32
single_transformer_blocks.24.attn.to_q.lora_B.default torch.float32
single_transformer_blocks.24.attn.to_k torch.float8_e4m3fn
single_transformer_blocks.24.attn.to_k.base_layer torch.float8_e4m3fn
single_transformer_blocks.24.attn.to_k.lora_A.default torch.float32
single_transformer_blocks.24.attn.to_k.lora_B.default torch.float32
single_transformer_blocks.24.attn.to_v torch.float8_e4m3fn
single_transformer_blocks.24.attn.to_v.base_layer torch.float8_e4m3fn
single_transformer_blocks.24.attn.to_v.lora_A.default torch.float32
single_transformer_blocks.24.attn.to_v.lora_B.default torch.float32
single_transformer_blocks.24.norm.linear torch.bfloat16
single_transformer_blocks.24.proj_mlp torch.float8_e4m3fn
single_transformer_blocks.24.proj_out torch.float8_e4m3fn
single_transformer_blocks.25.attn.norm_q torch.bfloat16
single_transformer_blocks.25.attn.norm_k torch.bfloat16
single_transformer_blocks.25.attn.to_q torch.float8_e4m3fn
single_transformer_blocks.25.attn.to_q.base_layer torch.float8_e4m3fn
single_transformer_blocks.25.attn.to_q.lora_A.default torch.float32
single_transformer_blocks.25.attn.to_q.lora_B.default torch.float32
single_transformer_blocks.25.attn.to_k torch.float8_e4m3fn
single_transformer_blocks.25.attn.to_k.base_layer torch.float8_e4m3fn
single_transformer_blocks.25.attn.to_k.lora_A.default torch.float32
single_transformer_blocks.25.attn.to_k.lora_B.default torch.float32
single_transformer_blocks.25.attn.to_v torch.float8_e4m3fn
single_transformer_blocks.25.attn.to_v.base_layer torch.float8_e4m3fn
single_transformer_blocks.25.attn.to_v.lora_A.default torch.float32
single_transformer_blocks.25.attn.to_v.lora_B.default torch.float32
single_transformer_blocks.25.norm.linear torch.bfloat16
single_transformer_blocks.25.proj_mlp torch.float8_e4m3fn
single_transformer_blocks.25.proj_out torch.float8_e4m3fn
single_transformer_blocks.26.attn.norm_q torch.bfloat16
single_transformer_blocks.26.attn.norm_k torch.bfloat16
single_transformer_blocks.26.attn.to_q torch.float8_e4m3fn
single_transformer_blocks.26.attn.to_q.base_layer torch.float8_e4m3fn
single_transformer_blocks.26.attn.to_q.lora_A.default torch.float32
single_transformer_blocks.26.attn.to_q.lora_B.default torch.float32
single_transformer_blocks.26.attn.to_k torch.float8_e4m3fn
single_transformer_blocks.26.attn.to_k.base_layer torch.float8_e4m3fn
single_transformer_blocks.26.attn.to_k.lora_A.default torch.float32
single_transformer_blocks.26.attn.to_k.lora_B.default torch.float32
single_transformer_blocks.26.attn.to_v torch.float8_e4m3fn
single_transformer_blocks.26.attn.to_v.base_layer torch.float8_e4m3fn
single_transformer_blocks.26.attn.to_v.lora_A.default torch.float32
single_transformer_blocks.26.attn.to_v.lora_B.default torch.float32
single_transformer_blocks.26.norm.linear torch.bfloat16
single_transformer_blocks.26.proj_mlp torch.float8_e4m3fn
single_transformer_blocks.26.proj_out torch.float8_e4m3fn
single_transformer_blocks.27.attn.norm_q torch.bfloat16
single_transformer_blocks.27.attn.norm_k torch.bfloat16
single_transformer_blocks.27.attn.to_q torch.float8_e4m3fn
single_transformer_blocks.27.attn.to_q.base_layer torch.float8_e4m3fn
single_transformer_blocks.27.attn.to_q.lora_A.default torch.float32
single_transformer_blocks.27.attn.to_q.lora_B.default torch.float32
single_transformer_blocks.27.attn.to_k torch.float8_e4m3fn
single_transformer_blocks.27.attn.to_k.base_layer torch.float8_e4m3fn
single_transformer_blocks.27.attn.to_k.lora_A.default torch.float32
single_transformer_blocks.27.attn.to_k.lora_B.default torch.float32
single_transformer_blocks.27.attn.to_v torch.float8_e4m3fn
single_transformer_blocks.27.attn.to_v.base_layer torch.float8_e4m3fn
single_transformer_blocks.27.attn.to_v.lora_A.default torch.float32
single_transformer_blocks.27.attn.to_v.lora_B.default torch.float32
single_transformer_blocks.27.norm.linear torch.bfloat16
single_transformer_blocks.27.proj_mlp torch.float8_e4m3fn
single_transformer_blocks.27.proj_out torch.float8_e4m3fn
single_transformer_blocks.28.attn.norm_q torch.bfloat16
single_transformer_blocks.28.attn.norm_k torch.bfloat16
single_transformer_blocks.28.attn.to_q torch.float8_e4m3fn
single_transformer_blocks.28.attn.to_q.base_layer torch.float8_e4m3fn
single_transformer_blocks.28.attn.to_q.lora_A.default torch.float32
single_transformer_blocks.28.attn.to_q.lora_B.default torch.float32
single_transformer_blocks.28.attn.to_k torch.float8_e4m3fn
single_transformer_blocks.28.attn.to_k.base_layer torch.float8_e4m3fn
single_transformer_blocks.28.attn.to_k.lora_A.default torch.float32
single_transformer_blocks.28.attn.to_k.lora_B.default torch.float32
single_transformer_blocks.28.attn.to_v torch.float8_e4m3fn
single_transformer_blocks.28.attn.to_v.base_layer torch.float8_e4m3fn
single_transformer_blocks.28.attn.to_v.lora_A.default torch.float32
single_transformer_blocks.28.attn.to_v.lora_B.default torch.float32
single_transformer_blocks.28.norm.linear torch.bfloat16
single_transformer_blocks.28.proj_mlp torch.float8_e4m3fn
single_transformer_blocks.28.proj_out torch.float8_e4m3fn
single_transformer_blocks.29.attn.norm_q torch.bfloat16
single_transformer_blocks.29.attn.norm_k torch.bfloat16
single_transformer_blocks.29.attn.to_q torch.float8_e4m3fn
single_transformer_blocks.29.attn.to_q.base_layer torch.float8_e4m3fn
single_transformer_blocks.29.attn.to_q.lora_A.default torch.float32
single_transformer_blocks.29.attn.to_q.lora_B.default torch.float32
single_transformer_blocks.29.attn.to_k torch.float8_e4m3fn
single_transformer_blocks.29.attn.to_k.base_layer torch.float8_e4m3fn
single_transformer_blocks.29.attn.to_k.lora_A.default torch.float32
single_transformer_blocks.29.attn.to_k.lora_B.default torch.float32
single_transformer_blocks.29.attn.to_v torch.float8_e4m3fn
single_transformer_blocks.29.attn.to_v.base_layer torch.float8_e4m3fn
single_transformer_blocks.29.attn.to_v.lora_A.default torch.float32
single_transformer_blocks.29.attn.to_v.lora_B.default torch.float32
single_transformer_blocks.29.norm.linear torch.bfloat16
single_transformer_blocks.29.proj_mlp torch.float8_e4m3fn
single_transformer_blocks.29.proj_out torch.float8_e4m3fn
single_transformer_blocks.30.attn.norm_q torch.bfloat16
single_transformer_blocks.30.attn.norm_k torch.bfloat16
single_transformer_blocks.30.attn.to_q torch.float8_e4m3fn
single_transformer_blocks.30.attn.to_q.base_layer torch.float8_e4m3fn
single_transformer_blocks.30.attn.to_q.lora_A.default torch.float32
single_transformer_blocks.30.attn.to_q.lora_B.default torch.float32
single_transformer_blocks.30.attn.to_k torch.float8_e4m3fn
single_transformer_blocks.30.attn.to_k.base_layer torch.float8_e4m3fn
single_transformer_blocks.30.attn.to_k.lora_A.default torch.float32
single_transformer_blocks.30.attn.to_k.lora_B.default torch.float32
single_transformer_blocks.30.attn.to_v torch.float8_e4m3fn
single_transformer_blocks.30.attn.to_v.base_layer torch.float8_e4m3fn
single_transformer_blocks.30.attn.to_v.lora_A.default torch.float32
single_transformer_blocks.30.attn.to_v.lora_B.default torch.float32
single_transformer_blocks.30.norm.linear torch.bfloat16
single_transformer_blocks.30.proj_mlp torch.float8_e4m3fn
single_transformer_blocks.30.proj_out torch.float8_e4m3fn
single_transformer_blocks.31.attn.norm_q torch.bfloat16
single_transformer_blocks.31.attn.norm_k torch.bfloat16
single_transformer_blocks.31.attn.to_q torch.float8_e4m3fn
single_transformer_blocks.31.attn.to_q.base_layer torch.float8_e4m3fn
single_transformer_blocks.31.attn.to_q.lora_A.default torch.float32
single_transformer_blocks.31.attn.to_q.lora_B.default torch.float32
single_transformer_blocks.31.attn.to_k torch.float8_e4m3fn
single_transformer_blocks.31.attn.to_k.base_layer torch.float8_e4m3fn
single_transformer_blocks.31.attn.to_k.lora_A.default torch.float32
single_transformer_blocks.31.attn.to_k.lora_B.default torch.float32
single_transformer_blocks.31.attn.to_v torch.float8_e4m3fn
single_transformer_blocks.31.attn.to_v.base_layer torch.float8_e4m3fn
single_transformer_blocks.31.attn.to_v.lora_A.default torch.float32
single_transformer_blocks.31.attn.to_v.lora_B.default torch.float32
single_transformer_blocks.31.norm.linear torch.bfloat16
single_transformer_blocks.31.proj_mlp torch.float8_e4m3fn
single_transformer_blocks.31.proj_out torch.float8_e4m3fn
single_transformer_blocks.32.attn.norm_q torch.bfloat16
single_transformer_blocks.32.attn.norm_k torch.bfloat16
single_transformer_blocks.32.attn.to_q torch.float8_e4m3fn
single_transformer_blocks.32.attn.to_q.base_layer torch.float8_e4m3fn
single_transformer_blocks.32.attn.to_q.lora_A.default torch.float32
single_transformer_blocks.32.attn.to_q.lora_B.default torch.float32
single_transformer_blocks.32.attn.to_k torch.float8_e4m3fn
single_transformer_blocks.32.attn.to_k.base_layer torch.float8_e4m3fn
single_transformer_blocks.32.attn.to_k.lora_A.default torch.float32
single_transformer_blocks.32.attn.to_k.lora_B.default torch.float32
single_transformer_blocks.32.attn.to_v torch.float8_e4m3fn
single_transformer_blocks.32.attn.to_v.base_layer torch.float8_e4m3fn
single_transformer_blocks.32.attn.to_v.lora_A.default torch.float32
single_transformer_blocks.32.attn.to_v.lora_B.default torch.float32
single_transformer_blocks.32.norm.linear torch.bfloat16
single_transformer_blocks.32.proj_mlp torch.float8_e4m3fn
single_transformer_blocks.32.proj_out torch.float8_e4m3fn
single_transformer_blocks.33.attn.norm_q torch.bfloat16
single_transformer_blocks.33.attn.norm_k torch.bfloat16
single_transformer_blocks.33.attn.to_q torch.float8_e4m3fn
single_transformer_blocks.33.attn.to_q.base_layer torch.float8_e4m3fn
single_transformer_blocks.33.attn.to_q.lora_A.default torch.float32
single_transformer_blocks.33.attn.to_q.lora_B.default torch.float32
single_transformer_blocks.33.attn.to_k torch.float8_e4m3fn
single_transformer_blocks.33.attn.to_k.base_layer torch.float8_e4m3fn
single_transformer_blocks.33.attn.to_k.lora_A.default torch.float32
single_transformer_blocks.33.attn.to_k.lora_B.default torch.float32
single_transformer_blocks.33.attn.to_v torch.float8_e4m3fn
single_transformer_blocks.33.attn.to_v.base_layer torch.float8_e4m3fn
single_transformer_blocks.33.attn.to_v.lora_A.default torch.float32
single_transformer_blocks.33.attn.to_v.lora_B.default torch.float32
single_transformer_blocks.33.norm.linear torch.bfloat16
single_transformer_blocks.33.proj_mlp torch.float8_e4m3fn
single_transformer_blocks.33.proj_out torch.float8_e4m3fn
single_transformer_blocks.34.attn.norm_q torch.bfloat16
single_transformer_blocks.34.attn.norm_k torch.bfloat16
single_transformer_blocks.34.attn.to_q torch.float8_e4m3fn
single_transformer_blocks.34.attn.to_q.base_layer torch.float8_e4m3fn
single_transformer_blocks.34.attn.to_q.lora_A.default torch.float32
single_transformer_blocks.34.attn.to_q.lora_B.default torch.float32
single_transformer_blocks.34.attn.to_k torch.float8_e4m3fn
single_transformer_blocks.34.attn.to_k.base_layer torch.float8_e4m3fn
single_transformer_blocks.34.attn.to_k.lora_A.default torch.float32
single_transformer_blocks.34.attn.to_k.lora_B.default torch.float32
single_transformer_blocks.34.attn.to_v torch.float8_e4m3fn
single_transformer_blocks.34.attn.to_v.base_layer torch.float8_e4m3fn
single_transformer_blocks.34.attn.to_v.lora_A.default torch.float32
single_transformer_blocks.34.attn.to_v.lora_B.default torch.float32
single_transformer_blocks.34.norm.linear torch.bfloat16
single_transformer_blocks.34.proj_mlp torch.float8_e4m3fn
single_transformer_blocks.34.proj_out torch.float8_e4m3fn
single_transformer_blocks.35.attn.norm_q torch.bfloat16
single_transformer_blocks.35.attn.norm_k torch.bfloat16
single_transformer_blocks.35.attn.to_q torch.float8_e4m3fn
single_transformer_blocks.35.attn.to_q.base_layer torch.float8_e4m3fn
single_transformer_blocks.35.attn.to_q.lora_A.default torch.float32
single_transformer_blocks.35.attn.to_q.lora_B.default torch.float32
single_transformer_blocks.35.attn.to_k torch.float8_e4m3fn
single_transformer_blocks.35.attn.to_k.base_layer torch.float8_e4m3fn
single_transformer_blocks.35.attn.to_k.lora_A.default torch.float32
single_transformer_blocks.35.attn.to_k.lora_B.default torch.float32
single_transformer_blocks.35.attn.to_v torch.float8_e4m3fn
single_transformer_blocks.35.attn.to_v.base_layer torch.float8_e4m3fn
single_transformer_blocks.35.attn.to_v.lora_A.default torch.float32
single_transformer_blocks.35.attn.to_v.lora_B.default torch.float32
single_transformer_blocks.35.norm.linear torch.bfloat16
single_transformer_blocks.35.proj_mlp torch.float8_e4m3fn
single_transformer_blocks.35.proj_out torch.float8_e4m3fn
single_transformer_blocks.36.attn.norm_q torch.bfloat16
single_transformer_blocks.36.attn.norm_k torch.bfloat16
single_transformer_blocks.36.attn.to_q torch.float8_e4m3fn
single_transformer_blocks.36.attn.to_q.base_layer torch.float8_e4m3fn
single_transformer_blocks.36.attn.to_q.lora_A.default torch.float32
single_transformer_blocks.36.attn.to_q.lora_B.default torch.float32
single_transformer_blocks.36.attn.to_k torch.float8_e4m3fn
single_transformer_blocks.36.attn.to_k.base_layer torch.float8_e4m3fn
single_transformer_blocks.36.attn.to_k.lora_A.default torch.float32
single_transformer_blocks.36.attn.to_k.lora_B.default torch.float32
single_transformer_blocks.36.attn.to_v torch.float8_e4m3fn
single_transformer_blocks.36.attn.to_v.base_layer torch.float8_e4m3fn
single_transformer_blocks.36.attn.to_v.lora_A.default torch.float32
single_transformer_blocks.36.attn.to_v.lora_B.default torch.float32
single_transformer_blocks.36.norm.linear torch.bfloat16
single_transformer_blocks.36.proj_mlp torch.float8_e4m3fn
single_transformer_blocks.36.proj_out torch.float8_e4m3fn
single_transformer_blocks.37.attn.norm_q torch.bfloat16
single_transformer_blocks.37.attn.norm_k torch.bfloat16
single_transformer_blocks.37.attn.to_q torch.float8_e4m3fn
single_transformer_blocks.37.attn.to_q.base_layer torch.float8_e4m3fn
single_transformer_blocks.37.attn.to_q.lora_A.default torch.float32
single_transformer_blocks.37.attn.to_q.lora_B.default torch.float32
single_transformer_blocks.37.attn.to_k torch.float8_e4m3fn
single_transformer_blocks.37.attn.to_k.base_layer torch.float8_e4m3fn
single_transformer_blocks.37.attn.to_k.lora_A.default torch.float32
single_transformer_blocks.37.attn.to_k.lora_B.default torch.float32
single_transformer_blocks.37.attn.to_v torch.float8_e4m3fn
single_transformer_blocks.37.attn.to_v.base_layer torch.float8_e4m3fn
single_transformer_blocks.37.attn.to_v.lora_A.default torch.float32
single_transformer_blocks.37.attn.to_v.lora_B.default torch.float32
single_transformer_blocks.37.norm.linear torch.bfloat16
single_transformer_blocks.37.proj_mlp torch.float8_e4m3fn
single_transformer_blocks.37.proj_out torch.float8_e4m3fn
single_transformer_blocks.38.attn.norm_q torch.bfloat16
single_transformer_blocks.38.attn.norm_k torch.bfloat16
single_transformer_blocks.38.attn.to_q torch.float8_e4m3fn
single_transformer_blocks.38.attn.to_q.base_layer torch.float8_e4m3fn
single_transformer_blocks.38.attn.to_q.lora_A.default torch.float32
single_transformer_blocks.38.attn.to_q.lora_B.default torch.float32
single_transformer_blocks.38.attn.to_k torch.float8_e4m3fn
single_transformer_blocks.38.attn.to_k.base_layer torch.float8_e4m3fn
single_transformer_blocks.38.attn.to_k.lora_A.default torch.float32
single_transformer_blocks.38.attn.to_k.lora_B.default torch.float32
single_transformer_blocks.38.attn.to_v torch.float8_e4m3fn
single_transformer_blocks.38.attn.to_v.base_layer torch.float8_e4m3fn
single_transformer_blocks.38.attn.to_v.lora_A.default torch.float32
single_transformer_blocks.38.attn.to_v.lora_B.default torch.float32
single_transformer_blocks.38.norm.linear torch.bfloat16
single_transformer_blocks.38.proj_mlp torch.float8_e4m3fn
single_transformer_blocks.38.proj_out torch.float8_e4m3fn
single_transformer_blocks.39.attn.norm_q torch.bfloat16
single_transformer_blocks.39.attn.norm_k torch.bfloat16
single_transformer_blocks.39.attn.to_q torch.float8_e4m3fn
single_transformer_blocks.39.attn.to_q.base_layer torch.float8_e4m3fn
single_transformer_blocks.39.attn.to_q.lora_A.default torch.float32
single_transformer_blocks.39.attn.to_q.lora_B.default torch.float32
single_transformer_blocks.39.attn.to_k torch.float8_e4m3fn
single_transformer_blocks.39.attn.to_k.base_layer torch.float8_e4m3fn
single_transformer_blocks.39.attn.to_k.lora_A.default torch.float32
single_transformer_blocks.39.attn.to_k.lora_B.default torch.float32
single_transformer_blocks.39.attn.to_v torch.float8_e4m3fn
single_transformer_blocks.39.attn.to_v.base_layer torch.float8_e4m3fn
single_transformer_blocks.39.attn.to_v.lora_A.default torch.float32
single_transformer_blocks.39.attn.to_v.lora_B.default torch.float32
single_transformer_blocks.39.norm.linear torch.bfloat16
single_transformer_blocks.39.proj_mlp torch.float8_e4m3fn
single_transformer_blocks.39.proj_out torch.float8_e4m3fn
norm_out.linear torch.bfloat16
proj_out torch.float8_e4m3fn
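
A listing like the one above can be generated with a small helper along these lines (a sketch; the function name is illustrative and not part of this repo):

import torch

def print_module_dtypes(model: torch.nn.Module) -> None:
    # Walk every submodule and report the dtype of its weight, if it has one.
    # With layerwise upcasting enabled, base linear layers show up in fp8 storage,
    # the LoRA adapters stay in fp32, and the skipped norm layers remain in bf16.
    for name, module in model.named_modules():
        weight = getattr(module, "weight", None)
        if isinstance(weight, torch.Tensor):
            print(name, weight.dtype)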

TODO: Allow LoRA to be in bf16 and only apply .to() patch if training in fp16/fp8

@a-r-r-o-w
Copy link
Owner Author

a-r-r-o-w commented Jan 7, 2025

Hmm..., so while the loss curves match in the beginning with bf16 training (they diverge later on), the results are overly smoothed and it looks like lora didn't learn anything. I am not sure if this is because of a validation + hooks related bug, or a training bug. Will look into it a little later. As we can see in the diffusers PR, inference is almost unaffected at fp8, so I highly doubt that we can't train this way. One way to train would be to just use fp8 hunyuan checkpoint directly and use their custom code, but that is less ideal and this is more generally applicable for all supported models.

image

@sayakpaul
Copy link
Collaborator

@a-r-r-o-w allow me to help :)

Sometimes a fresh pair of eyes helps. So, how about I look into it a bit, ask you questions, and run your experiments, while you keep polishing the diffusers PR and thinking about any tests that might be helpful to have/do?

WDYT?

@a-r-r-o-w
Copy link
Owner Author

It's an open PR so anyone is welcome to try/help lol. I don't yet have an idea of what's causing the bad results:

  • Is it training which is not working? That would explain bad results.
  • Or, if training is working as expected (loss curves and gradient norms are highly similar to bf16), is there something I'm not handling in validation/the diffusers pipeline that causes a lossy downcast/upcast of latents?

So any help is appreciated

@a-r-r-o-w
Copy link
Owner Author

FWIW, I've done several experiments with the exact same code for inference purposes with lora (comments in the diffusers PR) and did not find any problems other than the ones fixed by stupid workarounds, so it could just be something in training, with the graphs coinciding at the start being attributable to the zero-initialized LoRA weights, which eventually do start to diverge (after > ~1000 steps) and become garbage.

@sayakpaul
Copy link
Collaborator

One way to train would be to just use fp8 hunyuan checkpoint directly and use their custom code, but that is less ideal and this is more generally applicable for all supported models.

Agreed. Even if this takes time, I think it's worth it because it's general.

I've done several experiments with the exact same code for inference purposes with lora (comments in the diffusers PR)

Could you please point me to some? Will give me a better idea of what's tried already.

For my testing, can I use the command you provided in #184 (comment) or is there a different one I should use? You mentioned peft patches. Are they needed for me to run my experiments?

One test I would try to run on the diffusers PR is the following. Do the layerwise upcasting on a pipeline (such as Flux) where there are plenty of LoRAs already (https://huggingface.co/spaces/multimodalart/flux-lora-the-explorer/blob/main/loras.json), load a LoRA into it, and see if it's working as expected. Sorry if you have conducted it (or something similar) already. This will give a nice idea about whether something's broken.

@neph1
Copy link

neph1 commented Jan 7, 2025

so while the loss curves match in the beginning with bf16 training (they diverge later on)

Fwiw, I ran LTX-V for 400 steps in fp8, and the lora definitely did learn. But since you say they 'diverge later on', maybe it's an accumulating error.

@a-r-r-o-w
Copy link
Owner Author

@neph1 Thanks! I think it's most likely something to do with validation then. I could try running a long LTX training run and see if I get the expected behaviour. For Hunyuan, I was using images for the FP8 training and it seems to collapse the model within 1000 steps of the total 10000 steps of training. Will try to power through the debugging today and hopefully something comes out of it.

@sayakpaul Yes, here you go: https://wandb.ai/aryanvs/finetrainers-hunyuan-video/reports/FP8-vs-BF16-Hunyuan-Video--VmlldzoxMDg0OTMwMw?accessToken=vqwlt7y899u0qb25fyhma612khuk81o5ggi2pmds2y13xv78dfh1es97q0ksaprg.

This contains the two runs with the exact same starting conditions. The FP8 run did not have any validation performed because it errors out due to an unimplemented fp8 multiplication kernel. This is partly due to how we infer the dtype in ModelMixin in diffusers and how the pipelines are written. Will address these concerns in the diffusers PR.

For my testing, can I use the command you provided in #184 (comment) or is there a different one I should use? You mentioned peft patches. Are they needed for me to run my experiments?

Yes, you should modify the command provided in the description to do a longer training run. The only relevant parameter different from previous scripts is --layerwise_upcasting_modules transformer. One could maybe play with --layerwise_upcasting_skip_modules_pattern too.

The peft patches are enabled by default when you launch fp8 upcasting training. Without them, we would:

  • either error out somewhere in peft, I don't recall exactly
  • cast LoRA weights to fp8 and then back to fp32, which makes the Gaussian initialization of lora_A lossy. This isn't a problem in itself, but I would prefer that the LoRA never gets downcast, even at initialization (see the sketch below)
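
A minimal sketch of the idea, assuming peft's convention of naming adapter parameters lora_A/lora_B (the actual patches in this PR are more involved; the helper name is illustrative):

import torch

def keep_lora_weights_in_fp32(transformer: torch.nn.Module) -> None:
    # Re-cast only the LoRA adapter parameters to fp32 after add_adapter(),
    # so the Gaussian initialization of lora_A never round-trips through fp8.
    for name, param in transformer.named_parameters():
        if "lora_" in name:
            param.data = param.data.to(torch.float32)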

One test I would try to run on the diffusers PR is the following. Do the layerwise upcasting on a pipeline (such as Flux) where there are plenty of LoRAs already (https://huggingface.co/spaces/multimodalart/flux-lora-the-explorer/blob/main/loras.json), load a LoRA into it, and see if it's working as expected. Sorry if you have conducted it (or something similar) already. This will give a nice idea about whether something's broken.

The diffusers PR linked in the description contains a wide range of experiments with all the code/ideas for workarounds needed to make this work with transformers and peft. Flux works as expected when this is enabled for LoRA inference, with a little over half the memory needed by the transformer in bf16.
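
For context, the core trick behind fp8 layerwise upcasting can be sketched as a pair of forward hooks applied to each leaf module. This is a simplified illustration, not the exact code copied from the diffusers PR (which also handles skip-module patterns, non-blocking transfers, etc.):

import torch
import torch.nn as nn

def enable_fp8_weight_storage(
    module: nn.Module,
    storage_dtype: torch.dtype = torch.float8_e4m3fn,
    compute_dtype: torch.dtype = torch.bfloat16,
) -> None:
    # Keep the weights of a leaf layer (e.g. nn.Linear) in fp8 to save memory,
    # upcast them to the compute dtype just before forward, and downcast them
    # again right after, so only the currently-executing layer sits in bf16.
    def pre_forward(mod, args):
        mod.to(compute_dtype)
        return args

    def post_forward(mod, args, output):
        mod.to(storage_dtype)
        return output

    module.to(storage_dtype)
    module.register_forward_pre_hook(pre_forward)
    module.register_forward_hook(post_forward)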

@a-r-r-o-w
Copy link
Owner Author

a-r-r-o-w commented Jan 8, 2025

Latest commit fixes the following error when resuming fp8 training after a validation is performed.

stacktrace
01/08/2025 01:53:23 - ERROR - finetrainers - Traceback (most recent call last):
  File "/raid/aryan/cogvideox-distillation/train.py", line 35, in main
    trainer.train()
  File "/raid/aryan/cogvideox-distillation/finetrainers/trainer.py", line 813, in train
    self.optimizer.step()
  File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/accelerate/optimizer.py", line 171, in step
    self.optimizer.step(closure)
  File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/optim/lr_scheduler.py", line 137, in wrapper
    return func.__get__(opt, opt.__class__)(*args, **kwargs)
  File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/optim/optimizer.py", line 487, in wrapper
    out = func(*args, **kwargs)
  File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/optim/optimizer.py", line 91, in _use_grad
    ret = func(self, *args, **kwargs)
  File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/optim/adamw.py", line 220, in step
    adamw(
  File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/optim/optimizer.py", line 154, in maybe_fallback
    return func(*args, **kwargs)
  File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/optim/adamw.py", line 782, in adamw
    func(
  File "/raid/aryan/nightly-venv/lib/python3.10/site-packages/torch/optim/adamw.py", line 531, in _multi_tensor_adamw
    torch._foreach_lerp_(device_exp_avgs, device_grads, 1 - beta1)
RuntimeError: expected dtype float for `end` but got dtype c10::BFloat16

The device_grads should be in fp32, but due to a combination of the fp8 hooks and accelerate's convert_to_fp32 hooks, the gradients remain in bf16. Will investigate more deeply later, but the current fix seems like an okay thing to do.
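
A rough sketch of the kind of workaround meant here: upcast any low-precision gradients on the trainable parameters before the optimizer step (the helper name is illustrative; the actual change in the commit may differ):

import torch

def upcast_gradients_to_fp32(model: torch.nn.Module) -> None:
    # AdamW keeps its state (exp_avg, exp_avg_sq) in fp32, so gradients that
    # arrive in bf16 trip torch._foreach_lerp_; upcast them before the step.
    for param in model.parameters():
        if param.grad is not None and param.grad.dtype != torch.float32:
            param.grad = param.grad.to(torch.float32)

Something along these lines would be called after accelerator.backward(loss) and before optimizer.step().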

@a-r-r-o-w
Copy link
Owner Author

a-r-r-o-w commented Jan 8, 2025

List of supported FP8 ops: pytorch/pytorch#107256 (comment) (as of PT 2.1)

@sayakpaul
Copy link
Collaborator

Thanks! I think it's most likely something to do with validation then. I could try running a long LTX training run and see if I get the expected behaviour.

@a-r-r-o-w this could be. Very well could be.

When I was doing FP8 training with torchao here, the validation results were extremely depressing, so I just let it train, and when I performed inference with the trained LoRA things were fine.

I believe we could do something similar here?

@a-r-r-o-w
Copy link
Owner Author

When I was doing FP8 training with torchao here, the validation results were extremely depressing, so I just let it train, and when I performed inference with the trained LoRA things were fine.

I believe we could do something similar here?

Sure, that works. But:

  • If we run normal inference with fp8LU, it produces videos of almost the same quality as bf16.
  • If we run training and compare bf16 vs fp8, the loss is almost identical throughout training.

If what you say is true for this PR, then:

  • Why does LoRA inference work in fp8 after training, but not during training? The weights are in exactly the same state in both cases.

So, I would really like to get to the bottom of this and find the right solution, even if it takes a little more time before moving forward with this PR. Will test what you mentioned too and get back soon.

@a-r-r-o-w
Copy link
Owner Author

Started a small 20000-step fp8 run for LTX on a single GPU (5000 training steps with 4 gradient accumulation steps): https://wandb.ai/aryanvs/finetrainers-ltxv/runs/uy0evi7m

If this works as expected, then there's something fishy going on in Hunyuan, which I anticipate will be awful to debug :/

@sayakpaul
Copy link
Collaborator

@a-r-r-o-w sounds good! Yes, my comment was to confirm whether it's the validation that we need to debug, if what I said and observed turns out to be true.

If we run training and compare bf16 vs fp8, the loss is almost identical throughout training.

Oh, I thought you mentioned it starts to diverge after a certain number of steps?

How are we doing for Cog with FP8? Or is it not robust enough currently to be trained with FP8?

@a-r-r-o-w
Copy link
Owner Author

a-r-r-o-w commented Jan 8, 2025

Oh I thought you mentioned it starts to diverge after certain steps?

You can check this report for a better understanding of what I mean. It's not exactly the same to begin with, but the values are almost the same, so I would expect them to converge to roughly similar weights.

Edit: I just saw the report that was created. I'm not sure what wandb is doing here, but it does not show the loss at every step and is skipping some data points in the report. This is what the true graph looks like when overlapped (brown is bf16 and green is fp8):

image

image

How are we doing for Cog with FP8? Or is it not robust enough currently to be trained with FP8?

I am yet to test Cog. Will do it soon. Inference works really well with FP8 Cog though

@sayakpaul
Copy link
Collaborator

Will run some tests myself today. Thanks a lot for bearing with my questions and providing detailed answers.

@a-r-r-o-w
Copy link
Owner Author

a-r-r-o-w commented Jan 8, 2025

Okay, the LTX training run looks extremely promising in just 1500 steps so far (actually 6000 steps because of 4 gradient accumulation steps): https://wandb.ai/aryanvs/finetrainers-ltxv/runs/uy0evi7m?nw=nwuseraryanvs.

Definitely a Hunyuan pipeline problem or a dataset problem - I'm using the same 400 images I used for the bf16 run.

Memory required:

01/08/2025 03:25:29 - INFO - finetrainers - Memory after validation end: {
    "memory_allocated": 3.13,
    "memory_reserved": 3.195,
    "max_memory_allocated": 14.822,
    "max_memory_reserved": 18.625
}
Training steps:  10%|█████████████████▋                                                                                                                                                             | 504/5000 [28:34<19:35:57, 15.69s/it, grad_norm=0.0125, step_loss=0.303, lr=5.04e-6]
01/08/2025 03:25:54 - INFO - finetrainers - Memory after epoch 28: {
    "memory_allocated": 3.13,
    "memory_reserved": 5.398,
    "max_memory_allocated": 4.839,
    "max_memory_reserved": 5.398
}

The LTX team has some amazing chefs, I must say. With framewise encoding/decoding support in diffusers coming soon (huggingface/diffusers#10488) and group offloading, this will require next to no memory lol

@a-r-r-o-w
Copy link
Owner Author

a-r-r-o-w commented Jan 8, 2025

Some more reports: https://wandb.ai/aryanvs/finetrainers-ltxv/reports/fp8-uniform-vs-fp8-logit_normal-vs-bf16-logit_normal--VmlldzoxMDg2NDU3Ng

  • fp8 with uniform sampling comes closest to the original dataset videos
  • fp8 and bf16, with the logit_normal weighting scheme used in the paper, are very close throughout training, but do not converge to the expected style

@a-r-r-o-w a-r-r-o-w mentioned this pull request Jan 9, 2025