Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 117 additions & 27 deletions src/diffusers/modular_pipelines/modular_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import importlib
import inspect
import os
import sys
import traceback
import warnings
from collections import OrderedDict
Expand All @@ -28,10 +29,16 @@
from typing_extensions import Self

from ..configuration_utils import ConfigMixin, FrozenDict
from ..pipelines.pipeline_loading_utils import _fetch_class_library_tuple, simple_get_class_obj
from ..pipelines.pipeline_loading_utils import (
LOADABLE_CLASSES,
_fetch_class_library_tuple,
_unwrap_model,
simple_get_class_obj,
)
from ..utils import PushToHubMixin, is_accelerate_available, logging
from ..utils.dynamic_modules_utils import get_class_from_dynamic_module, resolve_trust_remote_code
from ..utils.hub_utils import load_or_create_model_card, populate_model_card
from ..utils.torch_utils import is_compiled_module
from .components_manager import ComponentsManager
from .modular_pipeline_utils import (
MODULAR_MODEL_CARD_TEMPLATE,
Expand Down Expand Up @@ -1819,29 +1826,124 @@ def from_pretrained(
)
return pipeline

def save_pretrained(self, save_directory: str | os.PathLike, push_to_hub: bool = False, **kwargs):
def save_pretrained(
self,
save_directory: str | os.PathLike,
safe_serialization: bool = True,
variant: str | None = None,
max_shard_size: int | str | None = None,
push_to_hub: bool = False,
**kwargs,
):
"""
Save the pipeline to a directory. It does not save components, you need to save them separately.
Save the pipeline and all its components to a directory, so that it can be re-loaded using the
[`~ModularPipeline.from_pretrained`] class method.

Args:
save_directory (`str` or `os.PathLike`):
Path to the directory where the pipeline will be saved.
push_to_hub (`bool`, optional):
Whether to push the pipeline to the huggingface hub.
**kwargs: Additional arguments passed to `save_config()` method
"""
Directory to save the pipeline to. Will be created if it doesn't exist.
safe_serialization (`bool`, *optional*, defaults to `True`):
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
variant (`str`, *optional*):
If specified, weights are saved in the format `pytorch_model.<variant>.bin`.
max_shard_size (`int` or `str`, defaults to `None`):
The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size
lower than this size. If expressed as a string, needs to be digits followed by a unit (like `"5GB"`).
If expressed as an integer, the unit is bytes.
push_to_hub (`bool`, *optional*, defaults to `False`):
Whether to push the pipeline to the Hugging Face model hub after saving it.
**kwargs: Additional keyword arguments:
- `overwrite_modular_index` (`bool`, *optional*, defaults to `False`):
When saving a Modular Pipeline, its components in `modular_model_index.json` may reference repos
different from the destination repo. Setting this to `True` updates all component references in
`modular_model_index.json` so they point to the repo specified by `repo_id`.
- `repo_id` (`str`, *optional*):
The repository ID to push the pipeline to. Defaults to the last component of `save_directory`.
- `commit_message` (`str`, *optional*):
Commit message for the push to hub operation.
- `private` (`bool`, *optional*):
Whether the repository should be private.
- `create_pr` (`bool`, *optional*, defaults to `False`):
Whether to create a pull request instead of pushing directly.
- `token` (`str`, *optional*):
The Hugging Face token to use for authentication.
"""
overwrite_modular_index = kwargs.pop("overwrite_modular_index", False)
repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])

if push_to_hub:
commit_message = kwargs.pop("commit_message", None)
private = kwargs.pop("private", None)
create_pr = kwargs.pop("create_pr", False)
token = kwargs.pop("token", None)
repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id

# Generate modular pipeline card content
card_content = generate_modular_model_card_content(self.blocks)
for component_name, component_spec in self._component_specs.items():
if component_spec.default_creation_method != "from_pretrained":
continue

component = getattr(self, component_name, None)
if component is None:
continue

model_cls = component.__class__
if is_compiled_module(component):
component = _unwrap_model(component)
model_cls = component.__class__

save_method_name = None
for library_name, library_classes in LOADABLE_CLASSES.items():
if library_name in sys.modules:
library = importlib.import_module(library_name)
else:
logger.info(
f"{library_name} is not installed. Cannot save {component_name} as {library_classes} from {library_name}"
)
continue

for base_class, save_load_methods in library_classes.items():
class_candidate = getattr(library, base_class, None)
if class_candidate is not None and issubclass(model_cls, class_candidate):
save_method_name = save_load_methods[0]
break
if save_method_name is not None:
break

if save_method_name is None:
logger.warning(f"self.{component_name}={component} of type {type(component)} cannot be saved.")
continue

save_method = getattr(component, save_method_name)
save_method_signature = inspect.signature(save_method)
save_method_accept_safe = "safe_serialization" in save_method_signature.parameters
save_method_accept_variant = "variant" in save_method_signature.parameters
save_method_accept_max_shard_size = "max_shard_size" in save_method_signature.parameters

# Create a new empty model card and eventually tag it
save_kwargs = {}
if save_method_accept_safe:
save_kwargs["safe_serialization"] = safe_serialization
if save_method_accept_variant:
save_kwargs["variant"] = variant
if save_method_accept_max_shard_size and max_shard_size is not None:
save_kwargs["max_shard_size"] = max_shard_size

component_save_path = os.path.join(save_directory, component_name)
save_method(component_save_path, **save_kwargs)

if component_name not in self.config:
continue

has_no_load_id = not hasattr(component, "_diffusers_load_id") or component._diffusers_load_id == "null"
if overwrite_modular_index or has_no_load_id:
library, class_name, component_spec_dict = self.config[component_name]
component_spec_dict["pretrained_model_name_or_path"] = repo_id if push_to_hub else save_directory
component_spec_dict["subfolder"] = component_name
self.register_to_config(**{component_name: (library, class_name, component_spec_dict)})
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not too sure about the objective of this block. What happens if its corresponding model_cls doesn't have the save method we support through LOADABLE_CLASSES?

Or is this unrelated?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


self.save_config(save_directory=save_directory)

if push_to_hub:
card_content = generate_modular_model_card_content(self.blocks)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this conditioned on the above changes? If not, maybe we can keep it in the earlier position?

model_card = load_or_create_model_card(
repo_id,
token=token,
Expand All @@ -1850,13 +1952,8 @@ def save_pretrained(self, save_directory: str | os.PathLike, push_to_hub: bool =
is_modular=True,
)
model_card = populate_model_card(model_card, tags=card_content["tags"])

model_card.save(os.path.join(save_directory, "README.md"))

# YiYi TODO: maybe order the json file to make it more readable: configs first, then components
self.save_config(save_directory=save_directory)

if push_to_hub:
self._upload_folder(
save_directory,
repo_id,
Expand Down Expand Up @@ -2124,8 +2221,9 @@ def update_components(self, **kwargs):
```

Notes:
- Components with trained weights should be loaded with `AutoModel.from_pretrained()` or
`ComponentSpec.load()` so that loading specs are preserved for serialization.
- Components loaded with `AutoModel.from_pretrained()` or `ComponentSpec.load()` will have
loading specs preserved for serialization. Custom or locally loaded components without Hub references will
have their `modular_model_index.json` entries updated automatically during `save_pretrained()`.
- ConfigMixin objects without weights (e.g., schedulers, guiders) can be passed directly.
"""

Expand All @@ -2147,14 +2245,6 @@ def update_components(self, **kwargs):
new_component_spec = current_component_spec
if hasattr(self, name) and getattr(self, name) is not None:
logger.warning(f"ModularPipeline.update_components: setting {name} to None (spec unchanged)")
elif current_component_spec.default_creation_method == "from_pretrained" and not (
hasattr(component, "_diffusers_load_id") and component._diffusers_load_id is not None
):
logger.warning(
f"ModularPipeline.update_components: {name} has no valid _diffusers_load_id. "
f"This will result in empty loading spec, use ComponentSpec.load() for proper specs"
)
new_component_spec = ComponentSpec(name=name, type_hint=type(component))
else:
new_component_spec = ComponentSpec.from_component(name, component)

Expand Down
12 changes: 12 additions & 0 deletions src/diffusers/modular_pipelines/modular_pipeline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,12 @@ def load(self, **kwargs) -> Any:
f"`type_hint` is required when loading a single file model but is missing for component: {self.name}"
)

