From 1466fbb0b6362b954e03b393d26b1556c61aa6bd Mon Sep 17 00:00:00 2001 From: MrEdwards007 <116316872+MrEdwards007@users.noreply.github.com> Date: Thu, 1 Jan 2026 07:38:20 -0500 Subject: [PATCH 1/6] Optimize wan2.2_i2v_infer.py (Image-to-Video) This is an optimized inference script for generating videos from a single input image and a text prompt. It features a Tiered Failover System to prevent "Out of Memory" (OOM) errors on consumer GPUs by automatically switching between high and low-noise models and utilizing tiled encoding/decoding. It is designed to maximize performance on hardware like the RTX 5090 while maintaining stability for high-resolution (720p) outputs. --- turbodiffusion/inference/wan2.2_i2v_infer.py | 658 ++++++++++++++----- 1 file changed, 505 insertions(+), 153 deletions(-) diff --git a/turbodiffusion/inference/wan2.2_i2v_infer.py b/turbodiffusion/inference/wan2.2_i2v_infer.py index e57e509..f6fe32d 100644 --- a/turbodiffusion/inference/wan2.2_i2v_infer.py +++ b/turbodiffusion/inference/wan2.2_i2v_infer.py @@ -1,28 +1,55 @@ +# Blackwell Bridge # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at +# ----------------------------------------------------------------------------------------- +# TURBODIFFUSION OPTIMIZED INFERENCE SCRIPT (I2V) # -# http://www.apache.org/licenses/LICENSE-2.0 +# Co-developed by: Waverly Edwards & Google Gemini (2025) # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# Modifications: +# - Implemented "Tiered Failover System" for robust OOM protection. +# - Added Intelligent Model Switching (High/Low Noise) with memory optimization. +# - Integrated Tiled Encoding & Decoding for high-resolution processing. +# - Added Support for Pre-cached Text Embeddings to skip T5 loading. +# - Optimized memory management (VAE unload/reload, aggressive GC). +# +# Acknowledgments: +# - Made possible by the work (cache_t5.py) and creativity of: John D. Pope +# +# Description: +# cache_t5.py pre-computes text embeddings to allow running inference on GPUs with limited VRAM +# by removing the need to keep the 11GB T5 encoder loaded in memory. +# +# CREDIT REQUEST: +# If you utilize, share, or build upon this specific optimized script, please +# acknowledge Waverly Edwards and Google Gemini in your documentation or credits. +# ----------------------------------------------------------------------------------------- import argparse import math +import os +import gc +import time +import sys + +# --- 1. Memory Tuning (Must be before torch imports) --- +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" import torch +import torch.nn.functional as F from einops import rearrange, repeat from tqdm import tqdm from PIL import Image import torchvision.transforms.v2 as T import numpy as np +# Safe import for system memory checks +try: + import psutil +except ImportError: + psutil = None + from imaginaire.utils.io import save_image_or_video from imaginaire.utils import log @@ -34,189 +61,514 @@ torch._dynamo.config.suppress_errors = True - def parse_arguments() -> argparse.Namespace: - parser = argparse.ArgumentParser(description="TurboDiffusion inference script for Wan2.2 I2V with High/Low Noise models") - parser.add_argument("--image_path", type=str, default=None, help="Path to the input image (required unless --serve)") - parser.add_argument("--high_noise_model_path", type=str, required=True, help="Path to the high-noise model") - parser.add_argument("--low_noise_model_path", type=str, required=True, help="Path to the low-noise model") - parser.add_argument("--boundary", type=float, default=0.9, help="Timestep boundary for switching from high to low noise model") - parser.add_argument("--model", choices=["Wan2.2-A14B"], default="Wan2.2-A14B", help="Model to use") - parser.add_argument("--num_samples", type=int, default=1, help="Number of samples to generate") - parser.add_argument("--num_steps", type=int, choices=[1, 2, 3, 4], default=4, help="1~4 for timestep-distilled inference") - parser.add_argument("--sigma_max", type=float, default=200, help="Initial sigma for rCM") - parser.add_argument("--vae_path", type=str, default="checkpoints/Wan2.1_VAE.pth", help="Path to the Wan2.1 VAE") - parser.add_argument("--text_encoder_path", type=str, default="checkpoints/models_t5_umt5-xxl-enc-bf16.pth", help="Path to the umT5 text encoder") - parser.add_argument("--num_frames", type=int, default=81, help="Number of frames to generate") - parser.add_argument("--prompt", type=str, default=None, help="Text prompt for video generation (required unless --serve)") - parser.add_argument("--resolution", default="720p", type=str, help="Resolution of the generated output") - parser.add_argument("--aspect_ratio", default="16:9", type=str, help="Aspect ratio of the generated output (width:height)") - parser.add_argument("--adaptive_resolution", action="store_true", help="If set, adapts the output resolution to the input image's aspect ratio, using the area defined by --resolution and --aspect_ratio as a target.") - parser.add_argument("--ode", action="store_true", help="Use ODE for sampling (sharper but less robust than SDE)") - parser.add_argument("--seed", type=int, default=0, help="Random seed for reproducibility") - parser.add_argument("--save_path", type=str, default="output/generated_video.mp4", help="Path to save the generated video (include file extension)") - parser.add_argument("--attention_type", choices=["sla", "sagesla", "original"], default="sagesla", help="Type of attention mechanism to use") - parser.add_argument("--sla_topk", type=float, default=0.1, help="Top-k ratio for SLA/SageSLA attention") - parser.add_argument("--quant_linear", action="store_true", help="Whether to replace Linear layers with quantized versions") - parser.add_argument("--default_norm", action="store_true", help="Whether to replace LayerNorm/RMSNorm layers with faster versions") - parser.add_argument("--serve", action="store_true", help="Launch interactive TUI server mode (keeps model loaded)") + parser = argparse.ArgumentParser(description="TurboDiffusion inference script for Wan2.2 I2V") + parser.add_argument("--image_path", type=str, default=None, help="Path to input image") + parser.add_argument("--high_noise_model_path", type=str, required=True, help="Path to high-noise model") + parser.add_argument("--low_noise_model_path", type=str, required=True, help="Path to low-noise model") + parser.add_argument("--boundary", type=float, default=0.9, help="Switch boundary") + parser.add_argument("--model", choices=["Wan2.2-A14B"], default="Wan2.2-A14B") + parser.add_argument("--num_samples", type=int, default=1) + parser.add_argument("--num_steps", type=int, default=4) + parser.add_argument("--sigma_max", type=float, default=200) + parser.add_argument("--vae_path", type=str, default="checkpoints/Wan2.1_VAE.pth") + parser.add_argument("--text_encoder_path", type=str, default="checkpoints/models_t5_umt5-xxl-enc-bf16.pth") + parser.add_argument("--cached_embedding", type=str, default=None) + parser.add_argument("--skip_t5", action="store_true") + parser.add_argument("--num_frames", type=int, default=81) + parser.add_argument("--prompt", type=str, default=None) + parser.add_argument("--resolution", default="720p", type=str) + parser.add_argument("--aspect_ratio", default="16:9", type=str) + parser.add_argument("--adaptive_resolution", action="store_true") + parser.add_argument("--ode", action="store_true") + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--save_path", type=str, default="output/generated_video.mp4") + parser.add_argument("--attention_type", choices=["sla", "sagesla", "original"], default="sagesla") + parser.add_argument("--sla_topk", type=float, default=0.1) + parser.add_argument("--quant_linear", action="store_true") + parser.add_argument("--default_norm", action="store_true") + parser.add_argument("--serve", action="store_true") + parser.add_argument("--offload_dit", action="store_true") return parser.parse_args() +def print_memory_status(step_name=""): + """ + Prints a detailed breakdown of GPU memory usage. + """ + if not torch.cuda.is_available(): + return + + torch.cuda.synchronize() + allocated_gb = torch.cuda.memory_allocated() / (1024**3) + reserved_gb = torch.cuda.memory_reserved() / (1024**3) + free_mem, total_mem = torch.cuda.mem_get_info() + free_gb = free_mem / (1024**3) + + print(f"\n๐Ÿ“Š [MEMORY REPORT] {step_name}") + print(f" โ”œโ”€โ”€ ๐Ÿ’พ VRAM In Use: {allocated_gb:.2f} GB") + print(f" โ”œโ”€โ”€ ๐Ÿ“ฆ VRAM Reserved: {reserved_gb:.2f} GB") + print(f" โ”œโ”€โ”€ ๐Ÿ†“ VRAM Free: {free_gb:.2f} GB") + print("-" * 60) + +def cleanup_memory(step_info=""): + """Aggressively clears VRAM.""" + if step_info: + print(f"๐Ÿงน Cleaning memory: {step_info}...") + gc.collect() + torch.cuda.empty_cache() + torch.cuda.ipc_collect() + if step_info: + print_memory_status(f"After Cleanup ({step_info})") + +def get_tensor_size_mb(tensor): + return tensor.element_size() * tensor.nelement() / (1024 * 1024) + +def force_cpu_float32(target_obj): + """ + Recursively forces a model or wrapper object to CPU Float32. + """ + def recursive_cast_to_cpu(obj): + if isinstance(obj, torch.Tensor): + return obj.cpu().float() + elif isinstance(obj, (list, tuple)): + return type(obj)(recursive_cast_to_cpu(x) for x in obj) + elif isinstance(obj, dict): + return {k: recursive_cast_to_cpu(v) for k, v in obj.items()} + return obj + + targets = [target_obj] + if hasattr(target_obj, "model"): + targets.append(target_obj.model) + + for obj in targets: + if isinstance(obj, torch.nn.Module): + try: obj.cpu().float() + except: pass + + for attr_name in dir(obj): + if attr_name.startswith("__"): continue + try: + val = getattr(obj, attr_name) + if isinstance(val, torch.nn.Module): + val.cpu().float() + elif isinstance(val, (torch.Tensor, list, tuple)): + setattr(obj, attr_name, recursive_cast_to_cpu(val)) + except Exception: pass + + if isinstance(obj, torch.nn.Module): + try: + for module in obj.modules(): + for param in module.parameters(recurse=False): + if param is not None: + param.data = param.data.cpu().float() + if param.grad is not None: param.grad = None + for buf in module.buffers(recurse=False): + if buf is not None: + buf.data = buf.data.cpu().float() + except: pass + else: + for attr_name in dir(obj): + if attr_name.startswith("__"): continue + try: + val = getattr(obj, attr_name) + if isinstance(val, torch.nn.Module): + for module in val.modules(): + for param in module.parameters(recurse=False): + if param is not None: + param.data = param.data.cpu().float() + for buf in module.buffers(recurse=False): + if buf is not None: + buf.data = buf.data.cpu().float() + except: pass + +def tiled_encode_4x(tokenizer, frames, target_dtype): + B, C, T, H, W = frames.shape + h_mid = H // 2 + w_mid = W // 2 + print(f"\n๐Ÿงฉ Starting 4-Chunk Tiled Encoding (Input: {W}x{H})") + latents_list = [[None, None], [None, None]] + + try: + print(" ๐Ÿ‘‰ Encoding Chunk 1/4 (Top-Left)...") + with torch.amp.autocast("cuda", dtype=target_dtype): + l_tl = tokenizer.encode(frames[:, :, :, :h_mid, :w_mid]) + latents_list[0][0] = l_tl.cpu() + del l_tl; cleanup_memory("After Chunk 1") + except Exception as e: + print(f"โŒ Chunk 1 Failed: {e}") + raise e + + try: + print(" ๐Ÿ‘‰ Encoding Chunk 2/4 (Top-Right)...") + with torch.amp.autocast("cuda", dtype=target_dtype): + l_tr = tokenizer.encode(frames[:, :, :, :h_mid, w_mid:]) + latents_list[0][1] = l_tr.cpu() + del l_tr; cleanup_memory("After Chunk 2") + except Exception as e: + print(f"โŒ Chunk 2 Failed: {e}") + raise e + + try: + print(" ๐Ÿ‘‰ Encoding Chunk 3/4 (Bottom-Left)...") + with torch.amp.autocast("cuda", dtype=target_dtype): + l_bl = tokenizer.encode(frames[:, :, :, h_mid:, :w_mid]) + latents_list[1][0] = l_bl.cpu() + del l_bl; cleanup_memory("After Chunk 3") + except Exception as e: + print(f"โŒ Chunk 3 Failed: {e}") + raise e + + try: + print(" ๐Ÿ‘‰ Encoding Chunk 4/4 (Bottom-Right)...") + with torch.amp.autocast("cuda", dtype=target_dtype): + l_br = tokenizer.encode(frames[:, :, :, h_mid:, w_mid:]) + latents_list[1][1] = l_br.cpu() + del l_br; cleanup_memory("After Chunk 4") + except Exception as e: + print(f"โŒ Chunk 4 Failed: {e}") + raise e + + print(" ๐Ÿงต Stitching Latents...") + row1 = torch.cat([latents_list[0][0], latents_list[0][1]], dim=4) + row2 = torch.cat([latents_list[1][0], latents_list[1][1]], dim=4) + full_latents = torch.cat([row1, row2], dim=3) + return full_latents.to(device=tensor_kwargs["device"], dtype=target_dtype) + +def safe_cpu_fallback_encode(tokenizer, frames, target_dtype): + log.warning("๐Ÿ”„ Switching to CPU for VAE Encode (Slow but reliable)...") + cleanup_memory("Pre-CPU Encode") + frames_cpu = frames.cpu().to(dtype=torch.float32) + force_cpu_float32(tokenizer) + t0 = time.time() + with torch.autocast("cpu", enabled=False): + with torch.autocast("cuda", enabled=False): + latents = tokenizer.encode(frames_cpu) + print(f" โฑ๏ธ CPU Encode took: {time.time() - t0:.2f}s") + return latents.to(device=tensor_kwargs["device"], dtype=target_dtype) + +def tiled_decode_gpu(tokenizer, latents, overlap=12): + """ + Decodes latents in 4 spatial quadrants with OVERLAP and SIGMOID BLENDING. + Overlap=12 latents (96 pixels). Safe for 720p. + Removing Global Color Matching to prevent exposure shifts. + """ + print(f"\n๐Ÿงฑ Starting Tiled GPU Decode (4 Quadrants, Overlap={overlap}, Blended)...") + B, C, T, H, W = latents.shape + scale = tokenizer.spatial_compression_factor + h_mid = H // 2 + w_mid = W // 2 + + def decode_tile(tile_latents, name): + cleanup_memory(f"Tile {name}") + with torch.no_grad(): + return tokenizer.decode(tile_latents).cpu() + + try: + # 1. Decode Top-Left and Top-Right + l_tl = latents[..., :h_mid+overlap, :w_mid+overlap] + l_tr = latents[..., :h_mid+overlap, w_mid-overlap:] + v_tl = decode_tile(l_tl, "1/4 (TL)") + v_tr = decode_tile(l_tr, "2/4 (TR)") + B_dec, C_dec, T_dec, H_tile, W_tile = v_tl.shape + + print(f" ๐Ÿงต Blending Top Row (Decoded Frames: {T_dec})...") + mid_pix = w_mid * scale + overlap_pix = overlap * scale + + # Slices for overlap + tl_blend_slice = v_tl[..., mid_pix-overlap_pix:] + tr_blend_slice = v_tr[..., :2*overlap_pix] + + row_top = torch.zeros(B_dec, 3, T_dec, H_tile, W*scale, dtype=v_tl.dtype, device='cpu') + + # Place non-overlapping parts (Clamped indices) + end_left = max(0, mid_pix - overlap_pix) + start_right = mid_pix + overlap_pix + + row_top[..., :end_left] = v_tl[..., :end_left] + row_top[..., start_right:] = v_tr[..., 2*overlap_pix:] + + x = torch.linspace(-6, 6, 2*overlap_pix, device='cpu') + alpha = torch.sigmoid(x).view(1, 1, 1, 1, -1) + blended_h = tl_blend_slice * (1 - alpha) + tr_blend_slice * alpha + + row_top[..., end_left:start_right] = blended_h + del v_tl, v_tr, l_tl, l_tr + + # 3. Decode Bottom-Left and Bottom-Right + l_bl = latents[..., h_mid-overlap:, :w_mid+overlap] + l_br = latents[..., h_mid-overlap:, w_mid-overlap:] + v_bl = decode_tile(l_bl, "3/4 (BL)") + v_br = decode_tile(l_br, "4/4 (BR)") + + print(" ๐Ÿงต Blending Bottom Row...") + bl_blend_slice = v_bl[..., mid_pix-overlap_pix:] + br_blend_slice = v_br[..., :2*overlap_pix] + + row_bot = torch.zeros(B_dec, 3, T_dec, H_tile, W*scale, dtype=v_bl.dtype, device='cpu') + row_bot[..., :end_left] = v_bl[..., :end_left] + row_bot[..., start_right:] = v_br[..., 2*overlap_pix:] + row_bot[..., end_left:start_right] = bl_blend_slice * (1 - alpha) + br_blend_slice * alpha + del v_bl, v_br, l_bl, l_br + + # 5. Blend Top and Bottom Vertically + print(" ๐Ÿงต Blending Rows Vertically...") + h_mid_pix = h_mid * scale + + # Slices + top_blend_slice = row_top[..., h_mid_pix-overlap_pix:, :] + bot_blend_slice = row_bot[..., :2*overlap_pix, :] + + video = torch.zeros(B_dec, 3, T_dec, H*scale, W*scale, dtype=row_top.dtype, device='cpu') + + end_top = max(0, h_mid_pix - overlap_pix) + start_bot = h_mid_pix + overlap_pix + + video[..., :end_top, :] = row_top[..., :end_top, :] + video[..., start_bot:, :] = row_bot[..., 2*overlap_pix:, :] + + alpha_v = torch.sigmoid(x).view(1, 1, 1, -1, 1) + blended_v = top_blend_slice * (1 - alpha_v) + bot_blend_slice * alpha_v + + video[..., end_top:start_bot, :] = blended_v + + except Exception as e: + print(f"โŒ Tiled GPU Decode Failed: {e}") + raise e + return video.to(latents.device) + +def load_dit_model(args, is_high_noise=True, force_offload=False): + """Helper to load the model, respecting overrides.""" + original_offload = args.offload_dit + if force_offload: + args.offload_dit = True + + path = args.high_noise_model_path if is_high_noise else args.low_noise_model_path + log.info(f"Loading {'High' if is_high_noise else 'Low'} Noise DiT (Offload={args.offload_dit})...") + + try: + model = create_model(dit_path=path, args=args).cpu() + finally: + args.offload_dit = original_offload + + return model if __name__ == "__main__": + print_memory_status("Script Start") args = parse_arguments() - # Handle serve mode if args.serve: - # Set mode to i2v for the TUI server args.mode = "i2v" from serve.tui import main as serve_main serve_main(args) exit(0) - - # Validate required args for one-shot mode - if args.prompt is None: - log.error("--prompt is required (unless using --serve mode)") - exit(1) - if args.image_path is None: - log.error("--image_path is required (unless using --serve mode)") - exit(1) - - log.info(f"Computing embedding for prompt: {args.prompt}") - with torch.no_grad(): + + # --- AUTO-ADJUST FRAME COUNT --- + if (args.num_frames - 1) % 4 != 0: + old_f = args.num_frames + new_f = ((old_f - 1) // 4 + 1) * 4 + 1 + print(f"โš ๏ธ Adjusting --num_frames from {old_f} to {new_f} to satisfy VAE temporal stride (4n+1).") + args.num_frames = new_f + + # --- AUTO-ENABLE OFFLOAD FOR HIGH FRAMES --- + if args.num_frames > 90 and not args.offload_dit: + print(f"โš ๏ธ High frame count ({args.num_frames}) detected. Enabling --offload_dit to prevent OOM.") + args.offload_dit = True + + # 1. Text Embeddings + if args.cached_embedding and os.path.exists(args.cached_embedding): + log.info(f"Loading cached embedding from: {args.cached_embedding}") + cache_data = torch.load(args.cached_embedding, map_location='cpu') + text_emb = cache_data['embeddings'][0]['embedding'].to(**tensor_kwargs) + else: + log.info(f"Computing embedding...") text_emb = get_umt5_embedding(checkpoint_path=args.text_encoder_path, prompts=args.prompt).to(**tensor_kwargs) - clear_umt5_memory() - - log.info(f"Loading DiT models.") - high_noise_model = create_model(dit_path=args.high_noise_model_path, args=args).cpu() - torch.cuda.empty_cache() - low_noise_model = create_model(dit_path=args.low_noise_model_path, args=args).cpu() - torch.cuda.empty_cache() - log.success(f"Successfully loaded DiT model.") + clear_umt5_memory() + # 2. VAE Encoding + print("-" * 20 + " VAE SETUP " + "-" * 20) tokenizer = Wan2pt1VAEInterface(vae_pth=args.vae_path) - - log.info(f"Loading and preprocessing image from: {args.image_path}") + target_dtype = tensor_kwargs.get("dtype", torch.bfloat16) input_image = Image.open(args.image_path).convert("RGB") + if args.adaptive_resolution: - log.info("Adaptive resolution mode enabled.") base_w, base_h = VIDEO_RES_SIZE_INFO[args.resolution][args.aspect_ratio] max_resolution_area = base_w * base_h - log.info(f"Target area is based on {args.resolution} {args.aspect_ratio} (~{max_resolution_area} pixels).") - orig_w, orig_h = input_image.size - image_aspect_ratio = orig_h / orig_w - - ideal_w = np.sqrt(max_resolution_area / image_aspect_ratio) - ideal_h = np.sqrt(max_resolution_area * image_aspect_ratio) - + aspect = orig_h / orig_w + ideal_w = np.sqrt(max_resolution_area / aspect) + ideal_h = np.sqrt(max_resolution_area * aspect) stride = tokenizer.spatial_compression_factor * 2 - lat_h = round(ideal_h / stride) - lat_w = round(ideal_w / stride) - h = lat_h * stride - w = lat_w * stride - - log.info(f"Input image aspect ratio: {image_aspect_ratio:.4f}. Adaptive resolution set to: {w}x{h}") + h = round(ideal_h / stride) * stride + w = round(ideal_w / stride) * stride + log.info(f"Adaptive Res: {w}x{h}") else: - log.info("Fixed resolution mode.") w, h = VIDEO_RES_SIZE_INFO[args.resolution][args.aspect_ratio] - log.info(f"Resolution set to: {w}x{h}") + F = args.num_frames lat_h = h // tokenizer.spatial_compression_factor lat_w = w // tokenizer.spatial_compression_factor lat_t = tokenizer.get_latent_num_frames(F) - log.info(f"Preprocessing image to {w}x{h}...") - image_transforms = T.Compose( - [ - T.ToImage(), - T.Resize(size=(h, w), antialias=True), - T.ToDtype(torch.float32, scale=True), - T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), - ] - ) - image_tensor = image_transforms(input_image).unsqueeze(0).to(device=tensor_kwargs["device"], dtype=torch.float32) - - with torch.no_grad(): - frames_to_encode = torch.cat( - [image_tensor.unsqueeze(2), torch.zeros(1, 3, F - 1, h, w, device=image_tensor.device)], dim=2 - ) # -> B, C, T, H, W - encoded_latents = tokenizer.encode(frames_to_encode) # -> B, C_lat, T_lat, H_lat, W_lat - - del frames_to_encode - torch.cuda.empty_cache() - + image_transforms = T.Compose([ + T.ToImage(), + T.Resize(size=(h, w), antialias=True), + T.ToDtype(torch.float32, scale=True), + T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + ]) + + image_tensor = image_transforms(input_image).unsqueeze(0).to(device=tensor_kwargs["device"], dtype=target_dtype) + frames_to_encode = torch.cat([image_tensor.unsqueeze(2), torch.zeros(1, 3, F - 1, h, w, device=image_tensor.device, dtype=target_dtype)], dim=2) + + log.info(f"Encoding {F} frames...") + + try: + free_mem, _ = torch.cuda.mem_get_info() + if free_mem < 24 * (1024**3): + raise torch.OutOfMemoryError("Pre-emptive tiling") + with torch.amp.autocast("cuda", dtype=target_dtype): + encoded_latents = tokenizer.encode(frames_to_encode) + except torch.OutOfMemoryError: + try: + cleanup_memory("Switching to Tiled Encode") + encoded_latents = tiled_encode_4x(tokenizer, frames_to_encode, target_dtype) + except Exception as e: + log.warning(f"Tiling failed ({e}). Fallback to CPU.") + encoded_latents = safe_cpu_fallback_encode(tokenizer, frames_to_encode, target_dtype) + + print(f"โœ… VAE Encode Complete.") + del frames_to_encode + cleanup_memory("After VAE Encode") + + # Prepare for Diffusion msk = torch.zeros(1, 4, lat_t, lat_h, lat_w, device=tensor_kwargs["device"], dtype=tensor_kwargs["dtype"]) msk[:, :, 0, :, :] = 1.0 - y = torch.cat([msk, encoded_latents.to(**tensor_kwargs)], dim=1) y = y.repeat(args.num_samples, 1, 1, 1, 1) + saved_latent_ch = tokenizer.latent_ch + + del tokenizer + cleanup_memory("Unloaded VAE Model") + + # 3. Diffusion Sampling + print("-" * 20 + " DIT LOADING " + "-" * 20) + + current_model = load_dit_model(args, is_high_noise=True) + is_high_noise_active = True + fallback_triggered = args.offload_dit - log.info(f"Generating with prompt: {args.prompt}") condition = {"crossattn_emb": repeat(text_emb.to(**tensor_kwargs), "b l d -> (k b) l d", k=args.num_samples), "y_B_C_T_H_W": y} - - to_show = [] - - state_shape = [tokenizer.latent_ch, lat_t, lat_h, lat_w] - - generator = torch.Generator(device=tensor_kwargs["device"]) - generator.manual_seed(args.seed) - - init_noise = torch.randn( - args.num_samples, - *state_shape, - dtype=torch.float32, - device=tensor_kwargs["device"], - generator=generator, - ) - + + generator = torch.Generator(device=tensor_kwargs["device"]).manual_seed(args.seed) + init_noise = torch.randn(args.num_samples, saved_latent_ch, lat_t, lat_h, lat_w, dtype=torch.float32, device=tensor_kwargs["device"], generator=generator) + mid_t = [1.5, 1.4, 1.0][: args.num_steps - 1] - - t_steps = torch.tensor( - [math.atan(args.sigma_max), *mid_t, 0], - dtype=torch.float64, - device=init_noise.device, - ) - - # Convert TrigFlow timesteps to RectifiedFlow + t_steps = torch.tensor([math.atan(args.sigma_max), *mid_t, 0], dtype=torch.float64, device=init_noise.device) t_steps = torch.sin(t_steps) / (torch.cos(t_steps) + torch.sin(t_steps)) - x = init_noise.to(torch.float64) * t_steps[0] ones = torch.ones(x.size(0), 1, device=x.device, dtype=x.dtype) - total_steps = t_steps.shape[0] - 1 - high_noise_model.cuda() - net = high_noise_model - switched = False - for i, (t_cur, t_next) in enumerate(tqdm(list(zip(t_steps[:-1], t_steps[1:])), desc="Sampling", total=total_steps)): - if t_cur.item() < args.boundary and not switched: - high_noise_model.cpu() - torch.cuda.empty_cache() - low_noise_model.cuda() - net = low_noise_model - switched = True - log.info("Switched to low noise model.") - with torch.no_grad(): - v_pred = net(x_B_C_T_H_W=x.to(**tensor_kwargs), timesteps_B_T=(t_cur.float() * ones * 1000).to(**tensor_kwargs), **condition).to( - torch.float64 - ) - if args.ode: - x = x - (t_cur - t_next) * v_pred - else: - x = (1 - t_next) * (x - t_cur * v_pred) + t_next * torch.randn( - *x.shape, - dtype=torch.float32, - device=tensor_kwargs["device"], - generator=generator, - ) + + # Always ensure CUDA initially + current_model.cuda() + + print("-" * 20 + " SAMPLING START " + "-" * 20) + print_memory_status("High Noise Model to GPU") + + # Sampling Loop + for i, (t_cur, t_next) in enumerate(tqdm(list(zip(t_steps[:-1], t_steps[1:])), total=len(t_steps)-1)): + if t_cur.item() < args.boundary and is_high_noise_active: + print(f"\n๐Ÿ”„ Switching DiT Models (Step {i})...") + current_model.cpu() + del current_model + cleanup_memory("Unloaded High Noise") + + current_model = load_dit_model(args, is_high_noise=False, force_offload=fallback_triggered) + current_model.cuda() # Force CUDA + is_high_noise_active = False + print_memory_status("Loaded Low Noise to GPU") + + step_success = False + while not step_success: + try: + gc.collect() + torch.cuda.empty_cache() + with torch.no_grad(): + v_pred = current_model( + x_B_C_T_H_W=x.to(**tensor_kwargs), + timesteps_B_T=(t_cur.float() * ones * 1000).to(**tensor_kwargs), + **condition + ).to(torch.float64) + step_success = True + except torch.OutOfMemoryError: + if fallback_triggered: + log.error("โŒ OOM occurred even after reload. Physical Memory Limit Reached.") + sys.exit(1) + + print(f"\nโš ๏ธ OOM in DiT Sampling Step {i}. Reloading model to clear fragmentation...") + cleanup_memory("Pre-Reload") + + # Unload and Reload to Defrag + was_high = is_high_noise_active + current_model.cpu() + del current_model + cleanup_memory("Unload for Reload") + + fallback_triggered = True + current_model = load_dit_model(args, is_high_noise=was_high, force_offload=True) + current_model.cuda() # Move back to GPU + + print("โ™ป๏ธ Model Reloaded. Retrying step...") + + if args.ode: + x = x - (t_cur - t_next) * v_pred + else: + x = (1 - t_next) * (x - t_cur * v_pred) + t_next * torch.randn(*x.shape, dtype=torch.float32, device=tensor_kwargs["device"], generator=generator) + samples = x.float() - low_noise_model.cpu() - torch.cuda.empty_cache() - + + print("-" * 20 + " DECODE SETUP (DEFRAG) " + "-" * 20) + samples_cpu_backup = samples.cpu() + del samples + del x + current_model.cpu() + del current_model + cleanup_memory("FULL WIPE before VAE Load") + + log.info("Reloading VAE for decoding...") + tokenizer = Wan2pt1VAEInterface(vae_pth=args.vae_path) + print_memory_status("Reloaded VAE (Clean Slate)") + + samples = samples_cpu_backup.to(device=tensor_kwargs["device"]) + with torch.no_grad(): - video = tokenizer.decode(samples) - - to_show.append(video.float().cpu()) - + success = False + video = None + + try: + log.info("Attempting Standard GPU Decode...") + video = tokenizer.decode(samples) + success = True + except torch.OutOfMemoryError: + log.warning("โš ๏ธ GPU OOM (Standard). Switching to Tiled GPU Decode...") + cleanup_memory("Pre-Tile Fallback") + + try: + # 12 Latents overlap = 96 Image pixels + video = tiled_decode_gpu(tokenizer, samples, overlap=12) + success = True + except (torch.OutOfMemoryError, RuntimeError) as e: + log.warning(f"โš ๏ธ GPU Tiled Decode Failed ({e}). Switching to CPU Decode (Slow)...") + cleanup_memory("Pre-CPU Fallback") + + if not success: + log.info("Performing Hard Cast of VAE to CPU Float32...") + samples_cpu = samples.cpu().float() + force_cpu_float32(tokenizer) + with torch.autocast("cpu", enabled=False): + with torch.autocast("cuda", enabled=False): + video = tokenizer.decode(samples_cpu) + + to_show = [video.float().cpu()] to_show = (1.0 + torch.stack(to_show, dim=0).clamp(-1, 1)) / 2.0 - save_image_or_video(rearrange(to_show, "n b c t h w -> c t (n h) (b w)"), args.save_path, fps=16) + log.success("Done.") From 57a3e1a1822fe465ed85d2a97098cb40c6d1cab1 Mon Sep 17 00:00:00 2001 From: MrEdwards007 <116316872+MrEdwards007@users.noreply.github.com> Date: Thu, 1 Jan 2026 07:42:26 -0500 Subject: [PATCH 2/6] wan2.1_t2v_infer.py (Text-to-Video) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This script handles video generation directly from text prompts. It includes Intelligent Hardware Detection to automatically select the best precision (BF16/FP16) for your GPU and a three-tier memory recovery system (GPU โ†’ Checkpointing โ†’ CPU Offloading). It also enforces safety checks, such as disabling torch.compile for quantized models to prevent system crashes during the diffusion process. --- turbodiffusion/inference/wan2.1_t2v_infer.py | 386 ++++++++++++++----- 1 file changed, 300 insertions(+), 86 deletions(-) diff --git a/turbodiffusion/inference/wan2.1_t2v_infer.py b/turbodiffusion/inference/wan2.1_t2v_infer.py index c5581e4..8ecabf9 100644 --- a/turbodiffusion/inference/wan2.1_t2v_infer.py +++ b/turbodiffusion/inference/wan2.1_t2v_infer.py @@ -1,24 +1,58 @@ +# Blackwell Bridge # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at +# ----------------------------------------------------------------------------------------- +# TURBODIFFUSION OPTIMIZED INFERENCE SCRIPT (T2V) # -# http://www.apache.org/licenses/LICENSE-2.0 +# Co-developed by: Waverly Edwards & Google Gemini (2025) # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# Modifications: +# - Implemented "Tiered Failover System" for robust OOM protection (GPU -> Checkpoint -> CPU). +# - Added Intelligent Hardware Detection (TF32/BF16/FP16 auto-switching). +# - Integrated Tiled Decoding for high-resolution VAE processing. +# - Added Support for Pre-cached Text Embeddings to skip T5 loading. +# - Optimized compilation logic for Quantized models (preventing graph breaks). +# +# Acknowledgments: +# - Made possible by the work (cache_t5.py) and creativity of: John D. Pope +# +# Description: +# cache_t5.py pre-computes text embeddings to allow running inference on GPUs with limited VRAM +# by removing the need to keep the 11GB T5 encoder loaded in memory. +# +# CREDIT REQUEST: +# If you utilize, share, or build upon this specific optimized script, please +# acknowledge Waverly Edwards and Google Gemini in your documentation or credits. +# ----------------------------------------------------------------------------------------- import argparse import math +import os +import gc +import time +import sys + +# --- 1. Memory Tuning --- +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" import torch +import torch.nn.functional as F from einops import rearrange, repeat from tqdm import tqdm +import numpy as np + +# --- 2. Hardware Optimization --- +if torch.cuda.is_available(): + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + # 'high' allows TF32 but maintains reasonable precision + torch.set_float32_matmul_precision('high') + +try: + import psutil +except ImportError: + psutil = None from imaginaire.utils.io import save_image_or_video from imaginaire.utils import log @@ -29,123 +63,303 @@ from modify_model import tensor_kwargs, create_model +# Suppress graph break warnings for cleaner output torch._dynamo.config.suppress_errors = True - +torch._dynamo.config.verbose = False def parse_arguments() -> argparse.Namespace: parser = argparse.ArgumentParser(description="TurboDiffusion inference script for Wan2.1 T2V") - parser.add_argument("--dit_path", type=str, required=True, help="Custom path to the DiT model checkpoint for distilled models") + parser.add_argument("--dit_path", type=str, required=True, help="Custom path to the DiT model checkpoint") parser.add_argument("--model", choices=["Wan2.1-1.3B", "Wan2.1-14B"], default="Wan2.1-1.3B", help="Model to use") parser.add_argument("--num_samples", type=int, default=1, help="Number of samples to generate") parser.add_argument("--num_steps", type=int, choices=[1, 2, 3, 4], default=4, help="1~4 for timestep-distilled inference") parser.add_argument("--sigma_max", type=float, default=80, help="Initial sigma for rCM") parser.add_argument("--vae_path", type=str, default="checkpoints/Wan2.1_VAE.pth", help="Path to the Wan2.1 VAE") parser.add_argument("--text_encoder_path", type=str, default="checkpoints/models_t5_umt5-xxl-enc-bf16.pth", help="Path to the umT5 text encoder") + parser.add_argument("--cached_embedding", type=str, default=None, help="Path to cached text embeddings (pt file)") + parser.add_argument("--skip_t5", action="store_true", help="Skip T5 loading (implied if cached_embedding is used)") parser.add_argument("--num_frames", type=int, default=81, help="Number of frames to generate") - parser.add_argument("--prompt", type=str, default=None, help="Text prompt for video generation (required unless --serve)") + parser.add_argument("--prompt", type=str, required=True, help="Text prompt for video generation") parser.add_argument("--resolution", default="480p", type=str, help="Resolution of the generated output") parser.add_argument("--aspect_ratio", default="16:9", type=str, help="Aspect ratio of the generated output (width:height)") parser.add_argument("--seed", type=int, default=0, help="Random seed for reproducibility") - parser.add_argument("--save_path", type=str, default="output/generated_video.mp4", help="Path to save the generated video (include file extension)") - parser.add_argument("--attention_type", choices=["sla", "sagesla", "original"], default="sagesla", help="Type of attention mechanism to use") + parser.add_argument("--save_path", type=str, default="output/generated_video.mp4", help="Path to save the generated video") + parser.add_argument("--attention_type", choices=["sla", "sagesla", "original"], default="sagesla", help="Type of attention mechanism") parser.add_argument("--sla_topk", type=float, default=0.1, help="Top-k ratio for SLA/SageSLA attention") parser.add_argument("--quant_linear", action="store_true", help="Whether to replace Linear layers with quantized versions") - parser.add_argument("--default_norm", action="store_true", help="Whether to replace LayerNorm/RMSNorm layers with faster versions") - parser.add_argument("--serve", action="store_true", help="Launch interactive TUI server mode (keeps model loaded)") + parser.add_argument("--default_norm", action="store_true", help="Whether to replace LayerNorm/RMSNorm with faster versions") + parser.add_argument("--offload_dit", action="store_true", help="Offload DiT to CPU when not in use to save VRAM") + parser.add_argument("--compile", action="store_true", help="Use torch.compile (Inductor) for faster inference") return parser.parse_args() +def check_hardware_compatibility(): + if not torch.cuda.is_available(): return + gpu_name = torch.cuda.get_device_name(0) + log.info(f"Hardware: {gpu_name}") + + current_dtype = tensor_kwargs.get("dtype") + if current_dtype == torch.bfloat16 and not torch.cuda.is_bf16_supported(): + log.warning(f"โš ๏ธ Device does not support BFloat16. Switching to Float16.") + tensor_kwargs["dtype"] = torch.float16 -if __name__ == "__main__": - args = parse_arguments() +def print_memory_status(step_name=""): + if not torch.cuda.is_available(): return + torch.cuda.synchronize() + allocated = torch.cuda.memory_allocated() / (1024**3) + free, total = torch.cuda.mem_get_info() + free_gb = free / (1024**3) + print(f"๐Ÿ“Š [MEM] {step_name}: InUse={allocated:.2f}GB, Free={free_gb:.2f}GB") - # Handle serve mode - if args.serve: - # Set mode to t2v for the TUI server - args.mode = "t2v" - from serve.tui import main as serve_main - serve_main(args) - exit(0) +def cleanup_memory(step_info=""): + gc.collect() + torch.cuda.empty_cache() - # Validate prompt is provided for one-shot mode - if args.prompt is None: - log.error("--prompt is required (unless using --serve mode)") - exit(1) +def load_dit_model(args, force_offload=False): + orig_offload = args.offload_dit + if force_offload: args.offload_dit = True + log.info(f"Loading DiT (Offload={args.offload_dit})...") + model = create_model(dit_path=args.dit_path, args=args).cpu() + args.offload_dit = orig_offload + return model - log.info(f"Computing embedding for prompt: {args.prompt}") - with torch.no_grad(): - text_emb = get_umt5_embedding(checkpoint_path=args.text_encoder_path, prompts=args.prompt).to(**tensor_kwargs) - clear_umt5_memory() +def tiled_decode_gpu(tokenizer, latents, overlap=12): + print(f"\n๐Ÿงฑ Starting Tiled GPU Decode (Overlap={overlap})...") + B, C, T, H, W = latents.shape + scale = tokenizer.spatial_compression_factor + h_mid = H // 2 + w_mid = W // 2 + + def decode_tile(tile_latents): + cleanup_memory() + with torch.no_grad(): return tokenizer.decode(tile_latents).cpu() - log.info(f"Loading DiT model from {args.dit_path}") - net = create_model(dit_path=args.dit_path, args=args).cpu() - torch.cuda.empty_cache() - log.success("Successfully loaded DiT model.") + # 1. Top Tiles + l_tl = latents[..., :h_mid+overlap, :w_mid+overlap] + l_tr = latents[..., :h_mid+overlap, w_mid-overlap:] + v_tl = decode_tile(l_tl) + v_tr = decode_tile(l_tr) - tokenizer = Wan2pt1VAEInterface(vae_pth=args.vae_path) + B_dec, C_dec, T_dec, H_tile, W_tile = v_tl.shape + mid_pix = w_mid * scale + overlap_pix = overlap * scale + + row_top = torch.zeros(B_dec, 3, T_dec, H_tile, W*scale, dtype=v_tl.dtype, device='cpu') + end_left = max(0, mid_pix - overlap_pix) + start_right = mid_pix + overlap_pix + + row_top[..., :end_left] = v_tl[..., :end_left] + row_top[..., start_right:] = v_tr[..., 2*overlap_pix:] + + x_linspace = torch.linspace(-6, 6, 2*overlap_pix, device='cpu') + alpha = torch.sigmoid(x_linspace).view(1, 1, 1, 1, -1) + row_top[..., end_left:start_right] = v_tl[..., mid_pix-overlap_pix:] * (1 - alpha) + v_tr[..., :2*overlap_pix] * alpha + del v_tl, v_tr - w, h = VIDEO_RES_SIZE_INFO[args.resolution][args.aspect_ratio] + # 2. Bottom Tiles + l_bl = latents[..., h_mid-overlap:, :w_mid+overlap] + l_br = latents[..., h_mid-overlap:, w_mid-overlap:] + v_bl = decode_tile(l_bl) + v_br = decode_tile(l_br) + + row_bot = torch.zeros(B_dec, 3, T_dec, H_tile, W*scale, dtype=v_bl.dtype, device='cpu') + row_bot[..., :end_left] = v_bl[..., :end_left] + row_bot[..., start_right:] = v_br[..., 2*overlap_pix:] + row_bot[..., end_left:start_right] = v_bl[..., mid_pix-overlap_pix:] * (1 - alpha) + v_br[..., :2*overlap_pix] * alpha + del v_bl, v_br - log.info(f"Generating with prompt: {args.prompt}") - condition = {"crossattn_emb": repeat(text_emb.to(**tensor_kwargs), "b l d -> (k b) l d", k=args.num_samples)} + # 3. Blend Vertically + h_mid_pix = h_mid * scale + video = torch.zeros(B_dec, 3, T_dec, H*scale, W*scale, dtype=row_top.dtype, device='cpu') + end_top = max(0, h_mid_pix - overlap_pix) + start_bot = h_mid_pix + overlap_pix + + video[..., :end_top, :] = row_top[..., :end_top, :] + video[..., start_bot:, :] = row_bot[..., 2*overlap_pix:, :] + + alpha_v = torch.sigmoid(x_linspace).view(1, 1, 1, -1, 1) + video[..., end_top:start_bot, :] = row_top[..., h_mid_pix-overlap_pix:, :] * (1 - alpha_v) + row_bot[..., :2*overlap_pix, :] * alpha_v + + return video.to(latents.device) - to_show = [] +def force_cpu_float32(target_obj): + for module in target_obj.modules(): + module.cpu().float() - state_shape = [ - tokenizer.latent_ch, - tokenizer.get_latent_num_frames(args.num_frames), - h // tokenizer.spatial_compression_factor, - w // tokenizer.spatial_compression_factor, - ] +def apply_manual_offload(model, device="cuda"): + log.info("Applying Tier 3 Offload...") + block_list_name = None + max_len = 0 + for name, child in model.named_children(): + if isinstance(child, torch.nn.ModuleList): + if len(child) > max_len: + max_len = len(child) + block_list_name = name + + if not block_list_name: + log.warning("Could not identify Block List! Offloading entire model to CPU.") + model.to("cpu") + return - generator = torch.Generator(device=tensor_kwargs["device"]) - generator.manual_seed(args.seed) + print(f" ๐Ÿ‘‰ Identified Transformer Blocks: '{block_list_name}' ({max_len} layers)") + try: model.to(device) + except RuntimeError: model.to("cpu") + + blocks = getattr(model, block_list_name) + blocks.to("cpu") + + def pre_hook(module, args): + module.to(device) + return args + def post_hook(module, args, output): + module.to("cpu") + return output + + for i, block in enumerate(blocks): + block.register_forward_pre_hook(pre_hook) + block.register_forward_hook(post_hook) - init_noise = torch.randn( - args.num_samples, - *state_shape, - dtype=torch.float32, - device=tensor_kwargs["device"], - generator=generator, - ) +if __name__ == "__main__": + print_memory_status("Script Start") + + # --- CREDIT PRINT --- + log.info("----------------------------------------------------------------") + log.info("๐Ÿš€ TurboDiffusion Optimized Inference") + log.info(" Co-developed by [Your Name/Handle] & Google Gemini") + log.info("----------------------------------------------------------------") - # mid_t = [1.3, 1.0, 0.6][: args.num_steps - 1] - # For better visual quality - mid_t = [1.5, 1.4, 1.0][: args.num_steps - 1] + check_hardware_compatibility() + args = parse_arguments() - t_steps = torch.tensor( - [math.atan(args.sigma_max), *mid_t, 0], - dtype=torch.float64, - device=init_noise.device, - ) + if (args.num_frames - 1) % 4 != 0: + new_f = ((args.num_frames - 1) // 4 + 1) * 4 + 1 + print(f"โš ๏ธ Adjusting --num_frames to {new_f}") + args.num_frames = new_f - # Convert TrigFlow timesteps to RectifiedFlow - t_steps = torch.sin(t_steps) / (torch.cos(t_steps) + torch.sin(t_steps)) + if args.num_frames > 90 and not args.offload_dit: + args.offload_dit = True + + # --- CRITICAL FIX: Strictly Disable Compile for Quantized Models --- + if args.compile and args.quant_linear: + log.warning("๐Ÿšซ Quantized Model Detected: FORCE DISABLING `torch.compile` to avoid OOM.") + log.warning(" (Custom quantized kernels are not compatible with CUDA Graphs)") + args.compile = False + + # 1. Text Embeddings + if args.cached_embedding and os.path.exists(args.cached_embedding): + log.info(f"Loading cache: {args.cached_embedding}") + c = torch.load(args.cached_embedding, map_location='cpu') + text_emb = c['embeddings'][0]['embedding'].to(**tensor_kwargs) if isinstance(c, dict) else c.to(**tensor_kwargs) + else: + log.info(f"Computing embedding...") + with torch.no_grad(): + text_emb = get_umt5_embedding(args.text_encoder_path, args.prompt).to(**tensor_kwargs) + clear_umt5_memory() + cleanup_memory() + + # 2. VAE Shape Calc & UNLOAD + log.info("VAE Setup (Temp)...") + tokenizer = Wan2pt1VAEInterface(vae_pth=args.vae_path) + w, h = VIDEO_RES_SIZE_INFO[args.resolution][args.aspect_ratio] + state_shape = [tokenizer.latent_ch, tokenizer.get_latent_num_frames(args.num_frames), h // tokenizer.spatial_compression_factor, w // tokenizer.spatial_compression_factor] + del tokenizer + cleanup_memory("VAE Unloaded") - # Sampling steps + # 3. Load DiT + net = load_dit_model(args) + + # 4. Noise & Schedule + gen = torch.Generator(device=tensor_kwargs["device"]).manual_seed(args.seed) + cond = {"crossattn_emb": repeat(text_emb.to(**tensor_kwargs), "b l d -> (k b) l d", k=args.num_samples)} + init_noise = torch.randn(args.num_samples, *state_shape, dtype=torch.float32, device=tensor_kwargs["device"], generator=gen) + + mid_t = [1.5, 1.4, 1.0][: args.num_steps - 1] + t_steps = torch.tensor([math.atan(args.sigma_max), *mid_t, 0], dtype=torch.float64, device=init_noise.device) + t_steps = torch.sin(t_steps) / (torch.cos(t_steps) + torch.sin(t_steps)) + x = init_noise.to(torch.float64) * t_steps[0] ones = torch.ones(x.size(0), 1, device=x.device, dtype=x.dtype) - total_steps = t_steps.shape[0] - 1 + + # 5. Fast Sampling Loop + log.info("๐Ÿ”ฅ STARTING SAMPLING (INFERENCE MODE) ๐Ÿ”ฅ") + torch.cuda.empty_cache() net.cuda() - for i, (t_cur, t_next) in enumerate(tqdm(list(zip(t_steps[:-1], t_steps[1:])), desc="Sampling", total=total_steps)): - with torch.no_grad(): - v_pred = net(x_B_C_T_H_W=x.to(**tensor_kwargs), timesteps_B_T=(t_cur.float() * ones * 1000).to(**tensor_kwargs), **condition).to( - torch.float64 - ) - x = (1 - t_next) * (x - t_cur * v_pred) + t_next * torch.randn( - *x.shape, - dtype=torch.float32, - device=tensor_kwargs["device"], - generator=generator, - ) + print_memory_status("Tier 1: GPU Ready") + + # Compile? (Only if NOT disabled above) + if args.compile: + log.info("๐Ÿš€ Compiling model...") + try: + net = torch.compile(net, mode="reduce-overhead") + except Exception as e: + log.warning(f"Compile failed: {e}. Running eager.") + + failover = 0 + + with torch.inference_mode(): + for i, (t_cur, t_next) in enumerate(tqdm(zip(t_steps[:-1], t_steps[1:]), total=len(t_steps)-1)): + retry = True + while retry: + try: + t_cur_scalar = t_cur.item() + t_next_scalar = t_next.item() + + v_pred = net( + x_B_C_T_H_W=x.to(**tensor_kwargs), + timesteps_B_T=(t_cur * ones * 1000).to(**tensor_kwargs), + **cond + ).to(torch.float64) + + if args.offload_dit and i == len(t_steps)-2 and failover == 0: + net.cpu() + + noise = torch.randn(*x.shape, dtype=torch.float32, device=x.device, generator=gen).to(torch.float64) + term1 = x - (v_pred * t_cur_scalar) + x = term1 * (1.0 - t_next_scalar) + (noise * t_next_scalar) + + retry = False + + except torch.OutOfMemoryError: + log.warning(f"โš ๏ธ OOM at Step {i}. Recovering...") + try: net.cpu() + except: pass + del net + cleanup_memory() + failover += 1 + + if failover == 1: + print("โ™ป๏ธ Tier 2: Checkpointing") + net = load_dit_model(args, force_offload=True) + net.cuda() + # Retry compile with safer mode if first attempt was aggressive + if args.compile: + try: net = torch.compile(net, mode="default") + except: pass + elif failover == 2: + print("โ™ป๏ธ Tier 3: Manual Offload") + net = load_dit_model(args, force_offload=True) + apply_manual_offload(net) + else: + sys.exit("โŒ Critical OOM.") + samples = x.float() - net.cpu() - torch.cuda.empty_cache() + # 6. Decode + if 'net' in locals(): + try: net.cpu() + except: pass + del net + cleanup_memory("Pre-VAE") + + log.info("Decoding...") + tokenizer = Wan2pt1VAEInterface(vae_pth=args.vae_path) with torch.no_grad(): - video = tokenizer.decode(samples) - - to_show.append(video.float().cpu()) + try: + video = tokenizer.decode(samples) + except torch.OutOfMemoryError: + log.warning("Falling back to Tiled Decode...") + video = tiled_decode_gpu(tokenizer, samples) + to_show = [video.float().cpu()] to_show = (1.0 + torch.stack(to_show, dim=0).clamp(-1, 1)) / 2.0 - save_image_or_video(rearrange(to_show, "n b c t h w -> c t (n h) (b w)"), args.save_path, fps=16) + log.success(f"Saved: {args.save_path}") From 5b5f5dd658d8a6901f468640c444f0023afa586d Mon Sep 17 00:00:00 2001 From: MrEdwards007 <116316872+MrEdwards007@users.noreply.github.com> Date: Thu, 1 Jan 2026 07:45:10 -0500 Subject: [PATCH 3/6] cache_t5.py (Memory Utility) A utility script designed by John D. Pope to pre-compute and save text embeddings to a file. By "caching" these embeddings, you can run the main inference scripts without loading the heavy 11GB T5 Text Encoder into VRAM. This is essential for running the large 14B models on GPUs with limited memory, effectively "skipping" the most memory-intensive part of the initialization. --- turbodiffusion/inference/cache_t5.py | 117 +++++++++++++++++++++++++++ 1 file changed, 117 insertions(+) create mode 100644 turbodiffusion/inference/cache_t5.py diff --git a/turbodiffusion/inference/cache_t5.py b/turbodiffusion/inference/cache_t5.py new file mode 100644 index 0000000..a25e3cf --- /dev/null +++ b/turbodiffusion/inference/cache_t5.py @@ -0,0 +1,117 @@ +#!/usr/bin/env python +# ----------------------------------------------------------------------------------------- +# T5 EMBEDDING CACHE UTILITY +# +# Acknowledgments: +# - Work and creativity of: John D. Pope +# +# Description: +# Pre-computes text embeddings to allow running inference on GPUs with limited VRAM +# by removing the need to keep the 11GB T5 encoder loaded in memory. +# ----------------------------------------------------------------------------------------- +""" +Pre-cache T5 text embeddings to avoid loading the 11GB model during inference. + +Usage: + # Cache a single prompt + python scripts/cache_t5.py --prompt "slow head turn, cinematic" --output cached_embeddings.pt + + # Cache multiple prompts from file + python scripts/cache_t5.py --prompts_file prompts.txt --output cached_embeddings.pt + +Then use with inference: + python turbodiffusion/inference/wan2.2_i2v_infer.py \ + --cached_embedding cached_embeddings.pt \ + --skip_t5 \ + ... +""" +import os +import sys +import argparse +import torch + +# Add repo root to path for imports +SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) +REPO_ROOT = os.path.dirname(SCRIPT_DIR) +sys.path.insert(0, REPO_ROOT) + +def main(): + parser = argparse.ArgumentParser(description="Pre-cache T5 text embeddings") + parser.add_argument("--prompt", type=str, default=None, help="Single prompt to cache") + parser.add_argument("--prompts_file", type=str, default=None, help="File with prompts (one per line)") + parser.add_argument("--text_encoder_path", type=str, + default="/media/2TB/ComfyUI/models/text_encoders/models_t5_umt5-xxl-enc-bf16.pth", + help="Path to the umT5 text encoder") + parser.add_argument("--output", type=str, default="cached_t5_embeddings.pt", + help="Output path for cached embeddings") + parser.add_argument("--device", type=str, default="cuda", + help="Device to use for encoding (cuda is faster, memory freed after)") + args = parser.parse_args() + + # Collect prompts + prompts = [] + if args.prompt: + prompts.append(args.prompt) + if args.prompts_file and os.path.exists(args.prompts_file): + with open(args.prompts_file, 'r') as f: + prompts.extend([line.strip() for line in f if line.strip()]) + + if not prompts: + print("Error: Provide --prompt or --prompts_file") + sys.exit(1) + + print(f"Caching embeddings for {len(prompts)} prompt(s)") + print(f"Text encoder: {args.text_encoder_path}") + print(f"Device: {args.device}") + print() + + # Import after path setup + from rcm.utils.umt5 import get_umt5_embedding, clear_umt5_memory + + cache_data = { + 'prompts': prompts, + 'embeddings': [], + 'text_encoder_path': args.text_encoder_path, + } + + with torch.no_grad(): + for i, prompt in enumerate(prompts): + print(f"[{i+1}/{len(prompts)}] Encoding: '{prompt[:60]}...' " if len(prompt) > 60 else f"[{i+1}/{len(prompts)}] Encoding: '{prompt}'") + + # Get embedding (loads T5 if not already loaded) + embedding = get_umt5_embedding( + checkpoint_path=args.text_encoder_path, + prompts=prompt + ) + + # Move to CPU for storage + cache_data['embeddings'].append({ + 'prompt': prompt, + 'embedding': embedding.cpu(), + 'shape': list(embedding.shape), + }) + + print(f" Shape: {embedding.shape}, dtype: {embedding.dtype}") + + # Clear T5 from memory + print("\nClearing T5 from memory...") + clear_umt5_memory() + torch.cuda.empty_cache() + + # Save cache + print(f"\nSaving to: {args.output}") + torch.save(cache_data, args.output) + + # Summary + file_size = os.path.getsize(args.output) / (1024 * 1024) + print(f"Done! Cache file size: {file_size:.2f} MB") + print() + print("Usage:") + print(f" python turbodiffusion/inference/wan2.2_i2v_infer.py \\") + print(f" --cached_embedding {args.output} \\") + print(f" --skip_t5 \\") + print(f" ... (other args)") + + +if __name__ == "__main__": + main() From 95aef856f42f46c5607ee19ab88647bc2140e677 Mon Sep 17 00:00:00 2001 From: MrEdwards007 <116316872+MrEdwards007@users.noreply.github.com> Date: Thu, 1 Jan 2026 07:53:48 -0500 Subject: [PATCH 4/6] Gradio-based Interface (Unified Web UI) A Gradio-based dashboard that centralizes Text-to-Video and Image-to-Video workflows. It automates environment setup, provides real-time VRAM monitoring, and enforces the 4n+1 frame rule for VAE stability. It also supports automatic T5 caching and saves reproduction metadata for every video generated. --- .../turbo_diffusion_t5_cache_optimize_v6.py | 343 ++++++++++++++++++ 1 file changed, 343 insertions(+) create mode 100644 turbodiffusion/inference/turbo_diffusion_t5_cache_optimize_v6.py diff --git a/turbodiffusion/inference/turbo_diffusion_t5_cache_optimize_v6.py b/turbodiffusion/inference/turbo_diffusion_t5_cache_optimize_v6.py new file mode 100644 index 0000000..0525f4b --- /dev/null +++ b/turbodiffusion/inference/turbo_diffusion_t5_cache_optimize_v6.py @@ -0,0 +1,343 @@ +import os +import sys +import subprocess +import gradio as gr +import glob +import random +import time +import select +import torch +from datetime import datetime + +# --- 1. System Setup --- +PROJECT_ROOT = "/home/wedwards/Documents/Programs/TurboDiffusion" +os.chdir(PROJECT_ROOT) +os.system('clear' if os.name == 'posix' else 'cls') + +CHECKPOINT_DIR = os.path.join(PROJECT_ROOT, "checkpoints") +OUTPUT_DIR = os.path.join(PROJECT_ROOT, "output") +os.makedirs(OUTPUT_DIR, exist_ok=True) + +T2V_SCRIPT = "turbodiffusion/inference/wan2.1_t2v_infer.py" +I2V_SCRIPT = "turbodiffusion/inference/wan2.2_i2v_infer.py" +CACHE_SCRIPT = "turbodiffusion/inference/cache_t5.py" + +def get_gpu_status_original(): + """System-level GPU check.""" + try: + res = subprocess.check_output( + ["nvidia-smi", "--query-gpu=name,memory.used,memory.total", "--format=csv,nounits,noheader"], + encoding='utf-8' + ).strip().split(',') + return f"๐Ÿ–ฅ๏ธ {res[0]} | โšก VRAM: {res[1]}MB / {res[2]}MB" + except: + return "๐Ÿ–ฅ๏ธ GPU Monitor Active" + + +def get_gpu_status(): + """ + Check GPU status using PyTorch. + Returns system-wide VRAM usage without relying on nvidia-smi CLI. + """ + try: + # 1. Check for CUDA (NVIDIA) or ROCm (AMD) + if torch.cuda.is_available(): + # mem_get_info returns (free_bytes, total_bytes) + free_mem, total_mem = torch.cuda.mem_get_info() + + used_mem = total_mem - free_mem + + # Convert to MB for display + total_mb = int(total_mem / 1024**2) + used_mb = int(used_mem / 1024**2) + name = torch.cuda.get_device_name(0) + + return f"๐Ÿ–ฅ๏ธ {name} | โšก VRAM: {used_mb}MB / {total_mb}MB" + + # 2. Check for Apple Silicon (MPS) + # Note: Apple uses Unified Memory, so 'VRAM' is shared with System RAM. + elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): + return "๐Ÿ–ฅ๏ธ Apple Silicon (MPS) | โšก Unified Memory Active" + + # 3. Fallback to CPU + else: + return "๐Ÿ–ฅ๏ธ Running on CPU" + + except ImportError: + return "๐Ÿ–ฅ๏ธ GPU Monitor: PyTorch not installed" + except Exception as e: + return f"๐Ÿ–ฅ๏ธ GPU Monitor Error: {str(e)}" + +def save_debug_metadata(video_path, script_rel, cmd_list, cache_cmd_list=None): + """ + Saves a fully executable reproduction script with env vars. + """ + meta_path = video_path.replace(".mp4", "_metadata.txt") + with open(meta_path, "w") as f: + f.write(f"# Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n") + f.write("# Copy and paste the lines below to reproduce this video exactly:\n\n") + + # Environment Variables + f.write("export PYTHONPATH=turbodiffusion\n") + f.write("export PYTORCH_ALLOC_CONF=expandable_segments:True\n") + f.write("export TOKENIZERS_PARALLELISM=false\n\n") + + # Optional Cache Step + if cache_cmd_list: + f.write("# --- Step 1: Pre-Cache Embeddings ---\n") + f.write(f"python {CACHE_SCRIPT} \\\n") + c_args = cache_cmd_list[2:] + for i, arg in enumerate(c_args): + if arg.startswith("--"): + val = f'"{c_args[i+1]}"' if i+1 < len(c_args) and not c_args[i+1].startswith("--") else "" + f.write(f" {arg} {val} \\\n") + f.write("\n# --- Step 2: Run Inference ---\n") + + # Main Inference Command + f.write(f"python {script_rel} \\\n") + args_only = cmd_list[2:] + for i, arg in enumerate(args_only): + if arg.startswith("--"): + val = f'"{args_only[i+1]}"' if i+1 < len(args_only) and not args_only[i+1].startswith("--") else "" + f.write(f" {arg} {val} \\\n") + +def sync_path(scale): + fname = "TurboWan2.1-T2V-1.3B-480P-quant.pth" if "1.3B" in scale else "TurboWan2.1-T2V-14B-720P-quant.pth" + return os.path.join(CHECKPOINT_DIR, fname) + +# --- 2. Unified Generation Logic (With Safety Checks) --- + +def run_gen(mode, prompt, model, dit_path, i2v_high, i2v_low, image, res, ratio, steps, seed, quant, attn, top_k, frames, sigma, norm, adapt, ode, use_cache, cache_path, pr=gr.Progress()): + # --- PRE-FLIGHT SAFETY CHECK --- + error_msg = "" + if mode == "T2V": + if "quant" in dit_path.lower() and not quant: + error_msg = "โŒ CONFIG ERROR: Quantized model selected but '8-bit' disabled." + if attn == "original" and ("turbo" in dit_path.lower() or "quant" in dit_path.lower()): + error_msg = "โŒ COMPATIBILITY ERROR: 'Original' attention with Turbo/Quantized checkpoint." + else: + if ("quant" in i2v_high.lower() or "quant" in i2v_low.lower()) and not quant: + error_msg = "โŒ CONFIG ERROR: Quantized I2V model selected but '8-bit' disabled." + if attn == "original" and (("turbo" in i2v_high.lower() or "quant" in i2v_high.lower()) or ("turbo" in i2v_low.lower() or "quant" in i2v_low.lower())): + error_msg = "โŒ COMPATIBILITY ERROR: 'Original' attention with Turbo/Quantized checkpoints." + + if error_msg: + yield None, None, "โŒ Config Error", "๐Ÿ›‘ Aborted", error_msg + return + # ------------------------------- + + actual_seed = random.randint(1, 1000000) if seed <= 0 else int(seed) + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + start_time = time.time() + + full_log = f"๐Ÿš€ Starting Job: {timestamp}\n" + pr(0, desc="๐Ÿš€ Starting...") + + # --- FRAME SANITIZATION (4n+1 RULE) --- + # Wan2.1 VAE requires frames to be (4n + 1). If not, we sanitize. + target_frames = int(frames) + valid_frames = ((target_frames - 1) // 4) * 4 + 1 + + # If the user input (e.g., 32) became smaller (29) or changed, we log it. + # Note: We enforce a minimum of 1 frame just in case. + valid_frames = max(1, valid_frames) + + if valid_frames != target_frames: + warning_msg = f"โš ๏ธ AUTO-ADJUST: Frame count {target_frames} is incompatible with VAE (requires 4n+1).\n" + warning_msg += f" Adjusted {target_frames} -> {valid_frames} frames to prevent kernel crash.\n" + full_log += warning_msg + print(warning_msg) # Print to console as well + + # Use valid_frames for the rest of the logic + frames = valid_frames + # -------------------------------------- + + # --- AUTO-CACHE STEP --- + cache_cmd_list = None + if use_cache: + pr(0, desc="๐Ÿ’พ Auto-Caching T5 Embeddings...") + cache_script_full = os.path.join(PROJECT_ROOT, CACHE_SCRIPT) + encoder_path = os.path.join(CHECKPOINT_DIR, "models_t5_umt5-xxl-enc-bf16.pth") + + cache_cmd = [ + sys.executable, + cache_script_full, + "--prompt", prompt, + "--output", cache_path, + "--text_encoder_path", encoder_path + ] + cache_cmd_list = cache_cmd + + full_log += f"\n[System] Running Cache Script: {' '.join(cache_cmd)}\n" + yield None, None, f"Seed: {actual_seed}", "๐Ÿ’พ Caching...", full_log + + cache_process = subprocess.Popen(cache_cmd, cwd=PROJECT_ROOT, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, bufsize=1) + + while True: + if cache_process.poll() is not None: + rest = cache_process.stdout.read() + if rest: full_log += rest + break + line = cache_process.stdout.readline() + if line: + full_log += line + yield None, None, f"Seed: {actual_seed}", "๐Ÿ’พ Caching...", full_log + time.sleep(0.02) + + if cache_process.returncode != 0: + full_log += "\nโŒ CACHE FAILED. Aborting generation." + yield None, None, "โŒ Cache Failed", "๐Ÿ›‘ Aborted", full_log + return + + full_log += "\nโœ… Cache Complete. Starting Inference...\n" + # ----------------------------------------- + + # --- SETUP VIDEO GENERATION --- + if mode == "T2V": + save_path = os.path.join(OUTPUT_DIR, f"t2v_{timestamp}.mp4") + script_rel = T2V_SCRIPT + cmd = [sys.executable, os.path.join(PROJECT_ROOT, T2V_SCRIPT), "--model", model, "--dit_path", dit_path, "--prompt", prompt, "--resolution", res, "--aspect_ratio", ratio, "--num_steps", str(steps), "--seed", str(actual_seed), "--attention_type", attn, "--sla_topk", str(top_k), "--num_samples", "1", "--num_frames", str(frames), "--sigma_max", str(sigma)] + else: + save_path = os.path.join(OUTPUT_DIR, f"i2v_{timestamp}.mp4") + script_rel = I2V_SCRIPT + # Note: Added frames to I2V command in previous step, maintained here. + cmd = [sys.executable, os.path.join(PROJECT_ROOT, I2V_SCRIPT), "--prompt", prompt, "--image_path", image, "--high_noise_model_path", i2v_high, "--low_noise_model_path", i2v_low, "--resolution", res, "--aspect_ratio", ratio, "--num_steps", str(steps), "--seed", str(actual_seed), "--attention_type", attn, "--sla_topk", str(top_k), "--num_frames", str(frames)] + if adapt: cmd.append("--adaptive_resolution") + if ode: cmd.append("--ode") + + if quant: cmd.append("--quant_linear") + if norm: cmd.append("--default_norm") + + if use_cache: + cmd.extend(["--cached_embedding", cache_path, "--skip_t5"]) + + cmd.extend(["--save_path", save_path]) + + # Call the restored metadata saver + save_debug_metadata(save_path, script_rel, cmd, cache_cmd_list) + + env = os.environ.copy() + env["PYTHONPATH"] = os.path.join(PROJECT_ROOT, "turbodiffusion") + env["PYTORCH_ALLOC_CONF"] = "expandable_segments:True" + env["TOKENIZERS_PARALLELISM"] = "false" + env["PYTHONUNBUFFERED"] = "1" + + full_log += f"\n[System] Running Inference: {' '.join(cmd)}\n" + process = subprocess.Popen(cmd, cwd=PROJECT_ROOT, env=env, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, bufsize=1) + + last_ui_update = 0 + + while True: + if process.poll() is not None: + rest = process.stdout.read() + if rest: full_log += rest + break + + reads = [process.stdout.fileno()] + ret = select.select(reads, [], [], 0.1) + + if ret[0]: + line = process.stdout.readline() + full_log += line + + if "Loading DiT" in line: pr(0.1, desc="โšก Loading weights...") + if "Encoding" in line: pr(0.05, desc="๐Ÿ–ผ๏ธ VAE Encoding...") + if "Switching to CPU" in line: pr(0.1, desc="โš ๏ธ CPU Fallback...") + if "Sampling:" in line: + try: + pct = int(line.split('%')[0].split('|')[-1].strip()) + pr(0.2 + (pct/100 * 0.7), desc=f"๐ŸŽฌ Sampling: {pct}%") + except: pass + if "decoding" in line.lower(): pr(0.95, desc="๐ŸŽฅ Decoding VAE...") + + current_time = time.time() + if current_time - last_ui_update > 0.25: + last_ui_update = current_time + elapsed = f"{int(current_time - start_time)}s" + yield None, None, f"Seed: {actual_seed}", f"โฑ๏ธ Time: {elapsed}", full_log + + history = sorted(glob.glob(os.path.join(OUTPUT_DIR, "*.mp4")), key=os.path.getmtime, reverse=True) + total_time = f"{int(time.time() - start_time)}s" + + yield save_path, history, f"โœ… Done | Seed: {actual_seed}", f"๐Ÿ Finished in {total_time}", full_log + +# --- 3. UI Layout --- +with gr.Blocks(title="TurboDiffusion Studio") as demo: + with gr.Row(): + gr.HTML("

