From 1466fbb0b6362b954e03b393d26b1556c61aa6bd Mon Sep 17 00:00:00 2001 From: MrEdwards007 <116316872+MrEdwards007@users.noreply.github.com> Date: Thu, 1 Jan 2026 07:38:20 -0500 Subject: [PATCH 1/6] Optimize wan2.2_i2v_infer.py (Image-to-Video) This is an optimized inference script for generating videos from a single input image and a text prompt. It features a Tiered Failover System to prevent "Out of Memory" (OOM) errors on consumer GPUs by automatically switching between high and low-noise models and utilizing tiled encoding/decoding. It is designed to maximize performance on hardware like the RTX 5090 while maintaining stability for high-resolution (720p) outputs. --- turbodiffusion/inference/wan2.2_i2v_infer.py | 658 ++++++++++++++----- 1 file changed, 505 insertions(+), 153 deletions(-) diff --git a/turbodiffusion/inference/wan2.2_i2v_infer.py b/turbodiffusion/inference/wan2.2_i2v_infer.py index e57e509..f6fe32d 100644 --- a/turbodiffusion/inference/wan2.2_i2v_infer.py +++ b/turbodiffusion/inference/wan2.2_i2v_infer.py @@ -1,28 +1,55 @@ +# Blackwell Bridge # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at +# ----------------------------------------------------------------------------------------- +# TURBODIFFUSION OPTIMIZED INFERENCE SCRIPT (I2V) # -# http://www.apache.org/licenses/LICENSE-2.0 +# Co-developed by: Waverly Edwards & Google Gemini (2025) # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# Modifications: +# - Implemented "Tiered Failover System" for robust OOM protection. +# - Added Intelligent Model Switching (High/Low Noise) with memory optimization. +# - Integrated Tiled Encoding & Decoding for high-resolution processing. +# - Added Support for Pre-cached Text Embeddings to skip T5 loading. +# - Optimized memory management (VAE unload/reload, aggressive GC). +# +# Acknowledgments: +# - Made possible by the work (cache_t5.py) and creativity of: John D. Pope +# +# Description: +# cache_t5.py pre-computes text embeddings to allow running inference on GPUs with limited VRAM +# by removing the need to keep the 11GB T5 encoder loaded in memory. +# +# CREDIT REQUEST: +# If you utilize, share, or build upon this specific optimized script, please +# acknowledge Waverly Edwards and Google Gemini in your documentation or credits. +# ----------------------------------------------------------------------------------------- import argparse import math +import os +import gc +import time +import sys + +# --- 1. 
Memory Tuning (Must be before torch imports) --- +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" import torch +import torch.nn.functional as F from einops import rearrange, repeat from tqdm import tqdm from PIL import Image import torchvision.transforms.v2 as T import numpy as np +# Safe import for system memory checks +try: + import psutil +except ImportError: + psutil = None + from imaginaire.utils.io import save_image_or_video from imaginaire.utils import log @@ -34,189 +61,514 @@ torch._dynamo.config.suppress_errors = True - def parse_arguments() -> argparse.Namespace: - parser = argparse.ArgumentParser(description="TurboDiffusion inference script for Wan2.2 I2V with High/Low Noise models") - parser.add_argument("--image_path", type=str, default=None, help="Path to the input image (required unless --serve)") - parser.add_argument("--high_noise_model_path", type=str, required=True, help="Path to the high-noise model") - parser.add_argument("--low_noise_model_path", type=str, required=True, help="Path to the low-noise model") - parser.add_argument("--boundary", type=float, default=0.9, help="Timestep boundary for switching from high to low noise model") - parser.add_argument("--model", choices=["Wan2.2-A14B"], default="Wan2.2-A14B", help="Model to use") - parser.add_argument("--num_samples", type=int, default=1, help="Number of samples to generate") - parser.add_argument("--num_steps", type=int, choices=[1, 2, 3, 4], default=4, help="1~4 for timestep-distilled inference") - parser.add_argument("--sigma_max", type=float, default=200, help="Initial sigma for rCM") - parser.add_argument("--vae_path", type=str, default="checkpoints/Wan2.1_VAE.pth", help="Path to the Wan2.1 VAE") - parser.add_argument("--text_encoder_path", type=str, default="checkpoints/models_t5_umt5-xxl-enc-bf16.pth", help="Path to the umT5 text encoder") - parser.add_argument("--num_frames", type=int, default=81, help="Number of frames to generate") - parser.add_argument("--prompt", type=str, default=None, help="Text prompt for video generation (required unless --serve)") - parser.add_argument("--resolution", default="720p", type=str, help="Resolution of the generated output") - parser.add_argument("--aspect_ratio", default="16:9", type=str, help="Aspect ratio of the generated output (width:height)") - parser.add_argument("--adaptive_resolution", action="store_true", help="If set, adapts the output resolution to the input image's aspect ratio, using the area defined by --resolution and --aspect_ratio as a target.") - parser.add_argument("--ode", action="store_true", help="Use ODE for sampling (sharper but less robust than SDE)") - parser.add_argument("--seed", type=int, default=0, help="Random seed for reproducibility") - parser.add_argument("--save_path", type=str, default="output/generated_video.mp4", help="Path to save the generated video (include file extension)") - parser.add_argument("--attention_type", choices=["sla", "sagesla", "original"], default="sagesla", help="Type of attention mechanism to use") - parser.add_argument("--sla_topk", type=float, default=0.1, help="Top-k ratio for SLA/SageSLA attention") - parser.add_argument("--quant_linear", action="store_true", help="Whether to replace Linear layers with quantized versions") - parser.add_argument("--default_norm", action="store_true", help="Whether to replace LayerNorm/RMSNorm layers with faster versions") - parser.add_argument("--serve", action="store_true", help="Launch interactive TUI server mode (keeps model loaded)") + parser = 
argparse.ArgumentParser(description="TurboDiffusion inference script for Wan2.2 I2V") + parser.add_argument("--image_path", type=str, default=None, help="Path to input image") + parser.add_argument("--high_noise_model_path", type=str, required=True, help="Path to high-noise model") + parser.add_argument("--low_noise_model_path", type=str, required=True, help="Path to low-noise model") + parser.add_argument("--boundary", type=float, default=0.9, help="Switch boundary") + parser.add_argument("--model", choices=["Wan2.2-A14B"], default="Wan2.2-A14B") + parser.add_argument("--num_samples", type=int, default=1) + parser.add_argument("--num_steps", type=int, default=4) + parser.add_argument("--sigma_max", type=float, default=200) + parser.add_argument("--vae_path", type=str, default="checkpoints/Wan2.1_VAE.pth") + parser.add_argument("--text_encoder_path", type=str, default="checkpoints/models_t5_umt5-xxl-enc-bf16.pth") + parser.add_argument("--cached_embedding", type=str, default=None) + parser.add_argument("--skip_t5", action="store_true") + parser.add_argument("--num_frames", type=int, default=81) + parser.add_argument("--prompt", type=str, default=None) + parser.add_argument("--resolution", default="720p", type=str) + parser.add_argument("--aspect_ratio", default="16:9", type=str) + parser.add_argument("--adaptive_resolution", action="store_true") + parser.add_argument("--ode", action="store_true") + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--save_path", type=str, default="output/generated_video.mp4") + parser.add_argument("--attention_type", choices=["sla", "sagesla", "original"], default="sagesla") + parser.add_argument("--sla_topk", type=float, default=0.1) + parser.add_argument("--quant_linear", action="store_true") + parser.add_argument("--default_norm", action="store_true") + parser.add_argument("--serve", action="store_true") + parser.add_argument("--offload_dit", action="store_true") return parser.parse_args() +def print_memory_status(step_name=""): + """ + Prints a detailed breakdown of GPU memory usage. + """ + if not torch.cuda.is_available(): + return + + torch.cuda.synchronize() + allocated_gb = torch.cuda.memory_allocated() / (1024**3) + reserved_gb = torch.cuda.memory_reserved() / (1024**3) + free_mem, total_mem = torch.cuda.mem_get_info() + free_gb = free_mem / (1024**3) + + print(f"\n๐ [MEMORY REPORT] {step_name}") + print(f" โโโ ๐พ VRAM In Use: {allocated_gb:.2f} GB") + print(f" โโโ ๐ฆ VRAM Reserved: {reserved_gb:.2f} GB") + print(f" โโโ ๐ VRAM Free: {free_gb:.2f} GB") + print("-" * 60) + +def cleanup_memory(step_info=""): + """Aggressively clears VRAM.""" + if step_info: + print(f"๐งน Cleaning memory: {step_info}...") + gc.collect() + torch.cuda.empty_cache() + torch.cuda.ipc_collect() + if step_info: + print_memory_status(f"After Cleanup ({step_info})") + +def get_tensor_size_mb(tensor): + return tensor.element_size() * tensor.nelement() / (1024 * 1024) + +def force_cpu_float32(target_obj): + """ + Recursively forces a model or wrapper object to CPU Float32. 
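+    Used as the last-resort fallback when both the standard and tiled GPU VAE paths run out of memory.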
+ """ + def recursive_cast_to_cpu(obj): + if isinstance(obj, torch.Tensor): + return obj.cpu().float() + elif isinstance(obj, (list, tuple)): + return type(obj)(recursive_cast_to_cpu(x) for x in obj) + elif isinstance(obj, dict): + return {k: recursive_cast_to_cpu(v) for k, v in obj.items()} + return obj + + targets = [target_obj] + if hasattr(target_obj, "model"): + targets.append(target_obj.model) + + for obj in targets: + if isinstance(obj, torch.nn.Module): + try: obj.cpu().float() + except: pass + + for attr_name in dir(obj): + if attr_name.startswith("__"): continue + try: + val = getattr(obj, attr_name) + if isinstance(val, torch.nn.Module): + val.cpu().float() + elif isinstance(val, (torch.Tensor, list, tuple)): + setattr(obj, attr_name, recursive_cast_to_cpu(val)) + except Exception: pass + + if isinstance(obj, torch.nn.Module): + try: + for module in obj.modules(): + for param in module.parameters(recurse=False): + if param is not None: + param.data = param.data.cpu().float() + if param.grad is not None: param.grad = None + for buf in module.buffers(recurse=False): + if buf is not None: + buf.data = buf.data.cpu().float() + except: pass + else: + for attr_name in dir(obj): + if attr_name.startswith("__"): continue + try: + val = getattr(obj, attr_name) + if isinstance(val, torch.nn.Module): + for module in val.modules(): + for param in module.parameters(recurse=False): + if param is not None: + param.data = param.data.cpu().float() + for buf in module.buffers(recurse=False): + if buf is not None: + buf.data = buf.data.cpu().float() + except: pass + +def tiled_encode_4x(tokenizer, frames, target_dtype): + B, C, T, H, W = frames.shape + h_mid = H // 2 + w_mid = W // 2 + print(f"\n๐งฉ Starting 4-Chunk Tiled Encoding (Input: {W}x{H})") + latents_list = [[None, None], [None, None]] + + try: + print(" ๐ Encoding Chunk 1/4 (Top-Left)...") + with torch.amp.autocast("cuda", dtype=target_dtype): + l_tl = tokenizer.encode(frames[:, :, :, :h_mid, :w_mid]) + latents_list[0][0] = l_tl.cpu() + del l_tl; cleanup_memory("After Chunk 1") + except Exception as e: + print(f"โ Chunk 1 Failed: {e}") + raise e + + try: + print(" ๐ Encoding Chunk 2/4 (Top-Right)...") + with torch.amp.autocast("cuda", dtype=target_dtype): + l_tr = tokenizer.encode(frames[:, :, :, :h_mid, w_mid:]) + latents_list[0][1] = l_tr.cpu() + del l_tr; cleanup_memory("After Chunk 2") + except Exception as e: + print(f"โ Chunk 2 Failed: {e}") + raise e + + try: + print(" ๐ Encoding Chunk 3/4 (Bottom-Left)...") + with torch.amp.autocast("cuda", dtype=target_dtype): + l_bl = tokenizer.encode(frames[:, :, :, h_mid:, :w_mid]) + latents_list[1][0] = l_bl.cpu() + del l_bl; cleanup_memory("After Chunk 3") + except Exception as e: + print(f"โ Chunk 3 Failed: {e}") + raise e + + try: + print(" ๐ Encoding Chunk 4/4 (Bottom-Right)...") + with torch.amp.autocast("cuda", dtype=target_dtype): + l_br = tokenizer.encode(frames[:, :, :, h_mid:, w_mid:]) + latents_list[1][1] = l_br.cpu() + del l_br; cleanup_memory("After Chunk 4") + except Exception as e: + print(f"โ Chunk 4 Failed: {e}") + raise e + + print(" ๐งต Stitching Latents...") + row1 = torch.cat([latents_list[0][0], latents_list[0][1]], dim=4) + row2 = torch.cat([latents_list[1][0], latents_list[1][1]], dim=4) + full_latents = torch.cat([row1, row2], dim=3) + return full_latents.to(device=tensor_kwargs["device"], dtype=target_dtype) + +def safe_cpu_fallback_encode(tokenizer, frames, target_dtype): + log.warning("๐ Switching to CPU for VAE Encode (Slow but reliable)...") + 
cleanup_memory("Pre-CPU Encode") + frames_cpu = frames.cpu().to(dtype=torch.float32) + force_cpu_float32(tokenizer) + t0 = time.time() + with torch.autocast("cpu", enabled=False): + with torch.autocast("cuda", enabled=False): + latents = tokenizer.encode(frames_cpu) + print(f" โฑ๏ธ CPU Encode took: {time.time() - t0:.2f}s") + return latents.to(device=tensor_kwargs["device"], dtype=target_dtype) + +def tiled_decode_gpu(tokenizer, latents, overlap=12): + """ + Decodes latents in 4 spatial quadrants with OVERLAP and SIGMOID BLENDING. + Overlap=12 latents (96 pixels). Safe for 720p. + Removing Global Color Matching to prevent exposure shifts. + """ + print(f"\n๐งฑ Starting Tiled GPU Decode (4 Quadrants, Overlap={overlap}, Blended)...") + B, C, T, H, W = latents.shape + scale = tokenizer.spatial_compression_factor + h_mid = H // 2 + w_mid = W // 2 + + def decode_tile(tile_latents, name): + cleanup_memory(f"Tile {name}") + with torch.no_grad(): + return tokenizer.decode(tile_latents).cpu() + + try: + # 1. Decode Top-Left and Top-Right + l_tl = latents[..., :h_mid+overlap, :w_mid+overlap] + l_tr = latents[..., :h_mid+overlap, w_mid-overlap:] + v_tl = decode_tile(l_tl, "1/4 (TL)") + v_tr = decode_tile(l_tr, "2/4 (TR)") + B_dec, C_dec, T_dec, H_tile, W_tile = v_tl.shape + + print(f" ๐งต Blending Top Row (Decoded Frames: {T_dec})...") + mid_pix = w_mid * scale + overlap_pix = overlap * scale + + # Slices for overlap + tl_blend_slice = v_tl[..., mid_pix-overlap_pix:] + tr_blend_slice = v_tr[..., :2*overlap_pix] + + row_top = torch.zeros(B_dec, 3, T_dec, H_tile, W*scale, dtype=v_tl.dtype, device='cpu') + + # Place non-overlapping parts (Clamped indices) + end_left = max(0, mid_pix - overlap_pix) + start_right = mid_pix + overlap_pix + + row_top[..., :end_left] = v_tl[..., :end_left] + row_top[..., start_right:] = v_tr[..., 2*overlap_pix:] + + x = torch.linspace(-6, 6, 2*overlap_pix, device='cpu') + alpha = torch.sigmoid(x).view(1, 1, 1, 1, -1) + blended_h = tl_blend_slice * (1 - alpha) + tr_blend_slice * alpha + + row_top[..., end_left:start_right] = blended_h + del v_tl, v_tr, l_tl, l_tr + + # 3. Decode Bottom-Left and Bottom-Right + l_bl = latents[..., h_mid-overlap:, :w_mid+overlap] + l_br = latents[..., h_mid-overlap:, w_mid-overlap:] + v_bl = decode_tile(l_bl, "3/4 (BL)") + v_br = decode_tile(l_br, "4/4 (BR)") + + print(" ๐งต Blending Bottom Row...") + bl_blend_slice = v_bl[..., mid_pix-overlap_pix:] + br_blend_slice = v_br[..., :2*overlap_pix] + + row_bot = torch.zeros(B_dec, 3, T_dec, H_tile, W*scale, dtype=v_bl.dtype, device='cpu') + row_bot[..., :end_left] = v_bl[..., :end_left] + row_bot[..., start_right:] = v_br[..., 2*overlap_pix:] + row_bot[..., end_left:start_right] = bl_blend_slice * (1 - alpha) + br_blend_slice * alpha + del v_bl, v_br, l_bl, l_br + + # 5. 
Blend Top and Bottom Vertically + print(" ๐งต Blending Rows Vertically...") + h_mid_pix = h_mid * scale + + # Slices + top_blend_slice = row_top[..., h_mid_pix-overlap_pix:, :] + bot_blend_slice = row_bot[..., :2*overlap_pix, :] + + video = torch.zeros(B_dec, 3, T_dec, H*scale, W*scale, dtype=row_top.dtype, device='cpu') + + end_top = max(0, h_mid_pix - overlap_pix) + start_bot = h_mid_pix + overlap_pix + + video[..., :end_top, :] = row_top[..., :end_top, :] + video[..., start_bot:, :] = row_bot[..., 2*overlap_pix:, :] + + alpha_v = torch.sigmoid(x).view(1, 1, 1, -1, 1) + blended_v = top_blend_slice * (1 - alpha_v) + bot_blend_slice * alpha_v + + video[..., end_top:start_bot, :] = blended_v + + except Exception as e: + print(f"โ Tiled GPU Decode Failed: {e}") + raise e + return video.to(latents.device) + +def load_dit_model(args, is_high_noise=True, force_offload=False): + """Helper to load the model, respecting overrides.""" + original_offload = args.offload_dit + if force_offload: + args.offload_dit = True + + path = args.high_noise_model_path if is_high_noise else args.low_noise_model_path + log.info(f"Loading {'High' if is_high_noise else 'Low'} Noise DiT (Offload={args.offload_dit})...") + + try: + model = create_model(dit_path=path, args=args).cpu() + finally: + args.offload_dit = original_offload + + return model if __name__ == "__main__": + print_memory_status("Script Start") args = parse_arguments() - # Handle serve mode if args.serve: - # Set mode to i2v for the TUI server args.mode = "i2v" from serve.tui import main as serve_main serve_main(args) exit(0) - - # Validate required args for one-shot mode - if args.prompt is None: - log.error("--prompt is required (unless using --serve mode)") - exit(1) - if args.image_path is None: - log.error("--image_path is required (unless using --serve mode)") - exit(1) - - log.info(f"Computing embedding for prompt: {args.prompt}") - with torch.no_grad(): + + # --- AUTO-ADJUST FRAME COUNT --- + if (args.num_frames - 1) % 4 != 0: + old_f = args.num_frames + new_f = ((old_f - 1) // 4 + 1) * 4 + 1 + print(f"โ ๏ธ Adjusting --num_frames from {old_f} to {new_f} to satisfy VAE temporal stride (4n+1).") + args.num_frames = new_f + + # --- AUTO-ENABLE OFFLOAD FOR HIGH FRAMES --- + if args.num_frames > 90 and not args.offload_dit: + print(f"โ ๏ธ High frame count ({args.num_frames}) detected. Enabling --offload_dit to prevent OOM.") + args.offload_dit = True + + # 1. Text Embeddings + if args.cached_embedding and os.path.exists(args.cached_embedding): + log.info(f"Loading cached embedding from: {args.cached_embedding}") + cache_data = torch.load(args.cached_embedding, map_location='cpu') + text_emb = cache_data['embeddings'][0]['embedding'].to(**tensor_kwargs) + else: + log.info(f"Computing embedding...") text_emb = get_umt5_embedding(checkpoint_path=args.text_encoder_path, prompts=args.prompt).to(**tensor_kwargs) - clear_umt5_memory() - - log.info(f"Loading DiT models.") - high_noise_model = create_model(dit_path=args.high_noise_model_path, args=args).cpu() - torch.cuda.empty_cache() - low_noise_model = create_model(dit_path=args.low_noise_model_path, args=args).cpu() - torch.cuda.empty_cache() - log.success(f"Successfully loaded DiT model.") + clear_umt5_memory() + # 2. 
VAE Encoding + print("-" * 20 + " VAE SETUP " + "-" * 20) tokenizer = Wan2pt1VAEInterface(vae_pth=args.vae_path) - - log.info(f"Loading and preprocessing image from: {args.image_path}") + target_dtype = tensor_kwargs.get("dtype", torch.bfloat16) input_image = Image.open(args.image_path).convert("RGB") + if args.adaptive_resolution: - log.info("Adaptive resolution mode enabled.") base_w, base_h = VIDEO_RES_SIZE_INFO[args.resolution][args.aspect_ratio] max_resolution_area = base_w * base_h - log.info(f"Target area is based on {args.resolution} {args.aspect_ratio} (~{max_resolution_area} pixels).") - orig_w, orig_h = input_image.size - image_aspect_ratio = orig_h / orig_w - - ideal_w = np.sqrt(max_resolution_area / image_aspect_ratio) - ideal_h = np.sqrt(max_resolution_area * image_aspect_ratio) - + aspect = orig_h / orig_w + ideal_w = np.sqrt(max_resolution_area / aspect) + ideal_h = np.sqrt(max_resolution_area * aspect) stride = tokenizer.spatial_compression_factor * 2 - lat_h = round(ideal_h / stride) - lat_w = round(ideal_w / stride) - h = lat_h * stride - w = lat_w * stride - - log.info(f"Input image aspect ratio: {image_aspect_ratio:.4f}. Adaptive resolution set to: {w}x{h}") + h = round(ideal_h / stride) * stride + w = round(ideal_w / stride) * stride + log.info(f"Adaptive Res: {w}x{h}") else: - log.info("Fixed resolution mode.") w, h = VIDEO_RES_SIZE_INFO[args.resolution][args.aspect_ratio] - log.info(f"Resolution set to: {w}x{h}") + F = args.num_frames lat_h = h // tokenizer.spatial_compression_factor lat_w = w // tokenizer.spatial_compression_factor lat_t = tokenizer.get_latent_num_frames(F) - log.info(f"Preprocessing image to {w}x{h}...") - image_transforms = T.Compose( - [ - T.ToImage(), - T.Resize(size=(h, w), antialias=True), - T.ToDtype(torch.float32, scale=True), - T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), - ] - ) - image_tensor = image_transforms(input_image).unsqueeze(0).to(device=tensor_kwargs["device"], dtype=torch.float32) - - with torch.no_grad(): - frames_to_encode = torch.cat( - [image_tensor.unsqueeze(2), torch.zeros(1, 3, F - 1, h, w, device=image_tensor.device)], dim=2 - ) # -> B, C, T, H, W - encoded_latents = tokenizer.encode(frames_to_encode) # -> B, C_lat, T_lat, H_lat, W_lat - - del frames_to_encode - torch.cuda.empty_cache() - + image_transforms = T.Compose([ + T.ToImage(), + T.Resize(size=(h, w), antialias=True), + T.ToDtype(torch.float32, scale=True), + T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + ]) + + image_tensor = image_transforms(input_image).unsqueeze(0).to(device=tensor_kwargs["device"], dtype=target_dtype) + frames_to_encode = torch.cat([image_tensor.unsqueeze(2), torch.zeros(1, 3, F - 1, h, w, device=image_tensor.device, dtype=target_dtype)], dim=2) + + log.info(f"Encoding {F} frames...") + + try: + free_mem, _ = torch.cuda.mem_get_info() + if free_mem < 24 * (1024**3): + raise torch.OutOfMemoryError("Pre-emptive tiling") + with torch.amp.autocast("cuda", dtype=target_dtype): + encoded_latents = tokenizer.encode(frames_to_encode) + except torch.OutOfMemoryError: + try: + cleanup_memory("Switching to Tiled Encode") + encoded_latents = tiled_encode_4x(tokenizer, frames_to_encode, target_dtype) + except Exception as e: + log.warning(f"Tiling failed ({e}). 
Fallback to CPU.") + encoded_latents = safe_cpu_fallback_encode(tokenizer, frames_to_encode, target_dtype) + + print(f"โ VAE Encode Complete.") + del frames_to_encode + cleanup_memory("After VAE Encode") + + # Prepare for Diffusion msk = torch.zeros(1, 4, lat_t, lat_h, lat_w, device=tensor_kwargs["device"], dtype=tensor_kwargs["dtype"]) msk[:, :, 0, :, :] = 1.0 - y = torch.cat([msk, encoded_latents.to(**tensor_kwargs)], dim=1) y = y.repeat(args.num_samples, 1, 1, 1, 1) + saved_latent_ch = tokenizer.latent_ch + + del tokenizer + cleanup_memory("Unloaded VAE Model") + + # 3. Diffusion Sampling + print("-" * 20 + " DIT LOADING " + "-" * 20) + + current_model = load_dit_model(args, is_high_noise=True) + is_high_noise_active = True + fallback_triggered = args.offload_dit - log.info(f"Generating with prompt: {args.prompt}") condition = {"crossattn_emb": repeat(text_emb.to(**tensor_kwargs), "b l d -> (k b) l d", k=args.num_samples), "y_B_C_T_H_W": y} - - to_show = [] - - state_shape = [tokenizer.latent_ch, lat_t, lat_h, lat_w] - - generator = torch.Generator(device=tensor_kwargs["device"]) - generator.manual_seed(args.seed) - - init_noise = torch.randn( - args.num_samples, - *state_shape, - dtype=torch.float32, - device=tensor_kwargs["device"], - generator=generator, - ) - + + generator = torch.Generator(device=tensor_kwargs["device"]).manual_seed(args.seed) + init_noise = torch.randn(args.num_samples, saved_latent_ch, lat_t, lat_h, lat_w, dtype=torch.float32, device=tensor_kwargs["device"], generator=generator) + mid_t = [1.5, 1.4, 1.0][: args.num_steps - 1] - - t_steps = torch.tensor( - [math.atan(args.sigma_max), *mid_t, 0], - dtype=torch.float64, - device=init_noise.device, - ) - - # Convert TrigFlow timesteps to RectifiedFlow + t_steps = torch.tensor([math.atan(args.sigma_max), *mid_t, 0], dtype=torch.float64, device=init_noise.device) t_steps = torch.sin(t_steps) / (torch.cos(t_steps) + torch.sin(t_steps)) - x = init_noise.to(torch.float64) * t_steps[0] ones = torch.ones(x.size(0), 1, device=x.device, dtype=x.dtype) - total_steps = t_steps.shape[0] - 1 - high_noise_model.cuda() - net = high_noise_model - switched = False - for i, (t_cur, t_next) in enumerate(tqdm(list(zip(t_steps[:-1], t_steps[1:])), desc="Sampling", total=total_steps)): - if t_cur.item() < args.boundary and not switched: - high_noise_model.cpu() - torch.cuda.empty_cache() - low_noise_model.cuda() - net = low_noise_model - switched = True - log.info("Switched to low noise model.") - with torch.no_grad(): - v_pred = net(x_B_C_T_H_W=x.to(**tensor_kwargs), timesteps_B_T=(t_cur.float() * ones * 1000).to(**tensor_kwargs), **condition).to( - torch.float64 - ) - if args.ode: - x = x - (t_cur - t_next) * v_pred - else: - x = (1 - t_next) * (x - t_cur * v_pred) + t_next * torch.randn( - *x.shape, - dtype=torch.float32, - device=tensor_kwargs["device"], - generator=generator, - ) + + # Always ensure CUDA initially + current_model.cuda() + + print("-" * 20 + " SAMPLING START " + "-" * 20) + print_memory_status("High Noise Model to GPU") + + # Sampling Loop + for i, (t_cur, t_next) in enumerate(tqdm(list(zip(t_steps[:-1], t_steps[1:])), total=len(t_steps)-1)): + if t_cur.item() < args.boundary and is_high_noise_active: + print(f"\n๐ Switching DiT Models (Step {i})...") + current_model.cpu() + del current_model + cleanup_memory("Unloaded High Noise") + + current_model = load_dit_model(args, is_high_noise=False, force_offload=fallback_triggered) + current_model.cuda() # Force CUDA + is_high_noise_active = False + 
print_memory_status("Loaded Low Noise to GPU") + + step_success = False + while not step_success: + try: + gc.collect() + torch.cuda.empty_cache() + with torch.no_grad(): + v_pred = current_model( + x_B_C_T_H_W=x.to(**tensor_kwargs), + timesteps_B_T=(t_cur.float() * ones * 1000).to(**tensor_kwargs), + **condition + ).to(torch.float64) + step_success = True + except torch.OutOfMemoryError: + if fallback_triggered: + log.error("โ OOM occurred even after reload. Physical Memory Limit Reached.") + sys.exit(1) + + print(f"\nโ ๏ธ OOM in DiT Sampling Step {i}. Reloading model to clear fragmentation...") + cleanup_memory("Pre-Reload") + + # Unload and Reload to Defrag + was_high = is_high_noise_active + current_model.cpu() + del current_model + cleanup_memory("Unload for Reload") + + fallback_triggered = True + current_model = load_dit_model(args, is_high_noise=was_high, force_offload=True) + current_model.cuda() # Move back to GPU + + print("โป๏ธ Model Reloaded. Retrying step...") + + if args.ode: + x = x - (t_cur - t_next) * v_pred + else: + x = (1 - t_next) * (x - t_cur * v_pred) + t_next * torch.randn(*x.shape, dtype=torch.float32, device=tensor_kwargs["device"], generator=generator) + samples = x.float() - low_noise_model.cpu() - torch.cuda.empty_cache() - + + print("-" * 20 + " DECODE SETUP (DEFRAG) " + "-" * 20) + samples_cpu_backup = samples.cpu() + del samples + del x + current_model.cpu() + del current_model + cleanup_memory("FULL WIPE before VAE Load") + + log.info("Reloading VAE for decoding...") + tokenizer = Wan2pt1VAEInterface(vae_pth=args.vae_path) + print_memory_status("Reloaded VAE (Clean Slate)") + + samples = samples_cpu_backup.to(device=tensor_kwargs["device"]) + with torch.no_grad(): - video = tokenizer.decode(samples) - - to_show.append(video.float().cpu()) - + success = False + video = None + + try: + log.info("Attempting Standard GPU Decode...") + video = tokenizer.decode(samples) + success = True + except torch.OutOfMemoryError: + log.warning("โ ๏ธ GPU OOM (Standard). Switching to Tiled GPU Decode...") + cleanup_memory("Pre-Tile Fallback") + + try: + # 12 Latents overlap = 96 Image pixels + video = tiled_decode_gpu(tokenizer, samples, overlap=12) + success = True + except (torch.OutOfMemoryError, RuntimeError) as e: + log.warning(f"โ ๏ธ GPU Tiled Decode Failed ({e}). Switching to CPU Decode (Slow)...") + cleanup_memory("Pre-CPU Fallback") + + if not success: + log.info("Performing Hard Cast of VAE to CPU Float32...") + samples_cpu = samples.cpu().float() + force_cpu_float32(tokenizer) + with torch.autocast("cpu", enabled=False): + with torch.autocast("cuda", enabled=False): + video = tokenizer.decode(samples_cpu) + + to_show = [video.float().cpu()] to_show = (1.0 + torch.stack(to_show, dim=0).clamp(-1, 1)) / 2.0 - save_image_or_video(rearrange(to_show, "n b c t h w -> c t (n h) (b w)"), args.save_path, fps=16) + log.success("Done.") From 57a3e1a1822fe465ed85d2a97098cb40c6d1cab1 Mon Sep 17 00:00:00 2001 From: MrEdwards007 <116316872+MrEdwards007@users.noreply.github.com> Date: Thu, 1 Jan 2026 07:42:26 -0500 Subject: [PATCH 2/6] wan2.1_t2v_infer.py (Text-to-Video) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This script handles video generation directly from text prompts. It includes Intelligent Hardware Detection to automatically select the best precision (BF16/FP16) for your GPU and a three-tier memory recovery system (GPU โ Checkpointing โ CPU Offloading). 
It also enforces safety checks, such as disabling torch.compile for quantized models to prevent system crashes during the diffusion process. --- turbodiffusion/inference/wan2.1_t2v_infer.py | 386 ++++++++++++++----- 1 file changed, 300 insertions(+), 86 deletions(-) diff --git a/turbodiffusion/inference/wan2.1_t2v_infer.py b/turbodiffusion/inference/wan2.1_t2v_infer.py index c5581e4..8ecabf9 100644 --- a/turbodiffusion/inference/wan2.1_t2v_infer.py +++ b/turbodiffusion/inference/wan2.1_t2v_infer.py @@ -1,24 +1,58 @@ +# Blackwell Bridge # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at +# ----------------------------------------------------------------------------------------- +# TURBODIFFUSION OPTIMIZED INFERENCE SCRIPT (T2V) # -# http://www.apache.org/licenses/LICENSE-2.0 +# Co-developed by: Waverly Edwards & Google Gemini (2025) # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# Modifications: +# - Implemented "Tiered Failover System" for robust OOM protection (GPU -> Checkpoint -> CPU). +# - Added Intelligent Hardware Detection (TF32/BF16/FP16 auto-switching). +# - Integrated Tiled Decoding for high-resolution VAE processing. +# - Added Support for Pre-cached Text Embeddings to skip T5 loading. +# - Optimized compilation logic for Quantized models (preventing graph breaks). +# +# Acknowledgments: +# - Made possible by the work (cache_t5.py) and creativity of: John D. Pope +# +# Description: +# cache_t5.py pre-computes text embeddings to allow running inference on GPUs with limited VRAM +# by removing the need to keep the 11GB T5 encoder loaded in memory. +# +# CREDIT REQUEST: +# If you utilize, share, or build upon this specific optimized script, please +# acknowledge Waverly Edwards and Google Gemini in your documentation or credits. +# ----------------------------------------------------------------------------------------- import argparse import math +import os +import gc +import time +import sys + +# --- 1. Memory Tuning --- +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" import torch +import torch.nn.functional as F from einops import rearrange, repeat from tqdm import tqdm +import numpy as np + +# --- 2. 
Hardware Optimization --- +if torch.cuda.is_available(): + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + # 'high' allows TF32 but maintains reasonable precision + torch.set_float32_matmul_precision('high') + +try: + import psutil +except ImportError: + psutil = None from imaginaire.utils.io import save_image_or_video from imaginaire.utils import log @@ -29,123 +63,303 @@ from modify_model import tensor_kwargs, create_model +# Suppress graph break warnings for cleaner output torch._dynamo.config.suppress_errors = True - +torch._dynamo.config.verbose = False def parse_arguments() -> argparse.Namespace: parser = argparse.ArgumentParser(description="TurboDiffusion inference script for Wan2.1 T2V") - parser.add_argument("--dit_path", type=str, required=True, help="Custom path to the DiT model checkpoint for distilled models") + parser.add_argument("--dit_path", type=str, required=True, help="Custom path to the DiT model checkpoint") parser.add_argument("--model", choices=["Wan2.1-1.3B", "Wan2.1-14B"], default="Wan2.1-1.3B", help="Model to use") parser.add_argument("--num_samples", type=int, default=1, help="Number of samples to generate") parser.add_argument("--num_steps", type=int, choices=[1, 2, 3, 4], default=4, help="1~4 for timestep-distilled inference") parser.add_argument("--sigma_max", type=float, default=80, help="Initial sigma for rCM") parser.add_argument("--vae_path", type=str, default="checkpoints/Wan2.1_VAE.pth", help="Path to the Wan2.1 VAE") parser.add_argument("--text_encoder_path", type=str, default="checkpoints/models_t5_umt5-xxl-enc-bf16.pth", help="Path to the umT5 text encoder") + parser.add_argument("--cached_embedding", type=str, default=None, help="Path to cached text embeddings (pt file)") + parser.add_argument("--skip_t5", action="store_true", help="Skip T5 loading (implied if cached_embedding is used)") parser.add_argument("--num_frames", type=int, default=81, help="Number of frames to generate") - parser.add_argument("--prompt", type=str, default=None, help="Text prompt for video generation (required unless --serve)") + parser.add_argument("--prompt", type=str, required=True, help="Text prompt for video generation") parser.add_argument("--resolution", default="480p", type=str, help="Resolution of the generated output") parser.add_argument("--aspect_ratio", default="16:9", type=str, help="Aspect ratio of the generated output (width:height)") parser.add_argument("--seed", type=int, default=0, help="Random seed for reproducibility") - parser.add_argument("--save_path", type=str, default="output/generated_video.mp4", help="Path to save the generated video (include file extension)") - parser.add_argument("--attention_type", choices=["sla", "sagesla", "original"], default="sagesla", help="Type of attention mechanism to use") + parser.add_argument("--save_path", type=str, default="output/generated_video.mp4", help="Path to save the generated video") + parser.add_argument("--attention_type", choices=["sla", "sagesla", "original"], default="sagesla", help="Type of attention mechanism") parser.add_argument("--sla_topk", type=float, default=0.1, help="Top-k ratio for SLA/SageSLA attention") parser.add_argument("--quant_linear", action="store_true", help="Whether to replace Linear layers with quantized versions") - parser.add_argument("--default_norm", action="store_true", help="Whether to replace LayerNorm/RMSNorm layers with faster versions") - parser.add_argument("--serve", action="store_true", help="Launch interactive TUI server 
mode (keeps model loaded)") + parser.add_argument("--default_norm", action="store_true", help="Whether to replace LayerNorm/RMSNorm with faster versions") + parser.add_argument("--offload_dit", action="store_true", help="Offload DiT to CPU when not in use to save VRAM") + parser.add_argument("--compile", action="store_true", help="Use torch.compile (Inductor) for faster inference") return parser.parse_args() +def check_hardware_compatibility(): + if not torch.cuda.is_available(): return + gpu_name = torch.cuda.get_device_name(0) + log.info(f"Hardware: {gpu_name}") + + current_dtype = tensor_kwargs.get("dtype") + if current_dtype == torch.bfloat16 and not torch.cuda.is_bf16_supported(): + log.warning(f"โ ๏ธ Device does not support BFloat16. Switching to Float16.") + tensor_kwargs["dtype"] = torch.float16 -if __name__ == "__main__": - args = parse_arguments() +def print_memory_status(step_name=""): + if not torch.cuda.is_available(): return + torch.cuda.synchronize() + allocated = torch.cuda.memory_allocated() / (1024**3) + free, total = torch.cuda.mem_get_info() + free_gb = free / (1024**3) + print(f"๐ [MEM] {step_name}: InUse={allocated:.2f}GB, Free={free_gb:.2f}GB") - # Handle serve mode - if args.serve: - # Set mode to t2v for the TUI server - args.mode = "t2v" - from serve.tui import main as serve_main - serve_main(args) - exit(0) +def cleanup_memory(step_info=""): + gc.collect() + torch.cuda.empty_cache() - # Validate prompt is provided for one-shot mode - if args.prompt is None: - log.error("--prompt is required (unless using --serve mode)") - exit(1) +def load_dit_model(args, force_offload=False): + orig_offload = args.offload_dit + if force_offload: args.offload_dit = True + log.info(f"Loading DiT (Offload={args.offload_dit})...") + model = create_model(dit_path=args.dit_path, args=args).cpu() + args.offload_dit = orig_offload + return model - log.info(f"Computing embedding for prompt: {args.prompt}") - with torch.no_grad(): - text_emb = get_umt5_embedding(checkpoint_path=args.text_encoder_path, prompts=args.prompt).to(**tensor_kwargs) - clear_umt5_memory() +def tiled_decode_gpu(tokenizer, latents, overlap=12): + print(f"\n๐งฑ Starting Tiled GPU Decode (Overlap={overlap})...") + B, C, T, H, W = latents.shape + scale = tokenizer.spatial_compression_factor + h_mid = H // 2 + w_mid = W // 2 + + def decode_tile(tile_latents): + cleanup_memory() + with torch.no_grad(): return tokenizer.decode(tile_latents).cpu() - log.info(f"Loading DiT model from {args.dit_path}") - net = create_model(dit_path=args.dit_path, args=args).cpu() - torch.cuda.empty_cache() - log.success("Successfully loaded DiT model.") + # 1. 
Top Tiles + l_tl = latents[..., :h_mid+overlap, :w_mid+overlap] + l_tr = latents[..., :h_mid+overlap, w_mid-overlap:] + v_tl = decode_tile(l_tl) + v_tr = decode_tile(l_tr) - tokenizer = Wan2pt1VAEInterface(vae_pth=args.vae_path) + B_dec, C_dec, T_dec, H_tile, W_tile = v_tl.shape + mid_pix = w_mid * scale + overlap_pix = overlap * scale + + row_top = torch.zeros(B_dec, 3, T_dec, H_tile, W*scale, dtype=v_tl.dtype, device='cpu') + end_left = max(0, mid_pix - overlap_pix) + start_right = mid_pix + overlap_pix + + row_top[..., :end_left] = v_tl[..., :end_left] + row_top[..., start_right:] = v_tr[..., 2*overlap_pix:] + + x_linspace = torch.linspace(-6, 6, 2*overlap_pix, device='cpu') + alpha = torch.sigmoid(x_linspace).view(1, 1, 1, 1, -1) + row_top[..., end_left:start_right] = v_tl[..., mid_pix-overlap_pix:] * (1 - alpha) + v_tr[..., :2*overlap_pix] * alpha + del v_tl, v_tr - w, h = VIDEO_RES_SIZE_INFO[args.resolution][args.aspect_ratio] + # 2. Bottom Tiles + l_bl = latents[..., h_mid-overlap:, :w_mid+overlap] + l_br = latents[..., h_mid-overlap:, w_mid-overlap:] + v_bl = decode_tile(l_bl) + v_br = decode_tile(l_br) + + row_bot = torch.zeros(B_dec, 3, T_dec, H_tile, W*scale, dtype=v_bl.dtype, device='cpu') + row_bot[..., :end_left] = v_bl[..., :end_left] + row_bot[..., start_right:] = v_br[..., 2*overlap_pix:] + row_bot[..., end_left:start_right] = v_bl[..., mid_pix-overlap_pix:] * (1 - alpha) + v_br[..., :2*overlap_pix] * alpha + del v_bl, v_br - log.info(f"Generating with prompt: {args.prompt}") - condition = {"crossattn_emb": repeat(text_emb.to(**tensor_kwargs), "b l d -> (k b) l d", k=args.num_samples)} + # 3. Blend Vertically + h_mid_pix = h_mid * scale + video = torch.zeros(B_dec, 3, T_dec, H*scale, W*scale, dtype=row_top.dtype, device='cpu') + end_top = max(0, h_mid_pix - overlap_pix) + start_bot = h_mid_pix + overlap_pix + + video[..., :end_top, :] = row_top[..., :end_top, :] + video[..., start_bot:, :] = row_bot[..., 2*overlap_pix:, :] + + alpha_v = torch.sigmoid(x_linspace).view(1, 1, 1, -1, 1) + video[..., end_top:start_bot, :] = row_top[..., h_mid_pix-overlap_pix:, :] * (1 - alpha_v) + row_bot[..., :2*overlap_pix, :] * alpha_v + + return video.to(latents.device) - to_show = [] +def force_cpu_float32(target_obj): + for module in target_obj.modules(): + module.cpu().float() - state_shape = [ - tokenizer.latent_ch, - tokenizer.get_latent_num_frames(args.num_frames), - h // tokenizer.spatial_compression_factor, - w // tokenizer.spatial_compression_factor, - ] +def apply_manual_offload(model, device="cuda"): + log.info("Applying Tier 3 Offload...") + block_list_name = None + max_len = 0 + for name, child in model.named_children(): + if isinstance(child, torch.nn.ModuleList): + if len(child) > max_len: + max_len = len(child) + block_list_name = name + + if not block_list_name: + log.warning("Could not identify Block List! 
Offloading entire model to CPU.") + model.to("cpu") + return - generator = torch.Generator(device=tensor_kwargs["device"]) - generator.manual_seed(args.seed) + print(f" ๐ Identified Transformer Blocks: '{block_list_name}' ({max_len} layers)") + try: model.to(device) + except RuntimeError: model.to("cpu") + + blocks = getattr(model, block_list_name) + blocks.to("cpu") + + def pre_hook(module, args): + module.to(device) + return args + def post_hook(module, args, output): + module.to("cpu") + return output + + for i, block in enumerate(blocks): + block.register_forward_pre_hook(pre_hook) + block.register_forward_hook(post_hook) - init_noise = torch.randn( - args.num_samples, - *state_shape, - dtype=torch.float32, - device=tensor_kwargs["device"], - generator=generator, - ) +if __name__ == "__main__": + print_memory_status("Script Start") + + # --- CREDIT PRINT --- + log.info("----------------------------------------------------------------") + log.info("๐ TurboDiffusion Optimized Inference") + log.info(" Co-developed by [Your Name/Handle] & Google Gemini") + log.info("----------------------------------------------------------------") - # mid_t = [1.3, 1.0, 0.6][: args.num_steps - 1] - # For better visual quality - mid_t = [1.5, 1.4, 1.0][: args.num_steps - 1] + check_hardware_compatibility() + args = parse_arguments() - t_steps = torch.tensor( - [math.atan(args.sigma_max), *mid_t, 0], - dtype=torch.float64, - device=init_noise.device, - ) + if (args.num_frames - 1) % 4 != 0: + new_f = ((args.num_frames - 1) // 4 + 1) * 4 + 1 + print(f"โ ๏ธ Adjusting --num_frames to {new_f}") + args.num_frames = new_f - # Convert TrigFlow timesteps to RectifiedFlow - t_steps = torch.sin(t_steps) / (torch.cos(t_steps) + torch.sin(t_steps)) + if args.num_frames > 90 and not args.offload_dit: + args.offload_dit = True + + # --- CRITICAL FIX: Strictly Disable Compile for Quantized Models --- + if args.compile and args.quant_linear: + log.warning("๐ซ Quantized Model Detected: FORCE DISABLING `torch.compile` to avoid OOM.") + log.warning(" (Custom quantized kernels are not compatible with CUDA Graphs)") + args.compile = False + + # 1. Text Embeddings + if args.cached_embedding and os.path.exists(args.cached_embedding): + log.info(f"Loading cache: {args.cached_embedding}") + c = torch.load(args.cached_embedding, map_location='cpu') + text_emb = c['embeddings'][0]['embedding'].to(**tensor_kwargs) if isinstance(c, dict) else c.to(**tensor_kwargs) + else: + log.info(f"Computing embedding...") + with torch.no_grad(): + text_emb = get_umt5_embedding(args.text_encoder_path, args.prompt).to(**tensor_kwargs) + clear_umt5_memory() + cleanup_memory() + + # 2. VAE Shape Calc & UNLOAD + log.info("VAE Setup (Temp)...") + tokenizer = Wan2pt1VAEInterface(vae_pth=args.vae_path) + w, h = VIDEO_RES_SIZE_INFO[args.resolution][args.aspect_ratio] + state_shape = [tokenizer.latent_ch, tokenizer.get_latent_num_frames(args.num_frames), h // tokenizer.spatial_compression_factor, w // tokenizer.spatial_compression_factor] + del tokenizer + cleanup_memory("VAE Unloaded") - # Sampling steps + # 3. Load DiT + net = load_dit_model(args) + + # 4. 
Noise & Schedule + gen = torch.Generator(device=tensor_kwargs["device"]).manual_seed(args.seed) + cond = {"crossattn_emb": repeat(text_emb.to(**tensor_kwargs), "b l d -> (k b) l d", k=args.num_samples)} + init_noise = torch.randn(args.num_samples, *state_shape, dtype=torch.float32, device=tensor_kwargs["device"], generator=gen) + + mid_t = [1.5, 1.4, 1.0][: args.num_steps - 1] + t_steps = torch.tensor([math.atan(args.sigma_max), *mid_t, 0], dtype=torch.float64, device=init_noise.device) + t_steps = torch.sin(t_steps) / (torch.cos(t_steps) + torch.sin(t_steps)) + x = init_noise.to(torch.float64) * t_steps[0] ones = torch.ones(x.size(0), 1, device=x.device, dtype=x.dtype) - total_steps = t_steps.shape[0] - 1 + + # 5. Fast Sampling Loop + log.info("๐ฅ STARTING SAMPLING (INFERENCE MODE) ๐ฅ") + torch.cuda.empty_cache() net.cuda() - for i, (t_cur, t_next) in enumerate(tqdm(list(zip(t_steps[:-1], t_steps[1:])), desc="Sampling", total=total_steps)): - with torch.no_grad(): - v_pred = net(x_B_C_T_H_W=x.to(**tensor_kwargs), timesteps_B_T=(t_cur.float() * ones * 1000).to(**tensor_kwargs), **condition).to( - torch.float64 - ) - x = (1 - t_next) * (x - t_cur * v_pred) + t_next * torch.randn( - *x.shape, - dtype=torch.float32, - device=tensor_kwargs["device"], - generator=generator, - ) + print_memory_status("Tier 1: GPU Ready") + + # Compile? (Only if NOT disabled above) + if args.compile: + log.info("๐ Compiling model...") + try: + net = torch.compile(net, mode="reduce-overhead") + except Exception as e: + log.warning(f"Compile failed: {e}. Running eager.") + + failover = 0 + + with torch.inference_mode(): + for i, (t_cur, t_next) in enumerate(tqdm(zip(t_steps[:-1], t_steps[1:]), total=len(t_steps)-1)): + retry = True + while retry: + try: + t_cur_scalar = t_cur.item() + t_next_scalar = t_next.item() + + v_pred = net( + x_B_C_T_H_W=x.to(**tensor_kwargs), + timesteps_B_T=(t_cur * ones * 1000).to(**tensor_kwargs), + **cond + ).to(torch.float64) + + if args.offload_dit and i == len(t_steps)-2 and failover == 0: + net.cpu() + + noise = torch.randn(*x.shape, dtype=torch.float32, device=x.device, generator=gen).to(torch.float64) + term1 = x - (v_pred * t_cur_scalar) + x = term1 * (1.0 - t_next_scalar) + (noise * t_next_scalar) + + retry = False + + except torch.OutOfMemoryError: + log.warning(f"โ ๏ธ OOM at Step {i}. Recovering...") + try: net.cpu() + except: pass + del net + cleanup_memory() + failover += 1 + + if failover == 1: + print("โป๏ธ Tier 2: Checkpointing") + net = load_dit_model(args, force_offload=True) + net.cuda() + # Retry compile with safer mode if first attempt was aggressive + if args.compile: + try: net = torch.compile(net, mode="default") + except: pass + elif failover == 2: + print("โป๏ธ Tier 3: Manual Offload") + net = load_dit_model(args, force_offload=True) + apply_manual_offload(net) + else: + sys.exit("โ Critical OOM.") + samples = x.float() - net.cpu() - torch.cuda.empty_cache() + # 6. 
Decode + if 'net' in locals(): + try: net.cpu() + except: pass + del net + cleanup_memory("Pre-VAE") + + log.info("Decoding...") + tokenizer = Wan2pt1VAEInterface(vae_pth=args.vae_path) with torch.no_grad(): - video = tokenizer.decode(samples) - - to_show.append(video.float().cpu()) + try: + video = tokenizer.decode(samples) + except torch.OutOfMemoryError: + log.warning("Falling back to Tiled Decode...") + video = tiled_decode_gpu(tokenizer, samples) + to_show = [video.float().cpu()] to_show = (1.0 + torch.stack(to_show, dim=0).clamp(-1, 1)) / 2.0 - save_image_or_video(rearrange(to_show, "n b c t h w -> c t (n h) (b w)"), args.save_path, fps=16) + log.success(f"Saved: {args.save_path}") From 5b5f5dd658d8a6901f468640c444f0023afa586d Mon Sep 17 00:00:00 2001 From: MrEdwards007 <116316872+MrEdwards007@users.noreply.github.com> Date: Thu, 1 Jan 2026 07:45:10 -0500 Subject: [PATCH 3/6] cache_t5.py (Memory Utility) A utility script designed by John D. Pope to pre-compute and save text embeddings to a file. By "caching" these embeddings, you can run the main inference scripts without loading the heavy 11GB T5 Text Encoder into VRAM. This is essential for running the large 14B models on GPUs with limited memory, effectively "skipping" the most memory-intensive part of the initialization. --- turbodiffusion/inference/cache_t5.py | 117 +++++++++++++++++++++++++++ 1 file changed, 117 insertions(+) create mode 100644 turbodiffusion/inference/cache_t5.py diff --git a/turbodiffusion/inference/cache_t5.py b/turbodiffusion/inference/cache_t5.py new file mode 100644 index 0000000..a25e3cf --- /dev/null +++ b/turbodiffusion/inference/cache_t5.py @@ -0,0 +1,117 @@ +#!/usr/bin/env python +# ----------------------------------------------------------------------------------------- +# T5 EMBEDDING CACHE UTILITY +# +# Acknowledgments: +# - Work and creativity of: John D. Pope +# +# Description: +# Pre-computes text embeddings to allow running inference on GPUs with limited VRAM +# by removing the need to keep the 11GB T5 encoder loaded in memory. +# ----------------------------------------------------------------------------------------- +""" +Pre-cache T5 text embeddings to avoid loading the 11GB model during inference. + +Usage: + # Cache a single prompt + python scripts/cache_t5.py --prompt "slow head turn, cinematic" --output cached_embeddings.pt + + # Cache multiple prompts from file + python scripts/cache_t5.py --prompts_file prompts.txt --output cached_embeddings.pt + +Then use with inference: + python turbodiffusion/inference/wan2.2_i2v_infer.py \ + --cached_embedding cached_embeddings.pt \ + --skip_t5 \ + ... 
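+
+At inference time the cached file is read back roughly like this (a sketch that
+mirrors the loading code in wan2.2_i2v_infer.py):
+
+    cache = torch.load("cached_embeddings.pt", map_location="cpu")
+    text_emb = cache["embeddings"][0]["embedding"]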
+""" +import os +import sys +import argparse +import torch + +# Add repo root to path for imports +SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) +REPO_ROOT = os.path.dirname(SCRIPT_DIR) +sys.path.insert(0, REPO_ROOT) + +def main(): + parser = argparse.ArgumentParser(description="Pre-cache T5 text embeddings") + parser.add_argument("--prompt", type=str, default=None, help="Single prompt to cache") + parser.add_argument("--prompts_file", type=str, default=None, help="File with prompts (one per line)") + parser.add_argument("--text_encoder_path", type=str, + default="/media/2TB/ComfyUI/models/text_encoders/models_t5_umt5-xxl-enc-bf16.pth", + help="Path to the umT5 text encoder") + parser.add_argument("--output", type=str, default="cached_t5_embeddings.pt", + help="Output path for cached embeddings") + parser.add_argument("--device", type=str, default="cuda", + help="Device to use for encoding (cuda is faster, memory freed after)") + args = parser.parse_args() + + # Collect prompts + prompts = [] + if args.prompt: + prompts.append(args.prompt) + if args.prompts_file and os.path.exists(args.prompts_file): + with open(args.prompts_file, 'r') as f: + prompts.extend([line.strip() for line in f if line.strip()]) + + if not prompts: + print("Error: Provide --prompt or --prompts_file") + sys.exit(1) + + print(f"Caching embeddings for {len(prompts)} prompt(s)") + print(f"Text encoder: {args.text_encoder_path}") + print(f"Device: {args.device}") + print() + + # Import after path setup + from rcm.utils.umt5 import get_umt5_embedding, clear_umt5_memory + + cache_data = { + 'prompts': prompts, + 'embeddings': [], + 'text_encoder_path': args.text_encoder_path, + } + + with torch.no_grad(): + for i, prompt in enumerate(prompts): + print(f"[{i+1}/{len(prompts)}] Encoding: '{prompt[:60]}...' " if len(prompt) > 60 else f"[{i+1}/{len(prompts)}] Encoding: '{prompt}'") + + # Get embedding (loads T5 if not already loaded) + embedding = get_umt5_embedding( + checkpoint_path=args.text_encoder_path, + prompts=prompt + ) + + # Move to CPU for storage + cache_data['embeddings'].append({ + 'prompt': prompt, + 'embedding': embedding.cpu(), + 'shape': list(embedding.shape), + }) + + print(f" Shape: {embedding.shape}, dtype: {embedding.dtype}") + + # Clear T5 from memory + print("\nClearing T5 from memory...") + clear_umt5_memory() + torch.cuda.empty_cache() + + # Save cache + print(f"\nSaving to: {args.output}") + torch.save(cache_data, args.output) + + # Summary + file_size = os.path.getsize(args.output) / (1024 * 1024) + print(f"Done! Cache file size: {file_size:.2f} MB") + print() + print("Usage:") + print(f" python turbodiffusion/inference/wan2.2_i2v_infer.py \\") + print(f" --cached_embedding {args.output} \\") + print(f" --skip_t5 \\") + print(f" ... (other args)") + + +if __name__ == "__main__": + main() From 95aef856f42f46c5607ee19ab88647bc2140e677 Mon Sep 17 00:00:00 2001 From: MrEdwards007 <116316872+MrEdwards007@users.noreply.github.com> Date: Thu, 1 Jan 2026 07:53:48 -0500 Subject: [PATCH 4/6] Gradio-based Interface (Unified Web UI) A Gradio-based dashboard that centralizes Text-to-Video and Image-to-Video workflows. It automates environment setup, provides real-time VRAM monitoring, and enforces the 4n+1 frame rule for VAE stability. It also supports automatic T5 caching and saves reproduction metadata for every video generated. 
--- .../turbo_diffusion_t5_cache_optimize_v6.py | 343 ++++++++++++++++++ 1 file changed, 343 insertions(+) create mode 100644 turbodiffusion/inference/turbo_diffusion_t5_cache_optimize_v6.py diff --git a/turbodiffusion/inference/turbo_diffusion_t5_cache_optimize_v6.py b/turbodiffusion/inference/turbo_diffusion_t5_cache_optimize_v6.py new file mode 100644 index 0000000..0525f4b --- /dev/null +++ b/turbodiffusion/inference/turbo_diffusion_t5_cache_optimize_v6.py @@ -0,0 +1,343 @@ +import os +import sys +import subprocess +import gradio as gr +import glob +import random +import time +import select +import torch +from datetime import datetime + +# --- 1. System Setup --- +PROJECT_ROOT = "/home/wedwards/Documents/Programs/TurboDiffusion" +os.chdir(PROJECT_ROOT) +os.system('clear' if os.name == 'posix' else 'cls') + +CHECKPOINT_DIR = os.path.join(PROJECT_ROOT, "checkpoints") +OUTPUT_DIR = os.path.join(PROJECT_ROOT, "output") +os.makedirs(OUTPUT_DIR, exist_ok=True) + +T2V_SCRIPT = "turbodiffusion/inference/wan2.1_t2v_infer.py" +I2V_SCRIPT = "turbodiffusion/inference/wan2.2_i2v_infer.py" +CACHE_SCRIPT = "turbodiffusion/inference/cache_t5.py" + +def get_gpu_status_original(): + """System-level GPU check.""" + try: + res = subprocess.check_output( + ["nvidia-smi", "--query-gpu=name,memory.used,memory.total", "--format=csv,nounits,noheader"], + encoding='utf-8' + ).strip().split(',') + return f"๐ฅ๏ธ {res[0]} | โก VRAM: {res[1]}MB / {res[2]}MB" + except: + return "๐ฅ๏ธ GPU Monitor Active" + + +def get_gpu_status(): + """ + Check GPU status using PyTorch. + Returns system-wide VRAM usage without relying on nvidia-smi CLI. + """ + try: + # 1. Check for CUDA (NVIDIA) or ROCm (AMD) + if torch.cuda.is_available(): + # mem_get_info returns (free_bytes, total_bytes) + free_mem, total_mem = torch.cuda.mem_get_info() + + used_mem = total_mem - free_mem + + # Convert to MB for display + total_mb = int(total_mem / 1024**2) + used_mb = int(used_mem / 1024**2) + name = torch.cuda.get_device_name(0) + + return f"๐ฅ๏ธ {name} | โก VRAM: {used_mb}MB / {total_mb}MB" + + # 2. Check for Apple Silicon (MPS) + # Note: Apple uses Unified Memory, so 'VRAM' is shared with System RAM. + elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): + return "๐ฅ๏ธ Apple Silicon (MPS) | โก Unified Memory Active" + + # 3. Fallback to CPU + else: + return "๐ฅ๏ธ Running on CPU" + + except ImportError: + return "๐ฅ๏ธ GPU Monitor: PyTorch not installed" + except Exception as e: + return f"๐ฅ๏ธ GPU Monitor Error: {str(e)}" + +def save_debug_metadata(video_path, script_rel, cmd_list, cache_cmd_list=None): + """ + Saves a fully executable reproduction script with env vars. 
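+    The script is written next to the output video as <video_name>_metadata.txt.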
+ """ + meta_path = video_path.replace(".mp4", "_metadata.txt") + with open(meta_path, "w") as f: + f.write(f"# Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n") + f.write("# Copy and paste the lines below to reproduce this video exactly:\n\n") + + # Environment Variables + f.write("export PYTHONPATH=turbodiffusion\n") + f.write("export PYTORCH_ALLOC_CONF=expandable_segments:True\n") + f.write("export TOKENIZERS_PARALLELISM=false\n\n") + + # Optional Cache Step + if cache_cmd_list: + f.write("# --- Step 1: Pre-Cache Embeddings ---\n") + f.write(f"python {CACHE_SCRIPT} \\\n") + c_args = cache_cmd_list[2:] + for i, arg in enumerate(c_args): + if arg.startswith("--"): + val = f'"{c_args[i+1]}"' if i+1 < len(c_args) and not c_args[i+1].startswith("--") else "" + f.write(f" {arg} {val} \\\n") + f.write("\n# --- Step 2: Run Inference ---\n") + + # Main Inference Command + f.write(f"python {script_rel} \\\n") + args_only = cmd_list[2:] + for i, arg in enumerate(args_only): + if arg.startswith("--"): + val = f'"{args_only[i+1]}"' if i+1 < len(args_only) and not args_only[i+1].startswith("--") else "" + f.write(f" {arg} {val} \\\n") + +def sync_path(scale): + fname = "TurboWan2.1-T2V-1.3B-480P-quant.pth" if "1.3B" in scale else "TurboWan2.1-T2V-14B-720P-quant.pth" + return os.path.join(CHECKPOINT_DIR, fname) + +# --- 2. Unified Generation Logic (With Safety Checks) --- + +def run_gen(mode, prompt, model, dit_path, i2v_high, i2v_low, image, res, ratio, steps, seed, quant, attn, top_k, frames, sigma, norm, adapt, ode, use_cache, cache_path, pr=gr.Progress()): + # --- PRE-FLIGHT SAFETY CHECK --- + error_msg = "" + if mode == "T2V": + if "quant" in dit_path.lower() and not quant: + error_msg = "โ CONFIG ERROR: Quantized model selected but '8-bit' disabled." + if attn == "original" and ("turbo" in dit_path.lower() or "quant" in dit_path.lower()): + error_msg = "โ COMPATIBILITY ERROR: 'Original' attention with Turbo/Quantized checkpoint." + else: + if ("quant" in i2v_high.lower() or "quant" in i2v_low.lower()) and not quant: + error_msg = "โ CONFIG ERROR: Quantized I2V model selected but '8-bit' disabled." + if attn == "original" and (("turbo" in i2v_high.lower() or "quant" in i2v_high.lower()) or ("turbo" in i2v_low.lower() or "quant" in i2v_low.lower())): + error_msg = "โ COMPATIBILITY ERROR: 'Original' attention with Turbo/Quantized checkpoints." + + if error_msg: + yield None, None, "โ Config Error", "๐ Aborted", error_msg + return + # ------------------------------- + + actual_seed = random.randint(1, 1000000) if seed <= 0 else int(seed) + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + start_time = time.time() + + full_log = f"๐ Starting Job: {timestamp}\n" + pr(0, desc="๐ Starting...") + + # --- FRAME SANITIZATION (4n+1 RULE) --- + # Wan2.1 VAE requires frames to be (4n + 1). If not, we sanitize. + target_frames = int(frames) + valid_frames = ((target_frames - 1) // 4) * 4 + 1 + + # If the user input (e.g., 32) became smaller (29) or changed, we log it. + # Note: We enforce a minimum of 1 frame just in case. 
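+    # Worked example: 32 requested frames -> ((32 - 1) // 4) * 4 + 1 = 29; 81 -> 81 (already 4n+1).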
+    valid_frames = max(1, valid_frames)
+
+    if valid_frames != target_frames:
+        warning_msg = f"⚠️ AUTO-ADJUST: Frame count {target_frames} is incompatible with the VAE (requires 4n+1).\n"
+        warning_msg += f"   Adjusted {target_frames} -> {valid_frames} frames to prevent a kernel crash.\n"
+        full_log += warning_msg
+        print(warning_msg)  # Print to the console as well
+
+    # Use valid_frames for the rest of the logic
+    frames = valid_frames
+    # --------------------------------------
+
+    # --- AUTO-CACHE STEP ---
+    cache_cmd_list = None
+    if use_cache:
+        pr(0, desc="💾 Auto-Caching T5 Embeddings...")
+        cache_script_full = os.path.join(PROJECT_ROOT, CACHE_SCRIPT)
+        encoder_path = os.path.join(CHECKPOINT_DIR, "models_t5_umt5-xxl-enc-bf16.pth")
+
+        cache_cmd = [
+            sys.executable,
+            cache_script_full,
+            "--prompt", prompt,
+            "--output", cache_path,
+            "--text_encoder_path", encoder_path
+        ]
+        cache_cmd_list = cache_cmd
+
+        full_log += f"\n[System] Running Cache Script: {' '.join(cache_cmd)}\n"
+        yield None, None, f"Seed: {actual_seed}", "💾 Caching...", full_log
+
+        cache_process = subprocess.Popen(cache_cmd, cwd=PROJECT_ROOT, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, bufsize=1)
+
+        while True:
+            if cache_process.poll() is not None:
+                rest = cache_process.stdout.read()
+                if rest: full_log += rest
+                break
+            line = cache_process.stdout.readline()
+            if line:
+                full_log += line
+                yield None, None, f"Seed: {actual_seed}", "💾 Caching...", full_log
+            time.sleep(0.02)
+
+        if cache_process.returncode != 0:
+            full_log += "\n❌ CACHE FAILED. Aborting generation."
+            yield None, None, "❌ Cache Failed", "🛑 Aborted", full_log
+            return
+
+        full_log += "\n✅ Cache Complete. Starting Inference...\n"
+    # -----------------------------------------
+
+    # --- SETUP VIDEO GENERATION ---
+    if mode == "T2V":
+        save_path = os.path.join(OUTPUT_DIR, f"t2v_{timestamp}.mp4")
+        script_rel = T2V_SCRIPT
+        cmd = [sys.executable, os.path.join(PROJECT_ROOT, T2V_SCRIPT), "--model", model, "--dit_path", dit_path, "--prompt", prompt, "--resolution", res, "--aspect_ratio", ratio, "--num_steps", str(steps), "--seed", str(actual_seed), "--attention_type", attn, "--sla_topk", str(top_k), "--num_samples", "1", "--num_frames", str(frames), "--sigma_max", str(sigma)]
+    else:
+        save_path = os.path.join(OUTPUT_DIR, f"i2v_{timestamp}.mp4")
+        script_rel = I2V_SCRIPT
+        # The sanitized frame count is passed to the I2V command as well (--num_frames below).
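+        # Illustrative example (assumed values, paths abbreviated) of the assembled I2V call:
+        #   python wan2.2_i2v_infer.py --prompt "..." --image_path input.png \
+        #          --high_noise_model_path <high.pth> --low_noise_model_path <low.pth> \
+        #          --resolution 720p --aspect_ratio 16:9 --num_steps 4 --seed 42 \
+        #          --attention_type sagesla --sla_topk 0.1 --num_frames 81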
+        cmd = [sys.executable, os.path.join(PROJECT_ROOT, I2V_SCRIPT), "--prompt", prompt, "--image_path", image, "--high_noise_model_path", i2v_high, "--low_noise_model_path", i2v_low, "--resolution", res, "--aspect_ratio", ratio, "--num_steps", str(steps), "--seed", str(actual_seed), "--attention_type", attn, "--sla_topk", str(top_k), "--num_frames", str(frames)]
+        if adapt: cmd.append("--adaptive_resolution")
+        if ode: cmd.append("--ode")
+
+    if quant: cmd.append("--quant_linear")
+    if norm: cmd.append("--default_norm")
+
+    if use_cache:
+        cmd.extend(["--cached_embedding", cache_path, "--skip_t5"])
+
+    cmd.extend(["--save_path", save_path])
+
+    # Call the restored metadata saver
+    save_debug_metadata(save_path, script_rel, cmd, cache_cmd_list)
+
+    env = os.environ.copy()
+    env["PYTHONPATH"] = os.path.join(PROJECT_ROOT, "turbodiffusion")
+    env["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"
+    env["TOKENIZERS_PARALLELISM"] = "false"
+    env["PYTHONUNBUFFERED"] = "1"
+
+    full_log += f"\n[System] Running Inference: {' '.join(cmd)}\n"
+    process = subprocess.Popen(cmd, cwd=PROJECT_ROOT, env=env, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, bufsize=1)
+
+    last_ui_update = 0
+
+    while True:
+        if process.poll() is not None:
+            rest = process.stdout.read()
+            if rest: full_log += rest
+            break
+
+        reads = [process.stdout.fileno()]
+        ret = select.select(reads, [], [], 0.1)
+
+        if ret[0]:
+            line = process.stdout.readline()
+            full_log += line
+
+            if "Loading DiT" in line: pr(0.1, desc="⚡ Loading weights...")
+            if "Encoding" in line: pr(0.05, desc="🖼️ VAE Encoding...")
+            if "Switching to CPU" in line: pr(0.1, desc="⚠️ CPU Fallback...")
+            if "Sampling:" in line:
+                try:
+                    pct = int(line.split('%')[0].split('|')[-1].strip())
+                    pr(0.2 + (pct/100 * 0.7), desc=f"🎬 Sampling: {pct}%")
+                except (ValueError, IndexError):
+                    pass
+            if "decoding" in line.lower(): pr(0.95, desc="🎥 Decoding VAE...")
+
+        current_time = time.time()
+        if current_time - last_ui_update > 0.25:
+            last_ui_update = current_time
+            elapsed = f"{int(current_time - start_time)}s"
+            yield None, None, f"Seed: {actual_seed}", f"⏱️ Time: {elapsed}", full_log
+
+    history = sorted(glob.glob(os.path.join(OUTPUT_DIR, "*.mp4")), key=os.path.getmtime, reverse=True)
+    total_time = f"{int(time.time() - start_time)}s"
+
+    yield save_path, history, f"✅ Done | Seed: {actual_seed}", f"🏁 Finished in {total_time}", full_log
+
+# --- 3. UI Layout ---
+with gr.Blocks(title="TurboDiffusion Studio") as demo:
+    with gr.Row():
+        gr.HTML("