9 changes: 9 additions & 0 deletions src/streamdiffusion/acceleration/tensorrt/builder.py
@@ -63,6 +63,9 @@ def build(
         del self.network
         gc.collect()
         torch.cuda.empty_cache()
+        # Additional cleanup for compile_engines_only mode
+        torch.cuda.synchronize()
+        torch.cuda.ipc_collect()
         if not force_onnx_optimize and os.path.exists(onnx_opt_path):
             print(f"Found cached model: {onnx_opt_path}")
         else:
@@ -72,6 +75,9 @@ def build(
                 onnx_opt_path=onnx_opt_path,
                 model_data=self.model,
             )
+            # Cleanup after ONNX optimization
+            gc.collect()
+            torch.cuda.empty_cache()
         self.model.min_latent_shape = min_image_resolution // 8
         self.model.max_latent_shape = max_image_resolution // 8
         if not force_engine_build and os.path.exists(engine_path):
@@ -97,3 +103,6 @@ def build(
 
         gc.collect()
         torch.cuda.empty_cache()
+        # Additional cleanup after engine build
+        torch.cuda.synchronize()
+        torch.cuda.ipc_collect()
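
The gc.collect / empty_cache / synchronize / ipc_collect sequence added above now repeats at each build stage. A minimal sketch of the same cleanup factored into one helper, for reference (the release_cuda_memory name and the sync flag are illustrative, not part of this PR):

import gc

import torch


def release_cuda_memory(sync: bool = False) -> None:
    # Drop unreferenced Python objects first so their CUDA blocks become freeable.
    gc.collect()
    if torch.cuda.is_available():
        # Return cached allocator blocks to the driver.
        torch.cuda.empty_cache()
        if sync:
            # Wait for in-flight kernels before the next build measures memory.
            torch.cuda.synchronize()
            # Reclaim CUDA IPC memory held by destroyed shared tensors.
            torch.cuda.ipc_collect()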
4 changes: 2 additions & 2 deletions src/streamdiffusion/modules/controlnet_module.py
@@ -93,7 +93,7 @@ def install(self, stream) -> None:
 
     def add_controlnet(self, cfg: ControlNetConfig, control_image: Optional[Union[str, Any, torch.Tensor]] = None) -> None:
         model = self._load_pytorch_controlnet_model(cfg.model_id)
-        model = model.to(dtype=self.dtype)
+        model = model.to(device=self.device, dtype=self.dtype)
 
         preproc = None
         if cfg.preprocessor:
@@ -595,7 +595,7 @@ def _load_pytorch_controlnet_model(self, model_id: str) -> ControlNetModel:
         controlnet = ControlNetModel.from_pretrained(
             model_id, torch_dtype=self.dtype
         )
-        controlnet = controlnet.to(dtype=self.dtype)
+        controlnet = controlnet.to(device=self.device, dtype=self.dtype)
         # Track model_id for updater diffing
         try:
             setattr(controlnet, 'model_id', model_id)
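
Both changes fix the same bug: calling .to(dtype=...) only casts the ControlNet's weights and leaves them on the CPU, so the first forward pass against CUDA latents fails with a device mismatch. A minimal sketch of the corrected pattern (the checkpoint id and the device/dtype selection are illustrative, standing in for self.device and self.dtype):

import torch
from diffusers import ControlNetModel

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float16 if device.type == "cuda" else torch.float32

controlnet = ControlNetModel.from_pretrained(
    "lllyasviel/sd-controlnet-canny",  # example checkpoint
    torch_dtype=dtype,
)

# Passing device and dtype in one .to() call moves and casts the weights together.
controlnet = controlnet.to(device=device, dtype=dtype)
assert next(controlnet.parameters()).device.type == device.type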
45 changes: 45 additions & 0 deletions src/streamdiffusion/wrapper.py
@@ -1043,6 +1043,10 @@ def _load_model(
             from streamdiffusion.acceleration.tensorrt.export_wrappers.unet_controlnet_export import create_controlnet_wrapper
             from streamdiffusion.acceleration.tensorrt.export_wrappers.unet_ipadapter_export import create_ipadapter_wrapper
 
+            # Clean up VRAM before starting TensorRT compilation when only compiling engines
+            if compile_engines_only:
+                self.cleanup_gpu_memory_for_compilation()
+
             # Legacy TensorRT implementation (fallback)
             # Initialize engine manager
             engine_manager = EngineManager(engine_dir)
@@ -1334,6 +1338,10 @@ def _load_model(
                     'max_image_resolution': 1024,
                 }
             )
+
+            # Clean up VRAM after VAE decoder compilation when only compiling engines
+            if compile_engines_only:
+                self.cleanup_gpu_memory_for_compilation()
 
             # Compile VAE encoder engine using EngineManager
             vae_encoder = TorchVAEEncoder(stream.vae)
@@ -1359,6 +1367,10 @@ def _load_model(
                     'max_image_resolution': 1024,
                 }
             )
+
+            # Clean up VRAM after VAE encoder compilation when only compiling engines
+            if compile_engines_only:
+                self.cleanup_gpu_memory_for_compilation()
 
             cuda_stream = cuda.Stream()
 
@@ -1388,6 +1400,10 @@ def _load_model(
                 if load_engine:
                     logger.info("TensorRT UNet engine loaded successfully")
 
+                # Clean up VRAM after UNet compilation when only compiling engines
+                if compile_engines_only:
+                    self.cleanup_gpu_memory_for_compilation()
+
             except Exception as e:
                 error_msg = str(e).lower()
                 is_oom_error = ('out of memory' in error_msg or 'outofmemory' in error_msg or
@@ -1506,6 +1522,10 @@ def _load_model(
                     cuda_stream=None,
                     load_engine=load_engine,
                 )
+
+                # Clean up VRAM after safety checker compilation when only compiling engines
+                if compile_engines_only:
+                    self.cleanup_gpu_memory_for_compilation()
 
                 if load_engine:
                     self.safety_checker = NSFWDetectorEngine(
@@ -1585,6 +1605,10 @@ def _load_model(
                     logger.info(f"Compiled/loaded {len(compiled_cn_engines)} ControlNet TensorRT engine(s)")
                 except Exception:
                     pass
+
+                # Clean up VRAM after ControlNet compilation when only compiling engines
+                if compile_engines_only:
+                    self.cleanup_gpu_memory_for_compilation()
             except Exception:
                 import traceback
                 traceback.print_exc()
@@ -1859,6 +1883,27 @@ def cleanup_gpu_memory(self) -> None:
 
         logger.info(" Enhanced GPU memory cleanup complete")
 
+    def cleanup_gpu_memory_for_compilation(self) -> None:
+        """Lightweight GPU memory cleanup specifically for TensorRT engine compilation."""
+        import gc
+        import torch
+
+        logger.info("Cleaning up GPU memory for engine compilation...")
+
+        # Force garbage collection (two passes to break reference cycles)
+        for _ in range(2):
+            gc.collect()
+
+        # Clear CUDA cache
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+            torch.cuda.synchronize()
+
+            # Get memory info for logging (guarded: only meaningful with CUDA available)
+            allocated = torch.cuda.memory_allocated() / (1024**3)  # GB
+            cached = torch.cuda.memory_reserved() / (1024**3)  # GB
+            logger.info(f" GPU Memory after compilation cleanup: {allocated:.2f}GB allocated, {cached:.2f}GB cached")
+
     def check_gpu_memory_for_engine(self, engine_size_gb: float) -> bool:
         """
         Check if there's enough GPU memory to load a TensorRT engine.
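
The new cleanup_gpu_memory_for_compilation hook fires between every compilation stage when compile_engines_only is set, so each TensorRT build starts with as much free VRAM as possible. A minimal sketch of how it could compose with the existing check_gpu_memory_for_engine guard (the ensure_headroom helper and its free-memory arithmetic are assumptions, not part of the PR):

import torch

def ensure_headroom(wrapper, engine_size_gb: float) -> bool:
    # Free cached allocations first, mirroring what the PR does between stages.
    wrapper.cleanup_gpu_memory_for_compilation()
    if not torch.cuda.is_available():
        return False
    # torch.cuda.mem_get_info() reports (free, total) bytes on the current device.
    free_bytes, _total = torch.cuda.mem_get_info()
    return free_bytes / (1024 ** 3) >= engine_size_gb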