โšก TurboDiffusion Studio

") + with gr.Column(scale=1): + gpu_display = gr.Markdown(get_gpu_status()) + + gr.Timer(2).tick(get_gpu_status, outputs=gpu_display) + + with gr.Tabs(): + with gr.Tab("Text-to-Video"): + with gr.Row(): + with gr.Column(scale=4): + t2v_p = gr.Textbox(label="Prompt", lines=3, value="A stylish woman walks down a Tokyo street...") + with gr.Row(): + t2v_m = gr.Radio(["Wan2.1-1.3B", "Wan2.1-14B"], label="Model", value="Wan2.1-1.3B") + t2v_res = gr.Dropdown(["480p", "720p"], label="Resolution", value="480p") + t2v_ratio = gr.Dropdown(["16:9", "4:3", "1:1", "9:16"], label="Aspect Ratio", value="16:9") + t2v_dit = gr.Textbox(label="DiT Path", value=sync_path("Wan2.1-1.3B"), interactive=False) + t2v_btn = gr.Button("Generate Video", variant="primary") + with gr.Column(scale=3): + t2v_out = gr.Video(label="Result", height=320) + with gr.Row(): + t2v_stat = gr.Textbox(label="Status", interactive=False, scale=2) + t2v_time = gr.Textbox(label="Timer", value="โฑ๏ธ Ready", interactive=False, scale=1) + + with gr.Tab("Image-to-Video"): + with gr.Row(): + with gr.Column(scale=4): + with gr.Row(): + i2v_img = gr.Image(label="Source", type="filepath", height=200) + i2v_p = gr.Textbox(label="Motion Prompt", lines=7) + with gr.Row(): + i2v_res = gr.Dropdown(["480p", "720p"], label="Resolution", value="720p") + i2v_ratio = gr.Dropdown(["16:9", "4:3", "1:1", "9:16"], label="Aspect Ratio", value="16:9") + with gr.Row(): + i2v_adapt = gr.Checkbox(label="Adaptive Resolution", value=True) + i2v_ode = gr.Checkbox(label="Use ODE", value=False) + with gr.Accordion("I2V Path Overrides", open=False): + i2v_high = gr.Textbox(label="High-Noise", value=os.path.join(CHECKPOINT_DIR, "TurboWan2.2-I2V-A14B-high-720P-quant.pth")) + i2v_low = gr.Textbox(label="Low-Noise", value=os.path.join(CHECKPOINT_DIR, "TurboWan2.2-I2V-A14B-low-720P-quant.pth")) + i2v_btn = gr.Button("Animate Image", variant="primary") + with gr.Column(scale=3): + i2v_out = gr.Video(label="Result", height=320) + with gr.Row(): + i2v_stat_2 = gr.Textbox(label="Status", interactive=False, scale=2) + i2v_time_2 = gr.Textbox(label="Timer", value="โฑ๏ธ Ready", interactive=False, scale=1) + + console_out = gr.Textbox(label="Live CLI Console Output", lines=8, max_lines=8, interactive=False) + + with gr.Accordion("โš™๏ธ Precision & Advanced Settings", open=False): + with gr.Row(): + quant_opt = gr.Checkbox(label="Enable --quant_linear (8-bit)", value=True) + steps_opt = gr.Slider(1, 4, value=4, step=1, label="Steps") + seed_opt = gr.Number(label="Seed (0=Random)", value=0, precision=0) + with gr.Row(): + top_k_opt = gr.Slider(0.01, 0.5, value=0.15, step=0.01, label="SLA Top-K") + attn_opt = gr.Radio(["sagesla", "sla", "original"], label="Attention", value="sagesla") + sigma_opt = gr.Number(label="Sigma Max", value=80) + norm_opt = gr.Checkbox(label="Original Norms", value=False) + frames_opt = gr.Slider(1, 120, value=77, step=4, label="Frames (Steps of 4)") + with gr.Row(variant="panel"): + # --- T5 CACHE UI --- + use_cache_opt = gr.Checkbox(label="Use Cached T5 Embeddings (Auto-Run)", value=True) + cache_path_opt = gr.Textbox(label="Cache File Path", value="cached_t5_embeddings.pt", scale=2) + # ------------------- + + history_gal = gr.Gallery(value=sorted(glob.glob(os.path.join(OUTPUT_DIR, "*.mp4")), reverse=True), columns=6, height="auto") + + # --- 4. Logic Bindings --- + t2v_m.change(fn=sync_path, inputs=t2v_m, outputs=t2v_dit) + + t2v_args = [gr.State("T2V"), t2v_p, t2v_m, t2v_dit, gr.State(""), gr.State(""), gr.State(""), t2v_res, t2v_ratio, steps_opt, seed_opt, quant_opt, attn_opt, top_k_opt, frames_opt, sigma_opt, norm_opt, gr.State(False), gr.State(False), use_cache_opt, cache_path_opt] + t2v_btn.click(run_gen, t2v_args, [t2v_out, history_gal, t2v_stat, t2v_time, console_out], show_progress="hidden") + + i2v_args = [i2v_img, i2v_p, gr.State("Wan2.2-A14B"), gr.State(""), i2v_high, i2v_low, i2v_img, i2v_res, i2v_ratio, steps_opt, seed_opt, quant_opt, attn_opt, top_k_opt, frames_opt, gr.State(200), norm_opt, i2v_adapt, i2v_ode, use_cache_opt, cache_path_opt] + i2v_btn.click(run_gen, i2v_args, [i2v_out, history_gal, i2v_stat_2, i2v_time_2, console_out], show_progress="hidden") + +if __name__ == "__main__": + demo.launch(theme=gr.themes.Default(), allowed_paths=[OUTPUT_DIR]) \ No newline at end of file From 5461725c9dcd5130e7e2bb5f881a89dad5c9e9c5 Mon Sep 17 00:00:00 2001 From: MrEdwards007 <116316872+MrEdwards007@users.noreply.github.com> Date: Thu, 1 Jan 2026 08:35:18 -0500 Subject: [PATCH 5/6] TurboDiffusion Studio Readme Readme for TurboDiffusion Studio , a unified, high-performance Gradio interface for Wan2.1 and Wan2.2 video generation that automates memory management through real-time VRAM monitoring and integrated T5 embedding caching. --- TurboDiffusion_Studio.md | 123 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 123 insertions(+) create mode 100644 TurboDiffusion_Studio.md diff --git a/TurboDiffusion_Studio.md b/TurboDiffusion_Studio.md new file mode 100644 index 0000000..d0d7f4f --- /dev/null +++ b/TurboDiffusion_Studio.md @@ -0,0 +1,123 @@ + +## ๐Ÿš€ Scripts & Inference + +This repository contains optimized inference engines for the Wan2.1 and Wan2.2 models, specifically tuned for high-resolution output and robust memory management on consumer hardware. + +### ๐ŸŽฅ Inference Engines + +| Script | Function | Key Features | +| --- | --- | --- | +| **`wan2.2_i2v_infer.py`** | **Image-to-Video** | **Tiered Failover System**: Automatic recovery from OOM errors.
+ +
**Intelligent Model Switching**: Transitions between High and Low Noise models based on step boundaries.
+ +
**Tiled Processing**: Uses 4-chunk tiled encoding/decoding for 720p+ stability. | +| **`wan2.1_t2v_infer.py`** | **Text-to-Video** | **Hardware Auto-Detection**: Automatically selects TF32, BF16, or FP16 based on GPU capabilities.
+ +
**Quantization Safety**: Force-disables `torch.compile` for quantized models to prevent graph-break OOMs.
+ +
**3-Tier Recovery**: Escalates from GPU โž” Checkpointing โž” Manual CPU Offloading if memory is exceeded. | + +### ๐Ÿ› ๏ธ Utilities + +* **`cache_t5.py`** +* **Purpose**: Pre-computes and saves T5 text embeddings to disk. +* **VRAM Benefit**: Eliminates the need to load the **11GB T5 encoder** during the main inference run, allowing 14B models to fit on GPUs with lower VRAM. +* **Usage**: Run this first to generate a `.pt` file, then pass it to the inference scripts using the `--cached_embedding` flag. + + +--- + +## ๐Ÿš€ Getting Started with TurboDiffusion + +To run the large 14B models on consumer GPUs, it is recommended to use the **T5 Caching** workflow. This offloads the 11GB text encoder from VRAM, leaving more space for the DiT model and high-resolution video decoding. + +### **Step 1: Environment Setup** + +Ensure your project structure is organized as follows: + +* **Root**: `/your/path/to/TurboDiffusion` +* **Checkpoints**: Place your `.pth` models in the `checkpoints/` directory. +* **Output**: Generated videos and metadata will be saved to `output/`. + +### **Step 2: The Two Ways to Cache T5** + +#### **Option A: Manual Pre-Caching (Recommended for Batching)** + +If you have a list of prompts you want to use frequently, use the standalone utility: + +```bash +python turbodiffusion/inference/cache_t5.py --prompt "Your descriptive prompt here" --output cached_t5_embeddings.pt + +``` + +This saves the processed text into a small `.pt` file, allowing the inference scripts to "skip" the heavy T5 model entirely. + +#### **Option B: Automatic Caching via Web UI** + +For a more streamlined experience, use the **TurboDiffusion Studio**: + +1. Launch the UI: `python turbo_diffusion_t5_cache_optimize_v6.py`. +2. Open the **Precision & Advanced Settings** accordion. +3. Check **Use Cached T5 Embeddings (Auto-Run)**. +4. When you click generate, the UI will automatically run the caching script first, clear the T5 model from memory, and then start the video generation. + +### **Step 3: Running Inference** + +Once your UI is launched and caching is configured: + +1. **Select Mode**: Choose between **Text-to-Video** (Wan2.1) or **Image-to-Video** (Wan2.2). +2. **Apply Quantization**: For 24GB VRAM GPUs (like the RTX 3090/4090/5090), ensure **Enable --quant_linear (8-bit)** is checked to avoid OOM errors. +3. **Monitor Hardware**: Watch the **Live GPU Monitor** at the top of the UI to track real-time VRAM usage during the sampling process. +4. **Retrieve Results**: Your video and its reproduction metadata (containing the exact CLI command used) will appear in the `output/` gallery. + + +--- + +## ๐Ÿ–ฅ๏ธ TurboDiffusion Studio (Web UI) + +The `turbo_diffusion_t5_cache_optimize_v6.py` script provides a high-performance, unified **Gradio-based Web interface** for both Text-to-Video and Image-to-Video generation. It serves as a centralized "Studio" dashboard that automates complex environment setups and memory optimizations. + +### **Key Features** + +| Feature | Description | +| --- | --- | +| **Unified Interface** | Toggle between **Text-to-Video (Wan2.1)** and **Image-to-Video (Wan2.2)** workflows within a single dashboard. | +| **Real-time GPU Monitor** | Native PyTorch-based VRAM monitoring that displays current memory usage and hardware status directly in the UI. | +| **Auto-Cache T5 Integration** | Automatically runs the `cache_t5.py` utility before inference to offload the 11GB text encoder, significantly reducing peak VRAM usage. | +| **Frame Sanitization** | Automatically enforces the **4n + 1 rule** required by the Wan VAE to prevent kernel crashes during decoding. | +| **Reproduction Metadata** | Every generated video automatically saves a matching `_metadata.txt` file containing the exact CLI command and environment variables needed to reproduce the result. | +| **Live Console Output** | Pipes real-time CLI logs and progress bars directly into a "Live Console" window in the web browser. | + +### **Advanced Controls** + +The UI exposes granular controls for technical users: + +* **Precision & Quantization:** Toggle 8-bit `--quant_linear` mode for low-VRAM operation. +* **Attention Tuning:** Switch between `sagesla`, `sla`, and `original` attention mechanisms. +* **Adaptive I2V:** Enable adaptive resolution and ODE solvers for Image-to-Video workflows. +* **Integrated Gallery:** Browse and view your output history directly within the `output/` directory. + +--- + +## ๐Ÿ› ๏ธ Usage + +To launch the studio: + +```bash +python turbo_diffusion_t5_cache_optimize_v6.py + +``` + +> **Note:** The script defaults to `/your/path/to/TurboDiffusion`as the project root. Ensure your local paths are configured accordingly in the **System Setup** section of the code. + + +--- + +## ๐Ÿ’ณ Credits & Acknowledgments + +If you utilize, share, or build upon these optimized scripts, please include the following acknowledgments: + +* **Optimization & Development**: Co-developed by **Waverly Edwards** and **Google Gemini**. +* **T5 Caching Logic**: Original concept and utility implementation by **John D. Pope**. +* **Base Framework**: Built upon the NVIDIA Imaginaire and Wan-Video research. From 1cf0b7bb778db63b41324a6a973cd5a28e6c428f Mon Sep 17 00:00:00 2001 From: MrEdwards007 <116316872+MrEdwards007@users.noreply.github.com> Date: Thu, 1 Jan 2026 13:31:06 -0500 Subject: [PATCH 6/6] Update wan2.1_t2v_infer.py logging --- turbodiffusion/inference/wan2.1_t2v_infer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/turbodiffusion/inference/wan2.1_t2v_infer.py b/turbodiffusion/inference/wan2.1_t2v_infer.py index 8ecabf9..a45220d 100644 --- a/turbodiffusion/inference/wan2.1_t2v_infer.py +++ b/turbodiffusion/inference/wan2.1_t2v_infer.py @@ -224,7 +224,7 @@ def post_hook(module, args, output): # --- CREDIT PRINT --- log.info("----------------------------------------------------------------") log.info("๐Ÿš€ TurboDiffusion Optimized Inference") - log.info(" Co-developed by [Your Name/Handle] & Google Gemini") + log.info(" Co-developed by Waverly Edwards & Google Gemini") log.info("----------------------------------------------------------------") check_hardware_compatibility()