Skip to content

Commit 7e50d90

Browse files
committed
Expose group offload pinning options in API
1 parent 83401fa commit 7e50d90

File tree

2 files changed

+13
-0
lines changed

2 files changed

+13
-0
lines changed

src/diffusers/models/modeling_utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -531,6 +531,8 @@ def enable_group_offload(
531531
record_stream: bool = False,
532532
low_cpu_mem_usage=False,
533533
offload_to_disk_path: Optional[str] = None,
534+
pin_groups: Optional[Union[str, Callable]] = None,
535+
pin_first_last: bool = False,
534536
) -> None:
535537
r"""
536538
Activates group offloading for the current model.
@@ -583,6 +585,8 @@ def enable_group_offload(
583585
low_cpu_mem_usage=low_cpu_mem_usage,
584586
offload_to_disk_path=offload_to_disk_path,
585587
block_modules=block_modules,
588+
pin_groups=pin_groups,
589+
pin_first_last=pin_first_last,
586590
)
587591

588592
def set_attention_backend(self, backend: str) -> None:

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1342,6 +1342,8 @@ def enable_group_offload(
13421342
low_cpu_mem_usage=False,
13431343
offload_to_disk_path: Optional[str] = None,
13441344
exclude_modules: Optional[Union[str, List[str]]] = None,
1345+
pin_groups: Optional[Union[str, Callable]] = None,
1346+
pin_first_last: bool = False,
13451347
) -> None:
13461348
r"""
13471349
Applies group offloading to the internal layers of a torch.nn.Module. To understand what group offloading is,
@@ -1402,6 +1404,11 @@ def enable_group_offload(
14021404
This option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be
14031405
useful when the CPU memory is a bottleneck but may counteract the benefits of using streams.
14041406
exclude_modules (`Union[str, List[str]]`, defaults to `None`): List of modules to exclude from offloading.
1407+
pin_groups (`\"first_last\"` | `\"all\"` | `Callable`, *optional*):
1408+
Optionally keep selected groups on the onload device permanently. See `ModelMixin.enable_group_offload`
1409+
for details.
1410+
pin_first_last (`bool`, *optional*, defaults to `False`):
1411+
Deprecated alias for `pin_groups=\"first_last\"`.
14051412
14061413
Example:
14071414
```python
@@ -1442,6 +1449,8 @@ def enable_group_offload(
14421449
"record_stream": record_stream,
14431450
"low_cpu_mem_usage": low_cpu_mem_usage,
14441451
"offload_to_disk_path": offload_to_disk_path,
1452+
"pin_groups": pin_groups,
1453+
"pin_first_last": pin_first_last,
14451454
}
14461455
for name, component in self.components.items():
14471456
if name not in exclude_modules and isinstance(component, torch.nn.Module):

0 commit comments

Comments
 (0)