Commit 9c3c14f

Add pinning support to group offloading hooks
1 parent 3455019 commit 9c3c14f

File tree

1 file changed: +111 -1 lines changed

src/diffusers/hooks/group_offloading.py

Lines changed: 111 additions & 1 deletion
@@ -17,7 +17,7 @@
 from contextlib import contextmanager, nullcontext
 from dataclasses import dataclass
 from enum import Enum
-from typing import Dict, List, Optional, Set, Tuple, Union
+from typing import Callable, Dict, List, Optional, Set, Tuple, Union
 
 import safetensors.torch
 import torch
@@ -60,6 +60,7 @@ class GroupOffloadingConfig:
     offload_to_disk_path: Optional[str] = None
     stream: Optional[Union[torch.cuda.Stream, torch.Stream]] = None
     block_modules: Optional[List[str]] = None
+    pin_groups: Optional[Union[str, Callable]] = None
 
 
 class ModuleGroup:
@@ -92,6 +93,7 @@ def __init__(
         self.record_stream = record_stream
         self.onload_self = onload_self
         self.low_cpu_mem_usage = low_cpu_mem_usage
+        self.pinned = False
 
         self.offload_to_disk_path = offload_to_disk_path
         self._is_offloaded_to_disk = False
@@ -297,6 +299,24 @@ def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
         if self.group.onload_leader is None:
             self.group.onload_leader = module
 
+        if self.group.pinned:
+            if self.group.onload_leader == module and not self._is_group_on_device():
+                self.group.onload_()
+
+                should_onload_next_group = self.next_group is not None and not self.next_group.onload_self
+                if should_onload_next_group:
+                    self.next_group.onload_()
+
+                should_synchronize = (
+                    not self.group.onload_self and self.group.stream is not None and not should_onload_next_group
+                )
+                if should_synchronize:
+                    self.group.stream.synchronize()
+
+            args = send_to_device(args, self.group.onload_device, non_blocking=self.group.non_blocking)
+            kwargs = send_to_device(kwargs, self.group.onload_device, non_blocking=self.group.non_blocking)
+            return args, kwargs
+
         # If the current module is the onload_leader of the group, we onload the group if it is supposed
         # to onload itself. In the case of using prefetching with streams, we onload the next group if
         # it is not supposed to onload itself.
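
The pinned branch short-circuits the usual onload/prefetch/offload cycle: a pinned group re-onloads its weights only if they were evicted, still triggers stream prefetching for the following group, and otherwise just moves the call's inputs to the onload device. A minimal sketch of that control flow, using a toy stand-in rather than the library's hook classes:

```python
import torch


class ToyPinnedGroup:
    """Illustrative stand-in for a pinned group; not the library's class."""

    def __init__(self, module: torch.nn.Module, onload_device: torch.device):
        self.module = module
        self.onload_device = onload_device
        self.pinned = True

    def _is_on_device(self) -> bool:
        tensors = list(self.module.parameters()) + list(self.module.buffers())
        return all(t.device == self.onload_device for t in tensors) if tensors else True

    def pre_forward(self, *args):
        if self.pinned:
            if not self._is_on_device():
                # Weights were evicted (e.g. by an external `.to()`): onload once.
                self.module.to(self.onload_device)
            # Inputs still follow the module; the group itself is never offloaded.
            return tuple(a.to(self.onload_device) for a in args)
        raise NotImplementedError("non-pinned groups take the regular path")


group = ToyPinnedGroup(torch.nn.Linear(4, 4), torch.device("cpu"))
(x,) = group.pre_forward(torch.randn(2, 4))
```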
@@ -325,10 +345,26 @@ def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
         return args, kwargs
 
     def post_forward(self, module: torch.nn.Module, output):
+        if self.group.pinned:
+            return output
+
         if self.group.offload_leader == module:
             self.group.offload_()
         return output
 
+    def _is_group_on_device(self) -> bool:
+        tensors = []
+        for group_module in self.group.modules:
+            tensors.extend(list(group_module.parameters()))
+            tensors.extend(list(group_module.buffers()))
+        tensors.extend(self.group.parameters)
+        tensors.extend(self.group.buffers)
+
+        if len(tensors) == 0:
+            return True
+
+        return all(t.device == self.group.onload_device for t in tensors)
+
 
 class LazyPrefetchGroupOffloadingHook(ModelHook):
     r"""
@@ -424,6 +460,51 @@ def post_forward(self, module, output):
             group_offloading_hooks[i].next_group = group_offloading_hooks[i + 1].group
             group_offloading_hooks[i].next_group.onload_self = False
 
+        pin_groups = getattr(base_module_registry, "_group_offload_pin_groups", None)
+        if pin_groups is not None and num_executed > 0:
+            param_exec_info = []
+            for idx, ((name, submodule), hook) in enumerate(zip(self.execution_order, group_offloading_hooks)):
+                if hook is None:
+                    continue
+                if next(submodule.parameters(), None) is None and next(submodule.buffers(), None) is None:
+                    continue
+                param_exec_info.append((name, submodule, hook))
+
+            num_param_modules = len(param_exec_info)
+            if num_param_modules > 0:
+                pinned_indices = set()
+                if isinstance(pin_groups, str):
+                    if pin_groups == "all":
+                        pinned_indices = set(range(num_param_modules))
+                    elif pin_groups == "first_last":
+                        pinned_indices.add(0)
+                        pinned_indices.add(num_param_modules - 1)
+                elif callable(pin_groups):
+                    for idx, (name, submodule, _) in enumerate(param_exec_info):
+                        should_pin = False
+                        try:
+                            should_pin = bool(pin_groups(submodule))
+                        except TypeError:
+                            try:
+                                should_pin = bool(pin_groups(name, submodule))
+                            except TypeError:
+                                should_pin = bool(pin_groups(name, submodule, idx))
+                        if should_pin:
+                            pinned_indices.add(idx)
+
+                pinned_groups = set()
+                for idx in pinned_indices:
+                    if idx >= num_param_modules:
+                        continue
+                    group = param_exec_info[idx][2].group
+                    if group not in pinned_groups:
+                        group.pinned = True
+                        pinned_groups.add(group)
+
+                for group in pinned_groups:
+                    if group.offload_device != group.onload_device:
+                        group.onload_()
+
         return output
 
 
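The nested `try`/`except TypeError` chain gives `pin_groups` callables a flexible signature: the hook probes `pin_groups(submodule)` first, then `(name, submodule)`, then `(name, submodule, idx)`. A few sketch predicates the dispatch would accept (the name prefix is purely illustrative):

```python
import torch


def pin_norm_layers(module: torch.nn.Module) -> bool:
    # single-argument form: pin every normalization layer's group
    return isinstance(module, (torch.nn.LayerNorm, torch.nn.GroupNorm))


def pin_first_block(name: str, module: torch.nn.Module) -> bool:
    # two-argument form: pin by qualified module name
    # ("transformer_blocks.0" is a hypothetical prefix, for illustration only)
    return name.startswith("transformer_blocks.0")


def pin_every_fourth(name: str, module: torch.nn.Module, idx: int) -> bool:
    # three-argument form: pin by position in the recorded execution order
    return idx % 4 == 0
```

One caveat of this dispatch: a `TypeError` raised inside a one-argument predicate is indistinguishable from an arity mismatch and gets retried with more arguments, so predicates should avoid raising `TypeError` themselves.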
@@ -455,6 +536,8 @@ def apply_group_offloading(
     low_cpu_mem_usage: bool = False,
     offload_to_disk_path: Optional[str] = None,
     block_modules: Optional[List[str]] = None,
+    pin_groups: Optional[Union[str, Callable]] = None,
+    pin_first_last: bool = False,
 ) -> None:
     r"""
     Applies group offloading to the internal layers of a torch.nn.Module. To understand what group offloading is, and
@@ -515,6 +598,12 @@
         block_modules (`List[str]`, *optional*):
             List of module names that should be treated as blocks for offloading. If provided, only these modules
             will be considered for block-level offloading. If not provided, the default block detection logic will be used.
+        pin_groups (`"first_last"` or `"all"` or `Callable`, *optional*, defaults to `None`):
+            Optionally keeps selected groups on the onload device permanently. Use `"first_last"` to pin the first
+            and last parameter-bearing groups, `"all"` to pin every parameter-bearing group, or pass a callable that
+            receives a module (and optionally the module name and index) and returns `True` to pin that group.
+        pin_first_last (`bool`, *optional*, defaults to `False`):
+            Deprecated alias for `pin_groups="first_last"`.
 
     Example:
     ```python
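
The hunk ends at the opening fence of the docstring's existing example, which is unchanged. For orientation, a usage sketch of the new parameter (the model and devices are placeholders; the import path and surrounding arguments are assumed from the existing API):

```python
import torch
from diffusers.hooks import apply_group_offloading

# `model` stands in for any torch.nn.Module, e.g. a diffusion transformer.
model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 8))

apply_group_offloading(
    model,
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="block_level",
    num_blocks_per_group=1,
    use_stream=True,  # the lazy prefetch hook is what resolves the pins
    pin_groups="first_last",  # keep first/last parameter-bearing groups resident
)
```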
@@ -554,7 +643,24 @@
     if offload_type == GroupOffloadingType.BLOCK_LEVEL and num_blocks_per_group is None:
         raise ValueError("`num_blocks_per_group` must be provided when using `offload_type='block_level'.")
 
+    if pin_first_last:
+        if pin_groups is not None and pin_groups != "first_last":
+            raise ValueError("`pin_first_last` cannot be combined with a different `pin_groups` setting.")
+        pin_groups = "first_last"
+
+    normalized_pin_groups = pin_groups
+    if isinstance(pin_groups, str):
+        normalized_pin_groups = pin_groups.lower()
+        if normalized_pin_groups not in {"first_last", "all"}:
+            raise ValueError("`pin_groups` must be one of `None`, 'first_last', 'all', or a callable.")
+    elif pin_groups is not None and not callable(pin_groups):
+        raise ValueError("`pin_groups` must be one of `None`, 'first_last', 'all', or a callable.")
+
+    pin_groups = normalized_pin_groups
+
     _raise_error_if_accelerate_model_or_sequential_hook_present(module)
+    registry = HookRegistry.check_if_exists_or_initialize(module)
+    registry._group_offload_pin_groups = pin_groups
 
     config = GroupOffloadingConfig(
         onload_device=onload_device,
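
A standalone restatement of the normalization rules above, handy for checking which spellings are accepted (the helper name is ours, not the library's):

```python
from typing import Callable, Optional, Union


def normalize_pin_groups(
    pin_groups: Optional[Union[str, Callable]], pin_first_last: bool = False
) -> Optional[Union[str, Callable]]:
    # Mirrors the validation added in this hunk.
    if pin_first_last:
        if pin_groups is not None and pin_groups != "first_last":
            raise ValueError("`pin_first_last` cannot be combined with a different `pin_groups` setting.")
        pin_groups = "first_last"
    if isinstance(pin_groups, str):
        pin_groups = pin_groups.lower()
        if pin_groups not in {"first_last", "all"}:
            raise ValueError("`pin_groups` must be one of `None`, 'first_last', 'all', or a callable.")
    elif pin_groups is not None and not callable(pin_groups):
        raise ValueError("`pin_groups` must be one of `None`, 'first_last', 'all', or a callable.")
    return pin_groups


assert normalize_pin_groups("FIRST_LAST") == "first_last"  # strings are lower-cased
assert normalize_pin_groups(None, pin_first_last=True) == "first_last"  # deprecated alias
```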
@@ -567,11 +673,15 @@
         low_cpu_mem_usage=low_cpu_mem_usage,
         offload_to_disk_path=offload_to_disk_path,
         block_modules=block_modules,
+        pin_groups=pin_groups,
     )
     _apply_group_offloading(module, config)
 
 
 def _apply_group_offloading(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
+    registry = HookRegistry.check_if_exists_or_initialize(module)
+    registry._group_offload_pin_groups = config.pin_groups
+
     if config.offload_type == GroupOffloadingType.BLOCK_LEVEL:
         _apply_group_offloading_block_level(module, config)
     elif config.offload_type == GroupOffloadingType.LEAF_LEVEL:
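
Note that the pin decision is applied lazily: `_group_offload_pin_groups` is stashed on the hook registry here and only consumed at the end of the first forward pass by `LazyPrefetchGroupOffloadingHook.post_forward`, which marks the selected groups and onloads them once. A sketch of the callable form end to end, assuming a CUDA device and that streams are enabled so the lazy prefetch hook is installed:

```python
import torch
from diffusers.hooks import apply_group_offloading

model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU(), torch.nn.Linear(16, 16)).eval()


def keep_linears_resident(module: torch.nn.Module) -> bool:
    # pin the groups of all Linear layers; everything else keeps offloading
    return isinstance(module, torch.nn.Linear)


apply_group_offloading(
    model,
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="leaf_level",
    use_stream=True,
    pin_groups=keep_linears_resident,
)

with torch.no_grad():
    # The first pass records the execution order and applies the pins.
    model(torch.randn(1, 16))
```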
