Hi authors,
Thanks for your great work! I tried to reproduce the reported speed on an RTX 5090. This is the command I ran:
python turbodiffusion/inference/wan2.1_t2v_infer.py --model Wan2.1-1.3B --dit_path checkpoints/TurboWan2.1-T2V-1.3B-480P-quant.pth --resolution 480p --prompt "A stylish woman walks down a Tokyo street filled with warm glowing neon and animated city signage. She wears a black leather jacket, a long red dress, and black boots, and carries a black purse. She wears sunglasses and red lipstick. She walks confidently and casually. The street is damp and reflective, creating a mirror effect of the colorful lights. Many pedestrians walk about." --num_samples 1 --num_steps 3 --quant_linear --attention_type sagesla --sla_topk 0.1
And the output:
[01-19 08:01:52|INFO|turbodiffusion/inference/wan2.1_t2v_infer.py:88:<module>] Generating with prompt: A stylish woman walks down a Tokyo street filled with warm glowing neon and animated city signage. She wears a black leather jacket, a long red dress, and black boots, and carries a black purse. She wears sunglasses and red lipstick. She walks confidently and casually. The street is damp and reflective, creating a mirror effect of the colorful lights. Many pedestrians walk about.
Sampling: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:02<00:00, 1.25it/s]
[01-19 08:01:54|INFO|turbodiffusion/inference/wan2.1_t2v_infer.py:142:<module>] Sampling completed in 2.40 seconds.
And here is how I implemented the timing measurement:
t1 = time.time()
for i, (t_cur, t_next) in enumerate(tqdm(list(zip(t_steps[:-1], t_steps[1:])), desc="Sampling", total=total_steps)):
    with torch.no_grad():
        v_pred = net(x_B_C_T_H_W=x.to(**tensor_kwargs), timesteps_B_T=(t_cur.float() * ones * 1000).to(**tensor_kwargs), **condition).to(
            torch.float64
        )
        x = (1 - t_next) * (x - t_cur * v_pred) + t_next * torch.randn(
            *x.shape,
            dtype=torch.float32,
            device=tensor_kwargs["device"],
            generator=generator,
        )
t2 = time.time()
log.info(f"Sampling completed in {t2 - t1:.2f} seconds.")
Did I perform the test correctly? I only measured the time for the DiT part. My PyTorch version is 2.8.0+cu129. The 2.40 s I measured does not match the reported 1.9 s.
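One thing I am unsure about in my measurement: CUDA kernels are launched asynchronously, so `time.time()` around GPU work can be off unless the device is synchronized before each timestamp, and the first pass may include one-time warm-up cost. I do not know whether either effect explains the gap. Below is a minimal sketch of a more careful timer. The helper name `time_call` and its parameters are my own, hypothetical, and not part of this repository. It is stdlib-only and takes the synchronization function as an argument.

```python
import time
from typing import Callable, Optional


def time_call(
    fn: Callable[[], object],
    sync: Optional[Callable[[], None]] = None,
    warmup: int = 1,
    repeats: int = 3,
) -> float:
    """Return mean wall-clock seconds per call of fn, excluding warm-up runs.

    `sync` is an optional function that blocks until pending device work
    is finished. It is called once before the timer starts and once
    before it stops, so queued asynchronous work is included.
    """
    if repeats < 1:
        raise ValueError("repeats must be at least 1")

    def _sync() -> None:
        if sync is not None:
            sync()

    # Warm-up runs absorb one-time costs (caching, lazy initialization).
    for _ in range(warmup):
        fn()

    _sync()  # make sure warm-up work is done before starting the clock
    start = time.perf_counter()
    for _ in range(repeats):
        fn()
    _sync()  # wait for all timed work to finish before stopping the clock
    return (time.perf_counter() - start) / repeats


if __name__ == "__main__":
    # CPU-only stand-in for the sampling loop, just to show the call shape.
    mean_s = time_call(lambda: sum(range(100_000)), warmup=1, repeats=5)
    print(f"mean time per call: {mean_s:.6f} s")
```

With PyTorch on a GPU, I assume the sampling loop would be wrapped in a function and passed as `fn`, with `sync=torch.cuda.synchronize`.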