131 changes: 129 additions & 2 deletions invokeai/backend/model_manager/configs/lora.py
@@ -79,6 +79,32 @@ def _get_flux_lora_format(mod: ModelOnDisk) -> FluxLoRAFormat | None:
_FLUX1_MLP_RATIO = 4


def _lokr_in_dim(state_dict: dict[str | int, Any], key_prefix: str) -> int | None:
"""Compute the input dimension of a LOKR layer: w1.shape[1] * w2.shape[1].

Supports both full LoKR (lokr_w1/lokr_w2) and factorized LoKR (lokr_w1_b/lokr_w2_b).
Returns None if the required keys are not present.
"""
if f"{key_prefix}.lokr_w1" in state_dict and f"{key_prefix}.lokr_w2" in state_dict:
return state_dict[f"{key_prefix}.lokr_w1"].shape[1] * state_dict[f"{key_prefix}.lokr_w2"].shape[1]
elif f"{key_prefix}.lokr_w1_b" in state_dict and f"{key_prefix}.lokr_w2_b" in state_dict:
return state_dict[f"{key_prefix}.lokr_w1_b"].shape[1] * state_dict[f"{key_prefix}.lokr_w2_b"].shape[1]
return None


def _lokr_out_dim(state_dict: dict[str | int, Any], key_prefix: str) -> int | None:
"""Compute the output dimension of a LOKR layer: w1.shape[0] * w2.shape[0].

Supports both full LoKR (lokr_w1/lokr_w2) and factorized LoKR (lokr_w1_a/lokr_w2_a).
Returns None if the required keys are not present.
"""
if f"{key_prefix}.lokr_w1" in state_dict and f"{key_prefix}.lokr_w2" in state_dict:
return state_dict[f"{key_prefix}.lokr_w1"].shape[0] * state_dict[f"{key_prefix}.lokr_w2"].shape[0]
elif f"{key_prefix}.lokr_w1_a" in state_dict and f"{key_prefix}.lokr_w2_a" in state_dict:
return state_dict[f"{key_prefix}.lokr_w1_a"].shape[0] * state_dict[f"{key_prefix}.lokr_w2_a"].shape[0]
return None
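
For context, an illustrative sketch that is not part of the changed file: a LoKR layer reconstructs its effective weight as a Kronecker product of the two factors, so the full weight's shape is the elementwise product of the factor shapes, which is exactly what these helpers compute. The factor shapes below are arbitrary and hypothetical.

```python
import torch

w1 = torch.zeros(4, 8)      # hypothetical LoKR factor
w2 = torch.zeros(768, 384)  # hypothetical LoKR factor
full = torch.kron(w1, w2)   # effective weight, shape (4 * 768, 8 * 384)

assert full.shape[1] == w1.shape[1] * w2.shape[1]  # what _lokr_in_dim returns
assert full.shape[0] == w1.shape[0] * w2.shape[0]  # what _lokr_out_dim returns
```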


def _is_flux2_lora(mod: ModelOnDisk) -> bool:
"""Check if a FLUX-format LoRA is specifically for FLUX.2 (Klein) rather than FLUX.1.

@@ -147,7 +173,30 @@ def _is_flux2_lora_state_dict(state_dict: dict[str | int, Any]) -> bool:
elif "vector_in" in key and key.endswith("lora_A.weight"):
return state_dict[key].shape[1] in _FLUX2_VEC_IN_DIMS

# BFL PEFT: hidden_size matches FLUX.1. Check MLP ratio to distinguish Klein 4B.
# BFL LyCORIS (LoKR/LoHA): attention projection → check hidden_size via product of dims
elif key.endswith((".img_attn.proj.lokr_w1", ".img_attn.proj.lokr_w1_b")):
layer_prefix = key.rsplit(".", 1)[0]
in_dim = _lokr_in_dim(state_dict, layer_prefix)
if in_dim is not None:
if in_dim != _FLUX1_HIDDEN_SIZE:
return True
bfl_hidden_size = in_dim # ambiguous, keep checking

# BFL LyCORIS: context_embedder/txt_in
elif "txt_in" in key and key.endswith((".lokr_w1", ".lokr_w1_b")):
layer_prefix = key.rsplit(".", 1)[0]
in_dim = _lokr_in_dim(state_dict, layer_prefix)
if in_dim is not None:
return in_dim in _FLUX2_CONTEXT_IN_DIMS

# BFL LyCORIS: vector_in
elif "vector_in" in key and key.endswith((".lokr_w1", ".lokr_w1_b")):
layer_prefix = key.rsplit(".", 1)[0]
in_dim = _lokr_in_dim(state_dict, layer_prefix)
if in_dim is not None:
return in_dim in _FLUX2_VEC_IN_DIMS

# BFL PEFT/LyCORIS: hidden_size matches FLUX.1. Check MLP ratio to distinguish Klein 4B.
# Klein 4B uses mlp_ratio=6 (ffn_dim=18432), FLUX.1 uses mlp_ratio=4 (ffn_dim=12288).
if bfl_hidden_size == _FLUX1_HIDDEN_SIZE:
for key in state_dict:
@@ -158,6 +207,13 @@ def _is_flux2_lora_state_dict(state_dict: dict[str | int, Any]) -> bool:
if ffn_dim != bfl_hidden_size * _FLUX1_MLP_RATIO:
return True
break
# BFL LyCORIS: check output dim of img_mlp.0 via product of dims
if key.startswith(_bfl_prefixes) and key.endswith((".img_mlp.0.lokr_w1", ".img_mlp.0.lokr_w1_a")):
layer_prefix = key.rsplit(".", 1)[0]
out_dim = _lokr_out_dim(state_dict, layer_prefix)
if out_dim is not None and out_dim != bfl_hidden_size * _FLUX1_MLP_RATIO:
return True
break
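
To make the ratio arithmetic above concrete (an illustrative sketch, not part of the changed file): both models share a hidden size of 3072, so FLUX.1's img_mlp.0 output dimension is 3072 * 4 = 12288 while Klein 4B's is 3072 * 6 = 18432. The sketch below uses hypothetical LoKR factor shapes whose products land on those numbers; the `double_blocks.0` portion of the key name is assumed for illustration.

```python
import torch

# Hypothetical Klein-4B-sized img_mlp.0 LoKR factors: out = 96 * 192 = 18432, in = 48 * 64 = 3072.
state_dict = {
    "diffusion_model.double_blocks.0.img_mlp.0.lokr_w1": torch.zeros(96, 48),
    "diffusion_model.double_blocks.0.img_mlp.0.lokr_w2": torch.zeros(192, 64),
}
out_dim = _lokr_out_dim(state_dict, "diffusion_model.double_blocks.0.img_mlp.0")
assert out_dim == 18432  # != 3072 * _FLUX1_MLP_RATIO (12288), so not a FLUX.1-shaped MLP
```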

# Check kohya format: look for context_embedder or vector_in keys
# Kohya format uses lora_unet_ prefix with underscores instead of dots
@@ -167,9 +223,21 @@ def _is_flux2_lora_state_dict(state_dict: dict[str | int, Any]) -> bool:
if key.startswith("lora_unet_txt_in.") or key.startswith("lora_unet_context_embedder."):
if key.endswith("lora_down.weight"):
return state_dict[key].shape[1] in _FLUX2_CONTEXT_IN_DIMS
# Kohya LyCORIS (LoKR)
elif key.endswith((".lokr_w1", ".lokr_w1_b")):
layer_prefix = key.rsplit(".", 1)[0]
in_dim = _lokr_in_dim(state_dict, layer_prefix)
if in_dim is not None:
return in_dim in _FLUX2_CONTEXT_IN_DIMS
if key.startswith("lora_unet_vector_in.") or key.startswith("lora_unet_time_text_embed_text_embedder_"):
if key.endswith("lora_down.weight"):
return state_dict[key].shape[1] in _FLUX2_VEC_IN_DIMS
# Kohya LyCORIS (LoKR)
elif key.endswith((".lokr_w1", ".lokr_w1_b")):
layer_prefix = key.rsplit(".", 1)[0]
in_dim = _lokr_in_dim(state_dict, layer_prefix)
if in_dim is not None:
return in_dim in _FLUX2_VEC_IN_DIMS

return False
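
As a small usage sketch of the prefix handling shared by the LoKR branches above (illustrative only, not part of the changed file): the layer prefix is recovered by splitting the matched key at its last dot, and the factorized variant multiplies the `*_b` (input-side) or `*_a` (output-side) shapes. The key name and shapes here are hypothetical and chosen only to show the arithmetic.

```python
import torch

sd = {
    "lora_unet_txt_in.lokr_w1_a": torch.zeros(10, 5),
    "lora_unet_txt_in.lokr_w1_b": torch.zeros(5, 20),
    "lora_unet_txt_in.lokr_w2_a": torch.zeros(40, 7),
    "lora_unet_txt_in.lokr_w2_b": torch.zeros(7, 50),
}
key = "lora_unet_txt_in.lokr_w1_b"
layer_prefix = key.rsplit(".", 1)[0]    # "lora_unet_txt_in"
print(_lokr_in_dim(sd, layer_prefix))   # 20 * 50 = 1000
print(_lokr_out_dim(sd, layer_prefix))  # 10 * 40 = 400
```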

@@ -244,7 +312,7 @@ def _get_flux2_lora_variant(state_dict: dict[str | int, Any]) -> Flux2VariantTyp
return Flux2VariantType.Klein9B
return None

# Check BFL PEFT format (diffusion_model.* or base_model.model.* prefix with BFL names)
# Check BFL PEFT/LyCORIS format (diffusion_model.* or base_model.model.* prefix with BFL names)
_bfl_prefixes = ("diffusion_model.", "base_model.model.")
for key in state_dict:
if not isinstance(key, str):
@@ -279,6 +347,39 @@ def _get_flux2_lora_variant(state_dict: dict[str | int, Any]) -> Flux2VariantTyp
return Flux2VariantType.Klein9B
return None

# BFL LyCORIS (LoKR): context embedder (txt_in)
if "txt_in" in key and key.endswith((".lokr_w1", ".lokr_w1_b")):
layer_prefix = key.rsplit(".", 1)[0]
in_dim = _lokr_in_dim(state_dict, layer_prefix)
if in_dim is not None:
if in_dim == KLEIN_4B_CONTEXT_DIM:
return Flux2VariantType.Klein4B
if in_dim == KLEIN_9B_CONTEXT_DIM:
return Flux2VariantType.Klein9B
return None

# BFL LyCORIS (LoKR): vector embedder (vector_in)
if "vector_in" in key and key.endswith((".lokr_w1", ".lokr_w1_b")):
layer_prefix = key.rsplit(".", 1)[0]
in_dim = _lokr_in_dim(state_dict, layer_prefix)
if in_dim is not None:
if in_dim == KLEIN_4B_VEC_DIM:
return Flux2VariantType.Klein4B
if in_dim == KLEIN_9B_VEC_DIM:
return Flux2VariantType.Klein9B
return None

# BFL LyCORIS (LoKR): attention projection
if key.endswith((".img_attn.proj.lokr_w1", ".img_attn.proj.lokr_w1_b")):
layer_prefix = key.rsplit(".", 1)[0]
in_dim = _lokr_in_dim(state_dict, layer_prefix)
if in_dim is not None:
if in_dim == KLEIN_4B_HIDDEN_SIZE:
return Flux2VariantType.Klein4B
if in_dim == KLEIN_9B_HIDDEN_SIZE:
return Flux2VariantType.Klein9B
return None

# Check kohya format
for key in state_dict:
if not isinstance(key, str):
@@ -291,6 +392,16 @@ def _get_flux2_lora_variant(state_dict: dict[str | int, Any]) -> Flux2VariantTyp
if dim == KLEIN_9B_CONTEXT_DIM:
return Flux2VariantType.Klein9B
return None
# Kohya LyCORIS (LoKR)
elif key.endswith((".lokr_w1", ".lokr_w1_b")):
layer_prefix = key.rsplit(".", 1)[0]
in_dim = _lokr_in_dim(state_dict, layer_prefix)
if in_dim is not None:
if in_dim == KLEIN_4B_CONTEXT_DIM:
return Flux2VariantType.Klein4B
if in_dim == KLEIN_9B_CONTEXT_DIM:
return Flux2VariantType.Klein9B
return None
if key.startswith("lora_unet_vector_in.") or key.startswith("lora_unet_time_text_embed_text_embedder_"):
if key.endswith("lora_down.weight"):
dim = state_dict[key].shape[1]
@@ -299,6 +410,16 @@ def _get_flux2_lora_variant(state_dict: dict[str | int, Any]) -> Flux2VariantTyp
if dim == KLEIN_9B_VEC_DIM:
return Flux2VariantType.Klein9B
return None
# Kohya LyCORIS (LoKR)
elif key.endswith((".lokr_w1", ".lokr_w1_b")):
layer_prefix = key.rsplit(".", 1)[0]
in_dim = _lokr_in_dim(state_dict, layer_prefix)
if in_dim is not None:
if in_dim == KLEIN_4B_VEC_DIM:
return Flux2VariantType.Klein4B
if in_dim == KLEIN_9B_VEC_DIM:
return Flux2VariantType.Klein9B
return None

return None

@@ -423,6 +544,12 @@ def _validate_looks_like_lora(cls, mod: ModelOnDisk) -> None:
"to_q_lora.down.weight",
"lora_A.weight",
"lora_B.weight",
# LyCORIS LoKR suffixes
"lokr_w1",
"lokr_w2",
# LyCORIS LoHA suffixes
"hada_w1_a",
"hada_w2_a",
},
)

@@ -30,6 +30,27 @@ def is_state_dict_likely_in_flux_aitoolkit_format(
if not _has_flux_layer_structure(state_dict):
return False

# AIToolkit only produces standard PEFT LoRA (lora_A.weight / lora_B.weight).
# Exclude LyCORIS algorithm variants (LoKR, LoHA, etc.) which use different weight key suffixes.
# These are handled by the BFL PEFT converter instead.
_LYCORIS_SUFFIXES = (
"lokr_w1",
"lokr_w2",
"lokr_w1_a",
"lokr_w1_b",
"lokr_w2_a",
"lokr_w2_b",
"lokr_t2",
"hada_w1_a",
"hada_w1_b",
"hada_w2_a",
"hada_w2_b",
"hada_t1",
"hada_t2",
)
if any(k.endswith(_LYCORIS_SUFFIXES) for k in state_dict.keys() if isinstance(k, str)):
return False
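
For illustration (not part of the changed file), the filter above relies on `str.endswith` accepting a tuple of suffixes, so a single pass over the keys rejects any LyCORIS-style state dict. The key names below are hypothetical.

```python
keys = [
    "diffusion_model.double_blocks.0.img_attn.proj.lokr_w1",
    "diffusion_model.double_blocks.0.img_attn.proj.lokr_w2_b",
]
print(any(k.endswith(_LYCORIS_SUFFIXES) for k in keys))  # True -> not AIToolkit PEFT format
```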

if metadata:
try:
software = json.loads(metadata.get("software", "{}"))