fp8 full quant #105

Open
GokaNik wants to merge 1 commit into kandinskylab:main from GokaNik:fp8-full

Conversation


GokaNik commented on Jan 31, 2026

Summary

This PR adds support for full FP8 quantization (FP8 activations + FP8 weights) of the DiT model using TorchAO, reducing memory usage and improving execution speed while preserving model quality.

Quantization details

  • Quantization method:
    FP8 dynamic activations + FP8 weights (Float8DynamicActivationFloat8WeightConfig)

  • All DiT modules are quantized by default

  • The following layers are explicitly excluded for numerical stability:

    • Time, text, and visual embedding input/output layers

    • Final output and modulation layers

    • FFN output projections:

      • `text_transformer_blocks[0-3].feed_forward.out_layer`
      • `visual_transformer_blocks[0-59].feed_forward.out_layer`
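The exclusion rules above can be expressed as a module-name filter. The sketch below is hypothetical, not this PR's actual code: the substrings used to match the embedding, output, and modulation layers are assumed names, and only the `feed_forward.out_layer` pattern is taken verbatim from the list above.

```python
# Hypothetical filter implementing the exclusion list from this PR.
# The substrings below are assumed layer-name fragments, not verified
# against the actual DiT module tree.
EXCLUDED_SUBSTRINGS = (
    "time_embed", "text_embed", "visual_embed",  # embedding input/output layers (assumed names)
    "out_head", "modulation",                    # final output and modulation layers (assumed names)
)

def is_quantizable(fqn: str) -> bool:
    """Return True if the module with this fully qualified name should be FP8-quantized."""
    if any(s in fqn for s in EXCLUDED_SUBSTRINGS):
        return False
    # FFN output projections stay in high precision for numerical stability
    if fqn.endswith("feed_forward.out_layer"):
        return False
    return True
```

With TorchAO, such a predicate would be passed as the `filter_fn` argument of `quantize_`, e.g. `quantize_(dit, Float8DynamicActivationFloat8WeightConfig(), filter_fn=lambda mod, fqn: is_quantizable(fqn))`.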

Performance comparison (FP8 vs Base)

| Stage | Base (s) | FP8 (s) | Δ vs Base |
|-------|---------:|--------:|----------:|
| Pipeline initialization | 96.227 | 98.510 | +2.4% |
| First generation | 607.649 | 552.182 | −9% |
| Next generations | ~574.5 | ~522.5 | −9% |

Percentages are computed relative to the base model. Negative values indicate faster execution.
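The Δ column can be reproduced directly from the raw timings:

```python
# Recompute the "Δ vs Base" column from the timings in the table above.
def delta_pct(base_s: float, fp8_s: float) -> float:
    """Relative change of the FP8 time versus the base time, in percent."""
    return (fp8_s / base_s - 1.0) * 100.0

print(round(delta_pct(96.227, 98.510), 1))  # pipeline initialization: 2.4
print(round(delta_pct(607.649, 552.182)))   # first generation: -9
print(round(delta_pct(574.5, 522.5)))       # next generations: -9
```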

New dependency

FP8 quantization relies on TorchAO:

```shell
pip install torchao
```

How to generate FP8 weights

```shell
python create_fp8_full.py
```

The script produces, for example:

/data/kandinsky-5/weights/K5_pro_5s_ao.pt

How to run the FP8-quantized model

```python
pipe = get_video_pipeline(
    device_map={"dit": "cuda:0", "vae": "cuda:0", "text_embedder": "cuda:0"},
    conf_path="configs/k5_pro_t2v_5s_sft_sd.yaml",
    model_type="fp8",
    quantized_model_path="/data/kandinsky-5/weights/K5_pro_5s_ao.pt",
    mode="t2v",
)
```

Important

FP8-quantized DiT currently does not work with offload=True.

Running the FP8-quantized model with offload=True results in a runtime failure with the following error:

```
    raise RuntimeError("Cannot swap t1 because it has weakref associated with it")
RuntimeError: Cannot swap t1 because it has weakref associated with it

    raise RuntimeError(
RuntimeError: _apply(): Couldn't swap Linear.weight
```

The issue is known but not yet resolved: due to project time constraints, the interaction between TorchAO FP8-quantized modules and the offloading mechanics has not been fully investigated or fixed.

At the moment, FP8 quantization is supported only with offload=False.
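Until the underlying weakref issue is resolved, callers can fail fast with an explicit argument check instead of hitting the `_apply()` error mid-run. The guard below is a hypothetical sketch, not code from this PR:

```python
# Hypothetical guard for the unsupported fp8 + offload combination
# described above; the function name is an assumption, not part of the PR.
def check_offload_compat(model_type: str, offload: bool) -> None:
    """Reject model_type='fp8' with offload=True up front, instead of
    failing later inside nn.Module._apply() during tensor swapping."""
    if model_type == "fp8" and offload:
        raise ValueError(
            "FP8-quantized DiT does not support offload=True "
            "(TorchAO tensors hold weakrefs that block parameter swapping); "
            "use offload=False."
        )

check_offload_compat("fp8", False)  # supported combination, no error
```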
