diff --git a/src/streamdiffusion/acceleration/tensorrt/builder.py b/src/streamdiffusion/acceleration/tensorrt/builder.py
index 87abf90d..3c7db255 100644
--- a/src/streamdiffusion/acceleration/tensorrt/builder.py
+++ b/src/streamdiffusion/acceleration/tensorrt/builder.py
@@ -63,6 +63,9 @@ def build(
         del self.network
         gc.collect()
         torch.cuda.empty_cache()
+        # Additional cleanup for compile_engines_only mode
+        torch.cuda.synchronize()
+        torch.cuda.ipc_collect()
         if not force_onnx_optimize and os.path.exists(onnx_opt_path):
             print(f"Found cached model: {onnx_opt_path}")
         else:
@@ -72,6 +75,9 @@ def build(
                 onnx_opt_path=onnx_opt_path,
                 model_data=self.model,
             )
+            # Cleanup after ONNX optimization
+            gc.collect()
+            torch.cuda.empty_cache()
         self.model.min_latent_shape = min_image_resolution // 8
         self.model.max_latent_shape = max_image_resolution // 8
         if not force_engine_build and os.path.exists(engine_path):
@@ -97,3 +103,6 @@ def build(
 
         gc.collect()
         torch.cuda.empty_cache()
+        # Additional cleanup after engine build
+        torch.cuda.synchronize()
+        torch.cuda.ipc_collect()
diff --git a/src/streamdiffusion/modules/controlnet_module.py b/src/streamdiffusion/modules/controlnet_module.py
index 984776a8..0f1ab183 100644
--- a/src/streamdiffusion/modules/controlnet_module.py
+++ b/src/streamdiffusion/modules/controlnet_module.py
@@ -93,7 +93,7 @@ def install(self, stream) -> None:
 
     def add_controlnet(self, cfg: ControlNetConfig, control_image: Optional[Union[str, Any, torch.Tensor]] = None) -> None:
         model = self._load_pytorch_controlnet_model(cfg.model_id)
-        model = model.to(dtype=self.dtype)
+        model = model.to(device=self.device, dtype=self.dtype)
 
         preproc = None
         if cfg.preprocessor:
@@ -595,7 +595,7 @@ def _load_pytorch_controlnet_model(self, model_id: str) -> ControlNetModel:
         controlnet = ControlNetModel.from_pretrained(
             model_id, torch_dtype=self.dtype
         )
-        controlnet = controlnet.to(dtype=self.dtype)
+        controlnet = controlnet.to(device=self.device, dtype=self.dtype)
         # Track model_id for updater diffing
         try:
             setattr(controlnet, 'model_id', model_id)
diff --git a/src/streamdiffusion/wrapper.py b/src/streamdiffusion/wrapper.py
index 58ebbcab..35b6ba87 100644
--- a/src/streamdiffusion/wrapper.py
+++ b/src/streamdiffusion/wrapper.py
@@ -1043,6 +1043,10 @@ def _load_model(
             from streamdiffusion.acceleration.tensorrt.export_wrappers.unet_controlnet_export import create_controlnet_wrapper
             from streamdiffusion.acceleration.tensorrt.export_wrappers.unet_ipadapter_export import create_ipadapter_wrapper
 
+            # Clean up VRAM before starting TensorRT compilation when only compiling engines
+            if compile_engines_only:
+                self.cleanup_gpu_memory_for_compilation()
+
             # Legacy TensorRT implementation (fallback)
             # Initialize engine manager
             engine_manager = EngineManager(engine_dir)
@@ -1334,6 +1338,10 @@ def _load_model(
                     'max_image_resolution': 1024,
                 }
             )
+
+            # Clean up VRAM after VAE decoder compilation when only compiling engines
+            if compile_engines_only:
+                self.cleanup_gpu_memory_for_compilation()
 
             # Compile VAE encoder engine using EngineManager
             vae_encoder = TorchVAEEncoder(stream.vae)
@@ -1359,6 +1367,10 @@ def _load_model(
                     'max_image_resolution': 1024,
                 }
             )
+
+            # Clean up VRAM after VAE encoder compilation when only compiling engines
+            if compile_engines_only:
+                self.cleanup_gpu_memory_for_compilation()
 
             cuda_stream = cuda.Stream()
 
@@ -1388,6 +1400,10 @@ def _load_model(
             if load_engine:
                 logger.info("TensorRT UNet engine loaded successfully")
 
+                # Clean up VRAM after UNet compilation when only compiling engines
+                if compile_engines_only:
+                    self.cleanup_gpu_memory_for_compilation()
+
         except Exception as e:
             error_msg = str(e).lower()
             is_oom_error = ('out of memory' in error_msg or 'outofmemory' in error_msg or
@@ -1506,6 +1522,10 @@ def _load_model(
                 cuda_stream=None,
                 load_engine=load_engine,
             )
+
+            # Clean up VRAM after safety checker compilation when only compiling engines
+            if compile_engines_only:
+                self.cleanup_gpu_memory_for_compilation()
 
             if load_engine:
                 self.safety_checker = NSFWDetectorEngine(
@@ -1585,6 +1605,10 @@ def _load_model(
                     logger.info(f"Compiled/loaded {len(compiled_cn_engines)} ControlNet TensorRT engine(s)")
                 except Exception:
                     pass
+
+                # Clean up VRAM after ControlNet compilation when only compiling engines
+                if compile_engines_only:
+                    self.cleanup_gpu_memory_for_compilation()
             except Exception:
                 import traceback
                 traceback.print_exc()
@@ -1859,6 +1883,27 @@ def cleanup_gpu_memory(self) -> None:
 
         logger.info(" Enhanced GPU memory cleanup complete")
 
+    def cleanup_gpu_memory_for_compilation(self) -> None:
+        """Lightweight GPU memory cleanup specifically for TensorRT engine compilation."""
+        import gc
+        import torch
+
+        logger.info("Cleaning up GPU memory for engine compilation...")
+
+        # Force garbage collection
+        for _ in range(2):
+            gc.collect()
+
+        # Clear CUDA cache, then log remaining usage
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+            torch.cuda.synchronize()
+
+            # Get memory info for logging
+            allocated = torch.cuda.memory_allocated() / (1024**3)  # GB
+            cached = torch.cuda.memory_reserved() / (1024**3)  # GB
+            logger.info(f" GPU Memory after compilation cleanup: {allocated:.2f}GB allocated, {cached:.2f}GB cached")
+
     def check_gpu_memory_for_engine(self, engine_size_gb: float) -> bool:
         """
         Check if there's enough GPU memory to load a TensorRT engine.
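
Note on the new cleanup_gpu_memory_for_compilation() hook: the intent is to hand back as much VRAM as possible between successive engine builds (VAE decoder, VAE encoder, UNet, safety checker, ControlNets), since the TensorRT builder needs large contiguous workspaces. A minimal standalone sketch of the same sequence, for reference only; the function name free_vram_between_builds is illustrative and not part of this patch:

    import gc
    import torch

    def free_vram_between_builds() -> None:
        """Illustrative sketch: drop Python garbage and cached CUDA blocks,
        and wait for in-flight work before starting the next engine build."""
        gc.collect()                      # release unreferenced tensors
        if torch.cuda.is_available():
            torch.cuda.empty_cache()      # return cached blocks to the driver
            torch.cuda.synchronize()      # wait for pending kernels to finish
            torch.cuda.ipc_collect()      # reclaim memory from dead IPC handles
            free, total = torch.cuda.mem_get_info()
            print(f"free VRAM: {free / 2**30:.2f} / {total / 2**30:.2f} GiB")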