From b646b2ca3335613b966c3d34bfb4bac2684317a7 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 24 Feb 2026 14:03:57 +0000 Subject: [PATCH] Add FLUX.2 LOKR model support (detection and loading) Co-authored-by: lstein <111189+lstein@users.noreply.github.com> Fix BFL LOKR models being misidentified as AIToolkit format Co-authored-by: lstein <111189+lstein@users.noreply.github.com> Fix alpha key warning in LOKR QKV split layers Co-authored-by: lstein <111189+lstein@users.noreply.github.com> --- .../backend/model_manager/configs/lora.py | 131 ++++++++++- .../flux_aitoolkit_lora_conversion_utils.py | 21 ++ .../flux_bfl_peft_lora_conversion_utils.py | 203 +++++++++++++++--- .../lora_state_dicts/flux_lokr_bfl_format.py | 22 ++ ...st_flux_aitoolkit_lora_conversion_utils.py | 8 +- .../__test_metadata__.json | 3 + .../model.safetensors | 3 + .../__test_metadata__.json | 3 + .../model.safetensors | 3 + .../__test_metadata__.json | 3 + .../model.safetensors | 3 + .../__test_metadata__.json | 3 + .../model.safetensors | 3 + 13 files changed, 375 insertions(+), 34 deletions(-) create mode 100644 tests/backend/patches/lora_conversions/lora_state_dicts/flux_lokr_bfl_format.py create mode 100644 tests/model_identification/stripped_models/flux1-bfl-lokr-attn-proj-mlp/__test_metadata__.json create mode 100644 tests/model_identification/stripped_models/flux1-bfl-lokr-attn-proj-mlp/model.safetensors create mode 100644 tests/model_identification/stripped_models/flux2-klein4b-bfl-lokr-mlp-ratio/__test_metadata__.json create mode 100644 tests/model_identification/stripped_models/flux2-klein4b-bfl-lokr-mlp-ratio/model.safetensors create mode 100644 tests/model_identification/stripped_models/flux2-klein4b-bfl-lokr-txt-in/__test_metadata__.json create mode 100644 tests/model_identification/stripped_models/flux2-klein4b-bfl-lokr-txt-in/model.safetensors create mode 100644 tests/model_identification/stripped_models/flux2-klein9b-bfl-lokr-attn-proj/__test_metadata__.json create mode 100644 tests/model_identification/stripped_models/flux2-klein9b-bfl-lokr-attn-proj/model.safetensors diff --git a/invokeai/backend/model_manager/configs/lora.py b/invokeai/backend/model_manager/configs/lora.py index f5b80a72f00..dcacf8d4928 100644 --- a/invokeai/backend/model_manager/configs/lora.py +++ b/invokeai/backend/model_manager/configs/lora.py @@ -79,6 +79,32 @@ def _get_flux_lora_format(mod: ModelOnDisk) -> FluxLoRAFormat | None: _FLUX1_MLP_RATIO = 4 +def _lokr_in_dim(state_dict: dict[str | int, Any], key_prefix: str) -> int | None: + """Compute the input dimension of a LOKR layer: w1.shape[1] * w2.shape[1]. + + Supports both full LOKR (lokr_w1/lokr_w2) and factorized LOKR (lokr_w1_b/lokr_w2_b). + Returns None if the required keys are not present. + """ + if f"{key_prefix}.lokr_w1" in state_dict and f"{key_prefix}.lokr_w2" in state_dict: + return state_dict[f"{key_prefix}.lokr_w1"].shape[1] * state_dict[f"{key_prefix}.lokr_w2"].shape[1] + elif f"{key_prefix}.lokr_w1_b" in state_dict and f"{key_prefix}.lokr_w2_b" in state_dict: + return state_dict[f"{key_prefix}.lokr_w1_b"].shape[1] * state_dict[f"{key_prefix}.lokr_w2_b"].shape[1] + return None + + +def _lokr_out_dim(state_dict: dict[str | int, Any], key_prefix: str) -> int | None: + """Compute the output dimension of a LOKR layer: w1.shape[0] * w2.shape[0]. + + Supports both full LOKR (lokr_w1/lokr_w2) and factorized LOKR (lokr_w1_a/lokr_w2_a). + Returns None if the required keys are not present. + """ + if f"{key_prefix}.lokr_w1" in state_dict and f"{key_prefix}.lokr_w2" in state_dict: + return state_dict[f"{key_prefix}.lokr_w1"].shape[0] * state_dict[f"{key_prefix}.lokr_w2"].shape[0] + elif f"{key_prefix}.lokr_w1_a" in state_dict and f"{key_prefix}.lokr_w2_a" in state_dict: + return state_dict[f"{key_prefix}.lokr_w1_a"].shape[0] * state_dict[f"{key_prefix}.lokr_w2_a"].shape[0] + return None + + def _is_flux2_lora(mod: ModelOnDisk) -> bool: """Check if a FLUX-format LoRA is specifically for FLUX.2 (Klein) rather than FLUX.1. @@ -147,7 +173,30 @@ def _is_flux2_lora_state_dict(state_dict: dict[str | int, Any]) -> bool: elif "vector_in" in key and key.endswith("lora_A.weight"): return state_dict[key].shape[1] in _FLUX2_VEC_IN_DIMS - # BFL PEFT: hidden_size matches FLUX.1. Check MLP ratio to distinguish Klein 4B. + # BFL LyCORIS (LoKR/LoHA): attention projection → check hidden_size via product of dims + elif key.endswith((".img_attn.proj.lokr_w1", ".img_attn.proj.lokr_w1_b")): + layer_prefix = key.rsplit(".", 1)[0] + in_dim = _lokr_in_dim(state_dict, layer_prefix) + if in_dim is not None: + if in_dim != _FLUX1_HIDDEN_SIZE: + return True + bfl_hidden_size = in_dim # ambiguous, keep checking + + # BFL LyCORIS: context_embedder/txt_in + elif "txt_in" in key and key.endswith((".lokr_w1", ".lokr_w1_b")): + layer_prefix = key.rsplit(".", 1)[0] + in_dim = _lokr_in_dim(state_dict, layer_prefix) + if in_dim is not None: + return in_dim in _FLUX2_CONTEXT_IN_DIMS + + # BFL LyCORIS: vector_in + elif "vector_in" in key and key.endswith((".lokr_w1", ".lokr_w1_b")): + layer_prefix = key.rsplit(".", 1)[0] + in_dim = _lokr_in_dim(state_dict, layer_prefix) + if in_dim is not None: + return in_dim in _FLUX2_VEC_IN_DIMS + + # BFL PEFT/LyCORIS: hidden_size matches FLUX.1. Check MLP ratio to distinguish Klein 4B. # Klein 4B uses mlp_ratio=6 (ffn_dim=18432), FLUX.1 uses mlp_ratio=4 (ffn_dim=12288). if bfl_hidden_size == _FLUX1_HIDDEN_SIZE: for key in state_dict: @@ -158,6 +207,13 @@ def _is_flux2_lora_state_dict(state_dict: dict[str | int, Any]) -> bool: if ffn_dim != bfl_hidden_size * _FLUX1_MLP_RATIO: return True break + # BFL LyCORIS: check output dim of img_mlp.0 via product of dims + if key.startswith(_bfl_prefixes) and key.endswith((".img_mlp.0.lokr_w1", ".img_mlp.0.lokr_w1_a")): + layer_prefix = key.rsplit(".", 1)[0] + out_dim = _lokr_out_dim(state_dict, layer_prefix) + if out_dim is not None and out_dim != bfl_hidden_size * _FLUX1_MLP_RATIO: + return True + break # Check kohya format: look for context_embedder or vector_in keys # Kohya format uses lora_unet_ prefix with underscores instead of dots @@ -167,9 +223,21 @@ def _is_flux2_lora_state_dict(state_dict: dict[str | int, Any]) -> bool: if key.startswith("lora_unet_txt_in.") or key.startswith("lora_unet_context_embedder."): if key.endswith("lora_down.weight"): return state_dict[key].shape[1] in _FLUX2_CONTEXT_IN_DIMS + # Kohya LyCORIS (LoKR) + elif key.endswith((".lokr_w1", ".lokr_w1_b")): + layer_prefix = key.rsplit(".", 1)[0] + in_dim = _lokr_in_dim(state_dict, layer_prefix) + if in_dim is not None: + return in_dim in _FLUX2_CONTEXT_IN_DIMS if key.startswith("lora_unet_vector_in.") or key.startswith("lora_unet_time_text_embed_text_embedder_"): if key.endswith("lora_down.weight"): return state_dict[key].shape[1] in _FLUX2_VEC_IN_DIMS + # Kohya LyCORIS (LoKR) + elif key.endswith((".lokr_w1", ".lokr_w1_b")): + layer_prefix = key.rsplit(".", 1)[0] + in_dim = _lokr_in_dim(state_dict, layer_prefix) + if in_dim is not None: + return in_dim in _FLUX2_VEC_IN_DIMS return False @@ -244,7 +312,7 @@ def _get_flux2_lora_variant(state_dict: dict[str | int, Any]) -> Flux2VariantTyp return Flux2VariantType.Klein9B return None - # Check BFL PEFT format (diffusion_model.* or base_model.model.* prefix with BFL names) + # Check BFL PEFT/LyCORIS format (diffusion_model.* or base_model.model.* prefix with BFL names) _bfl_prefixes = ("diffusion_model.", "base_model.model.") for key in state_dict: if not isinstance(key, str): @@ -279,6 +347,39 @@ def _get_flux2_lora_variant(state_dict: dict[str | int, Any]) -> Flux2VariantTyp return Flux2VariantType.Klein9B return None + # BFL LyCORIS (LoKR): context embedder (txt_in) + if "txt_in" in key and key.endswith((".lokr_w1", ".lokr_w1_b")): + layer_prefix = key.rsplit(".", 1)[0] + in_dim = _lokr_in_dim(state_dict, layer_prefix) + if in_dim is not None: + if in_dim == KLEIN_4B_CONTEXT_DIM: + return Flux2VariantType.Klein4B + if in_dim == KLEIN_9B_CONTEXT_DIM: + return Flux2VariantType.Klein9B + return None + + # BFL LyCORIS (LoKR): vector embedder (vector_in) + if "vector_in" in key and key.endswith((".lokr_w1", ".lokr_w1_b")): + layer_prefix = key.rsplit(".", 1)[0] + in_dim = _lokr_in_dim(state_dict, layer_prefix) + if in_dim is not None: + if in_dim == KLEIN_4B_VEC_DIM: + return Flux2VariantType.Klein4B + if in_dim == KLEIN_9B_VEC_DIM: + return Flux2VariantType.Klein9B + return None + + # BFL LyCORIS (LoKR): attention projection + if key.endswith((".img_attn.proj.lokr_w1", ".img_attn.proj.lokr_w1_b")): + layer_prefix = key.rsplit(".", 1)[0] + in_dim = _lokr_in_dim(state_dict, layer_prefix) + if in_dim is not None: + if in_dim == KLEIN_4B_HIDDEN_SIZE: + return Flux2VariantType.Klein4B + if in_dim == KLEIN_9B_HIDDEN_SIZE: + return Flux2VariantType.Klein9B + return None + # Check kohya format for key in state_dict: if not isinstance(key, str): @@ -291,6 +392,16 @@ def _get_flux2_lora_variant(state_dict: dict[str | int, Any]) -> Flux2VariantTyp if dim == KLEIN_9B_CONTEXT_DIM: return Flux2VariantType.Klein9B return None + # Kohya LyCORIS (LoKR) + elif key.endswith((".lokr_w1", ".lokr_w1_b")): + layer_prefix = key.rsplit(".", 1)[0] + in_dim = _lokr_in_dim(state_dict, layer_prefix) + if in_dim is not None: + if in_dim == KLEIN_4B_CONTEXT_DIM: + return Flux2VariantType.Klein4B + if in_dim == KLEIN_9B_CONTEXT_DIM: + return Flux2VariantType.Klein9B + return None if key.startswith("lora_unet_vector_in.") or key.startswith("lora_unet_time_text_embed_text_embedder_"): if key.endswith("lora_down.weight"): dim = state_dict[key].shape[1] @@ -299,6 +410,16 @@ def _get_flux2_lora_variant(state_dict: dict[str | int, Any]) -> Flux2VariantTyp if dim == KLEIN_9B_VEC_DIM: return Flux2VariantType.Klein9B return None + # Kohya LyCORIS (LoKR) + elif key.endswith((".lokr_w1", ".lokr_w1_b")): + layer_prefix = key.rsplit(".", 1)[0] + in_dim = _lokr_in_dim(state_dict, layer_prefix) + if in_dim is not None: + if in_dim == KLEIN_4B_VEC_DIM: + return Flux2VariantType.Klein4B + if in_dim == KLEIN_9B_VEC_DIM: + return Flux2VariantType.Klein9B + return None return None @@ -423,6 +544,12 @@ def _validate_looks_like_lora(cls, mod: ModelOnDisk) -> None: "to_q_lora.down.weight", "lora_A.weight", "lora_B.weight", + # LyCORIS LoKR suffixes + "lokr_w1", + "lokr_w2", + # LyCORIS LoHA suffixes + "hada_w1_a", + "hada_w2_a", }, ) diff --git a/invokeai/backend/patches/lora_conversions/flux_aitoolkit_lora_conversion_utils.py b/invokeai/backend/patches/lora_conversions/flux_aitoolkit_lora_conversion_utils.py index b5ae202fc10..f359e7caa32 100644 --- a/invokeai/backend/patches/lora_conversions/flux_aitoolkit_lora_conversion_utils.py +++ b/invokeai/backend/patches/lora_conversions/flux_aitoolkit_lora_conversion_utils.py @@ -30,6 +30,27 @@ def is_state_dict_likely_in_flux_aitoolkit_format( if not _has_flux_layer_structure(state_dict): return False + # AIToolkit only produces standard PEFT LoRA (lora_A.weight / lora_B.weight). + # Exclude LyCORIS algorithm variants (LoKR, LoHA, etc.) which use different weight key suffixes. + # These are handled by the BFL PEFT converter instead. + _LYCORIS_SUFFIXES = ( + "lokr_w1", + "lokr_w2", + "lokr_w1_a", + "lokr_w1_b", + "lokr_w2_a", + "lokr_w2_b", + "lokr_t2", + "hada_w1_a", + "hada_w1_b", + "hada_w2_a", + "hada_w2_b", + "hada_t1", + "hada_t2", + ) + if any(k.endswith(_LYCORIS_SUFFIXES) for k in state_dict.keys() if isinstance(k, str)): + return False + if metadata: try: software = json.loads(metadata.get("software", "{}")) diff --git a/invokeai/backend/patches/lora_conversions/flux_bfl_peft_lora_conversion_utils.py b/invokeai/backend/patches/lora_conversions/flux_bfl_peft_lora_conversion_utils.py index faef5fb200a..68f9eb7eaeb 100644 --- a/invokeai/backend/patches/lora_conversions/flux_bfl_peft_lora_conversion_utils.py +++ b/invokeai/backend/patches/lora_conversions/flux_bfl_peft_lora_conversion_utils.py @@ -2,12 +2,19 @@ This format uses BFL internal key names (double_blocks, single_blocks, etc.) with a 'diffusion_model.' prefix and PEFT-style LoRA suffixes (lora_A.weight, lora_B.weight). +LyCORIS variants (LoKR, LoHA, etc.) are also supported, using their respective weight key +suffixes (lokr_w1, lokr_w2, hada_w1_a, etc.) in place of the PEFT suffixes. -Example keys: +Example keys (LoRA PEFT): diffusion_model.double_blocks.0.img_attn.proj.lora_A.weight diffusion_model.double_blocks.0.img_attn.qkv.lora_B.weight diffusion_model.single_blocks.0.linear1.lora_A.weight +Example keys (LoKR): + diffusion_model.double_blocks.0.img_attn.proj.lokr_w1 + diffusion_model.double_blocks.0.img_attn.proj.lokr_w2 + diffusion_model.single_blocks.0.linear1.lokr_w1 + This format is used by some training tools (e.g. SimpleTuner, ComfyUI-based trainers) and is common for FLUX.2 Klein LoRAs. """ @@ -22,6 +29,9 @@ from invokeai.backend.patches.layers.utils import any_lora_layer_from_state_dict from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX from invokeai.backend.patches.model_patch_raw import ModelPatchRaw +from invokeai.backend.util.logging import InvokeAILogger + +logger = InvokeAILogger.get_logger(__name__) # The prefixes used in BFL PEFT format LoRAs. # Most commonly "diffusion_model.", but some PEFT-wrapped variants use "base_model.model.". @@ -42,6 +52,37 @@ _DOUBLE_BLOCK_RE = re.compile(r"^double_blocks\.(\d+)\.(.+)$") _SINGLE_BLOCK_RE = re.compile(r"^single_blocks\.(\d+)\.(.+)$") +# Weight key suffixes used by PEFT LoRA in BFL format. +_BFL_PEFT_LORA_SUFFIXES = ("lora_A.weight", "lora_B.weight") + +# Weight key suffixes used by LyCORIS algorithms (LoKR, LoHA, etc.) in BFL format. +# These are single-component suffixes (no dot), unlike the two-component PEFT suffixes. +_BFL_LYCORIS_WEIGHT_SUFFIXES = ( + # LoKR + "lokr_w1", + "lokr_w2", + "lokr_w1_a", + "lokr_w1_b", + "lokr_w2_a", + "lokr_w2_b", + "lokr_t2", + # LoHA + "hada_w1_a", + "hada_w1_b", + "hada_w2_a", + "hada_w2_b", + "hada_t1", + "hada_t2", + # Common to all LyCORIS algorithms + "alpha", + "dora_scale", + # Full/Diff + "diff", +) + +# All recognized BFL weight key suffixes (both PEFT and LyCORIS). +_BFL_ALL_WEIGHT_SUFFIXES = _BFL_PEFT_LORA_SUFFIXES + _BFL_LYCORIS_WEIGHT_SUFFIXES + # Mapping of BFL double block layer suffixes to diffusers equivalents (simple renames). _DOUBLE_BLOCK_RENAMES: dict[str, str] = { "img_attn.proj": "attn.to_out.0", @@ -60,19 +101,20 @@ def is_state_dict_likely_in_flux_bfl_peft_format(state_dict: dict[str | int, torch.Tensor]) -> bool: - """Checks if the provided state dict is likely in the BFL PEFT FLUX LoRA format. + """Checks if the provided state dict is likely in the BFL PEFT FLUX LoRA/LyCORIS format. - This format uses BFL key names (double_blocks, single_blocks, img_attn, etc.) with PEFT LoRA - suffixes (lora_A.weight, lora_B.weight). The keys may be prefixed with either 'diffusion_model.' + This format uses BFL key names (double_blocks, single_blocks, img_attn, etc.) with either + PEFT LoRA suffixes (lora_A.weight, lora_B.weight) or LyCORIS algorithm suffixes (lokr_w1, + lokr_w2, hada_w1_a, etc.). The keys may be prefixed with either 'diffusion_model.' (common for ComfyUI/SimpleTuner) or 'base_model.model.' (PEFT-wrapped variant). """ str_keys = [k for k in state_dict.keys() if isinstance(k, str)] if not str_keys: return False - # All keys must be in PEFT format (lora_A.weight / lora_B.weight) - all_peft = all(k.endswith(("lora_A.weight", "lora_B.weight")) for k in str_keys) - if not all_peft: + # All keys must use recognized weight suffixes (PEFT LoRA or LyCORIS). + all_valid_suffixes = all(k.endswith(_BFL_ALL_WEIGHT_SUFFIXES) for k in str_keys) + if not all_valid_suffixes: return False # Must have at least some keys with FLUX block structure (double_blocks/single_blocks) @@ -94,13 +136,31 @@ def _strip_bfl_peft_prefix(key: str) -> str: return key +def _split_bfl_key(key: str) -> tuple[str, str]: + """Split a BFL key (after prefix stripping) into (layer_name, weight_suffix). + + Handles: + - 2-component suffixes ending with '.weight': e.g., 'lora_A.weight', 'lora_B.weight' + - 1-component suffixes: e.g., 'lokr_w1', 'lokr_w2', 'alpha', 'dora_scale' + """ + if key.endswith(".weight"): + # 2-component suffix: e.g., 'lora_A.weight' → split at last 2 dots + parts = key.rsplit(".", maxsplit=2) + return parts[0], f"{parts[1]}.{parts[2]}" + else: + # 1-component suffix: e.g., 'lokr_w1', 'alpha' → split at last dot + parts = key.rsplit(".", maxsplit=1) + return parts[0], parts[1] + + def lora_model_from_flux_bfl_peft_state_dict( state_dict: Dict[str, torch.Tensor], alpha: float | None = None ) -> ModelPatchRaw: - """Convert a BFL PEFT format FLUX LoRA state dict to a ModelPatchRaw. + """Convert a BFL PEFT/LyCORIS format FLUX LoRA state dict to a ModelPatchRaw. The conversion is straightforward: strip the prefix ('diffusion_model.' or 'base_model.model.') to get the BFL internal key names, which are already the format used by InvokeAI internally. + Supports both PEFT LoRA (lora_A.weight / lora_B.weight) and LyCORIS algorithms (LoKR, LoHA, etc.). """ # Group keys by layer grouped_state_dict: dict[str, dict[str, torch.Tensor]] = {} @@ -109,15 +169,12 @@ def lora_model_from_flux_bfl_peft_state_dict( if isinstance(key, str): key = _strip_bfl_peft_prefix(key) - # Split off the lora_A.weight / lora_B.weight suffix - parts = key.rsplit(".", maxsplit=2) - layer_name = parts[0] - suffix = ".".join(parts[1:]) + layer_name, suffix = _split_bfl_key(key) if layer_name not in grouped_state_dict: grouped_state_dict[layer_name] = {} - # Convert PEFT naming to InvokeAI naming + # Convert PEFT naming to InvokeAI naming; LyCORIS keys pass through unchanged. if suffix == "lora_A.weight": grouped_state_dict[layer_name]["lora_down.weight"] = value elif suffix == "lora_B.weight": @@ -141,7 +198,7 @@ def lora_model_from_flux_bfl_peft_state_dict( def lora_model_from_flux2_bfl_peft_state_dict( state_dict: Dict[str, torch.Tensor], alpha: float | None = None ) -> ModelPatchRaw: - """Convert a BFL PEFT format FLUX LoRA state dict for use with FLUX.2 Klein (diffusers model). + """Convert a BFL PEFT/LyCORIS format FLUX LoRA state dict for use with FLUX.2 Klein (diffusers model). FLUX.2 Klein models are loaded as Flux2Transformer2DModel (diffusers), which uses different layer naming than BFL's internal format: @@ -149,7 +206,7 @@ def lora_model_from_flux2_bfl_peft_state_dict( - single_blocks.{i} → single_transformer_blocks.{i} - Fused QKV (img_attn.qkv) → separate Q/K/V (attn.to_q, attn.to_k, attn.to_v) - This function converts BFL PEFT keys to diffusers naming and splits fused QKV LoRAs + This function converts BFL PEFT/LyCORIS keys to diffusers naming and splits fused QKV LoRAs into separate Q/K/V LoRA layers. """ # First, strip the prefix and group by BFL layer name with PEFT→InvokeAI naming. @@ -158,9 +215,7 @@ def lora_model_from_flux2_bfl_peft_state_dict( if isinstance(key, str): key = _strip_bfl_peft_prefix(key) - parts = key.rsplit(".", maxsplit=2) - layer_name = parts[0] - suffix = ".".join(parts[1:]) + layer_name, suffix = _split_bfl_key(key) if layer_name not in grouped_state_dict: grouped_state_dict[layer_name] = {} @@ -189,7 +244,7 @@ def lora_model_from_flux2_bfl_peft_state_dict( def _convert_bfl_layer_to_diffusers( bfl_key: str, layer_sd: dict[str, torch.Tensor] ) -> list[tuple[str, dict[str, torch.Tensor]]]: - """Convert a single BFL-named LoRA layer to one or more diffusers-named layers. + """Convert a single BFL-named LoRA/LyCORIS layer to one or more diffusers-named layers. Returns a list of (diffusers_key, layer_state_dict) tuples. Most layers produce one entry, but fused QKV layers are split into three separate Q/K/V entries. @@ -202,20 +257,42 @@ def _convert_bfl_layer_to_diffusers( # Fused image QKV → split into separate Q, K, V if rest == "img_attn.qkv": - return _split_qkv_lora( - layer_sd, - q_key=f"{prefix}.attn.to_q", - k_key=f"{prefix}.attn.to_k", - v_key=f"{prefix}.attn.to_v", - ) + if "lora_down.weight" in layer_sd: + return _split_qkv_lora( + layer_sd, + q_key=f"{prefix}.attn.to_q", + k_key=f"{prefix}.attn.to_k", + v_key=f"{prefix}.attn.to_v", + ) + elif "lokr_w1" in layer_sd or "lokr_w1_a" in layer_sd: + return _split_qkv_lokr( + layer_sd, + q_key=f"{prefix}.attn.to_q", + k_key=f"{prefix}.attn.to_k", + v_key=f"{prefix}.attn.to_v", + ) + else: + logger.warning(f"Unsupported layer type for QKV split in {bfl_key}; layer will be skipped.") + return [] # Fused text QKV → split into separate Q, K, V if rest == "txt_attn.qkv": - return _split_qkv_lora( - layer_sd, - q_key=f"{prefix}.attn.add_q_proj", - k_key=f"{prefix}.attn.add_k_proj", - v_key=f"{prefix}.attn.add_v_proj", - ) + if "lora_down.weight" in layer_sd: + return _split_qkv_lora( + layer_sd, + q_key=f"{prefix}.attn.add_q_proj", + k_key=f"{prefix}.attn.add_k_proj", + v_key=f"{prefix}.attn.add_v_proj", + ) + elif "lokr_w1" in layer_sd or "lokr_w1_a" in layer_sd: + return _split_qkv_lokr( + layer_sd, + q_key=f"{prefix}.attn.add_q_proj", + k_key=f"{prefix}.attn.add_k_proj", + v_key=f"{prefix}.attn.add_v_proj", + ) + else: + logger.warning(f"Unsupported layer type for QKV split in {bfl_key}; layer will be skipped.") + return [] # Simple renames if rest in _DOUBLE_BLOCK_RENAMES: return [(f"{prefix}.{_DOUBLE_BLOCK_RENAMES[rest]}", layer_sd)] @@ -269,6 +346,70 @@ def _split_qkv_lora( return result +def _split_qkv_lokr( + layer_sd: dict[str, torch.Tensor], + q_key: str, + k_key: str, + v_key: str, +) -> list[tuple[str, dict[str, torch.Tensor]]]: + """Split a fused QKV LoKR layer into separate Q, K, V full (diff) layers. + + LoKR uses a Kronecker product which cannot be split cleanly, so we compute the full weight + matrix and store each third as a full weight update (diff). + + BFL uses fused QKV: full weight [3*hidden, hidden]. + Diffusers uses separate layers: each gets a [hidden, hidden] weight slice. + + For factorized LOKR (w1_a/w1_b), the alpha/rank scale is baked into the diff weights because + FullLayer always uses scale=1.0. + """ + w1 = layer_sd.get("lokr_w1") + w1_a = layer_sd.get("lokr_w1_a") + w1_b = layer_sd.get("lokr_w1_b") + w2 = layer_sd.get("lokr_w2") + w2_a = layer_sd.get("lokr_w2_a") + w2_b = layer_sd.get("lokr_w2_b") + t2 = layer_sd.get("lokr_t2") + alpha = layer_sd.get("alpha") + + # Compute rank for scaling (only valid for factorized LOKR). + if w1_b is not None: + rank: int | None = w1_b.shape[0] + elif w2_b is not None: + rank = w2_b.shape[0] + else: + rank = None + + if w1 is None: + assert w1_a is not None and w1_b is not None + w1 = w1_a @ w1_b + if w2 is None: + assert w2_a is not None and w2_b is not None + if t2 is not None: + w2 = torch.einsum("i j k l, i p, j r -> p r k l", t2, w2_a, w2_b) + else: + w2 = w2_a @ w2_b + + if len(w2.shape) == 4: + w1 = w1.unsqueeze(2).unsqueeze(2) + + full_weight = torch.kron(w1, w2) # [3*hidden, hidden] + + # For factorized LOKR, bake the alpha/rank scale into the weight because FullLayer.scale() + # always returns 1.0 (it has no alpha). For non-factorized LOKR, rank is None and scale is 1.0. + if rank is not None and alpha is not None: + scale = alpha.item() / rank + full_weight = full_weight * scale + + weight_q, weight_k, weight_v = full_weight.chunk(3, dim=0) + + result = [] + for key, weight_part in [(q_key, weight_q), (k_key, weight_k), (v_key, weight_v)]: + result.append((key, {"diff": weight_part})) + + return result + + def convert_bfl_lora_patch_to_diffusers(patch: ModelPatchRaw) -> ModelPatchRaw: """Convert a ModelPatchRaw with BFL-format layer keys to diffusers-format keys. diff --git a/tests/backend/patches/lora_conversions/lora_state_dicts/flux_lokr_bfl_format.py b/tests/backend/patches/lora_conversions/lora_state_dicts/flux_lokr_bfl_format.py new file mode 100644 index 00000000000..bff9470ec63 --- /dev/null +++ b/tests/backend/patches/lora_conversions/lora_state_dicts/flux_lokr_bfl_format.py @@ -0,0 +1,22 @@ +# A sample state dict in the BFL LOKR format (FLUX.1 hidden_size=3072). +# These keys represent a LOKR model using BFL internal key names with 'diffusion_model.' prefix. +state_dict_keys = { + "diffusion_model.double_blocks.0.img_attn.proj.lokr_w1": [32, 96], + "diffusion_model.double_blocks.0.img_attn.proj.lokr_w2": [32, 32], + "diffusion_model.double_blocks.0.img_attn.proj.alpha": [], + "diffusion_model.double_blocks.0.img_attn.qkv.lokr_w1": [32, 96], + "diffusion_model.double_blocks.0.img_attn.qkv.lokr_w2": [32, 288], + "diffusion_model.double_blocks.0.img_attn.qkv.alpha": [], + "diffusion_model.double_blocks.0.img_mlp.0.lokr_w1": [32, 96], + "diffusion_model.double_blocks.0.img_mlp.0.lokr_w2": [32, 128], + "diffusion_model.double_blocks.0.img_mlp.0.alpha": [], + "diffusion_model.double_blocks.0.img_mlp.2.lokr_w1": [32, 128], + "diffusion_model.double_blocks.0.img_mlp.2.lokr_w2": [32, 96], + "diffusion_model.double_blocks.0.img_mlp.2.alpha": [], + "diffusion_model.single_blocks.0.linear1.lokr_w1": [32, 128], + "diffusion_model.single_blocks.0.linear1.lokr_w2": [32, 128], + "diffusion_model.single_blocks.0.linear1.alpha": [], + "diffusion_model.single_blocks.0.linear2.lokr_w1": [32, 64], + "diffusion_model.single_blocks.0.linear2.lokr_w2": [32, 48], + "diffusion_model.single_blocks.0.linear2.alpha": [], +} diff --git a/tests/backend/patches/lora_conversions/test_flux_aitoolkit_lora_conversion_utils.py b/tests/backend/patches/lora_conversions/test_flux_aitoolkit_lora_conversion_utils.py index f9c20e82a5d..648f17438bb 100644 --- a/tests/backend/patches/lora_conversions/test_flux_aitoolkit_lora_conversion_utils.py +++ b/tests/backend/patches/lora_conversions/test_flux_aitoolkit_lora_conversion_utils.py @@ -12,6 +12,9 @@ from tests.backend.patches.lora_conversions.lora_state_dicts.flux_dora_onetrainer_format import ( state_dict_keys as flux_onetrainer_state_dict_keys, ) +from tests.backend.patches.lora_conversions.lora_state_dicts.flux_lokr_bfl_format import ( + state_dict_keys as flux_lokr_bfl_state_dict_keys, +) from tests.backend.patches.lora_conversions.lora_state_dicts.flux_lora_aitoolkit_format import ( state_dict_keys as flux_aitoolkit_state_dict_keys, ) @@ -26,7 +29,10 @@ def test_is_state_dict_likely_in_flux_aitoolkit_format(): assert is_state_dict_likely_in_flux_aitoolkit_format(state_dict) -@pytest.mark.parametrize("sd_keys", [flux_diffusers_state_dict_keys, flux_onetrainer_state_dict_keys]) +@pytest.mark.parametrize( + "sd_keys", + [flux_diffusers_state_dict_keys, flux_onetrainer_state_dict_keys, flux_lokr_bfl_state_dict_keys], +) def test_is_state_dict_likely_in_flux_kohya_format_false(sd_keys: dict[str, list[int]]): state_dict = keys_to_mock_state_dict(sd_keys) assert not is_state_dict_likely_in_flux_aitoolkit_format(state_dict) diff --git a/tests/model_identification/stripped_models/flux1-bfl-lokr-attn-proj-mlp/__test_metadata__.json b/tests/model_identification/stripped_models/flux1-bfl-lokr-attn-proj-mlp/__test_metadata__.json new file mode 100644 index 00000000000..12e61135b55 --- /dev/null +++ b/tests/model_identification/stripped_models/flux1-bfl-lokr-attn-proj-mlp/__test_metadata__.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fbc0e95aa9c9954422d9a08fe69a08c8aa290de30701375236f6511d5857c2a7 +size 308 diff --git a/tests/model_identification/stripped_models/flux1-bfl-lokr-attn-proj-mlp/model.safetensors b/tests/model_identification/stripped_models/flux1-bfl-lokr-attn-proj-mlp/model.safetensors new file mode 100644 index 00000000000..b5cdf0e0287 --- /dev/null +++ b/tests/model_identification/stripped_models/flux1-bfl-lokr-attn-proj-mlp/model.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:09401cf2cb3a6579498a92dc54d10d262e966523418ec53eee18a77a567875ed +size 769 diff --git a/tests/model_identification/stripped_models/flux2-klein4b-bfl-lokr-mlp-ratio/__test_metadata__.json b/tests/model_identification/stripped_models/flux2-klein4b-bfl-lokr-mlp-ratio/__test_metadata__.json new file mode 100644 index 00000000000..c8457316ae7 --- /dev/null +++ b/tests/model_identification/stripped_models/flux2-klein4b-bfl-lokr-mlp-ratio/__test_metadata__.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:16562b43278f6309c669077b44261c5cce1ca5aa1815dbd779bce0c1903d6a25 +size 367 diff --git a/tests/model_identification/stripped_models/flux2-klein4b-bfl-lokr-mlp-ratio/model.safetensors b/tests/model_identification/stripped_models/flux2-klein4b-bfl-lokr-mlp-ratio/model.safetensors new file mode 100644 index 00000000000..e25d4b21aeb --- /dev/null +++ b/tests/model_identification/stripped_models/flux2-klein4b-bfl-lokr-mlp-ratio/model.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e44472f3efa142982ddfa9efefd53aa00f742085c819b33fe4f4857fd25fb789 +size 769 diff --git a/tests/model_identification/stripped_models/flux2-klein4b-bfl-lokr-txt-in/__test_metadata__.json b/tests/model_identification/stripped_models/flux2-klein4b-bfl-lokr-txt-in/__test_metadata__.json new file mode 100644 index 00000000000..4490630bd68 --- /dev/null +++ b/tests/model_identification/stripped_models/flux2-klein4b-bfl-lokr-txt-in/__test_metadata__.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cda88f989dcb8523645a9aabece2f3537c46dfc8d06eb5d7c33955fc235b2268 +size 293 diff --git a/tests/model_identification/stripped_models/flux2-klein4b-bfl-lokr-txt-in/model.safetensors b/tests/model_identification/stripped_models/flux2-klein4b-bfl-lokr-txt-in/model.safetensors new file mode 100644 index 00000000000..f6f5cd72030 --- /dev/null +++ b/tests/model_identification/stripped_models/flux2-klein4b-bfl-lokr-txt-in/model.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ae8f2421999912107ea6eac7a2b7cee6a8873b96df45f64c798c2d9beda77e4e +size 842 diff --git a/tests/model_identification/stripped_models/flux2-klein9b-bfl-lokr-attn-proj/__test_metadata__.json b/tests/model_identification/stripped_models/flux2-klein9b-bfl-lokr-attn-proj/__test_metadata__.json new file mode 100644 index 00000000000..5c475f3f679 --- /dev/null +++ b/tests/model_identification/stripped_models/flux2-klein9b-bfl-lokr-attn-proj/__test_metadata__.json @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:20b432b2e42ff4781a4ba5e1974a6a7795c3bba45dc1377b8dcb6a7868eff1db +size 300 diff --git a/tests/model_identification/stripped_models/flux2-klein9b-bfl-lokr-attn-proj/model.safetensors b/tests/model_identification/stripped_models/flux2-klein9b-bfl-lokr-attn-proj/model.safetensors new file mode 100644 index 00000000000..c787e0f294d --- /dev/null +++ b/tests/model_identification/stripped_models/flux2-klein9b-bfl-lokr-attn-proj/model.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cd4555fa20a08bbaa486ed4b7916f49af05892bbb467d03796819b233a7f934a +size 457