283 changes: 238 additions & 45 deletions src/diffusers/hooks/group_offloading.py
@@ -17,7 +17,7 @@
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
from enum import Enum
from typing import Dict, List, Optional, Set, Tuple, Union
from typing import Callable, Dict, List, Optional, Set, Tuple, Union

import safetensors.torch
import torch
@@ -59,6 +59,8 @@ class GroupOffloadingConfig:
num_blocks_per_group: Optional[int] = None
offload_to_disk_path: Optional[str] = None
stream: Optional[Union[torch.cuda.Stream, torch.Stream]] = None
block_modules: Optional[List[str]] = None
pin_groups: Optional[Union[str, Callable]] = None


class ModuleGroup:
@@ -77,7 +79,7 @@ def __init__(
low_cpu_mem_usage: bool = False,
onload_self: bool = True,
offload_to_disk_path: Optional[str] = None,
group_id: Optional[int] = None,
group_id: Optional[Union[int, str]] = None,
) -> None:
self.modules = modules
self.offload_device = offload_device
@@ -91,6 +93,7 @@ def __init__(
self.record_stream = record_stream
self.onload_self = onload_self
self.low_cpu_mem_usage = low_cpu_mem_usage
self.pinned = False

self.offload_to_disk_path = offload_to_disk_path
self._is_offloaded_to_disk = False
@@ -296,6 +299,24 @@ def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
if self.group.onload_leader is None:
self.group.onload_leader = module

if self.group.pinned:
if self.group.onload_leader == module and not self._is_group_on_device():
self.group.onload_()

should_onload_next_group = self.next_group is not None and not self.next_group.onload_self
if should_onload_next_group:
self.next_group.onload_()

should_synchronize = (
not self.group.onload_self and self.group.stream is not None and not should_onload_next_group
)
if should_synchronize:
self.group.stream.synchronize()

args = send_to_device(args, self.group.onload_device, non_blocking=self.group.non_blocking)
kwargs = send_to_device(kwargs, self.group.onload_device, non_blocking=self.group.non_blocking)
return args, kwargs

# If the current module is the onload_leader of the group, we onload the group if it is supposed
# to onload itself. In the case of using prefetching with streams, we onload the next group if
# it is not supposed to onload itself.
Expand Down Expand Up @@ -324,10 +345,26 @@ def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
return args, kwargs

def post_forward(self, module: torch.nn.Module, output):
if self.group.pinned:
return output

if self.group.offload_leader == module:
self.group.offload_()
return output

def _is_group_on_device(self) -> bool:
tensors = []
for group_module in self.group.modules:
tensors.extend(list(group_module.parameters()))
tensors.extend(list(group_module.buffers()))
tensors.extend(self.group.parameters)
tensors.extend(self.group.buffers)

if len(tensors) == 0:
return True

return all(t.device == self.group.onload_device for t in tensors)


class LazyPrefetchGroupOffloadingHook(ModelHook):
r"""
@@ -423,6 +460,51 @@ def post_forward(self, module, output):
group_offloading_hooks[i].next_group = group_offloading_hooks[i + 1].group
group_offloading_hooks[i].next_group.onload_self = False

pin_groups = getattr(base_module_registry, "_group_offload_pin_groups", None)
if pin_groups is not None and num_executed > 0:
param_exec_info = []
for idx, ((name, submodule), hook) in enumerate(zip(self.execution_order, group_offloading_hooks)):
if hook is None:
continue
if next(submodule.parameters(), None) is None and next(submodule.buffers(), None) is None:
continue
param_exec_info.append((name, submodule, hook))

num_param_modules = len(param_exec_info)
if num_param_modules > 0:
pinned_indices = set()
if isinstance(pin_groups, str):
if pin_groups == "all":
pinned_indices = set(range(num_param_modules))
elif pin_groups == "first_last":
pinned_indices.add(0)
pinned_indices.add(num_param_modules - 1)
elif callable(pin_groups):
for idx, (name, submodule, _) in enumerate(param_exec_info):
should_pin = False
try:
should_pin = bool(pin_groups(submodule))
except TypeError:
try:
should_pin = bool(pin_groups(name, submodule))
except TypeError:
should_pin = bool(pin_groups(name, submodule, idx))
if should_pin:
pinned_indices.add(idx)

pinned_groups = set()
for idx in pinned_indices:
if idx >= num_param_modules:
continue
group = param_exec_info[idx][2].group
if group not in pinned_groups:
group.pinned = True
pinned_groups.add(group)

for group in pinned_groups:
if group.offload_device != group.onload_device:
group.onload_()

return output
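The callable form of `pin_groups` is resolved in the block above by falling back through one-, two-, and three-argument call signatures. A minimal sketch of callables this dispatch would accept (the predicates, names, and threshold below are illustrative assumptions, not part of this PR):

```python
import torch

# Accepted via the one-argument signature: pin_groups(submodule).
def pin_small_groups(submodule: torch.nn.Module) -> bool:
    # Pin any group whose parameters total fewer than ~1M elements (arbitrary threshold).
    return sum(p.numel() for p in submodule.parameters()) < 1_000_000

# Accepted via the two-argument fallback: pin_groups(name, submodule).
def pin_by_name(name: str, submodule: torch.nn.Module) -> bool:
    # Pin groups whose execution-order name ends with one of these (illustrative) suffixes.
    return name.endswith(("norm_out", "proj_out"))

# Either callable can then be passed as `pin_groups=pin_small_groups` (or `pin_by_name`)
# to `apply_group_offloading`.
```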


@@ -453,6 +535,9 @@ def apply_group_offloading(
record_stream: bool = False,
low_cpu_mem_usage: bool = False,
offload_to_disk_path: Optional[str] = None,
block_modules: Optional[List[str]] = None,
pin_groups: Optional[Union[str, Callable]] = None,
pin_first_last: bool = False,
) -> None:
r"""
Applies group offloading to the internal layers of a torch.nn.Module. To understand what group offloading is, and
@@ -510,6 +595,15 @@
If True, the CPU memory usage is minimized by pinning tensors on-the-fly instead of pre-pinning them. This
option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when
the CPU memory is a bottleneck but may counteract the benefits of using streams.
block_modules (`List[str]`, *optional*):
List of module names that should be treated as blocks for offloading. If provided, only these modules
will be considered for block-level offloading. If not provided, the default block detection logic will be used.
pin_groups (`"first_last"` or `"all"` or `Callable`, *optional*, defaults to `None`):
Optionally keeps selected groups on the onload device permanently. Use `"first_last"` to pin the first
and last parameter-bearing groups, `"all"` to pin every parameter-bearing group, or pass a callable that
receives a module (and optionally the module name and index) and returns `True` to pin that group.
pin_first_last (`bool`, *optional*, defaults to `False`):
Deprecated alias for `pin_groups="first_last"`.

Example:
```python
Expand Down Expand Up @@ -549,7 +643,24 @@ def apply_group_offloading(
if offload_type == GroupOffloadingType.BLOCK_LEVEL and num_blocks_per_group is None:
raise ValueError("`num_blocks_per_group` must be provided when using `offload_type='block_level'.")

if pin_first_last:
if pin_groups is not None and pin_groups != "first_last":
raise ValueError("`pin_first_last` cannot be combined with a different `pin_groups` setting.")
pin_groups = "first_last"

normalized_pin_groups = pin_groups
if isinstance(pin_groups, str):
normalized_pin_groups = pin_groups.lower()
if normalized_pin_groups not in {"first_last", "all"}:
raise ValueError("`pin_groups` must be one of `None`, 'first_last', 'all', or a callable.")
elif pin_groups is not None and not callable(pin_groups):
raise ValueError("`pin_groups` must be one of `None`, 'first_last', 'all', or a callable.")

pin_groups = normalized_pin_groups

_raise_error_if_accelerate_model_or_sequential_hook_present(module)
registry = HookRegistry.check_if_exists_or_initialize(module)
registry._group_offload_pin_groups = pin_groups

config = GroupOffloadingConfig(
onload_device=onload_device,
@@ -561,11 +672,16 @@
record_stream=record_stream,
low_cpu_mem_usage=low_cpu_mem_usage,
offload_to_disk_path=offload_to_disk_path,
block_modules=block_modules,
pin_groups=pin_groups,
)
_apply_group_offloading(module, config)
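A usage sketch of the new arguments follows. The toy model, tensor sizes, and device choices are illustrative assumptions; `use_stream=True` is included because the pin resolution above lives in the lazy prefetch hook, which is registered when streams are used:

```python
import torch
from diffusers.hooks import apply_group_offloading

# Hypothetical stand-in for a diffusers transformer model.
class ToyTransformer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.blocks = torch.nn.ModuleList([torch.nn.Linear(64, 64) for _ in range(8)])
        self.norm_out = torch.nn.LayerNorm(64)

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return self.norm_out(x)

model = ToyTransformer()
apply_group_offloading(
    model,
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="block_level",
    num_blocks_per_group=1,
    use_stream=True,  # requires a CUDA device; a CUDA stream is allocated here
    pin_groups="first_last",  # new in this PR: keep the first and last parameter-bearing groups on the GPU
)

# Pinned groups are resolved after the first forward pass establishes the execution order.
_ = model(torch.randn(1, 64))
```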


def _apply_group_offloading(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
registry = HookRegistry.check_if_exists_or_initialize(module)
registry._group_offload_pin_groups = config.pin_groups

if config.offload_type == GroupOffloadingType.BLOCK_LEVEL:
_apply_group_offloading_block_level(module, config)
elif config.offload_type == GroupOffloadingType.LEAF_LEVEL:
@@ -576,28 +692,123 @@

def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
r"""
This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks. In comparison to
the "leaf_level" offloading, which is more fine-grained, this offloading is done at the top-level blocks.
"""
This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks, as well as to
explicitly defined block modules. In comparison to the more fine-grained "leaf_level" offloading, this offloading
is done at the level of top-level blocks and the modules specified in block_modules.

When block_modules is provided, those modules are treated as explicit blocks for offloading. For each specified
module, we either offload the entire submodule as a single group or recursively apply block offloading to it.
"""
if config.stream is not None and config.num_blocks_per_group != 1:
logger.warning(
f"Using streams is only supported for num_blocks_per_group=1. Got {config.num_blocks_per_group=}. Setting it to 1."
)
config.num_blocks_per_group = 1

# Create module groups for ModuleList and Sequential blocks
block_modules = set(config.block_modules) if config.block_modules is not None else set()

# Create module groups for ModuleList and Sequential blocks, and explicitly defined block modules
modules_with_group_offloading = set()
unmatched_modules = []
matched_module_groups = []

for name, submodule in module.named_children():
if not isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)):
# Check if this is an explicitly defined block module
if name in block_modules:
# Apply block offloading to the specified submodule
_apply_block_offloading_to_submodule(
submodule, name, config, modules_with_group_offloading, matched_module_groups
)
elif isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)):
# Handle ModuleList and Sequential blocks as before
for i in range(0, len(submodule), config.num_blocks_per_group):
current_modules = list(submodule[i : i + config.num_blocks_per_group])
if len(current_modules) == 0:
continue

group_id = f"{name}_{i}_{i + len(current_modules) - 1}"
group = ModuleGroup(
modules=current_modules,
offload_device=config.offload_device,
onload_device=config.onload_device,
offload_to_disk_path=config.offload_to_disk_path,
offload_leader=current_modules[-1],
onload_leader=current_modules[0],
non_blocking=config.non_blocking,
stream=config.stream,
record_stream=config.record_stream,
low_cpu_mem_usage=config.low_cpu_mem_usage,
onload_self=True,
group_id=group_id,
)
matched_module_groups.append(group)
for j in range(i, i + len(current_modules)):
modules_with_group_offloading.add(f"{name}.{j}")
else:
# This is an unmatched module
unmatched_modules.append((name, submodule))
modules_with_group_offloading.add(name)
continue

# Apply group offloading hooks to the module groups
for i, group in enumerate(matched_module_groups):
for group_module in group.modules:
_apply_group_offloading_hook(group_module, group, config=config)

# Parameters and Buffers of the top-level module need to be offloaded/onloaded separately
# when the forward pass of this module is called. This is because the top-level module is not
# part of any group (as doing so would lead to no VRAM savings).
parameters = _gather_parameters_with_no_group_offloading_parent(module, modules_with_group_offloading)
buffers = _gather_buffers_with_no_group_offloading_parent(module, modules_with_group_offloading)
parameters = [param for _, param in parameters]
buffers = [buffer for _, buffer in buffers]

# Create a group for the remaining unmatched submodules of the top-level
# module so that they are on the correct device when the forward pass is called.
unmatched_modules = [unmatched_module for _, unmatched_module in unmatched_modules]
if len(unmatched_modules) > 0 or len(parameters) > 0 or len(buffers) > 0:
unmatched_group = ModuleGroup(
modules=unmatched_modules,
offload_device=config.offload_device,
onload_device=config.onload_device,
offload_to_disk_path=config.offload_to_disk_path,
offload_leader=module,
onload_leader=module,
parameters=parameters,
buffers=buffers,
non_blocking=False,
stream=None,
record_stream=False,
onload_self=True,
group_id=f"{module.__class__.__name__}_unmatched_group",
)
if config.stream is None:
_apply_group_offloading_hook(module, unmatched_group, config=config)
else:
_apply_lazy_group_offloading_hook(module, unmatched_group, config=config)


def _apply_block_offloading_to_submodule(
submodule: torch.nn.Module,
name: str,
config: GroupOffloadingConfig,
modules_with_group_offloading: Set[str],
matched_module_groups: List[ModuleGroup],
) -> None:
r"""
Apply block offloading to an explicitly defined submodule. This function either:
1. offloads the entire submodule as a single group (the simple approach), or
2. recursively applies block offloading to the submodule.

For now, we use the simple approach: the entire submodule is offloaded as a single group.
"""
# Simple approach: offload the entire submodule as a single group.
# Since autoencoder-style submodules are typically small, this is usually acceptable.
if isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)):
# If it's a ModuleList or Sequential, apply the normal block-level logic
for i in range(0, len(submodule), config.num_blocks_per_group):
current_modules = submodule[i : i + config.num_blocks_per_group]
current_modules = list(submodule[i : i + config.num_blocks_per_group])
if len(current_modules) == 0:
continue

group_id = f"{name}_{i}_{i + len(current_modules) - 1}"
group = ModuleGroup(
modules=current_modules,
@@ -616,42 +827,24 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
matched_module_groups.append(group)
for j in range(i, i + len(current_modules)):
modules_with_group_offloading.add(f"{name}.{j}")

# Apply group offloading hooks to the module groups
for i, group in enumerate(matched_module_groups):
for group_module in group.modules:
_apply_group_offloading_hook(group_module, group, config=config)

# Parameters and Buffers of the top-level module need to be offloaded/onloaded separately
# when the forward pass of this module is called. This is because the top-level module is not
# part of any group (as doing so would lead to no VRAM savings).
parameters = _gather_parameters_with_no_group_offloading_parent(module, modules_with_group_offloading)
buffers = _gather_buffers_with_no_group_offloading_parent(module, modules_with_group_offloading)
parameters = [param for _, param in parameters]
buffers = [buffer for _, buffer in buffers]

# Create a group for the unmatched submodules of the top-level module so that they are on the correct
# device when the forward pass is called.
unmatched_modules = [unmatched_module for _, unmatched_module in unmatched_modules]
unmatched_group = ModuleGroup(
modules=unmatched_modules,
offload_device=config.offload_device,
onload_device=config.onload_device,
offload_to_disk_path=config.offload_to_disk_path,
offload_leader=module,
onload_leader=module,
parameters=parameters,
buffers=buffers,
non_blocking=False,
stream=None,
record_stream=False,
onload_self=True,
group_id=f"{module.__class__.__name__}_unmatched_group",
)
if config.stream is None:
_apply_group_offloading_hook(module, unmatched_group, config=config)
else:
_apply_lazy_group_offloading_hook(module, unmatched_group, config=config)
# For other modules, treat the entire submodule as a single group
group = ModuleGroup(
modules=[submodule],
offload_device=config.offload_device,
onload_device=config.onload_device,
offload_to_disk_path=config.offload_to_disk_path,
offload_leader=submodule,
onload_leader=submodule,
non_blocking=config.non_blocking,
stream=config.stream,
record_stream=config.record_stream,
low_cpu_mem_usage=config.low_cpu_mem_usage,
onload_self=True,
group_id=name,
)
matched_module_groups.append(group)
modules_with_group_offloading.add(name)
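To illustrate how an explicitly named child flows through this helper, a minimal sketch (the toy model and the `refiner` name are assumptions for illustration): a child that is neither a `ModuleList` nor a `Sequential` would otherwise land in the catch-all unmatched group, while naming it in `block_modules` turns it into its own offload group.

```python
import torch
from diffusers.hooks import apply_group_offloading

class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.blocks = torch.nn.ModuleList([torch.nn.Linear(32, 32) for _ in range(4)])
        # A plain child module: not auto-detected as a block by the ModuleList/Sequential check.
        self.refiner = torch.nn.Linear(32, 32)

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return self.refiner(x)

model = ToyModel()
apply_group_offloading(
    model,
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="block_level",
    num_blocks_per_group=1,
    block_modules=["refiner"],  # offload `refiner` as its own group instead of leaving it unmatched
)
```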


def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None: