@@ -1342,6 +1342,8 @@ def enable_group_offload(
13421342 low_cpu_mem_usage = False ,
13431343 offload_to_disk_path : Optional [str ] = None ,
13441344 exclude_modules : Optional [Union [str , List [str ]]] = None ,
1345+ pin_groups : Optional [Union [str , Callable ]] = None ,
1346+ pin_first_last : bool = False ,
13451347 ) -> None :
13461348 r"""
13471349 Applies group offloading to the internal layers of a torch.nn.Module. To understand what group offloading is,
@@ -1402,6 +1404,11 @@ def enable_group_offload(
14021404 This option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be
14031405 useful when the CPU memory is a bottleneck but may counteract the benefits of using streams.
14041406 exclude_modules (`Union[str, List[str]]`, defaults to `None`): List of modules to exclude from offloading.
1407+            pin_groups (`"first_last"` | `"all"` | `Callable`, *optional*):
1408+                Optionally keep selected groups on the onload device permanently. See `ModelMixin.enable_group_offload`
1409+                for details.
1410+            pin_first_last (`bool`, *optional*, defaults to `False`):
1411+                Deprecated alias for `pin_groups="first_last"`.
14051412
14061413 Example:
14071414 ```python
@@ -1442,6 +1449,8 @@ def enable_group_offload(
14421449 "record_stream" : record_stream ,
14431450 "low_cpu_mem_usage" : low_cpu_mem_usage ,
14441451 "offload_to_disk_path" : offload_to_disk_path ,
1452+ "pin_groups" : pin_groups ,
1453+ "pin_first_last" : pin_first_last ,
14451454 }
14461455 for name , component in self .components .items ():
14471456 if name not in exclude_modules and isinstance (component , torch .nn .Module ):
0 commit comments