1717from contextlib import contextmanager , nullcontext
1818from dataclasses import dataclass
1919from enum import Enum
20- from typing import Dict , List , Optional , Set , Tuple , Union
20+ from typing import Callable , Dict , List , Optional , Set , Tuple , Union
2121
2222import safetensors .torch
2323import torch
@@ -60,6 +60,7 @@ class GroupOffloadingConfig:
6060 offload_to_disk_path : Optional [str ] = None
6161 stream : Optional [Union [torch .cuda .Stream , torch .Stream ]] = None
6262 block_modules : Optional [List [str ]] = None
63+ pin_groups : Optional [Union [str , Callable ]] = None
6364
6465
6566class ModuleGroup :
@@ -92,6 +93,7 @@ def __init__(
9293 self .record_stream = record_stream
9394 self .onload_self = onload_self
9495 self .low_cpu_mem_usage = low_cpu_mem_usage
96+ self .pinned = False
9597
9698 self .offload_to_disk_path = offload_to_disk_path
9799 self ._is_offloaded_to_disk = False
@@ -297,6 +299,24 @@ def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
297299 if self .group .onload_leader is None :
298300 self .group .onload_leader = module
299301
302+ if self .group .pinned :
303+ if self .group .onload_leader == module and not self ._is_group_on_device ():
304+ self .group .onload_ ()
305+
306+ should_onload_next_group = self .next_group is not None and not self .next_group .onload_self
307+ if should_onload_next_group :
308+ self .next_group .onload_ ()
309+
310+ should_synchronize = (
311+ not self .group .onload_self and self .group .stream is not None and not should_onload_next_group
312+ )
313+ if should_synchronize :
314+ self .group .stream .synchronize ()
315+
316+ args = send_to_device (args , self .group .onload_device , non_blocking = self .group .non_blocking )
317+ kwargs = send_to_device (kwargs , self .group .onload_device , non_blocking = self .group .non_blocking )
318+ return args , kwargs
319+
300320 # If the current module is the onload_leader of the group, we onload the group if it is supposed
301321 # to onload itself. In the case of using prefetching with streams, we onload the next group if
302322 # it is not supposed to onload itself.
@@ -325,10 +345,26 @@ def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
325345 return args , kwargs
326346
def post_forward(self, module: torch.nn.Module, output):
    """Offload the group after the forward pass, unless it is pinned.

    Pinned groups are kept resident on the onload device permanently, so
    they are never offloaded here. Otherwise, offloading is triggered only
    by the group's designated offload leader module.
    """
    group = self.group
    # A pinned group stays on-device for the lifetime of the model run.
    if group.pinned:
        return output
    # Only the offload leader is responsible for moving the group off-device.
    if module is group.offload_leader or module == group.offload_leader:
        group.offload_()
    return output
331354
355+ def _is_group_on_device (self ) -> bool :
356+ tensors = []
357+ for group_module in self .group .modules :
358+ tensors .extend (list (group_module .parameters ()))
359+ tensors .extend (list (group_module .buffers ()))
360+ tensors .extend (self .group .parameters )
361+ tensors .extend (self .group .buffers )
362+
363+ if len (tensors ) == 0 :
364+ return True
365+
366+ return all (t .device == self .group .onload_device for t in tensors )
367+
332368
333369class LazyPrefetchGroupOffloadingHook (ModelHook ):
334370 r"""
@@ -424,6 +460,51 @@ def post_forward(self, module, output):
424460 group_offloading_hooks [i ].next_group = group_offloading_hooks [i + 1 ].group
425461 group_offloading_hooks [i ].next_group .onload_self = False
426462
463+ pin_groups = getattr (base_module_registry , "_group_offload_pin_groups" , None )
464+ if pin_groups is not None and num_executed > 0 :
465+ param_exec_info = []
466+ for idx , ((name , submodule ), hook ) in enumerate (zip (self .execution_order , group_offloading_hooks )):
467+ if hook is None :
468+ continue
469+ if next (submodule .parameters (), None ) is None and next (submodule .buffers (), None ) is None :
470+ continue
471+ param_exec_info .append ((name , submodule , hook ))
472+
473+ num_param_modules = len (param_exec_info )
474+ if num_param_modules > 0 :
475+ pinned_indices = set ()
476+ if isinstance (pin_groups , str ):
477+ if pin_groups == "all" :
478+ pinned_indices = set (range (num_param_modules ))
479+ elif pin_groups == "first_last" :
480+ pinned_indices .add (0 )
481+ pinned_indices .add (num_param_modules - 1 )
482+ elif callable (pin_groups ):
483+ for idx , (name , submodule , _ ) in enumerate (param_exec_info ):
484+ should_pin = False
485+ try :
486+ should_pin = bool (pin_groups (submodule ))
487+ except TypeError :
488+ try :
489+ should_pin = bool (pin_groups (name , submodule ))
490+ except TypeError :
491+ should_pin = bool (pin_groups (name , submodule , idx ))
492+ if should_pin :
493+ pinned_indices .add (idx )
494+
495+ pinned_groups = set ()
496+ for idx in pinned_indices :
497+ if idx >= num_param_modules :
498+ continue
499+ group = param_exec_info [idx ][2 ].group
500+ if group not in pinned_groups :
501+ group .pinned = True
502+ pinned_groups .add (group )
503+
504+ for group in pinned_groups :
505+ if group .offload_device != group .onload_device :
506+ group .onload_ ()
507+
427508 return output
428509
429510
@@ -455,6 +536,8 @@ def apply_group_offloading(
455536 low_cpu_mem_usage : bool = False ,
456537 offload_to_disk_path : Optional [str ] = None ,
457538 block_modules : Optional [List [str ]] = None ,
539+ pin_groups : Optional [Union [str , Callable ]] = None ,
540+ pin_first_last : bool = False ,
458541) -> None :
459542 r"""
460543 Applies group offloading to the internal layers of a torch.nn.Module. To understand what group offloading is, and
@@ -515,6 +598,12 @@ def apply_group_offloading(
515598 block_modules (`List[str]`, *optional*):
516599 List of module names that should be treated as blocks for offloading. If provided, only these modules
517600 will be considered for block-level offloading. If not provided, the default block detection logic will be used.
601+ pin_groups (`"first_last"` or `"all"` or `Callable`, *optional*, defaults to `None`):
602+ Optionally keeps selected groups on the onload device permanently. Use `"first_last"` to pin the first
603+ and last parameter-bearing groups, `"all"` to pin every parameter-bearing group, or pass a callable that
604+ receives a module (and optionally the module name and index) and returns `True` to pin that group.
605+ pin_first_last (`bool`, *optional*, defaults to `False`):
606+ Deprecated alias for `pin_groups="first_last"`.
518607
519608 Example:
520609 ```python
@@ -554,7 +643,24 @@ def apply_group_offloading(
554643 if offload_type == GroupOffloadingType .BLOCK_LEVEL and num_blocks_per_group is None :
555644 raise ValueError ("`num_blocks_per_group` must be provided when using `offload_type='block_level'." )
556645
646+ if pin_first_last :
647+ if pin_groups is not None and pin_groups != "first_last" :
648+ raise ValueError ("`pin_first_last` cannot be combined with a different `pin_groups` setting." )
649+ pin_groups = "first_last"
650+
651+ normalized_pin_groups = pin_groups
652+ if isinstance (pin_groups , str ):
653+ normalized_pin_groups = pin_groups .lower ()
654+ if normalized_pin_groups not in {"first_last" , "all" }:
655+ raise ValueError ("`pin_groups` must be one of `None`, 'first_last', 'all', or a callable." )
656+ elif pin_groups is not None and not callable (pin_groups ):
657+ raise ValueError ("`pin_groups` must be one of `None`, 'first_last', 'all', or a callable." )
658+
659+ pin_groups = normalized_pin_groups
660+
557661 _raise_error_if_accelerate_model_or_sequential_hook_present (module )
662+ registry = HookRegistry .check_if_exists_or_initialize (module )
663+ registry ._group_offload_pin_groups = pin_groups
558664
559665 config = GroupOffloadingConfig (
560666 onload_device = onload_device ,
@@ -567,11 +673,15 @@ def apply_group_offloading(
567673 low_cpu_mem_usage = low_cpu_mem_usage ,
568674 offload_to_disk_path = offload_to_disk_path ,
569675 block_modules = block_modules ,
676+ pin_groups = pin_groups ,
570677 )
571678 _apply_group_offloading (module , config )
572679
573680
574681def _apply_group_offloading (module : torch .nn .Module , config : GroupOffloadingConfig ) -> None :
682+ registry = HookRegistry .check_if_exists_or_initialize (module )
683+ registry ._group_offload_pin_groups = config .pin_groups
684+
575685 if config .offload_type == GroupOffloadingType .BLOCK_LEVEL :
576686 _apply_group_offloading_block_level (module , config )
577687 elif config .offload_type == GroupOffloadingType .LEAF_LEVEL :
0 commit comments