diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py index 38f291f5203c..8b6d734f1e3f 100644 --- a/src/diffusers/hooks/group_offloading.py +++ b/src/diffusers/hooks/group_offloading.py @@ -17,7 +17,7 @@ from contextlib import contextmanager, nullcontext from dataclasses import dataclass from enum import Enum -from typing import Dict, List, Optional, Set, Tuple, Union +from typing import Callable, Dict, List, Optional, Set, Tuple, Union import safetensors.torch import torch @@ -59,6 +59,8 @@ class GroupOffloadingConfig: num_blocks_per_group: Optional[int] = None offload_to_disk_path: Optional[str] = None stream: Optional[Union[torch.cuda.Stream, torch.Stream]] = None + block_modules: Optional[List[str]] = None + pin_groups: Optional[Union[str, Callable]] = None class ModuleGroup: @@ -77,7 +79,7 @@ def __init__( low_cpu_mem_usage: bool = False, onload_self: bool = True, offload_to_disk_path: Optional[str] = None, - group_id: Optional[int] = None, + group_id: Optional[Union[int, str]] = None, ) -> None: self.modules = modules self.offload_device = offload_device @@ -91,6 +93,7 @@ def __init__( self.record_stream = record_stream self.onload_self = onload_self self.low_cpu_mem_usage = low_cpu_mem_usage + self.pinned = False self.offload_to_disk_path = offload_to_disk_path self._is_offloaded_to_disk = False @@ -296,6 +299,24 @@ def pre_forward(self, module: torch.nn.Module, *args, **kwargs): if self.group.onload_leader is None: self.group.onload_leader = module + if self.group.pinned: + if self.group.onload_leader == module and not self._is_group_on_device(): + self.group.onload_() + + should_onload_next_group = self.next_group is not None and not self.next_group.onload_self + if should_onload_next_group: + self.next_group.onload_() + + should_synchronize = ( + not self.group.onload_self and self.group.stream is not None and not should_onload_next_group + ) + if should_synchronize: + self.group.stream.synchronize() + + args = send_to_device(args, self.group.onload_device, non_blocking=self.group.non_blocking) + kwargs = send_to_device(kwargs, self.group.onload_device, non_blocking=self.group.non_blocking) + return args, kwargs + # If the current module is the onload_leader of the group, we onload the group if it is supposed # to onload itself. In the case of using prefetching with streams, we onload the next group if # it is not supposed to onload itself. 
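For reference, a minimal usage sketch of the two new `GroupOffloadingConfig` fields introduced above (`block_modules` and `pin_groups`). The toy model, device choice, and argument values are illustrative assumptions and not part of this diff:

```python
import torch

from diffusers.hooks.group_offloading import apply_group_offloading


class TinyBlockModel(torch.nn.Module):
    # Hypothetical model used only to illustrate the new arguments.
    def __init__(self):
        super().__init__()
        self.blocks = torch.nn.ModuleList([torch.nn.Linear(8, 8) for _ in range(4)])
        self.proj_out = torch.nn.Linear(8, 8)

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return self.proj_out(x)


model = TinyBlockModel()
apply_group_offloading(
    model,
    onload_device=torch.device("cuda"),  # assumes a CUDA device is available
    offload_device=torch.device("cpu"),
    offload_type="block_level",
    num_blocks_per_group=1,
    use_stream=True,
    block_modules=["blocks"],  # treat this named child as an offloading block
    pin_groups="first_last",  # keep the first and last parameter-bearing groups on the onload device
)

with torch.no_grad():
    model(torch.randn(1, 8, device="cuda"))
```

With `use_stream=True`, pinned groups are resolved by the lazy prefetch hook once the first forward pass has recorded the execution order, so the pinning shown here takes effect from the end of that first pass onward.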
@@ -324,10 +345,26 @@ def pre_forward(self, module: torch.nn.Module, *args, **kwargs): return args, kwargs def post_forward(self, module: torch.nn.Module, output): + if self.group.pinned: + return output + if self.group.offload_leader == module: self.group.offload_() return output + def _is_group_on_device(self) -> bool: + tensors = [] + for group_module in self.group.modules: + tensors.extend(list(group_module.parameters())) + tensors.extend(list(group_module.buffers())) + tensors.extend(self.group.parameters) + tensors.extend(self.group.buffers) + + if len(tensors) == 0: + return True + + return all(t.device == self.group.onload_device for t in tensors) + class LazyPrefetchGroupOffloadingHook(ModelHook): r""" @@ -423,6 +460,51 @@ def post_forward(self, module, output): group_offloading_hooks[i].next_group = group_offloading_hooks[i + 1].group group_offloading_hooks[i].next_group.onload_self = False + pin_groups = getattr(base_module_registry, "_group_offload_pin_groups", None) + if pin_groups is not None and num_executed > 0: + param_exec_info = [] + for idx, ((name, submodule), hook) in enumerate(zip(self.execution_order, group_offloading_hooks)): + if hook is None: + continue + if next(submodule.parameters(), None) is None and next(submodule.buffers(), None) is None: + continue + param_exec_info.append((name, submodule, hook)) + + num_param_modules = len(param_exec_info) + if num_param_modules > 0: + pinned_indices = set() + if isinstance(pin_groups, str): + if pin_groups == "all": + pinned_indices = set(range(num_param_modules)) + elif pin_groups == "first_last": + pinned_indices.add(0) + pinned_indices.add(num_param_modules - 1) + elif callable(pin_groups): + for idx, (name, submodule, _) in enumerate(param_exec_info): + should_pin = False + try: + should_pin = bool(pin_groups(submodule)) + except TypeError: + try: + should_pin = bool(pin_groups(name, submodule)) + except TypeError: + should_pin = bool(pin_groups(name, submodule, idx)) + if should_pin: + pinned_indices.add(idx) + + pinned_groups = set() + for idx in pinned_indices: + if idx >= num_param_modules: + continue + group = param_exec_info[idx][2].group + if group not in pinned_groups: + group.pinned = True + pinned_groups.add(group) + + for group in pinned_groups: + if group.offload_device != group.onload_device: + group.onload_() + return output @@ -453,6 +535,9 @@ def apply_group_offloading( record_stream: bool = False, low_cpu_mem_usage: bool = False, offload_to_disk_path: Optional[str] = None, + block_modules: Optional[List[str]] = None, + pin_groups: Optional[Union[str, Callable]] = None, + pin_first_last: bool = False, ) -> None: r""" Applies group offloading to the internal layers of a torch.nn.Module. To understand what group offloading is, and @@ -510,6 +595,15 @@ def apply_group_offloading( If True, the CPU memory usage is minimized by pinning tensors on-the-fly instead of pre-pinning them. This option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when the CPU memory is a bottleneck but may counteract the benefits of using streams. + block_modules (`List[str]`, *optional*): + List of module names that should be treated as blocks for offloading. If provided, only these modules + will be considered for block-level offloading. If not provided, the default block detection logic will be used. + pin_groups (`"first_last"` or `"all"` or `Callable`, *optional*, defaults to `None`): + Optionally keeps selected groups on the onload device permanently. 
Use `"first_last"` to pin the first + and last parameter-bearing groups, `"all"` to pin every parameter-bearing group, or pass a callable that + receives a module (and optionally the module name and index) and returns `True` to pin that group. + pin_first_last (`bool`, *optional*, defaults to `False`): + Deprecated alias for `pin_groups="first_last"`. Example: ```python @@ -549,7 +643,24 @@ def apply_group_offloading( if offload_type == GroupOffloadingType.BLOCK_LEVEL and num_blocks_per_group is None: raise ValueError("`num_blocks_per_group` must be provided when using `offload_type='block_level'.") + if pin_first_last: + if pin_groups is not None and pin_groups != "first_last": + raise ValueError("`pin_first_last` cannot be combined with a different `pin_groups` setting.") + pin_groups = "first_last" + + normalized_pin_groups = pin_groups + if isinstance(pin_groups, str): + normalized_pin_groups = pin_groups.lower() + if normalized_pin_groups not in {"first_last", "all"}: + raise ValueError("`pin_groups` must be one of `None`, 'first_last', 'all', or a callable.") + elif pin_groups is not None and not callable(pin_groups): + raise ValueError("`pin_groups` must be one of `None`, 'first_last', 'all', or a callable.") + + pin_groups = normalized_pin_groups + _raise_error_if_accelerate_model_or_sequential_hook_present(module) + registry = HookRegistry.check_if_exists_or_initialize(module) + registry._group_offload_pin_groups = pin_groups config = GroupOffloadingConfig( onload_device=onload_device, @@ -561,11 +672,16 @@ def apply_group_offloading( record_stream=record_stream, low_cpu_mem_usage=low_cpu_mem_usage, offload_to_disk_path=offload_to_disk_path, + block_modules=block_modules, + pin_groups=pin_groups, ) _apply_group_offloading(module, config) def _apply_group_offloading(module: torch.nn.Module, config: GroupOffloadingConfig) -> None: + registry = HookRegistry.check_if_exists_or_initialize(module) + registry._group_offload_pin_groups = config.pin_groups + if config.offload_type == GroupOffloadingType.BLOCK_LEVEL: _apply_group_offloading_block_level(module, config) elif config.offload_type == GroupOffloadingType.LEAF_LEVEL: @@ -576,28 +692,123 @@ def _apply_group_offloading(module: torch.nn.Module, config: GroupOffloadingConf def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None: r""" - This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks. In comparison to - the "leaf_level" offloading, which is more fine-grained, this offloading is done at the top-level blocks. - """ + This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks, and explicitly + defined block modules. In comparison to the "leaf_level" offloading, which is more fine-grained, this offloading + is done at the top-level blocks and modules specified in block_modules. + When block_modules is provided, only those modules will be treated as blocks for offloading. For each specified + module, we either offload the entire submodule or recursively apply block offloading to it. + """ if config.stream is not None and config.num_blocks_per_group != 1: logger.warning( f"Using streams is only supported for num_blocks_per_group=1. Got {config.num_blocks_per_group=}. Setting it to 1." 
) config.num_blocks_per_group = 1 - # Create module groups for ModuleList and Sequential blocks + block_modules = set(config.block_modules) if config.block_modules is not None else set() + + # Create module groups for ModuleList and Sequential blocks, and explicitly defined block modules modules_with_group_offloading = set() unmatched_modules = [] matched_module_groups = [] + for name, submodule in module.named_children(): - if not isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)): + # Check if this is an explicitly defined block module + if name in block_modules: + # Apply block offloading to the specified submodule + _apply_block_offloading_to_submodule( + submodule, name, config, modules_with_group_offloading, matched_module_groups + ) + elif isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)): + # Handle ModuleList and Sequential blocks as before + for i in range(0, len(submodule), config.num_blocks_per_group): + current_modules = list(submodule[i : i + config.num_blocks_per_group]) + if len(current_modules) == 0: + continue + + group_id = f"{name}_{i}_{i + len(current_modules) - 1}" + group = ModuleGroup( + modules=current_modules, + offload_device=config.offload_device, + onload_device=config.onload_device, + offload_to_disk_path=config.offload_to_disk_path, + offload_leader=current_modules[-1], + onload_leader=current_modules[0], + non_blocking=config.non_blocking, + stream=config.stream, + record_stream=config.record_stream, + low_cpu_mem_usage=config.low_cpu_mem_usage, + onload_self=True, + group_id=group_id, + ) + matched_module_groups.append(group) + for j in range(i, i + len(current_modules)): + modules_with_group_offloading.add(f"{name}.{j}") + else: + # This is an unmatched module unmatched_modules.append((name, submodule)) - modules_with_group_offloading.add(name) - continue + # Apply group offloading hooks to the module groups + for i, group in enumerate(matched_module_groups): + for group_module in group.modules: + _apply_group_offloading_hook(group_module, group, config=config) + + # Parameters and Buffers of the top-level module need to be offloaded/onloaded separately + # when the forward pass of this module is called. This is because the top-level module is not + # part of any group (as doing so would lead to no VRAM savings). + parameters = _gather_parameters_with_no_group_offloading_parent(module, modules_with_group_offloading) + buffers = _gather_buffers_with_no_group_offloading_parent(module, modules_with_group_offloading) + parameters = [param for _, param in parameters] + buffers = [buffer for _, buffer in buffers] + + # Create a group for the remaining unmatched submodules of the top-level + # module so that they are on the correct device when the forward pass is called. 
+    unmatched_modules = [unmatched_module for _, unmatched_module in unmatched_modules]
+    if len(unmatched_modules) > 0 or len(parameters) > 0 or len(buffers) > 0:
+        unmatched_group = ModuleGroup(
+            modules=unmatched_modules,
+            offload_device=config.offload_device,
+            onload_device=config.onload_device,
+            offload_to_disk_path=config.offload_to_disk_path,
+            offload_leader=module,
+            onload_leader=module,
+            parameters=parameters,
+            buffers=buffers,
+            non_blocking=False,
+            stream=None,
+            record_stream=False,
+            onload_self=True,
+            group_id=f"{module.__class__.__name__}_unmatched_group",
+        )
+        if config.stream is None:
+            _apply_group_offloading_hook(module, unmatched_group, config=config)
+        else:
+            _apply_lazy_group_offloading_hook(module, unmatched_group, config=config)
+
+
+def _apply_block_offloading_to_submodule(
+    submodule: torch.nn.Module,
+    name: str,
+    config: GroupOffloadingConfig,
+    modules_with_group_offloading: Set[str],
+    matched_module_groups: List[ModuleGroup],
+) -> None:
+    r"""
+    Apply block offloading to an explicitly defined submodule. This function either:
+    1. Offloads the entire submodule as a single group (the simple approach)
+    2. Recursively applies block offloading to the submodule
+
+    For now, we use the simple approach: offload the entire submodule as a single group.
+    """
+    # Simple approach: offload the entire submodule as a single group.
+    # Since autoencoders are typically small, this is usually acceptable.
+    if isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)):
+        # If it's a ModuleList or Sequential, apply the normal block-level logic
         for i in range(0, len(submodule), config.num_blocks_per_group):
-            current_modules = submodule[i : i + config.num_blocks_per_group]
+            current_modules = list(submodule[i : i + config.num_blocks_per_group])
+            if len(current_modules) == 0:
+                continue
+
             group_id = f"{name}_{i}_{i + len(current_modules) - 1}"
             group = ModuleGroup(
                 modules=current_modules,
@@ -616,42 +827,24 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOf
             matched_module_groups.append(group)
             for j in range(i, i + len(current_modules)):
                 modules_with_group_offloading.add(f"{name}.{j}")
-
-    # Apply group offloading hooks to the module groups
-    for i, group in enumerate(matched_module_groups):
-        for group_module in group.modules:
-            _apply_group_offloading_hook(group_module, group, config=config)
-
-    # Parameters and Buffers of the top-level module need to be offloaded/onloaded separately
-    # when the forward pass of this module is called. This is because the top-level module is not
-    # part of any group (as doing so would lead to no VRAM savings).
-    parameters = _gather_parameters_with_no_group_offloading_parent(module, modules_with_group_offloading)
-    buffers = _gather_buffers_with_no_group_offloading_parent(module, modules_with_group_offloading)
-    parameters = [param for _, param in parameters]
-    buffers = [buffer for _, buffer in buffers]
-
-    # Create a group for the unmatched submodules of the top-level module so that they are on the correct
-    # device when the forward pass is called.
- unmatched_modules = [unmatched_module for _, unmatched_module in unmatched_modules] - unmatched_group = ModuleGroup( - modules=unmatched_modules, - offload_device=config.offload_device, - onload_device=config.onload_device, - offload_to_disk_path=config.offload_to_disk_path, - offload_leader=module, - onload_leader=module, - parameters=parameters, - buffers=buffers, - non_blocking=False, - stream=None, - record_stream=False, - onload_self=True, - group_id=f"{module.__class__.__name__}_unmatched_group", - ) - if config.stream is None: - _apply_group_offloading_hook(module, unmatched_group, config=config) else: - _apply_lazy_group_offloading_hook(module, unmatched_group, config=config) + # For other modules, treat the entire submodule as a single group + group = ModuleGroup( + modules=[submodule], + offload_device=config.offload_device, + onload_device=config.onload_device, + offload_to_disk_path=config.offload_to_disk_path, + offload_leader=submodule, + onload_leader=submodule, + non_blocking=config.non_blocking, + stream=config.stream, + record_stream=config.record_stream, + low_cpu_mem_usage=config.low_cpu_mem_usage, + onload_self=True, + group_id=name, + ) + matched_module_groups.append(group) + modules_with_group_offloading.add(name) def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None: diff --git a/src/diffusers/models/autoencoders/autoencoder_kl.py b/src/diffusers/models/autoencoders/autoencoder_kl.py index ffc8778e7aca..4096b7c07609 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl.py @@ -72,6 +72,7 @@ class AutoencoderKL(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModel _supports_gradient_checkpointing = True _no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D"] + _group_offload_block_modules = ["quant_conv", "post_quant_conv", "encoder", "decoder"] @register_to_config def __init__( diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py index b0b2960aaf18..6b29a6273cd9 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py @@ -964,6 +964,7 @@ class AutoencoderKLWan(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalMo # keys toignore when AlignDeviceHook moves inputs/outputs between devices # these are shared mutable state modified in-place _skip_keys = ["feat_cache", "feat_idx"] + _group_offload_block_modules = ["quant_conv", "post_quant_conv", "encoder", "decoder"] @register_to_config def __init__( diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index f06822c741ca..86d2024f0a95 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -531,6 +531,8 @@ def enable_group_offload( record_stream: bool = False, low_cpu_mem_usage=False, offload_to_disk_path: Optional[str] = None, + pin_groups: Optional[Union[str, Callable]] = None, + pin_first_last: bool = False, ) -> None: r""" Activates group offloading for the current model. @@ -570,6 +572,7 @@ def enable_group_offload( f"`_supports_group_offloading` to `True` in the class definition. If you believe this is a mistake, please " f"open an issue at https://github.com/huggingface/diffusers/issues." 
) + block_modules = getattr(self, "_group_offload_block_modules", None) apply_group_offloading( module=self, onload_device=onload_device, @@ -581,6 +584,9 @@ def enable_group_offload( record_stream=record_stream, low_cpu_mem_usage=low_cpu_mem_usage, offload_to_disk_path=offload_to_disk_path, + block_modules=block_modules, + pin_groups=pin_groups, + pin_first_last=pin_first_last, ) def set_attention_backend(self, backend: str) -> None: diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 392d5fb3feb4..d0fab44a6187 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -1342,6 +1342,8 @@ def enable_group_offload( low_cpu_mem_usage=False, offload_to_disk_path: Optional[str] = None, exclude_modules: Optional[Union[str, List[str]]] = None, + pin_groups: Optional[Union[str, Callable]] = None, + pin_first_last: bool = False, ) -> None: r""" Applies group offloading to the internal layers of a torch.nn.Module. To understand what group offloading is, @@ -1402,6 +1404,11 @@ def enable_group_offload( This option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when the CPU memory is a bottleneck but may counteract the benefits of using streams. exclude_modules (`Union[str, List[str]]`, defaults to `None`): List of modules to exclude from offloading. + pin_groups (`\"first_last\"` | `\"all\"` | `Callable`, *optional*): + Optionally keep selected groups on the onload device permanently. See `ModelMixin.enable_group_offload` + for details. + pin_first_last (`bool`, *optional*, defaults to `False`): + Deprecated alias for `pin_groups=\"first_last\"`. Example: ```python @@ -1442,6 +1449,8 @@ def enable_group_offload( "record_stream": record_stream, "low_cpu_mem_usage": low_cpu_mem_usage, "offload_to_disk_path": offload_to_disk_path, + "pin_groups": pin_groups, + "pin_first_last": pin_first_last, } for name, component in self.components.items(): if name not in exclude_modules and isinstance(component, torch.nn.Module): diff --git a/tests/hooks/test_group_offloading.py b/tests/hooks/test_group_offloading.py index 96cbecfbf530..00b8f2df98e5 100644 --- a/tests/hooks/test_group_offloading.py +++ b/tests/hooks/test_group_offloading.py @@ -362,3 +362,84 @@ def apply_layer_output_tracker_hook(model: DummyModelWithLayerNorm): self.assertLess( cumulated_absmax, 1e-5, f"Output differences for {name} exceeded threshold: {cumulated_absmax:.5f}" ) + + def test_block_level_pin_first_last_groups_stay_on_device(self): + if torch.device(torch_device).type not in ["cuda", "xpu"]: + return + + def first_param_device(mod): + p = next(mod.parameters(), None) + self.assertIsNotNone(p, f"No parameters found for module {mod}") + return p.device + + def assert_all_modules_device(mods, expected_type: str, msg: str = ""): + bad = [] + for i, m in enumerate(mods): + dev_type = first_param_device(m).type + if dev_type != expected_type: + bad.append((i, m.__class__.__name__, dev_type)) + self.assertFalse( + bad, + (msg + "\n" if msg else "") + + f"Expected all modules on {expected_type}, but found mismatches: {bad}", + ) + + def get_param_modules_from_exec_order(model): + root_registry = HookRegistry.check_if_exists_or_initialize(model) + + lazy_hook = root_registry.get_hook("lazy_prefetch_group_offloading") + self.assertIsNotNone(lazy_hook, "lazy_prefetch_group_offloading hook was not registered") + + #record execution order with first forward + with torch.no_grad(): + model(self.input) + + 
mods = [m for _, m in lazy_hook.execution_order] + param_mods = [m for m in mods if next(m.parameters(), None) is not None] + self.assertGreaterEqual( + len(param_mods), 2, f"Expected >=2 param-bearing modules in execution_order, got {len(param_mods)}" + ) + + first = param_mods[0] + last = param_mods[-1] + middle_layers = param_mods[1:-1] + return first, middle_layers, last + + accel_type = torch.device(torch_device).type + + model_no_pin = self.get_model() + model_no_pin.enable_group_offload( + torch_device, + offload_type="block_level", + num_blocks_per_group=1, + use_stream=True, + ) + model_no_pin.eval() + first, middle, last = get_param_modules_from_exec_order(model_no_pin) + + self.assertEqual(first_param_device(first).type, "cpu") + self.assertEqual(first_param_device(last).type, "cpu") + assert_all_modules_device(middle, "cpu", msg="No-pin: expected ALL middle layers on CPU") + + model_pin = self.get_model() + model_pin.enable_group_offload( + torch_device, + offload_type="block_level", + num_blocks_per_group=1, + use_stream=True, + pin_first_last=True, + ) + model_pin.eval() + first, middle, last = get_param_modules_from_exec_order(model_pin) + + self.assertEqual(first_param_device(first).type, accel_type) + self.assertEqual(first_param_device(last).type, accel_type) + assert_all_modules_device(middle, "cpu", msg="Pin: expected ALL middle layers on CPU") + + # Should still hold after another invocation + with torch.no_grad(): + model_pin(self.input) + + self.assertEqual(first_param_device(first).type, accel_type) + self.assertEqual(first_param_device(last).type, accel_type) + assert_all_modules_device(middle, "cpu", msg="Pin (2nd forward): expected ALL middle layers on CPU")
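As a complement to the tests above, a hedged end-to-end sketch of the model-level API; the checkpoint, predicate, and devices are illustrative assumptions rather than values taken from this PR:

```python
import torch

from diffusers import AutoencoderKL

# Illustrative checkpoint; any AutoencoderKL weights would work the same way.
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")


def pin_small_groups(name, module):
    # Callable form of `pin_groups`: the lazy prefetch hook first tries `fn(module)`,
    # then `fn(name, module)`, then `fn(name, module, idx)`. Returning True pins the
    # group this module belongs to on the onload device.
    return sum(p.numel() for p in module.parameters()) < 100_000


vae.enable_group_offload(
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="block_level",
    num_blocks_per_group=1,
    use_stream=True,
    pin_groups=pin_small_groups,  # or "first_last" / "all"; `pin_first_last=True` is the deprecated alias
)

# Because `AutoencoderKL` now defines `_group_offload_block_modules`, its `encoder`,
# `decoder`, `quant_conv`, and `post_quant_conv` are forwarded to
# `apply_group_offloading` as `block_modules` without any extra argument here.

with torch.no_grad():
    _ = vae(torch.randn(1, 3, 64, 64, device="cuda"))
```

As in the tests, the pinned groups are determined after the first forward pass records the execution order, so the selected groups stay on the accelerator from that point on while every other group continues to round-trip to CPU.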