# `torch_dtype` is not an accepted parameter for tokenizers and processors.
# As a result, it gets stored in `init_kwargs`, which are written to the config
# during save. This causes JSON serialization to fail when saving the component.
if self.type_hint is not None and not issubclass(self.type_hint, torch.nn.Module):
kwargs.pop("torch_dtype", None)

if self.type_hint is None:
try:
from diffusers import AutoModel
Expand All @@ -328,6 +334,12 @@ def load(self, **kwargs) -> Any:
else getattr(self.type_hint, "from_pretrained")
)

# `torch_dtype` is not an accepted parameter for tokenizers and processors.
# As a result, it gets stored in `init_kwargs`, which are written to the config
# during save. This causes JSON serialization to fail when saving the component.
if not issubclass(self.type_hint, torch.nn.Module):
kwargs.pop("torch_dtype", None)

try:
component = load_method(pretrained_model_name_or_path, **load_kwargs, **kwargs)
except Exception as e:
Expand Down
77 changes: 77 additions & 0 deletions tests/modular_pipelines/test_modular_pipelines_common.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import gc
import os
import tempfile
from typing import Callable

Expand Down Expand Up @@ -699,3 +700,79 @@ def test_load_components_skips_invalid_pretrained_path(self):

# Verify test_component was not loaded
assert not hasattr(pipe, "test_component") or pipe.test_component is None


