Skip to content
Open
5 changes: 5 additions & 0 deletions docs/source/en/api/loaders/lora.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ LoRA is a fast and lightweight training method that inserts and trains a signifi
- [`AmusedLoraLoaderMixin`] is for the [`AmusedPipeline`].
- [`HiDreamImageLoraLoaderMixin`] provides similar functions for [HiDream Image](https://huggingface.co/docs/diffusers/main/en/api/pipelines/hidream)
- [`QwenImageLoraLoaderMixin`] provides similar functions for [Qwen Image](https://huggingface.co/docs/diffusers/main/en/api/pipelines/qwen).
- [`ZImageLoraLoaderMixin`] provides similar functions for [Z-Image](https://huggingface.co/docs/diffusers/main/en/api/pipelines/zimage).
- [`Flux2LoraLoaderMixin`] provides similar functions for [Flux2](https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux2).
- [`LoraBaseMixin`] provides a base class with several utility methods to fuse, unfuse, unload, LoRAs and more.

Expand Down Expand Up @@ -112,6 +113,10 @@ LoRA is a fast and lightweight training method that inserts and trains a signifi

[[autodoc]] loaders.lora_pipeline.QwenImageLoraLoaderMixin

## ZImageLoraLoaderMixin

[[autodoc]] loaders.lora_pipeline.ZImageLoraLoaderMixin

## KandinskyLoraLoaderMixin
[[autodoc]] loaders.lora_pipeline.KandinskyLoraLoaderMixin

Expand Down
2 changes: 2 additions & 0 deletions src/diffusers/loaders/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ def text_encoder_attn_modules(text_encoder):
"HiDreamImageLoraLoaderMixin",
"SkyReelsV2LoraLoaderMixin",
"QwenImageLoraLoaderMixin",
"ZImageLoraLoaderMixin",
"Flux2LoraLoaderMixin",
]
_import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"]
Expand Down Expand Up @@ -130,6 +131,7 @@ def text_encoder_attn_modules(text_encoder):
StableDiffusionLoraLoaderMixin,
StableDiffusionXLLoraLoaderMixin,
WanLoraLoaderMixin,
ZImageLoraLoaderMixin,
)
from .single_file import FromSingleFileMixin
from .textual_inversion import TextualInversionLoaderMixin
Expand Down
118 changes: 118 additions & 0 deletions src/diffusers/loaders/lora_conversion_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2351,3 +2351,121 @@ def _convert_non_diffusers_flux2_lora_to_diffusers(state_dict):
converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)

return converted_state_dict


def _convert_non_diffusers_z_image_lora_to_diffusers(state_dict):
"""
Convert non-diffusers ZImage LoRA state dict to diffusers format.

Handles:
- `diffusion_model.` prefix removal
- `lora_unet_` prefix conversion with key mapping
- `default.` prefix removal
- `.lora_down.weight`/`.lora_up.weight` → `.lora_A.weight`/`.lora_B.weight` conversion with alpha scaling
"""
has_diffusion_model = any(k.startswith("diffusion_model.") for k in state_dict)
if has_diffusion_model:
state_dict = {k.removeprefix("diffusion_model."): v for k, v in state_dict.items()}

has_lora_unet = any(k.startswith("lora_unet_") for k in state_dict)
if has_lora_unet:
state_dict = {k.removeprefix("lora_unet_"): v for k, v in state_dict.items()}

def convert_key(key: str) -> str:
# ZImage has: layers, noise_refiner, context_refiner blocks
# Keys may be like: layers_0_attention_to_q.lora_down.weight

if "." in key:
base, suffix = key.rsplit(".", 1)
else:
base, suffix = key, ""

# Protected n-grams that must keep their internal underscores
protected = {
# pairs for attention
("to", "q"),
("to", "k"),
("to", "v"),
("to", "out"),
# feed_forward
("feed", "forward"),
}

prot_by_len = {}
for ng in protected:
prot_by_len.setdefault(len(ng), set()).add(ng)

