[Benchmark] Inference Performance on RTX 5090 with Wan2.1-1.3B (PyTorch 2.8.0) #102

@haoyuhe04

Description


Hi authors,

Thanks for your great work! I reproduced the benchmark on an RTX 5090; this is my script:

python turbodiffusion/inference/wan2.1_t2v_infer.py \
    --model Wan2.1-1.3B \
    --dit_path checkpoints/TurboWan2.1-T2V-1.3B-480P-quant.pth \
    --resolution 480p \
    --prompt "A stylish woman walks down a Tokyo street filled with warm glowing neon and animated city signage. She wears a black leather jacket, a long red dress, and black boots, and carries a black purse. She wears sunglasses and red lipstick. She walks confidently and casually. The street is damp and reflective, creating a mirror effect of the colorful lights. Many pedestrians walk about." \
    --num_samples 1 \
    --num_steps 3 \
    --quant_linear \
    --attention_type sagesla \
    --sla_topk 0.1

and the output:

[01-19 08:01:52|INFO|turbodiffusion/inference/wan2.1_t2v_infer.py:88:<module>] Generating with prompt: A stylish woman walks down a Tokyo street filled with warm glowing neon and animated city signage. She wears a black leather jacket, a long red dress, and black boots, and carries a black purse. She wears sunglasses and red lipstick. She walks confidently and casually. The street is damp and reflective, creating a mirror effect of the colorful lights. Many pedestrians walk about.
Sampling: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:02<00:00,  1.25it/s]
[01-19 08:01:54|INFO|turbodiffusion/inference/wan2.1_t2v_infer.py:142:<module>] Sampling completed in 2.40 seconds.

Here is how I implemented the timing measurement:

    t1 = time.time()
    for i, (t_cur, t_next) in enumerate(tqdm(list(zip(t_steps[:-1], t_steps[1:])), desc="Sampling", total=total_steps)):
        with torch.no_grad():
            v_pred = net(x_B_C_T_H_W=x.to(**tensor_kwargs), timesteps_B_T=(t_cur.float() * ones * 1000).to(**tensor_kwargs), **condition).to(
                torch.float64
            )
            x = (1 - t_next) * (x - t_cur * v_pred) + t_next * torch.randn(
                *x.shape,
                dtype=torch.float32,
                device=tensor_kwargs["device"],
                generator=generator,
            )
    t2 = time.time()
    log.info(f"Sampling completed in {t2 - t1:.2f} seconds.")

Did I perform the test correctly? I only measured the time for the DiT sampling loop. My PyTorch version is 2.8.0+cu129. My measured 2.40 s does not match the 1.9 s reported in the paper.
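In case it matters: since CUDA kernels launch asynchronously, I understand that host-side `time.time()` around the loop can miscount unless the device is synchronized before and after timing, and that warm-up iterations are usually run first. This is a minimal sketch of the pattern I could switch to (the `torch.cuda.synchronize()` placement is indicated in comments since it only applies on CUDA, and `timed_avg` is just an illustrative helper, not part of the repo):

```python
import time

def timed_avg(fn, warmup=2, iters=5):
    # Warm-up runs amortize one-time costs (kernel autotuning, caches).
    for _ in range(warmup):
        fn()
    # On CUDA, call torch.cuda.synchronize() here so pending kernels
    # from the warm-up do not leak into the timed region.
    t0 = time.perf_counter()
    for _ in range(iters):
        fn()
    # On CUDA, call torch.cuda.synchronize() again before reading the
    # clock, since kernel launches return before the GPU finishes.
    t1 = time.perf_counter()
    return (t1 - t0) / iters

# Usage with a cheap stand-in workload:
avg = timed_avg(lambda: sum(range(1000)))
```

An alternative is timing with `torch.cuda.Event(enable_timing=True)` pairs, which measures on-device time directly.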
