Changes from all commits (47 commits)
7c72824
refactor: port MM probes to new api
psychedelicious Sep 23, 2025
8ae9716
feat(mm): port TIs to new API
psychedelicious Sep 23, 2025
8b6fe5c
tidy(mm): remove unused probes
psychedelicious Sep 23, 2025
cdcdecc
feat(mm): port spandrel to new API
psychedelicious Sep 23, 2025
12c3cbc
fix(mm): parsing for spandrel
psychedelicious Sep 23, 2025
7ab6042
fix(mm): loader for clip embed
psychedelicious Sep 23, 2025
1db1264
fix(mm): tis use existing weight_files method
psychedelicious Sep 23, 2025
82ffb58
feat(mm): port vae to new API
psychedelicious Sep 23, 2025
20a0231
fix(mm): vae class inheritance and config_path
psychedelicious Sep 23, 2025
c88fee6
tidy(mm): patcher types and import paths
psychedelicious Sep 23, 2025
5996e31
feat(mm): better errors when invalid model config found in db
psychedelicious Sep 23, 2025
8217fd9
feat(mm): port t5 to new API
psychedelicious Sep 23, 2025
1d3f6c4
feat(mm): make config_path optional
psychedelicious Sep 23, 2025
881f063
refactor(mm): simplify model classification process
psychedelicious Sep 24, 2025
049e9f2
refactor(mm): remove unused methods in config.py
psychedelicious Sep 24, 2025
8b6929b
refactor(mm): add model config parsing utils
psychedelicious Sep 24, 2025
4220657
fix(mm): abstractmethod bork
psychedelicious Sep 24, 2025
6c60e6d
tidy(mm): clarify that model id utils are private
psychedelicious Sep 24, 2025
b1780f9
fix(mm): fall back to UnknownModelConfig correctly
psychedelicious Sep 24, 2025
cfef478
feat(mm): port CLIPVisionDiffusersConfig to new api
psychedelicious Sep 24, 2025
4f4268e
feat(mm): port SigLIPDiffusersConfig to new api
psychedelicious Sep 24, 2025
01104f5
feat(mm): make match helpers more succint
psychedelicious Sep 24, 2025
6c66013
feat(mm): port flux redux to new api
psychedelicious Sep 24, 2025
20db2cb
feat(mm): port ip adapter to new api
psychedelicious Sep 24, 2025
f0e931c
tidy(mm): skip optimistic override handling for now
psychedelicious Sep 24, 2025
2813ec4
refactor(mm): continue iterating on config
psychedelicious Sep 25, 2025
e0d91ef
feat(mm): port flux "control lora" and t2i adapter to new api
psychedelicious Sep 25, 2025
5deb9bb
tidy(ui): use Extract to get model config types
psychedelicious Sep 25, 2025
07e99c9
fix(mm): t2i base determination
psychedelicious Sep 25, 2025
d27bef1
feat(mm): port cnet to new api
psychedelicious Sep 25, 2025
1268b23
refactor(mm): add config validation utils, make it all consistent and…
psychedelicious Sep 25, 2025
5f45a9c
feat(mm): wip port of main models to new api
psychedelicious Sep 25, 2025
7765c83
feat(mm): wip port of main models to new api
psychedelicious Sep 25, 2025
3a44fde
feat(mm): wip port of main models to new api
psychedelicious Sep 25, 2025
69efdc3
docs(mm): add todos
psychedelicious Sep 26, 2025
7765df4
tidy(mm): removed unused model merge class
psychedelicious Sep 29, 2025
9676cb8
feat(mm): wip port main models to new api
psychedelicious Sep 29, 2025
09449cf
tidy(mm): clean up model heuristic utils
psychedelicious Oct 1, 2025
d63348b
tidy(mm): clean up ModelOnDisk caching
psychedelicious Oct 1, 2025
bab7f62
tidy(mm): flux lora format util
psychedelicious Oct 1, 2025
935fafe
refactor(mm): make config classes narrow
psychedelicious Oct 1, 2025
17c5ad2
refactor(mm): diffusers loras
psychedelicious Oct 1, 2025
29087af
feat(mm): consistent naming for all model config classes
psychedelicious Oct 1, 2025
32a9ad1
fix(mm): tag generation & scattered probe fixes
psychedelicious Oct 1, 2025
508c488
tidy(mm): consistent class names
psychedelicious Oct 2, 2025
fb8b769
feat(mm): add abs method to load models to config
psychedelicious Oct 2, 2025
4c289e3
refactor(mm): migrate loaders to config classes
psychedelicious Oct 2, 2025
20 changes: 16 additions & 4 deletions invokeai/app/api/routers/model_manager.py
@@ -31,7 +31,10 @@
from invokeai.backend.model_manager import BaseModelType, ModelFormat, ModelType
from invokeai.backend.model_manager.config import (
AnyModelConfig,
MainCheckpointConfig,
Main_Checkpoint_SD1_Config,
Main_Checkpoint_SD2_Config,
Main_Checkpoint_SDXL_Config,
Main_Checkpoint_SDXLRefiner_Config,
)
from invokeai.backend.model_manager.load.model_cache.cache_stats import CacheStats
from invokeai.backend.model_manager.metadata.fetch.huggingface import HuggingFaceMetadataFetch
@@ -741,9 +744,18 @@ async def convert_model(
logger.error(str(e))
raise HTTPException(status_code=424, detail=str(e))

if not isinstance(model_config, MainCheckpointConfig):
logger.error(f"The model with key {key} is not a main checkpoint model.")
raise HTTPException(400, f"The model with key {key} is not a main checkpoint model.")
if not isinstance(
model_config,
(
Main_Checkpoint_SD1_Config,
Main_Checkpoint_SD2_Config,
Main_Checkpoint_SDXL_Config,
Main_Checkpoint_SDXLRefiner_Config,
),
):
msg = f"The model with key {key} is not a main SD 1/2/XL checkpoint model."
logger.error(msg)
raise HTTPException(400, msg)

with TemporaryDirectory(dir=ApiDependencies.invoker.services.configuration.models_path) as tmpdir:
convert_path = pathlib.Path(tmpdir) / pathlib.Path(model_config.path).stem
4 changes: 2 additions & 2 deletions invokeai/app/invocations/create_gradient_mask.py
@@ -21,7 +21,7 @@
from invokeai.app.invocations.model import UNetField, VAEField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager import LoadedModel
from invokeai.backend.model_manager.config import MainConfigBase
from invokeai.backend.model_manager.config import Main_Config_Base
from invokeai.backend.model_manager.taxonomy import ModelVariantType
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor

@@ -182,7 +182,7 @@ def invoke(self, context: InvocationContext) -> GradientMaskOutput:
if self.unet is not None and self.vae is not None and self.image is not None:
# all three fields must be present at the same time
main_model_config = context.models.get_config(self.unet.unet.key)
assert isinstance(main_model_config, MainConfigBase)
assert isinstance(main_model_config, Main_Config_Base)
if main_model_config.variant is ModelVariantType.Inpaint:
mask = dilated_mask_tensor
vae_info: LoadedModel = context.models.load(self.vae.vae)
6 changes: 3 additions & 3 deletions invokeai/app/invocations/flux_denoise.py
@@ -48,7 +48,7 @@
unpack,
)
from invokeai.backend.flux.text_conditioning import FluxReduxConditioning, FluxTextConditioning
from invokeai.backend.model_manager.taxonomy import ModelFormat, ModelVariantType
from invokeai.backend.model_manager.taxonomy import FluxVariantType, ModelFormat
from invokeai.backend.patches.layer_patcher import LayerPatcher
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
@@ -232,7 +232,7 @@ def _run_diffusion(
)

transformer_config = context.models.get_config(self.transformer.transformer)
is_schnell = "schnell" in getattr(transformer_config, "config_path", "")
is_schnell = transformer_config.variant is FluxVariantType.Schnell

# Calculate the timestep schedule.
timesteps = get_schedule(
@@ -277,7 +277,7 @@

# Prepare the extra image conditioning tensor (img_cond) for either FLUX structural control or FLUX Fill.
img_cond: torch.Tensor | None = None
is_flux_fill = transformer_config.variant == ModelVariantType.Inpaint # type: ignore
is_flux_fill = transformer_config.variant is FluxVariantType.DevFill
if is_flux_fill:
img_cond = self._prep_flux_fill_img_cond(
context, device=TorchDevice.choose_torch_device(), dtype=inference_dtype
5 changes: 2 additions & 3 deletions invokeai/app/invocations/flux_ip_adapter.py
@@ -17,8 +17,7 @@
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.config import (
IPAdapterCheckpointConfig,
IPAdapterInvokeAIConfig,
IPAdapter_Checkpoint_FLUX_Config,
)
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType

@@ -68,7 +67,7 @@ def validate_begin_end_step_percent(self) -> Self:
def invoke(self, context: InvocationContext) -> IPAdapterOutput:
# Lookup the CLIP Vision encoder that is intended to be used with the IP-Adapter model.
ip_adapter_info = context.models.get_config(self.ip_adapter_model.key)
assert isinstance(ip_adapter_info, (IPAdapterInvokeAIConfig, IPAdapterCheckpointConfig))
assert isinstance(ip_adapter_info, IPAdapter_Checkpoint_FLUX_Config)

# Note: There is a IPAdapterInvokeAIConfig.image_encoder_model_id field, but it isn't trustworthy.
image_encoder_starter_model = CLIP_VISION_MODEL_MAP[self.clip_vision_model]
8 changes: 4 additions & 4 deletions invokeai/app/invocations/flux_model_loader.py
@@ -13,9 +13,9 @@
preprocess_t5_encoder_model_identifier,
preprocess_t5_tokenizer_model_identifier,
)
from invokeai.backend.flux.util import max_seq_lengths
from invokeai.backend.flux.util import get_flux_max_seq_length
from invokeai.backend.model_manager.config import (
CheckpointConfigBase,
Checkpoint_Config_Base,
)
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType, SubModelType

@@ -87,12 +87,12 @@ def invoke(self, context: InvocationContext) -> FluxModelLoaderOutput:
t5_encoder = preprocess_t5_encoder_model_identifier(self.t5_encoder_model)

transformer_config = context.models.get_config(transformer)
assert isinstance(transformer_config, CheckpointConfigBase)
assert isinstance(transformer_config, Checkpoint_Config_Base)

return FluxModelLoaderOutput(
transformer=TransformerField(transformer=transformer, loras=[]),
clip=CLIPField(tokenizer=tokenizer, text_encoder=clip_encoder, loras=[], skipped_layers=0),
t5_encoder=T5EncoderField(tokenizer=tokenizer2, text_encoder=t5_encoder, loras=[]),
vae=VAEField(vae=vae),
max_seq_len=max_seq_lengths[transformer_config.config_path],
max_seq_len=get_flux_max_seq_length(transformer_config.variant),
)
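
Note on get_flux_max_seq_length: it replaces the old max_seq_lengths dict that was keyed on config_path, so the sequence length now follows the model's FluxVariantType rather than its legacy config file name. A minimal sketch of what such a helper could look like; the 256/512 values are assumptions carried over from the old lookup table (Schnell short context, Dev-family long context), and only the FluxVariantType members visible in this diff are taken as given:

# Sketch only; the real helper lives in invokeai/backend/flux/util.py.
from invokeai.backend.model_manager.taxonomy import FluxVariantType


def get_flux_max_seq_length(variant: FluxVariantType) -> int:
    """Return the T5 max sequence length for a FLUX transformer variant."""
    if variant is FluxVariantType.Schnell:
        return 256  # assumed: schnell uses the shorter 256-token T5 context
    return 512  # assumed: dev, dev-fill and other dev-derived variants use 512
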
8 changes: 4 additions & 4 deletions invokeai/app/invocations/ip_adapter.py
@@ -13,8 +13,8 @@
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.config import (
AnyModelConfig,
IPAdapterCheckpointConfig,
IPAdapterInvokeAIConfig,
IPAdapter_Checkpoint_Config_Base,
IPAdapter_InvokeAI_Config_Base,
)
from invokeai.backend.model_manager.starter_models import (
StarterModel,
@@ -123,9 +123,9 @@ def validate_begin_end_step_percent(self) -> Self:
def invoke(self, context: InvocationContext) -> IPAdapterOutput:
# Lookup the CLIP Vision encoder that is intended to be used with the IP-Adapter model.
ip_adapter_info = context.models.get_config(self.ip_adapter_model.key)
assert isinstance(ip_adapter_info, (IPAdapterInvokeAIConfig, IPAdapterCheckpointConfig))
assert isinstance(ip_adapter_info, (IPAdapter_InvokeAI_Config_Base, IPAdapter_Checkpoint_Config_Base))

if isinstance(ip_adapter_info, IPAdapterInvokeAIConfig):
if isinstance(ip_adapter_info, IPAdapter_InvokeAI_Config_Base):
image_encoder_model_id = ip_adapter_info.image_encoder_model_id
image_encoder_model_name = image_encoder_model_id.split("/")[-1].strip()
else:
5 changes: 3 additions & 2 deletions invokeai/app/invocations/model.py
@@ -24,8 +24,9 @@ class ModelIdentifierField(BaseModel):
name: str = Field(description="The model's name")
base: BaseModelType = Field(description="The model's base model type")
type: ModelType = Field(description="The model's type")
submodel_type: Optional[SubModelType] = Field(
description="The submodel to load, if this is a main model", default=None
submodel_type: SubModelType | None = Field(
description="The submodel to load, if this is a main model",
default=None,
)

@classmethod
27 changes: 11 additions & 16 deletions invokeai/app/services/model_install/model_install_default.py
@@ -5,6 +5,7 @@
import re
import threading
import time
from copy import deepcopy
from pathlib import Path
from queue import Empty, Queue
from shutil import move, rmtree
@@ -36,11 +37,10 @@
from invokeai.app.services.model_records.model_records_base import ModelRecordChanges
from invokeai.backend.model_manager.config import (
AnyModelConfig,
CheckpointConfigBase,
Checkpoint_Config_Base,
InvalidModelConfigException,
ModelConfigBase,
ModelConfigFactory,
)
from invokeai.backend.model_manager.legacy_probe import ModelProbe
from invokeai.backend.model_manager.metadata import (
AnyModelRepoMetadata,
HuggingFaceMetadataFetch,
@@ -370,6 +370,8 @@ def unconditionally_delete(self, key: str) -> None: # noqa D102
model_path = self.app_config.models_path / model.path
if model_path.is_file() or model_path.is_symlink():
model_path.unlink()
assert model_path.parent != self.app_config.models_path
os.rmdir(model_path.parent)
elif model_path.is_dir():
rmtree(model_path)
self.unregister(key)
@@ -598,18 +600,11 @@ def _probe(self, model_path: Path, config: Optional[ModelRecordChanges] = None):
hash_algo = self._app_config.hashing_algorithm
fields = config.model_dump()

# WARNING!
# The legacy probe relies on the implicit order of tests to determine model classification.
# This can lead to regressions between the legacy and new probes.
# Do NOT change the order of `probe` and `classify` without implementing one of the following fixes:
# Short-term fix: `classify` tests `matches` in the same order as the legacy probe.
# Long-term fix: Improve `matches` to be more specific so that only one config matches
# any given model - eliminating ambiguity and removing reliance on order.
# After implementing either of these fixes, remove @pytest.mark.xfail from `test_regression_against_model_probe`
try:
return ModelProbe.probe(model_path=model_path, fields=fields, hash_algo=hash_algo) # type: ignore
except InvalidModelConfigException:
return ModelConfigBase.classify(model_path, hash_algo, **fields)
return ModelConfigFactory.from_model_on_disk(
mod=model_path,
overrides=deepcopy(fields),
hash_algo=hash_algo,
)

def _register(
self, model_path: Path, config: Optional[ModelRecordChanges] = None, info: Optional[AnyModelConfig] = None
@@ -630,7 +625,7 @@ def _register(

info.path = model_path.as_posix()

if isinstance(info, CheckpointConfigBase):
if isinstance(info, Checkpoint_Config_Base) and info.config_path is not None:
# Checkpoints have a config file needed for conversion. Same handling as the model weights - if it's in the
# invoke-managed legacy config dir, we use a relative path.
legacy_config_path = self.app_config.legacy_conf_path / info.config_path
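
Note on the new _probe body: classification is now delegated entirely to ModelConfigFactory.from_model_on_disk, and the legacy ModelProbe fallback (with its order-dependent warning comment) is gone. The commits above ("refactor(mm): simplify model classification process", "fix(mm): fall back to UnknownModelConfig correctly") suggest each concrete config class identifies its own models and anything unmatched falls back to an unknown-model config. A small self-contained illustration of that pattern follows; every class and exception name in it is invented for the example, and only the overall shape mirrors the PR:

from dataclasses import dataclass
from pathlib import Path


class NotAMatch(Exception):
    """Raised by a config class when the model on disk is not of its type."""


@dataclass
class ExampleLoRAConfig:
    path: str

    @classmethod
    def from_model_on_disk(cls, path: Path) -> "ExampleLoRAConfig":
        # A real config class would inspect state dict keys, file layout, etc.
        if path.suffix != ".safetensors":
            raise NotAMatch("not a safetensors file")
        return cls(path=path.as_posix())


@dataclass
class ExampleUnknownConfig:
    path: str


def classify(path: Path) -> ExampleLoRAConfig | ExampleUnknownConfig:
    # Try each registered config class in turn; the real factory iterates the
    # full set of concrete config classes rather than this single example.
    for config_cls in (ExampleLoRAConfig,):
        try:
            return config_cls.from_model_on_disk(path)
        except NotAMatch:
            continue
    # Nothing matched: fall back so the model can still be registered.
    return ExampleUnknownConfig(path=path.as_posix())
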
93 changes: 68 additions & 25 deletions invokeai/app/services/model_load/model_load_default.py
@@ -2,7 +2,7 @@
"""Implementation of model loader service."""

from pathlib import Path
from typing import Callable, Optional, Type
from typing import Any, Callable, Optional

from picklescan.scanner import scan_file_path
from safetensors.torch import load_file as safetensors_load_file
@@ -11,36 +11,29 @@
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.model_load.model_load_base import ModelLoadServiceBase
from invokeai.backend.model_manager.config import AnyModelConfig
from invokeai.backend.model_manager.load import (
LoadedModel,
LoadedModelWithoutConfig,
ModelLoaderRegistry,
ModelLoaderRegistryBase,
)
from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
from invokeai.backend.model_manager.config import AnyModelConfig, Diffusers_Config_Base
from invokeai.backend.model_manager.load import LoadedModel, LoadedModelWithoutConfig
from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache, get_model_cache_key
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_fs
from invokeai.backend.model_manager.taxonomy import AnyModel, SubModelType
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.logging import InvokeAILogger


class ModelLoadService(ModelLoadServiceBase):
"""Wrapper around ModelLoaderRegistry."""
"""Model loading service using config-based loading."""

def __init__(
self,
app_config: InvokeAIAppConfig,
ram_cache: ModelCache,
registry: Optional[Type[ModelLoaderRegistryBase]] = ModelLoaderRegistry,
):
"""Initialize the model load service."""
logger = InvokeAILogger.get_logger(self.__class__.__name__)
logger.setLevel(app_config.log_level.upper())
self._logger = logger
self._app_config = app_config
self._ram_cache = ram_cache
self._registry = registry

def start(self, invoker: Invoker) -> None:
self._invoker = invoker
@@ -63,18 +56,49 @@ def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubMo
if hasattr(self, "_invoker"):
self._invoker.services.events.emit_model_load_started(model_config, submodel_type)

implementation, model_config, submodel_type = self._registry.get_implementation(model_config, submodel_type) # type: ignore
loaded_model: LoadedModel = implementation(
app_config=self._app_config,
logger=self._logger,
ram_cache=self._ram_cache,
).load_model(model_config, submodel_type)
loaded_model = self._load_model_from_config(model_config, submodel_type)

if hasattr(self, "_invoker"):
self._invoker.services.events.emit_model_load_complete(model_config, submodel_type)

return loaded_model

def _load_model_from_config(
self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None
) -> LoadedModel:
"""Load a model using the config's load_model method."""
model_path = Path(model_config.path)
stats_name = ":".join([model_config.base, model_config.type, model_config.name, (submodel_type or "")])

# Check if model is already in cache
try:
cache_record = self._ram_cache.get(key=get_model_cache_key(model_config.key, submodel_type), stats_name=stats_name)
return LoadedModel(config=model_config, cache_record=cache_record, cache=self._ram_cache)
except IndexError:
pass

# Make room in cache
variant = model_config.repo_variant if isinstance(model_config, Diffusers_Config_Base) else None
model_size = calc_model_size_by_fs(
model_path=model_path,
subfolder=submodel_type.value if submodel_type else None,
variant=variant,
)
self._ram_cache.make_room(model_size)

# Load the model using the config's load_model method
raw_model = model_config.load_model(submodel_type)

# Cache the loaded model
self._ram_cache.put(
get_model_cache_key(model_config.key, submodel_type),
model=raw_model,
)

# Retrieve from cache and return
cache_record = self._ram_cache.get(key=get_model_cache_key(model_config.key, submodel_type), stats_name=stats_name)
return LoadedModel(config=model_config, cache_record=cache_record, cache=self._ram_cache)

def load_model_from_path(
self, model_path: Path, loader: Optional[Callable[[Path], AnyModel]] = None
) -> LoadedModelWithoutConfig:
@@ -107,12 +131,31 @@ def torch_load_file(checkpoint: Path) -> AnyModel:
return result

def diffusers_load_directory(directory: Path) -> AnyModel:
load_class = GenericDiffusersLoader(
app_config=self._app_config,
logger=self._logger,
ram_cache=self._ram_cache,
convert_cache=self.convert_cache,
).get_hf_load_class(directory)
from diffusers.configuration_utils import ConfigMixin

class ConfigLoader(ConfigMixin):
"""Subclass of ConfigMixin for loading diffusers configuration files."""

@classmethod
def load_config(cls, *args: Any, **kwargs: Any) -> dict[str, Any]: # type: ignore
"""Load a diffusers ConfigMixin configuration."""
cls.config_name = kwargs.pop("config_name")
return super().load_config(*args, **kwargs) # type: ignore

config = ConfigLoader.load_config(directory, config_name="config.json")
if class_name := config.get("_class_name"):
import sys

res_type = sys.modules["diffusers"]
load_class = getattr(res_type, class_name)
elif class_name := config.get("architectures"):
import sys

res_type = sys.modules["transformers"]
load_class = getattr(res_type, class_name[0])
else:
raise Exception("Unable to determine load class from config.json")

return load_class.from_pretrained(model_path, torch_dtype=TorchDevice.choose_torch_dtype())

loader = loader or (
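
Note on the service rewrite: ModelLoadService no longer dispatches through ModelLoaderRegistry; it relies on each config exposing a load_model method (see raw_model = model_config.load_model(submodel_type) above and the commit "feat(mm): add abs method to load models to config"). A hedged sketch of that hook's likely shape, with the caveat that the exact base class name, where the abstract method is declared, and the return annotation may differ in the real config module:

from abc import ABC, abstractmethod
from typing import Optional

from invokeai.backend.model_manager.taxonomy import AnyModel, SubModelType


class ModelConfigBase(ABC):  # simplified; fields like key, path and name omitted
    @abstractmethod
    def load_model(self, submodel_type: Optional[SubModelType] = None) -> AnyModel:
        """Load this model (or the requested submodel) into memory and return it."""
        ...
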
5 changes: 4 additions & 1 deletion invokeai/app/services/model_records/model_records_base.py
@@ -21,6 +21,7 @@
from invokeai.backend.model_manager.taxonomy import (
BaseModelType,
ClipVariantType,
FluxVariantType,
ModelFormat,
ModelSourceType,
ModelType,
@@ -90,7 +91,9 @@ class ModelRecordChanges(BaseModelExcludeNull):

# Checkpoint-specific changes
# TODO(MM2): Should we expose these? Feels footgun-y...
variant: Optional[ModelVariantType | ClipVariantType] = Field(description="The variant of the model.", default=None)
variant: Optional[ModelVariantType | ClipVariantType | FluxVariantType] = Field(
description="The variant of the model.", default=None
)
prediction_type: Optional[SchedulerPredictionType] = Field(
description="The prediction type of the model.", default=None
)