parts = base.split("_")
merged = []
i = 0
lengths_desc = sorted(prot_by_len.keys(), reverse=True)

while i < len(parts):
matched = False
for L in lengths_desc:
if i + L <= len(parts) and tuple(parts[i : i + L]) in prot_by_len[L]:
merged.append("_".join(parts[i : i + L]))
i += L
matched = True
break
if not matched:
merged.append(parts[i])
i += 1

converted_base = ".".join(merged)
return converted_base + (("." + suffix) if suffix else "")

state_dict = {convert_key(k): v for k, v in state_dict.items()}

has_default = any("default." in k for k in state_dict)
if has_default:
state_dict = {k.replace("default.", ""): v for k, v in state_dict.items()}

converted_state_dict = {}
all_keys = list(state_dict.keys())
down_key = ".lora_down.weight"
up_key = ".lora_up.weight"
a_key = ".lora_A.weight"
b_key = ".lora_B.weight"

has_non_diffusers_lora_id = any(down_key in k or up_key in k for k in all_keys)
has_diffusers_lora_id = any(a_key in k or b_key in k for k in all_keys)

if has_non_diffusers_lora_id:

def get_alpha_scales(down_weight, alpha_key):
rank = down_weight.shape[0]
alpha = state_dict.pop(alpha_key).item()
scale = alpha / rank # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
scale_down = scale
scale_up = 1.0
while scale_down * 2 < scale_up:
scale_down *= 2
scale_up /= 2
return scale_down, scale_up

for k in all_keys:
if k.endswith(down_key):
diffusers_down_key = k.replace(down_key, ".lora_A.weight")
diffusers_up_key = k.replace(down_key, up_key).replace(up_key, ".lora_B.weight")
alpha_key = k.replace(down_key, ".alpha")

down_weight = state_dict.pop(k)
up_weight = state_dict.pop(k.replace(down_key, up_key))
scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)
converted_state_dict[diffusers_down_key] = down_weight * scale_down
converted_state_dict[diffusers_up_key] = up_weight * scale_up

# Already in diffusers format (lora_A/lora_B), just pop
elif has_diffusers_lora_id:
for k in all_keys:
if a_key in k or b_key in k:
converted_state_dict[k] = state_dict.pop(k)
elif ".alpha" in k:
state_dict.pop(k)

if len(state_dict) > 0:
raise ValueError(f"`state_dict` should be empty at this point but has {state_dict.keys()=}")

converted_state_dict = {f"transformer.{k}": v for k, v in converted_state_dict.items()}
return converted_state_dict
207 changes: 207 additions & 0 deletions src/diffusers/loaders/lora_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
_convert_non_diffusers_lumina2_lora_to_diffusers,
_convert_non_diffusers_qwen_lora_to_diffusers,
_convert_non_diffusers_wan_lora_to_diffusers,
_convert_non_diffusers_z_image_lora_to_diffusers,
_convert_xlabs_flux_lora_to_diffusers,
_maybe_map_sgm_blocks_to_diffusers,
)
Expand Down Expand Up @@ -5085,6 +5086,212 @@ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
super().unfuse_lora(components=components, **kwargs)


class ZImageLoraLoaderMixin(LoraBaseMixin):
r"""
Load LoRA layers into [`ZImageTransformer2DModel`]. Specific to [`ZImagePipeline`].
"""

_lora_loadable_modules = ["transformer"]
transformer_name = TRANSFORMER_NAME

@classmethod
@validate_hf_hub_args
def lora_state_dict(
cls,
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
**kwargs,
):
r"""
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details.
"""
# Load the main state dict first which has the LoRA layers for either of
# transformer and text encoder or both.
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", None)
token = kwargs.pop("token", None)
revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", None)
weight_name = kwargs.pop("weight_name", None)
use_safetensors = kwargs.pop("use_safetensors", None)
return_lora_metadata = kwargs.pop("return_lora_metadata", False)

allow_pickle = False
if use_safetensors is None:
use_safetensors = True
allow_pickle = True