class TestCustomModelSavePretrained:
def test_save_pretrained_updates_index_for_local_model(self, tmp_path):
"""When a component without _diffusers_load_id (custom/local model) is saved,
modular_model_index.json should point to the save directory."""
import json

pipe = ModularPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-xl-pipe")
pipe.load_components(torch_dtype=torch.float32)

pipe.unet._diffusers_load_id = "null"

save_dir = str(tmp_path / "my-pipeline")
pipe.save_pretrained(save_dir)

with open(os.path.join(save_dir, "modular_model_index.json")) as f:
index = json.load(f)

_library, _cls, unet_spec = index["unet"]
assert unet_spec["pretrained_model_name_or_path"] == save_dir
assert unet_spec["subfolder"] == "unet"

_library, _cls, vae_spec = index["vae"]
assert vae_spec["pretrained_model_name_or_path"] == "hf-internal-testing/tiny-stable-diffusion-xl-pipe"

def test_save_pretrained_roundtrip_with_local_model(self, tmp_path):
"""A pipeline with a custom/local model should be saveable and re-loadable with identical outputs."""
pipe = ModularPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-xl-pipe")
pipe.load_components(torch_dtype=torch.float32)

pipe.unet._diffusers_load_id = "null"

original_state_dict = pipe.unet.state_dict()

save_dir = str(tmp_path / "my-pipeline")
pipe.save_pretrained(save_dir)

loaded_pipe = ModularPipeline.from_pretrained(save_dir)
loaded_pipe.load_components(torch_dtype=torch.float32)

assert loaded_pipe.unet is not None
assert loaded_pipe.unet.__class__.__name__ == pipe.unet.__class__.__name__

loaded_state_dict = loaded_pipe.unet.state_dict()
assert set(original_state_dict.keys()) == set(loaded_state_dict.keys())
for key in original_state_dict:
assert torch.equal(original_state_dict[key], loaded_state_dict[key]), f"Mismatch in {key}"

def test_save_pretrained_overwrite_modular_index(self, tmp_path):
"""With overwrite_modular_index=True, all component references should point to the save directory."""
import json

pipe = ModularPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-xl-pipe")
pipe.load_components(torch_dtype=torch.float32)

save_dir = str(tmp_path / "my-pipeline")
pipe.save_pretrained(save_dir, overwrite_modular_index=True)

with open(os.path.join(save_dir, "modular_model_index.json")) as f:
index = json.load(f)

for component_name in ["unet", "vae", "text_encoder", "text_encoder_2"]:
if component_name not in index:
continue
_library, _cls, spec = index[component_name]
assert spec["pretrained_model_name_or_path"] == save_dir, (
f"{component_name} should point to save dir but got {spec['pretrained_model_name_or_path']}"
)
assert spec["subfolder"] == component_name

loaded_pipe = ModularPipeline.from_pretrained(save_dir)
loaded_pipe.load_components(torch_dtype=torch.float32)

assert loaded_pipe.unet is not None
assert loaded_pipe.vae is not None
Loading