user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}

state_dict, metadata = _fetch_state_dict(
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
weight_name=weight_name,
use_safetensors=use_safetensors,
local_files_only=local_files_only,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
token=token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
allow_pickle=allow_pickle,
)

is_dora_scale_present = any("dora_scale" in k for k in state_dict)
if is_dora_scale_present:
warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
logger.warning(warn_msg)
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}

has_alphas_in_sd = any(k.endswith(".alpha") for k in state_dict)
has_lora_unet = any(k.startswith("lora_unet_") for k in state_dict)
has_diffusion_model = any(k.startswith("diffusion_model.") for k in state_dict)
has_default = any("default." in k for k in state_dict)
if has_alphas_in_sd or has_lora_unet or has_diffusion_model or has_default:
state_dict = _convert_non_diffusers_z_image_lora_to_diffusers(state_dict)

out = (state_dict, metadata) if return_lora_metadata else state_dict
return out

# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
def load_lora_weights(
self,
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
adapter_name: Optional[str] = None,
hotswap: bool = False,
**kwargs,
):
"""
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for more details.
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")

low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
)

# if a dict is passed, copy it instead of modifying it inplace
if isinstance(pretrained_model_name_or_path_or_dict, dict):
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()

# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
kwargs["return_lora_metadata"] = True
state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)

is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint.")

self.load_lora_into_transformer(
state_dict,
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
adapter_name=adapter_name,
metadata=metadata,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
)

@classmethod
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->ZImageTransformer2DModel
def load_lora_into_transformer(
cls,
state_dict,
transformer,
adapter_name=None,
_pipeline=None,
low_cpu_mem_usage=False,
hotswap: bool = False,
metadata=None,
):
"""
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details.
"""
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
)

# Load the layers corresponding to transformer.
logger.info(f"Loading {cls.transformer_name}.")
transformer.load_lora_adapter(
state_dict,
network_alphas=None,
adapter_name=adapter_name,
metadata=metadata,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
)

@classmethod
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
def save_lora_weights(
cls,
save_directory: Union[str, os.PathLike],
transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
is_main_process: bool = True,
weight_name: str = None,
save_function: Callable = None,
safe_serialization: bool = True,
transformer_lora_adapter_metadata: Optional[dict] = None,
):
r"""
See [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for more information.
"""
lora_layers = {}
lora_metadata = {}

if transformer_lora_layers:
lora_layers[cls.transformer_name] = transformer_lora_layers
lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata

if not lora_layers:
raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")

cls._save_lora_weights(
save_directory=save_directory,
lora_layers=lora_layers,
lora_metadata=lora_metadata,
is_main_process=is_main_process,
weight_name=weight_name,
save_function=save_function,
safe_serialization=safe_serialization,
)

# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
def fuse_lora(
self,
components: List[str] = ["transformer"],
lora_scale: float = 1.0,
safe_fusing: bool = False,
adapter_names: Optional[List[str]] = None,
**kwargs,
):
r"""
See [`~loaders.StableDiffusionLoraLoaderMixin.fuse_lora`] for more details.
"""
super().fuse_lora(
components=components,
lora_scale=lora_scale,
safe_fusing=safe_fusing,
adapter_names=adapter_names,
**kwargs,
)

# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
r"""
See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details.
"""
super().unfuse_lora(components=components, **kwargs)


class Flux2LoraLoaderMixin(LoraBaseMixin):
r"""
Load LoRA layers into [`Flux2Transformer2DModel`]. Specific to [`Flux2Pipeline`].
Expand Down
1 change: 1 addition & 0 deletions src/diffusers/loaders/peft.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
"ChromaTransformer2DModel": lambda model_cls, weights: weights,
"QwenImageTransformer2DModel": lambda model_cls, weights: weights,
"Flux2Transformer2DModel": lambda model_cls, weights: weights,
"ZImageTransformer2DModel": lambda model_cls, weights: weights,
}


Expand Down
Loading
Loading