diff --git a/invokeai/app/api/routers/model_manager.py b/invokeai/app/api/routers/model_manager.py
index 6142239cf65..84db65252e1 100644
--- a/invokeai/app/api/routers/model_manager.py
+++ b/invokeai/app/api/routers/model_manager.py
@@ -31,7 +31,10 @@ from invokeai.backend.model_manager import BaseModelType, ModelFormat, ModelType
 from invokeai.backend.model_manager.config import (
     AnyModelConfig,
-    MainCheckpointConfig,
+    Main_Checkpoint_SD1_Config,
+    Main_Checkpoint_SD2_Config,
+    Main_Checkpoint_SDXL_Config,
+    Main_Checkpoint_SDXLRefiner_Config,
 )
 from invokeai.backend.model_manager.load.model_cache.cache_stats import CacheStats
 from invokeai.backend.model_manager.metadata.fetch.huggingface import HuggingFaceMetadataFetch
@@ -741,9 +744,18 @@ async def convert_model(
         logger.error(str(e))
         raise HTTPException(status_code=424, detail=str(e))
 
-    if not isinstance(model_config, MainCheckpointConfig):
-        logger.error(f"The model with key {key} is not a main checkpoint model.")
-        raise HTTPException(400, f"The model with key {key} is not a main checkpoint model.")
+    if not isinstance(
+        model_config,
+        (
+            Main_Checkpoint_SD1_Config,
+            Main_Checkpoint_SD2_Config,
+            Main_Checkpoint_SDXL_Config,
+            Main_Checkpoint_SDXLRefiner_Config,
+        ),
+    ):
+        msg = f"The model with key {key} is not a main SD 1/2/XL checkpoint model."
+        logger.error(msg)
+        raise HTTPException(400, msg)
 
     with TemporaryDirectory(dir=ApiDependencies.invoker.services.configuration.models_path) as tmpdir:
         convert_path = pathlib.Path(tmpdir) / pathlib.Path(model_config.path).stem
diff --git a/invokeai/app/invocations/create_gradient_mask.py b/invokeai/app/invocations/create_gradient_mask.py
index b232fbbc932..f6e046d096e 100644
--- a/invokeai/app/invocations/create_gradient_mask.py
+++ b/invokeai/app/invocations/create_gradient_mask.py
@@ -21,7 +21,7 @@
 from invokeai.app.invocations.model import UNetField, VAEField
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.backend.model_manager import LoadedModel
-from invokeai.backend.model_manager.config import MainConfigBase
+from invokeai.backend.model_manager.config import Main_Config_Base
 from invokeai.backend.model_manager.taxonomy import ModelVariantType
 from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
@@ -182,7 +182,7 @@ def invoke(self, context: InvocationContext) -> GradientMaskOutput:
         if self.unet is not None and self.vae is not None and self.image is not None:
             # all three fields must be present at the same time
             main_model_config = context.models.get_config(self.unet.unet.key)
-            assert isinstance(main_model_config, MainConfigBase)
+            assert isinstance(main_model_config, Main_Config_Base)
             if main_model_config.variant is ModelVariantType.Inpaint:
                 mask = dilated_mask_tensor
                 vae_info: LoadedModel = context.models.load(self.vae.vae)
diff --git a/invokeai/app/invocations/flux_denoise.py b/invokeai/app/invocations/flux_denoise.py
index 35d095e2799..1599e8428cb 100644
--- a/invokeai/app/invocations/flux_denoise.py
+++ b/invokeai/app/invocations/flux_denoise.py
@@ -48,7 +48,7 @@
     unpack,
 )
 from invokeai.backend.flux.text_conditioning import FluxReduxConditioning, FluxTextConditioning
-from invokeai.backend.model_manager.taxonomy import ModelFormat, ModelVariantType
+from invokeai.backend.model_manager.taxonomy import FluxVariantType, ModelFormat
 from invokeai.backend.patches.layer_patcher import LayerPatcher
 from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw @@ -232,7 +232,7 @@ def _run_diffusion( ) transformer_config = context.models.get_config(self.transformer.transformer) - is_schnell = "schnell" in getattr(transformer_config, "config_path", "") + is_schnell = transformer_config.variant is FluxVariantType.Schnell # Calculate the timestep schedule. timesteps = get_schedule( @@ -277,7 +277,7 @@ def _run_diffusion( # Prepare the extra image conditioning tensor (img_cond) for either FLUX structural control or FLUX Fill. img_cond: torch.Tensor | None = None - is_flux_fill = transformer_config.variant == ModelVariantType.Inpaint # type: ignore + is_flux_fill = transformer_config.variant is FluxVariantType.DevFill if is_flux_fill: img_cond = self._prep_flux_fill_img_cond( context, device=TorchDevice.choose_torch_device(), dtype=inference_dtype diff --git a/invokeai/app/invocations/flux_ip_adapter.py b/invokeai/app/invocations/flux_ip_adapter.py index db5754ee2b0..c564023a3a0 100644 --- a/invokeai/app/invocations/flux_ip_adapter.py +++ b/invokeai/app/invocations/flux_ip_adapter.py @@ -17,8 +17,7 @@ from invokeai.app.invocations.util import validate_begin_end_step, validate_weights from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.backend.model_manager.config import ( - IPAdapterCheckpointConfig, - IPAdapterInvokeAIConfig, + IPAdapter_Checkpoint_FLUX_Config, ) from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType @@ -68,7 +67,7 @@ def validate_begin_end_step_percent(self) -> Self: def invoke(self, context: InvocationContext) -> IPAdapterOutput: # Lookup the CLIP Vision encoder that is intended to be used with the IP-Adapter model. ip_adapter_info = context.models.get_config(self.ip_adapter_model.key) - assert isinstance(ip_adapter_info, (IPAdapterInvokeAIConfig, IPAdapterCheckpointConfig)) + assert isinstance(ip_adapter_info, IPAdapter_Checkpoint_FLUX_Config) # Note: There is a IPAdapterInvokeAIConfig.image_encoder_model_id field, but it isn't trustworthy. 
image_encoder_starter_model = CLIP_VISION_MODEL_MAP[self.clip_vision_model] diff --git a/invokeai/app/invocations/flux_model_loader.py b/invokeai/app/invocations/flux_model_loader.py index e5a1966c659..2803db48e02 100644 --- a/invokeai/app/invocations/flux_model_loader.py +++ b/invokeai/app/invocations/flux_model_loader.py @@ -13,9 +13,9 @@ preprocess_t5_encoder_model_identifier, preprocess_t5_tokenizer_model_identifier, ) -from invokeai.backend.flux.util import max_seq_lengths +from invokeai.backend.flux.util import get_flux_max_seq_length from invokeai.backend.model_manager.config import ( - CheckpointConfigBase, + Checkpoint_Config_Base, ) from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType, SubModelType @@ -87,12 +87,12 @@ def invoke(self, context: InvocationContext) -> FluxModelLoaderOutput: t5_encoder = preprocess_t5_encoder_model_identifier(self.t5_encoder_model) transformer_config = context.models.get_config(transformer) - assert isinstance(transformer_config, CheckpointConfigBase) + assert isinstance(transformer_config, Checkpoint_Config_Base) return FluxModelLoaderOutput( transformer=TransformerField(transformer=transformer, loras=[]), clip=CLIPField(tokenizer=tokenizer, text_encoder=clip_encoder, loras=[], skipped_layers=0), t5_encoder=T5EncoderField(tokenizer=tokenizer2, text_encoder=t5_encoder, loras=[]), vae=VAEField(vae=vae), - max_seq_len=max_seq_lengths[transformer_config.config_path], + max_seq_len=get_flux_max_seq_length(transformer_config.variant), ) diff --git a/invokeai/app/invocations/ip_adapter.py b/invokeai/app/invocations/ip_adapter.py index 35a98ff6ba0..7c3234bdc71 100644 --- a/invokeai/app/invocations/ip_adapter.py +++ b/invokeai/app/invocations/ip_adapter.py @@ -13,8 +13,8 @@ from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.backend.model_manager.config import ( AnyModelConfig, - IPAdapterCheckpointConfig, - IPAdapterInvokeAIConfig, + IPAdapter_Checkpoint_Config_Base, + IPAdapter_InvokeAI_Config_Base, ) from invokeai.backend.model_manager.starter_models import ( StarterModel, @@ -123,9 +123,9 @@ def validate_begin_end_step_percent(self) -> Self: def invoke(self, context: InvocationContext) -> IPAdapterOutput: # Lookup the CLIP Vision encoder that is intended to be used with the IP-Adapter model. 
ip_adapter_info = context.models.get_config(self.ip_adapter_model.key) - assert isinstance(ip_adapter_info, (IPAdapterInvokeAIConfig, IPAdapterCheckpointConfig)) + assert isinstance(ip_adapter_info, (IPAdapter_InvokeAI_Config_Base, IPAdapter_Checkpoint_Config_Base)) - if isinstance(ip_adapter_info, IPAdapterInvokeAIConfig): + if isinstance(ip_adapter_info, IPAdapter_InvokeAI_Config_Base): image_encoder_model_id = ip_adapter_info.image_encoder_model_id image_encoder_model_name = image_encoder_model_id.split("/")[-1].strip() else: diff --git a/invokeai/app/invocations/model.py b/invokeai/app/invocations/model.py index 2d338c677d2..327de6ac700 100644 --- a/invokeai/app/invocations/model.py +++ b/invokeai/app/invocations/model.py @@ -24,8 +24,9 @@ class ModelIdentifierField(BaseModel): name: str = Field(description="The model's name") base: BaseModelType = Field(description="The model's base model type") type: ModelType = Field(description="The model's type") - submodel_type: Optional[SubModelType] = Field( - description="The submodel to load, if this is a main model", default=None + submodel_type: SubModelType | None = Field( + description="The submodel to load, if this is a main model", + default=None, ) @classmethod diff --git a/invokeai/app/services/model_install/model_install_default.py b/invokeai/app/services/model_install/model_install_default.py index 454697ea5a1..10a954a5636 100644 --- a/invokeai/app/services/model_install/model_install_default.py +++ b/invokeai/app/services/model_install/model_install_default.py @@ -5,6 +5,7 @@ import re import threading import time +from copy import deepcopy from pathlib import Path from queue import Empty, Queue from shutil import move, rmtree @@ -36,11 +37,10 @@ from invokeai.app.services.model_records.model_records_base import ModelRecordChanges from invokeai.backend.model_manager.config import ( AnyModelConfig, - CheckpointConfigBase, + Checkpoint_Config_Base, InvalidModelConfigException, - ModelConfigBase, + ModelConfigFactory, ) -from invokeai.backend.model_manager.legacy_probe import ModelProbe from invokeai.backend.model_manager.metadata import ( AnyModelRepoMetadata, HuggingFaceMetadataFetch, @@ -370,6 +370,8 @@ def unconditionally_delete(self, key: str) -> None: # noqa D102 model_path = self.app_config.models_path / model.path if model_path.is_file() or model_path.is_symlink(): model_path.unlink() + assert model_path.parent != self.app_config.models_path + os.rmdir(model_path.parent) elif model_path.is_dir(): rmtree(model_path) self.unregister(key) @@ -598,18 +600,11 @@ def _probe(self, model_path: Path, config: Optional[ModelRecordChanges] = None): hash_algo = self._app_config.hashing_algorithm fields = config.model_dump() - # WARNING! - # The legacy probe relies on the implicit order of tests to determine model classification. - # This can lead to regressions between the legacy and new probes. - # Do NOT change the order of `probe` and `classify` without implementing one of the following fixes: - # Short-term fix: `classify` tests `matches` in the same order as the legacy probe. - # Long-term fix: Improve `matches` to be more specific so that only one config matches - # any given model - eliminating ambiguity and removing reliance on order. 
- # After implementing either of these fixes, remove @pytest.mark.xfail from `test_regression_against_model_probe` - try: - return ModelProbe.probe(model_path=model_path, fields=fields, hash_algo=hash_algo) # type: ignore - except InvalidModelConfigException: - return ModelConfigBase.classify(model_path, hash_algo, **fields) + return ModelConfigFactory.from_model_on_disk( + mod=model_path, + overrides=deepcopy(fields), + hash_algo=hash_algo, + ) def _register( self, model_path: Path, config: Optional[ModelRecordChanges] = None, info: Optional[AnyModelConfig] = None @@ -630,7 +625,7 @@ def _register( info.path = model_path.as_posix() - if isinstance(info, CheckpointConfigBase): + if isinstance(info, Checkpoint_Config_Base) and info.config_path is not None: # Checkpoints have a config file needed for conversion. Same handling as the model weights - if it's in the # invoke-managed legacy config dir, we use a relative path. legacy_config_path = self.app_config.legacy_conf_path / info.config_path diff --git a/invokeai/app/services/model_load/model_load_default.py b/invokeai/app/services/model_load/model_load_default.py index ad4ad97a02c..a8b24aba125 100644 --- a/invokeai/app/services/model_load/model_load_default.py +++ b/invokeai/app/services/model_load/model_load_default.py @@ -2,7 +2,7 @@ """Implementation of model loader service.""" from pathlib import Path -from typing import Callable, Optional, Type +from typing import Any, Callable, Optional from picklescan.scanner import scan_file_path from safetensors.torch import load_file as safetensors_load_file @@ -11,28 +11,22 @@ from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.invoker import Invoker from invokeai.app.services.model_load.model_load_base import ModelLoadServiceBase -from invokeai.backend.model_manager.config import AnyModelConfig -from invokeai.backend.model_manager.load import ( - LoadedModel, - LoadedModelWithoutConfig, - ModelLoaderRegistry, - ModelLoaderRegistryBase, -) -from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache -from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader +from invokeai.backend.model_manager.config import AnyModelConfig, Diffusers_Config_Base +from invokeai.backend.model_manager.load import LoadedModel, LoadedModelWithoutConfig +from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache, get_model_cache_key +from invokeai.backend.model_manager.load.model_util import calc_model_size_by_fs from invokeai.backend.model_manager.taxonomy import AnyModel, SubModelType from invokeai.backend.util.devices import TorchDevice from invokeai.backend.util.logging import InvokeAILogger class ModelLoadService(ModelLoadServiceBase): - """Wrapper around ModelLoaderRegistry.""" + """Model loading service using config-based loading.""" def __init__( self, app_config: InvokeAIAppConfig, ram_cache: ModelCache, - registry: Optional[Type[ModelLoaderRegistryBase]] = ModelLoaderRegistry, ): """Initialize the model load service.""" logger = InvokeAILogger.get_logger(self.__class__.__name__) @@ -40,7 +34,6 @@ def __init__( self._logger = logger self._app_config = app_config self._ram_cache = ram_cache - self._registry = registry def start(self, invoker: Invoker) -> None: self._invoker = invoker @@ -63,18 +56,49 @@ def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubMo if hasattr(self, "_invoker"): self._invoker.services.events.emit_model_load_started(model_config, 
submodel_type) - implementation, model_config, submodel_type = self._registry.get_implementation(model_config, submodel_type) # type: ignore - loaded_model: LoadedModel = implementation( - app_config=self._app_config, - logger=self._logger, - ram_cache=self._ram_cache, - ).load_model(model_config, submodel_type) + loaded_model = self._load_model_from_config(model_config, submodel_type) if hasattr(self, "_invoker"): self._invoker.services.events.emit_model_load_complete(model_config, submodel_type) return loaded_model + def _load_model_from_config( + self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None + ) -> LoadedModel: + """Load a model using the config's load_model method.""" + model_path = Path(model_config.path) + stats_name = ":".join([model_config.base, model_config.type, model_config.name, (submodel_type or "")]) + + # Check if model is already in cache + try: + cache_record = self._ram_cache.get(key=get_model_cache_key(model_config.key, submodel_type), stats_name=stats_name) + return LoadedModel(config=model_config, cache_record=cache_record, cache=self._ram_cache) + except IndexError: + pass + + # Make room in cache + variant = model_config.repo_variant if isinstance(model_config, Diffusers_Config_Base) else None + model_size = calc_model_size_by_fs( + model_path=model_path, + subfolder=submodel_type.value if submodel_type else None, + variant=variant, + ) + self._ram_cache.make_room(model_size) + + # Load the model using the config's load_model method + raw_model = model_config.load_model(submodel_type) + + # Cache the loaded model + self._ram_cache.put( + get_model_cache_key(model_config.key, submodel_type), + model=raw_model, + ) + + # Retrieve from cache and return + cache_record = self._ram_cache.get(key=get_model_cache_key(model_config.key, submodel_type), stats_name=stats_name) + return LoadedModel(config=model_config, cache_record=cache_record, cache=self._ram_cache) + def load_model_from_path( self, model_path: Path, loader: Optional[Callable[[Path], AnyModel]] = None ) -> LoadedModelWithoutConfig: @@ -107,12 +131,31 @@ def torch_load_file(checkpoint: Path) -> AnyModel: return result def diffusers_load_directory(directory: Path) -> AnyModel: - load_class = GenericDiffusersLoader( - app_config=self._app_config, - logger=self._logger, - ram_cache=self._ram_cache, - convert_cache=self.convert_cache, - ).get_hf_load_class(directory) + from diffusers.configuration_utils import ConfigMixin + + class ConfigLoader(ConfigMixin): + """Subclass of ConfigMixin for loading diffusers configuration files.""" + + @classmethod + def load_config(cls, *args: Any, **kwargs: Any) -> dict[str, Any]: # type: ignore + """Load a diffusers ConfigMixin configuration.""" + cls.config_name = kwargs.pop("config_name") + return super().load_config(*args, **kwargs) # type: ignore + + config = ConfigLoader.load_config(directory, config_name="config.json") + if class_name := config.get("_class_name"): + import sys + + res_type = sys.modules["diffusers"] + load_class = getattr(res_type, class_name) + elif class_name := config.get("architectures"): + import sys + + res_type = sys.modules["transformers"] + load_class = getattr(res_type, class_name[0]) + else: + raise Exception("Unable to determine load class from config.json") + return load_class.from_pretrained(model_path, torch_dtype=TorchDevice.choose_torch_dtype()) loader = loader or ( diff --git a/invokeai/app/services/model_records/model_records_base.py b/invokeai/app/services/model_records/model_records_base.py index 
740d548a4a3..48f53175364 100644 --- a/invokeai/app/services/model_records/model_records_base.py +++ b/invokeai/app/services/model_records/model_records_base.py @@ -21,6 +21,7 @@ from invokeai.backend.model_manager.taxonomy import ( BaseModelType, ClipVariantType, + FluxVariantType, ModelFormat, ModelSourceType, ModelType, @@ -90,7 +91,9 @@ class ModelRecordChanges(BaseModelExcludeNull): # Checkpoint-specific changes # TODO(MM2): Should we expose these? Feels footgun-y... - variant: Optional[ModelVariantType | ClipVariantType] = Field(description="The variant of the model.", default=None) + variant: Optional[ModelVariantType | ClipVariantType | FluxVariantType] = Field( + description="The variant of the model.", default=None + ) prediction_type: Optional[SchedulerPredictionType] = Field( description="The prediction type of the model.", default=None ) diff --git a/invokeai/app/services/model_records/model_records_sql.py b/invokeai/app/services/model_records/model_records_sql.py index e3b24a6e626..7fad1761cce 100644 --- a/invokeai/app/services/model_records/model_records_sql.py +++ b/invokeai/app/services/model_records/model_records_sql.py @@ -141,10 +141,25 @@ def update_model(self, key: str, changes: ModelRecordChanges) -> AnyModelConfig: with self._db.transaction() as cursor: record = self.get_model(key) - # Model configs use pydantic's `validate_assignment`, so each change is validated by pydantic. + # The changes may mean the model config class changes. So we need to: + # + # 1. convert the existing record to a dict + # 2. apply the changes to the dict + # 3. create a new model config from the updated dict + # + # This way we ensure that the update does not inadvertently create an invalid model config. + + # 1. convert the existing record to a dict + record_as_dict = record.model_dump() + + # 2. apply the changes to the dict for field_name in changes.model_fields_set: - setattr(record, field_name, getattr(changes, field_name)) + record_as_dict[field_name] = getattr(changes, field_name) + # 3. create a new model config from the updated dict + record = ModelConfigFactory.make_config(record_as_dict) + + # If we get this far, the updated model config is valid, so we can save it to the database. json_serialized = record.model_dump_json() cursor.execute( @@ -277,14 +292,19 @@ def search_by_attr( for row in result: try: model_config = ModelConfigFactory.make_config(json.loads(row[0]), timestamp=row[1]) - except pydantic.ValidationError: + except pydantic.ValidationError as e: # We catch this error so that the app can still run if there are invalid model configs in the database. # One reason that an invalid model config might be in the database is if someone had to rollback from a # newer version of the app that added a new model type. row_data = f"{row[0][:64]}..." if len(row[0]) > 64 else row[0] + try: + name = json.loads(row[0]).get("name", "") + except Exception: + name = "" self._logger.warning( - f"Found an invalid model config in the database. Ignoring this model. ({row_data})" + f"Skipping invalid model config in the database with name {name}. Ignoring this model. 
({row_data})" ) + self._logger.warning(f"Validation error: {e}") else: results.append(model_config) diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index 743b6208ead..16aacbb9855 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -21,7 +21,7 @@ from invokeai.app.util.step_callback import diffusion_step_callback from invokeai.backend.model_manager.config import ( AnyModelConfig, - ModelConfigBase, + Config_Base, ) from invokeai.backend.model_manager.load.load_base import LoadedModel, LoadedModelWithoutConfig from invokeai.backend.model_manager.taxonomy import AnyModel, BaseModelType, ModelFormat, ModelType, SubModelType @@ -558,7 +558,7 @@ def get_absolute_path(self, config_or_path: AnyModelConfig | Path | str) -> Path The absolute path to the model. """ - model_path = Path(config_or_path.path) if isinstance(config_or_path, ModelConfigBase) else Path(config_or_path) + model_path = Path(config_or_path.path) if isinstance(config_or_path, Config_Base) else Path(config_or_path) if model_path.is_absolute(): return model_path.resolve() diff --git a/invokeai/app/services/shared/sqlite_migrator/migrations/migration_22.py b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_22.py index c79b58bf2ad..08b0e760686 100644 --- a/invokeai/app/services/shared/sqlite_migrator/migrations/migration_22.py +++ b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_22.py @@ -8,7 +8,7 @@ from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration -from invokeai.backend.model_manager.config import AnyModelConfig, AnyModelConfigValidator +from invokeai.backend.model_manager.config import AnyModelConfigValidator class NormalizeResult(NamedTuple): @@ -30,7 +30,7 @@ def __call__(self, cursor: sqlite3.Cursor) -> None: for model_id, config_json in rows: try: # Get the model config as a pydantic object - config = self._load_model_config(config_json) + config = AnyModelConfigValidator.validate_json(config_json) except ValidationError: # This could happen if the config schema changed in a way that makes old configs invalid. Unlikely # for users, more likely for devs testing out migration paths. @@ -216,11 +216,6 @@ def _prune_empty_directories(self) -> None: self._logger.info("Pruned %d empty directories under %s", len(removed_dirs), self._models_dir) - def _load_model_config(self, config_json: str) -> AnyModelConfig: - # The typing of the validator says it returns Unknown, but it's really a AnyModelConfig. This utility function - # just makes that clear. - return AnyModelConfigValidator.validate_json(config_json) - def build_migration_22(app_config: InvokeAIAppConfig, logger: Logger) -> Migration: """Builds the migration object for migrating from version 21 to version 22. 
diff --git a/invokeai/app/util/custom_openapi.py b/invokeai/app/util/custom_openapi.py index d6b8f3786f1..2e07622530d 100644 --- a/invokeai/app/util/custom_openapi.py +++ b/invokeai/app/util/custom_openapi.py @@ -12,6 +12,7 @@ from invokeai.app.invocations.model import ModelIdentifierField from invokeai.app.services.events.events_common import EventBase from invokeai.app.services.session_processor.session_processor_common import ProgressImage +from invokeai.backend.model_manager.config import AnyModelConfigValidator from invokeai.backend.util.logging import InvokeAILogger logger = InvokeAILogger.get_logger() @@ -115,6 +116,13 @@ def openapi() -> dict[str, Any]: # additional_schemas[1] is a dict of $defs that we need to add to the top level of the schema move_defs_to_top_level(openapi_schema, additional_schemas[1]) + any_model_config_schema = AnyModelConfigValidator.json_schema( + mode="serialization", + ref_template="#/components/schemas/{model}", + ) + move_defs_to_top_level(openapi_schema, any_model_config_schema) + openapi_schema["components"]["schemas"]["AnyModelConfig"] = any_model_config_schema + if post_transform is not None: openapi_schema = post_transform(openapi_schema) diff --git a/invokeai/backend/flux/controlnet/state_dict_utils.py b/invokeai/backend/flux/controlnet/state_dict_utils.py index aa44e6c10f0..87eae5a96bc 100644 --- a/invokeai/backend/flux/controlnet/state_dict_utils.py +++ b/invokeai/backend/flux/controlnet/state_dict_utils.py @@ -5,7 +5,7 @@ from invokeai.backend.flux.model import FluxParams -def is_state_dict_xlabs_controlnet(sd: Dict[str, Any]) -> bool: +def is_state_dict_xlabs_controlnet(sd: dict[str | int, Any]) -> bool: """Is the state dict for an XLabs ControlNet model? This is intended to be a reasonably high-precision detector, but it is not guaranteed to have perfect precision. @@ -25,7 +25,7 @@ def is_state_dict_xlabs_controlnet(sd: Dict[str, Any]) -> bool: return False -def is_state_dict_instantx_controlnet(sd: Dict[str, Any]) -> bool: +def is_state_dict_instantx_controlnet(sd: dict[str | int, Any]) -> bool: """Is the state dict for an InstantX ControlNet model? This is intended to be a reasonably high-precision detector, but it is not guaranteed to have perfect precision. 
diff --git a/invokeai/backend/flux/flux_state_dict_utils.py b/invokeai/backend/flux/flux_state_dict_utils.py index 8ffab54c688..c306c88f965 100644 --- a/invokeai/backend/flux/flux_state_dict_utils.py +++ b/invokeai/backend/flux/flux_state_dict_utils.py @@ -1,10 +1,7 @@ -from typing import TYPE_CHECKING +from typing import Any -if TYPE_CHECKING: - from invokeai.backend.model_manager.legacy_probe import CkptType - -def get_flux_in_channels_from_state_dict(state_dict: "CkptType") -> int | None: +def get_flux_in_channels_from_state_dict(state_dict: dict[str | int, Any]) -> int | None: """Gets the in channels from the state dict.""" # "Standard" FLUX models use "img_in.weight", but some community fine tunes use diff --git a/invokeai/backend/flux/ip_adapter/state_dict_utils.py b/invokeai/backend/flux/ip_adapter/state_dict_utils.py index 90f11ff642b..24ac53550f9 100644 --- a/invokeai/backend/flux/ip_adapter/state_dict_utils.py +++ b/invokeai/backend/flux/ip_adapter/state_dict_utils.py @@ -1,11 +1,11 @@ -from typing import Any, Dict +from typing import Any import torch from invokeai.backend.flux.ip_adapter.xlabs_ip_adapter_flux import XlabsIpAdapterParams -def is_state_dict_xlabs_ip_adapter(sd: Dict[str, Any]) -> bool: +def is_state_dict_xlabs_ip_adapter(sd: dict[str | int, Any]) -> bool: """Is the state dict for an XLabs FLUX IP-Adapter model? This is intended to be a reasonably high-precision detector, but it is not guaranteed to have perfect precision. @@ -27,7 +27,7 @@ def is_state_dict_xlabs_ip_adapter(sd: Dict[str, Any]) -> bool: return False -def infer_xlabs_ip_adapter_params_from_state_dict(state_dict: dict[str, torch.Tensor]) -> XlabsIpAdapterParams: +def infer_xlabs_ip_adapter_params_from_state_dict(state_dict: dict[str | int, torch.Tensor]) -> XlabsIpAdapterParams: num_double_blocks = 0 context_dim = 0 hidden_dim = 0 diff --git a/invokeai/backend/flux/redux/flux_redux_state_dict_utils.py b/invokeai/backend/flux/redux/flux_redux_state_dict_utils.py index a5a13b402d3..83e96d38451 100644 --- a/invokeai/backend/flux/redux/flux_redux_state_dict_utils.py +++ b/invokeai/backend/flux/redux/flux_redux_state_dict_utils.py @@ -1,7 +1,7 @@ -from typing import Any, Dict +from typing import Any -def is_state_dict_likely_flux_redux(state_dict: Dict[str, Any]) -> bool: +def is_state_dict_likely_flux_redux(state_dict: dict[str | int, Any]) -> bool: """Checks if the provided state dict is likely a FLUX Redux model.""" expected_keys = {"redux_down.bias", "redux_down.weight", "redux_up.bias", "redux_up.weight"} diff --git a/invokeai/backend/flux/util.py b/invokeai/backend/flux/util.py index 2a5261cb5c6..2cf52b6ec11 100644 --- a/invokeai/backend/flux/util.py +++ b/invokeai/backend/flux/util.py @@ -1,10 +1,11 @@ # Initially pulled from https://github.com/black-forest-labs/flux from dataclasses import dataclass -from typing import Dict, Literal +from typing import Literal from invokeai.backend.flux.model import FluxParams from invokeai.backend.flux.modules.autoencoder import AutoEncoderParams +from invokeai.backend.model_manager.taxonomy import AnyVariant, FluxVariantType @dataclass @@ -41,30 +42,39 @@ class ModelSpec: ] -max_seq_lengths: Dict[str, Literal[256, 512]] = { - "flux-dev": 512, - "flux-dev-fill": 512, - "flux-schnell": 256, +_flux_max_seq_lengths: dict[AnyVariant, Literal[256, 512]] = { + FluxVariantType.Dev: 512, + FluxVariantType.DevFill: 512, + FluxVariantType.Schnell: 256, } -ae_params = { - "flux": AutoEncoderParams( - resolution=256, - in_channels=3, - ch=128, - out_ch=3, - ch_mult=[1, 2, 4, 
4], - num_res_blocks=2, - z_channels=16, - scale_factor=0.3611, - shift_factor=0.1159, - ) -} +def get_flux_max_seq_length(variant: AnyVariant): + try: + return _flux_max_seq_lengths[variant] + except KeyError: + raise ValueError(f"Unknown variant for FLUX max seq len: {variant}") + + +_flux_ae_params = AutoEncoderParams( + resolution=256, + in_channels=3, + ch=128, + out_ch=3, + ch_mult=[1, 2, 4, 4], + num_res_blocks=2, + z_channels=16, + scale_factor=0.3611, + shift_factor=0.1159, +) + +def get_flux_ae_params() -> AutoEncoderParams: + return _flux_ae_params -params = { - "flux-dev": FluxParams( + +_flux_transformer_params: dict[AnyVariant, FluxParams] = { + FluxVariantType.Dev: FluxParams( in_channels=64, vec_in_dim=768, context_in_dim=4096, @@ -78,7 +88,7 @@ class ModelSpec: qkv_bias=True, guidance_embed=True, ), - "flux-schnell": FluxParams( + FluxVariantType.Schnell: FluxParams( in_channels=64, vec_in_dim=768, context_in_dim=4096, @@ -92,7 +102,7 @@ class ModelSpec: qkv_bias=True, guidance_embed=False, ), - "flux-dev-fill": FluxParams( + FluxVariantType.DevFill: FluxParams( in_channels=384, out_channels=64, vec_in_dim=768, @@ -108,3 +118,10 @@ class ModelSpec: guidance_embed=True, ), } + + +def get_flux_transformers_params(variant: AnyVariant): + try: + return _flux_transformer_params[variant] + except KeyError: + raise ValueError(f"Unknown variant for FLUX transformer params: {variant}") diff --git a/invokeai/backend/model_manager/__init__.py b/invokeai/backend/model_manager/__init__.py index dca72f170e0..7d2667dae08 100644 --- a/invokeai/backend/model_manager/__init__.py +++ b/invokeai/backend/model_manager/__init__.py @@ -3,10 +3,9 @@ from invokeai.backend.model_manager.config import ( AnyModelConfig, InvalidModelConfigException, - ModelConfigBase, + Config_Base, ModelConfigFactory, ) -from invokeai.backend.model_manager.legacy_probe import ModelProbe from invokeai.backend.model_manager.load import LoadedModel from invokeai.backend.model_manager.search import ModelSearch from invokeai.backend.model_manager.taxonomy import ( @@ -28,9 +27,8 @@ "InvalidModelConfigException", "LoadedModel", "ModelConfigFactory", - "ModelProbe", "ModelSearch", - "ModelConfigBase", + "Config_Base", "AnyModel", "AnyVariant", "BaseModelType", diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py index 83f0c1d2bf5..c7bea5429bc 100644 --- a/invokeai/backend/model_manager/config.py +++ b/invokeai/backend/model_manager/config.py @@ -20,30 +20,82 @@ """ -# pyright: reportIncompatibleVariableOverride=false import json import logging +import re import time from abc import ABC, abstractmethod from enum import Enum +from functools import cache from inspect import isabstract from pathlib import Path -from typing import ClassVar, Literal, Optional, Type, TypeAlias, Union +from typing import ( + ClassVar, + Literal, + Optional, + Self, + Type, + Union, +) -from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, TypeAdapter +import accelerate +import torch +from diffusers import AutoencoderKL, ControlNetModel +from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline +from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint import StableDiffusionInpaintPipeline +from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import StableDiffusionXLPipeline +from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint import ( + StableDiffusionXLInpaintPipeline, +) +from 
pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, TypeAdapter, ValidationError +from pydantic_core import CoreSchema, PydanticUndefined, SchemaValidator +from safetensors.torch import load_file +from transformers import ( + AutoConfig, + AutoModelForTextEncoding, + CLIPTextModel, + CLIPTokenizer, + CLIPVisionModelWithProjection, + LlavaOnevisionForConditionalGeneration, + SiglipVisionModel, + T5EncoderModel, + T5TokenizerFast, +) from typing_extensions import Annotated, Any, Dict from invokeai.app.services.config.config_default import get_config from invokeai.app.util.misc import uuid_string +from invokeai.backend.flux.controlnet.instantx_controlnet_flux import InstantXControlNetFlux +from invokeai.backend.flux.controlnet.state_dict_utils import ( + convert_diffusers_instantx_state_dict_to_bfl_format, + infer_flux_params_from_state_dict, + infer_instantx_num_control_modes_from_state_dict, + is_state_dict_instantx_controlnet, + is_state_dict_xlabs_controlnet, +) +from invokeai.backend.flux.controlnet.xlabs_controlnet_flux import XLabsControlNetFlux +from invokeai.backend.ip_adapter.ip_adapter import build_ip_adapter +from invokeai.backend.flux.ip_adapter.state_dict_utils import ( + infer_xlabs_ip_adapter_params_from_state_dict, + is_state_dict_xlabs_ip_adapter, +) +from invokeai.backend.flux.ip_adapter.xlabs_ip_adapter_flux import XlabsIpAdapterFlux +from invokeai.backend.flux.modules.autoencoder import AutoEncoder +from invokeai.backend.flux.redux.flux_redux_model import FluxReduxModel +from invokeai.backend.flux.redux.flux_redux_state_dict_utils import is_state_dict_likely_flux_redux +from invokeai.backend.flux.util import get_flux_ae_params, get_flux_transformers_params +from invokeai.backend.flux.model import Flux from invokeai.backend.model_hash.hash_validator import validate_hash from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS from invokeai.backend.model_manager.model_on_disk import ModelOnDisk from invokeai.backend.model_manager.omi import flux_dev_1_lora, stable_diffusion_xl_1_lora from invokeai.backend.model_manager.taxonomy import ( + AnyModel, AnyVariant, BaseModelType, ClipVariantType, FluxLoRAFormat, + FluxVariantType, ModelFormat, ModelRepoVariant, ModelSourceType, @@ -51,9 +103,47 @@ ModelVariantType, SchedulerPredictionType, SubModelType, + variant_type_adapter, +) +from invokeai.backend.model_manager.omi.omi import convert_from_omi +from invokeai.backend.model_manager.util.model_util import lora_token_vector_length, convert_bundle_to_flux_transformer_checkpoint +from invokeai.backend.patches.lora_conversions.flux_aitoolkit_lora_conversion_utils import ( + is_state_dict_likely_in_flux_aitoolkit_format, + lora_model_from_flux_aitoolkit_state_dict, +) +from invokeai.backend.patches.lora_conversions.flux_control_lora_utils import ( + is_state_dict_likely_flux_control, + lora_model_from_flux_control_state_dict, +) +from invokeai.backend.patches.lora_conversions.flux_diffusers_lora_conversion_utils import ( + lora_model_from_flux_diffusers_state_dict, ) -from invokeai.backend.model_manager.util.model_util import lora_token_vector_length +from invokeai.backend.patches.lora_conversions.flux_kohya_lora_conversion_utils import ( + is_state_dict_likely_in_flux_kohya_format, + lora_model_from_flux_kohya_state_dict, +) +from invokeai.backend.patches.lora_conversions.flux_onetrainer_lora_conversion_utils import ( + is_state_dict_likely_in_flux_onetrainer_format, + lora_model_from_flux_onetrainer_state_dict, +) +from 
invokeai.backend.patches.lora_conversions.sd_lora_conversion_utils import lora_model_from_sd_state_dict +from invokeai.backend.patches.lora_conversions.sdxl_lora_conversion_utils import convert_sdxl_keys_to_diffusers_format +from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor +from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader +from invokeai.backend.quantization.gguf.utils import TORCH_COMPATIBLE_QTYPES +from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES +from invokeai.backend.textual_inversion import TextualInversionModelRaw +from invokeai.backend.util.devices import TorchDevice +from invokeai.backend.util.silence_warnings import SilenceWarnings + +try: + from invokeai.backend.quantization.bnb_llm_int8 import quantize_model_llm_int8 + from invokeai.backend.quantization.bnb_nf4 import quantize_model_nf4 + + bnb_available = True +except ImportError: + bnb_available = False logger = logging.getLogger(__name__) app_config = get_config() @@ -65,13 +155,285 @@ class InvalidModelConfigException(Exception): pass +class NotAMatch(Exception): + """Exception for when a model does not match a config class. + + Args: + config_class: The config class that was being tested. + reason: The reason why the model did not match. + """ + + def __init__( + self, + config_class: type, + reason: str, + ): + super().__init__(f"{config_class.__name__}: {reason}") + + DEFAULTS_PRECISION = Literal["fp16", "fp32"] +class FieldValidator: + """Utility class for validating individual fields of a Pydantic model without instantiating the whole model. + + See: https://github.com/pydantic/pydantic/discussions/7367#discussioncomment-14213144 + """ + + @staticmethod + def find_field_schema(model: type[BaseModel], field_name: str) -> CoreSchema: + """Find the Pydantic core schema for a specific field in a model.""" + schema: CoreSchema = model.__pydantic_core_schema__.copy() + # we shallow copied, be careful not to mutate the original schema! 
+ + assert schema["type"] in ["definitions", "model"] + + # find the field schema + field_schema = schema["schema"] # type: ignore + while "fields" not in field_schema: + field_schema = field_schema["schema"] # type: ignore + + field_schema = field_schema["fields"][field_name]["schema"] # type: ignore + + # if the original schema is a definition schema, replace the model schema with the field schema + if schema["type"] == "definitions": + schema["schema"] = field_schema + return schema + else: + return field_schema + + @cache + @staticmethod + def get_validator(model: type[BaseModel], field_name: str) -> SchemaValidator: + """Get a SchemaValidator for a specific field in a model.""" + return SchemaValidator(FieldValidator.find_field_schema(model, field_name)) + + @staticmethod + def validate_field(model: type[BaseModel], field_name: str, value: Any) -> Any: + """Validate a value for a specific field in a model.""" + return FieldValidator.get_validator(model, field_name).validate_python(value) + + +def has_any_keys(state_dict: dict[str | int, Any], keys: str | set[str]) -> bool: + """Returns true if the state dict has any of the specified keys.""" + _keys = {keys} if isinstance(keys, str) else keys + return any(key in state_dict for key in _keys) + + +def has_any_keys_starting_with(state_dict: dict[str | int, Any], prefixes: str | set[str]) -> bool: + """Returns true if the state dict has any keys starting with any of the specified prefixes.""" + _prefixes = {prefixes} if isinstance(prefixes, str) else prefixes + return any(any(key.startswith(prefix) for prefix in _prefixes) for key in state_dict.keys() if isinstance(key, str)) + + +def has_any_keys_ending_with(state_dict: dict[str | int, Any], suffixes: str | set[str]) -> bool: + """Returns true if the state dict has any keys ending with any of the specified suffixes.""" + _suffixes = {suffixes} if isinstance(suffixes, str) else suffixes + return any(any(key.endswith(suffix) for suffix in _suffixes) for key in state_dict.keys() if isinstance(key, str)) + + +def common_config_paths(path: Path) -> set[Path]: + """Returns common config file paths for models stored in directories.""" + return {path / "config.json", path / "model_index.json"} + + +def _hf_definition_to_type(module: str, class_name: str) -> type: + """Convert a HuggingFace module and class name to a Python type. + + Args: + module: The module name (e.g. 'diffusers', 'transformers') + class_name: The class name (e.g. 'T2IAdapter') + + Returns: + The Python class type. + """ + import sys + + if module in [ + "diffusers", + "transformers", + "invokeai.backend.quantization.fast_quantized_transformers_model", + "invokeai.backend.quantization.fast_quantized_diffusion_model", + ]: + res_type = sys.modules[module] + else: + import diffusers + res_type = diffusers.pipelines + result: type = getattr(res_type, class_name) + return result + + +def _get_hf_load_class_from_config(config: dict[str, Any]) -> type: + """Get the HuggingFace model class to use for loading from a config dict. + + Args: + config: The config dictionary loaded from config.json + + Returns: + The model class to use for loading. + + Raises: + InvalidModelConfigException: If unable to determine the load class. 
+ """ + if class_name := config.get("_class_name"): + return _hf_definition_to_type(module="diffusers", class_name=class_name) + elif class_name := config.get("architectures"): + return _hf_definition_to_type(module="transformers", class_name=class_name[0]) + else: + raise InvalidModelConfigException("Unable to decipher Load Class based on given config.json") + + +def _get_sd_checkpoint_pipeline_class(base: BaseModelType, variant: ModelVariantType) -> type: + """Get the appropriate pipeline class for SD checkpoint loading.""" + load_classes = { + BaseModelType.StableDiffusion1: { + ModelVariantType.Normal: StableDiffusionPipeline, + ModelVariantType.Inpaint: StableDiffusionInpaintPipeline, + }, + BaseModelType.StableDiffusion2: { + ModelVariantType.Normal: StableDiffusionPipeline, + ModelVariantType.Inpaint: StableDiffusionInpaintPipeline, + }, + BaseModelType.StableDiffusionXL: { + ModelVariantType.Normal: StableDiffusionXLPipeline, + ModelVariantType.Inpaint: StableDiffusionXLInpaintPipeline, + }, + BaseModelType.StableDiffusionXLRefiner: { + ModelVariantType.Normal: StableDiffusionXLPipeline, + }, + } + try: + return load_classes[base][variant] + except KeyError as e: + raise ValueError(f"No diffusers pipeline known for base={base}, variant={variant}") from e + + +def _cache_sd_submodels(config_key: str, pipeline: Any, submodel_type: SubModelType) -> None: + """Proactively cache all submodels from an SD pipeline except the one being returned.""" + from invokeai.backend.model_manager.load.model_cache.model_cache import get_model_cache_key + + for subtype in SubModelType: + if subtype == submodel_type: + continue + if submodel := getattr(pipeline, subtype.value, None): + app_config.ram_cache.put(get_model_cache_key(config_key, subtype), model=submodel) + + +# These utility functions are tightly coupled to the config classes below in order to make the process of raising +# NotAMatch exceptions as easy and consistent as possible. + + +def _get_config_or_raise( + config_class: type, + config_path: Path | set[Path], +) -> dict[str, Any]: + """Load the config file at the given path, or raise NotAMatch if it cannot be loaded.""" + paths_to_check = config_path if isinstance(config_path, set) else {config_path} + + problems: dict[Path, str] = {} + + for p in paths_to_check: + if not p.exists(): + problems[p] = "file does not exist" + continue + + try: + with open(p, "r") as file: + config = json.load(file) + + return config + except Exception as e: + problems[p] = str(e) + continue + + raise NotAMatch(config_class, f"unable to load config file(s): {problems}") + + +def _get_class_name_from_config( + config_class: type, + config_path: Path | set[Path], +) -> str: + """Load the config file and return the class name. + + Raises: + NotAMatch if the config file is missing or does not contain a valid class name. 
+ """ + + config = _get_config_or_raise(config_class, config_path) + + try: + if "_class_name" in config: + config_class_name = config["_class_name"] + elif "architectures" in config: + config_class_name = config["architectures"][0] + else: + raise ValueError("missing _class_name or architectures field") + except Exception as e: + raise NotAMatch(config_class, f"unable to determine class name from config file: {config_path}") from e + + if not isinstance(config_class_name, str): + raise NotAMatch(config_class, f"_class_name or architectures field is not a string: {config_class_name}") + + return config_class_name + + +def _validate_class_name(config_class: type[BaseModel], config_path: Path | set[Path], expected: set[str]) -> None: + """Check if the class name in the config file matches the expected class names. + + Args: + config_class: The config class that is being tested. + config_path: The path to the config file. + expected: The expected class names.""" + + class_name = _get_class_name_from_config(config_class, config_path) + if class_name not in expected: + raise NotAMatch(config_class, f"invalid class name from config: {class_name}") + + +def _validate_override_fields( + config_class: type[BaseModel], + override_fields: dict[str, Any], +) -> None: + """Check if the provided override fields are valid for the config class. + + Args: + config_class: The config class that is being tested. + override_fields: The override fields provided by the user. + + Raises: + NotAMatch if any override field is invalid for the config. + """ + for field_name, override_value in override_fields.items(): + if field_name not in config_class.model_fields: + raise NotAMatch(config_class, f"unknown override field: {field_name}") + try: + FieldValidator.validate_field(config_class, field_name, override_value) + except ValidationError as e: + raise NotAMatch(config_class, f"invalid override for field '{field_name}': {e}") from e + + +def _validate_is_file( + config_class: type, + mod: ModelOnDisk, +) -> None: + """Raise NotAMatch if the model path is not a file.""" + if not mod.path.is_file(): + raise NotAMatch(config_class, "model path is not a file") + + +def _validate_is_dir( + config_class: type, + mod: ModelOnDisk, +) -> None: + """Raise NotAMatch if the model path is not a directory.""" + if not mod.path.is_dir(): + raise NotAMatch(config_class, "model path is not a directory") + + class SubmodelDefinition(BaseModel): path_or_prefix: str model_type: ModelType - variant: AnyVariant = None + variant: AnyVariant | None = None model_config = ConfigDict(protected_namespaces=()) @@ -103,692 +465,3440 @@ class ControlAdapterDefaultSettings(BaseModel): model_config = ConfigDict(extra="forbid") -class MatchSpeed(int, Enum): - """Represents the estimated runtime speed of a config's 'matches' method.""" - - FAST = 0 - MED = 1 - SLOW = 2 - - class LegacyProbeMixin: """Mixin for classes using the legacy probe for model classification.""" - @classmethod - def matches(cls, *args, **kwargs): - raise NotImplementedError(f"Method 'matches' not implemented for {cls.__name__}") - - @classmethod - def parse(cls, *args, **kwargs): - raise NotImplementedError(f"Method 'parse' not implemented for {cls.__name__}") + pass -class ModelConfigBase(ABC, BaseModel): +class Config_Base(ABC, BaseModel): """ - Abstract Base class for model configurations. + Abstract base class for model configurations. A model config describes a specific combination of model base, type and + format, along with other metadata about the model. 
For example, a Stable Diffusion 1.x main model in checkpoint format + would have base=sd-1, type=main, format=checkpoint. To create a new config type, inherit from this class and implement its interface: - - (mandatory) override methods 'matches' and 'parse' - - (mandatory) define fields 'type' and 'format' as class attributes + - Define method 'from_model_on_disk' that returns an instance of the class or raises NotAMatch. This method will be + called during model installation to determine the correct config class for a model. + - Define fields 'type', 'base' and 'format' as pydantic fields. These should be Literals with a single value. A + default must be provided for each of these fields. - - (optional) override method 'get_tag' - - (optional) override field _MATCH_SPEED + If multiple combinations of base, type and format need to be supported, create a separate subclass for each. See MinimalConfigExample in test_model_probe.py for an example implementation. """ - @staticmethod - def json_schema_extra(schema: dict[str, Any]) -> None: - schema["required"].extend(["key", "base", "type", "format"]) - - model_config = ConfigDict(validate_assignment=True, json_schema_extra=json_schema_extra) - - key: str = Field(description="A unique key for this model.", default_factory=uuid_string) - hash: str = Field(description="The hash of the model file(s).") + key: str = Field( + description="A unique key for this model.", + default_factory=uuid_string, + ) + hash: str = Field( + description="The hash of the model file(s).", + ) path: str = Field( - description="Path to the model on the filesystem. Relative paths are relative to the Invoke root directory." + description="Path to the model on the filesystem. Relative paths are relative to the Invoke root directory.", + ) + file_size: int = Field( + description="The size of the model in bytes.", + ) + name: str = Field( + description="Name of the model.", + ) + description: str | None = Field( + description="Model description", + default=None, ) - file_size: int = Field(description="The size of the model in bytes.") - name: str = Field(description="Name of the model.") - type: ModelType = Field(description="Model type") - format: ModelFormat = Field(description="Model format") - base: BaseModelType = Field(description="The base model.") - source: str = Field(description="The original source of the model (path, URL or repo_id).") - source_type: ModelSourceType = Field(description="The type of source") - - description: Optional[str] = Field(description="Model description", default=None) - source_api_response: Optional[str] = Field( - description="The original API response from the source, as stringified JSON.", default=None + source: str = Field( + description="The original source of the model (path, URL or repo_id).", ) - cover_image: Optional[str] = Field(description="Url for image to preview model", default=None) - submodels: Optional[Dict[SubModelType, SubmodelDefinition]] = Field( - description="Loadable submodels in this model", default=None + source_type: ModelSourceType = Field( + description="The type of source", + ) + source_api_response: str | None = Field( + description="The original API response from the source, as stringified JSON.", + default=None, + ) + cover_image: str | None = Field( + description="Url for image to preview model", + default=None, + ) + submodels: dict[SubModelType, SubmodelDefinition] | None = Field( + description="Loadable submodels in this model", + default=None, + ) + usage_info: str | None = Field( + default=None, + 
description="Usage information for this model", ) - usage_info: Optional[str] = Field(default=None, description="Usage information for this model") - USING_LEGACY_PROBE: ClassVar[set[Type["ModelConfigBase"]]] = set() - USING_CLASSIFY_API: ClassVar[set[Type["ModelConfigBase"]]] = set() - _MATCH_SPEED: ClassVar[MatchSpeed] = MatchSpeed.MED + CONFIG_CLASSES: ClassVar[set[Type["AnyModelConfig"]]] = set() + model_config = ConfigDict( + validate_assignment=True, + json_schema_serialization_defaults_required=True, + json_schema_mode_override="serialization", + ) + + @classmethod def __init_subclass__(cls, **kwargs): super().__init_subclass__(**kwargs) - if issubclass(cls, LegacyProbeMixin): - ModelConfigBase.USING_LEGACY_PROBE.add(cls) - # Cannot use `elif isinstance(cls, UnknownModelConfig)` because UnknownModelConfig is not defined yet - else: - ModelConfigBase.USING_CLASSIFY_API.add(cls) + # Register non-abstract subclasses so we can iterate over them later during model probing. + if not isabstract(cls): + cls.CONFIG_CLASSES.add(cls) - @staticmethod - def all_config_classes(): - subclasses = ModelConfigBase.USING_LEGACY_PROBE | ModelConfigBase.USING_CLASSIFY_API - concrete = {cls for cls in subclasses if not isabstract(cls)} - return concrete + @classmethod + def __pydantic_init_subclass__(cls, **kwargs): + # Ensure that subclasses define 'base', 'type' and 'format' fields and provide defaults for them. Each subclass + # is expected to represent a single combination of base, type and format. + for name in ("type", "base", "format"): + assert name in cls.model_fields, f"{cls.__name__} must define a '{name}' field" + assert cls.model_fields[name].default is not PydanticUndefined, ( + f"{cls.__name__} must define a default for the '{name}' field" + ) + + @classmethod + def get_tag(cls) -> Tag: + """Constructs a pydantic discriminated union tag for this model config class. When a config is deserialized, + pydantic uses the tag to determine which subclass to instantiate. + + The tag is a dot-separated string of the type, format, base and variant (if applicable). + """ + tag_strings: list[str] = [] + for name in ("type", "format", "base", "variant"): + if field := cls.model_fields.get(name): + if field.default is not PydanticUndefined: + # We expect each of these fields has an Enum for its default; we want the value of the enum. + tag_strings.append(field.default.value) + return Tag(".".join(tag_strings)) @staticmethod - def classify( - mod: str | Path | ModelOnDisk, hash_algo: HASHING_ALGORITHMS = "blake3_single", **overrides - ) -> "AnyModelConfig": + def get_model_discriminator_value(v: Any) -> str: """ - Returns the best matching ModelConfig instance from a model's file/folder path. - Raises InvalidModelConfigException if no valid configuration is found. - Created to deprecate ModelProbe.probe + Computes the discriminator value for a model config. + https://docs.pydantic.dev/latest/concepts/unions/#discriminated-unions-with-callable-discriminator """ - if isinstance(mod, Path | str): - mod = ModelOnDisk(mod, hash_algo) - - candidates = ModelConfigBase.USING_CLASSIFY_API - sorted_by_match_speed = sorted(candidates, key=lambda cls: (cls._MATCH_SPEED, cls.__name__)) + if isinstance(v, Config_Base): + # We have an instance of a ModelConfigBase subclass - use its tag directly. + return v.get_tag().tag + if isinstance(v, dict): + # We have a dict - compute the tag from its fields. 
+ tag_strings: list[str] = [] + if type_ := v.get("type"): + if isinstance(type_, Enum): + type_ = type_.value + tag_strings.append(type_) + + if format_ := v.get("format"): + if isinstance(format_, Enum): + format_ = format_.value + tag_strings.append(format_) + + if base_ := v.get("base"): + if isinstance(base_, Enum): + base_ = base_.value + tag_strings.append(base_) + + # Special case: CLIP Embed models also need the variant to distinguish them. + if ( + type_ == ModelType.CLIPEmbed.value + and format_ == ModelFormat.Diffusers.value + and base_ == BaseModelType.Any.value + ): + if variant_value := v.get("variant"): + if isinstance(variant_value, Enum): + variant_value = variant_value.value + tag_strings.append(variant_value) + else: + raise ValueError("CLIP Embed model config dict must include a 'variant' field") + + return ".".join(tag_strings) + else: + raise TypeError("Model config discriminator value must be computed from a dict or ModelConfigBase instance") - for config_cls in sorted_by_match_speed: - try: - if not config_cls.matches(mod): - continue - except Exception as e: - logger.warning(f"Unexpected exception while matching {mod.name} to '{config_cls.__name__}': {e}") - continue - else: - return config_cls.from_model_on_disk(mod, **overrides) + @abstractmethod + @classmethod + def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: + """Given the model on disk and any overrides, return an instance of this config class. - if app_config.allow_unknown_models: - try: - return UnknownModelConfig.from_model_on_disk(mod, **overrides) - except Exception: - # Fall through to raising the exception below - pass + Implementations should raise NotAMatch if the model does not match this config class.""" + pass - raise InvalidModelConfigException("Unable to determine model type") + @abstractmethod + def load_model(self, submodel_type: Optional[SubModelType] = None) -> AnyModel: + """Load the model described by this config. - @classmethod - def get_tag(cls) -> Tag: - type = cls.model_fields["type"].default.value - format = cls.model_fields["format"].default.value - return Tag(f"{type}.{format}") + Args: + submodel_type: If specified, load the specified submodel instead of the main model. - @classmethod - @abstractmethod - def parse(cls, mod: ModelOnDisk) -> dict[str, Any]: - """Returns a dictionary with the fields needed to construct the model. - Raises InvalidModelConfigException if the model is invalid. + Returns: + The loaded model object. """ pass + +class Unknown_Config(Config_Base): + """Model config for unknown models, used as a fallback when we cannot identify a model.""" + + base: Literal[BaseModelType.Unknown] = Field(default=BaseModelType.Unknown) + type: Literal[ModelType.Unknown] = Field(default=ModelType.Unknown) + format: Literal[ModelFormat.Unknown] = Field(default=ModelFormat.Unknown) + @classmethod - @abstractmethod - def matches(cls, mod: ModelOnDisk) -> bool: - """Performs a quick check to determine if the config matches the model. 
- This doesn't need to be a perfect test - the aim is to eliminate unlikely matches quickly before parsing.""" - pass + def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: + raise NotAMatch(cls, "unknown model config cannot match any model") - @staticmethod - def cast_overrides(overrides: dict[str, Any]): - """Casts user overrides from str to Enum""" - if "type" in overrides: - overrides["type"] = ModelType(overrides["type"]) + def load_model(self, submodel_type: Optional[SubModelType] = None) -> AnyModel: + raise InvalidModelConfigException("Cannot load model with unknown config") + + +class Checkpoint_Config_Base(ABC, BaseModel): + """Base class for checkpoint-style models.""" - if "format" in overrides: - overrides["format"] = ModelFormat(overrides["format"]) + config_path: str | None = Field( + description="Path to the config for this model, if any.", + default=None, + ) + converted_at: float | None = Field( + description="When this model was last converted to diffusers", + default_factory=time.time, + ) - if "base" in overrides: - overrides["base"] = BaseModelType(overrides["base"]) - if "source_type" in overrides: - overrides["source_type"] = ModelSourceType(overrides["source_type"]) +class Diffusers_Config_Base(ABC, BaseModel): + """Base class for diffusers-style models.""" - if "variant" in overrides: - overrides["variant"] = ModelVariantType(overrides["variant"]) + format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers) + repo_variant: Optional[ModelRepoVariant] = Field(ModelRepoVariant.Default) @classmethod - def from_model_on_disk(cls, mod: ModelOnDisk, **overrides): - """Creates an instance of this config or raises InvalidModelConfigException.""" - fields = cls.parse(mod) - cls.cast_overrides(overrides) - fields.update(overrides) + def _get_repo_variant_or_raise(cls, mod: ModelOnDisk) -> ModelRepoVariant: + # get all files ending in .bin or .safetensors + weight_files = list(mod.path.glob("**/*.safetensors")) + weight_files.extend(list(mod.path.glob("**/*.bin"))) + for x in weight_files: + if ".fp16" in x.suffixes: + return ModelRepoVariant.FP16 + if "openvino_model" in x.name: + return ModelRepoVariant.OpenVINO + if "flax_model" in x.name: + return ModelRepoVariant.Flax + if x.suffix == ".onnx": + return ModelRepoVariant.ONNX + return ModelRepoVariant.Default + + +class T5Encoder_T5Encoder_Config(Config_Base): + base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any) + type: Literal[ModelType.T5Encoder] = Field(default=ModelType.T5Encoder) + format: Literal[ModelFormat.T5Encoder] = Field(default=ModelFormat.T5Encoder) - fields["path"] = mod.path.as_posix() - fields["source"] = fields.get("source") or fields["path"] - fields["source_type"] = fields.get("source_type") or ModelSourceType.Path - fields["name"] = fields.get("name") or mod.name - fields["hash"] = fields.get("hash") or mod.hash() - fields["key"] = fields.get("key") or uuid_string() - fields["description"] = fields.get("description") - fields["repo_variant"] = fields.get("repo_variant") or mod.repo_variant() - fields["file_size"] = fields.get("file_size") or mod.size() + @classmethod + def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: + _validate_is_dir(cls, mod) - return cls(**fields) + _validate_override_fields(cls, fields) + + _validate_class_name( + cls, + common_config_paths(mod.path), + { + "T5EncoderModel", + }, + ) + cls._validate_has_unquantized_config_file(mod) -class UnknownModelConfig(ModelConfigBase): - base: 
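The `_get_repo_variant_or_raise` helper above boils down to filename inspection over the diffusers folder. A rough standalone sketch of the same idea, with an illustrative enum standing in for `ModelRepoVariant`:

```python
from enum import Enum
from pathlib import Path


class SketchRepoVariant(str, Enum):
    FP16 = "fp16"
    OpenVINO = "openvino"
    Flax = "flax"
    ONNX = "onnx"
    Default = ""


def guess_repo_variant(model_dir: Path) -> SketchRepoVariant:
    # Look at every weight file in the folder and infer the repo variant from its name,
    # falling back to the default variant when nothing matches.
    weight_files = list(model_dir.glob("**/*.safetensors")) + list(model_dir.glob("**/*.bin"))
    for f in weight_files:
        if ".fp16" in f.suffixes:
            return SketchRepoVariant.FP16
        if "openvino_model" in f.name:
            return SketchRepoVariant.OpenVINO
        if "flax_model" in f.name:
            return SketchRepoVariant.Flax
        if f.suffix == ".onnx":
            return SketchRepoVariant.ONNX
    return SketchRepoVariant.Default
```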
Literal[BaseModelType.Unknown] = BaseModelType.Unknown - type: Literal[ModelType.Unknown] = ModelType.Unknown - format: Literal[ModelFormat.Unknown] = ModelFormat.Unknown + return cls(**fields) @classmethod - def matches(cls, mod: ModelOnDisk) -> bool: - return False + def _validate_has_unquantized_config_file(cls, mod: ModelOnDisk) -> None: + has_unquantized_config = (mod.path / "text_encoder_2" / "model.safetensors.index.json").exists() + + if not has_unquantized_config: + raise NotAMatch(cls, "missing text_encoder_2/model.safetensors.index.json") + + def load_model(self, submodel_type: Optional[SubModelType] = None) -> AnyModel: + match submodel_type: + case SubModelType.Tokenizer2 | SubModelType.Tokenizer3: + return T5TokenizerFast.from_pretrained(Path(self.path) / "tokenizer_2", max_length=512) + case SubModelType.TextEncoder2 | SubModelType.TextEncoder3: + return T5EncoderModel.from_pretrained( + Path(self.path) / "text_encoder_2", torch_dtype="auto", low_cpu_mem_usage=True + ) + case _: + raise ValueError( + f"Only Tokenizer2, Tokenizer3, TextEncoder2, and TextEncoder3 submodels are supported. Received: {submodel_type.value if submodel_type else 'None'}" + ) + + +class T5Encoder_BnBLLMint8_Config(Config_Base): + base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any) + type: Literal[ModelType.T5Encoder] = Field(default=ModelType.T5Encoder) + format: Literal[ModelFormat.BnbQuantizedLlmInt8b] = Field(default=ModelFormat.BnbQuantizedLlmInt8b) @classmethod - def parse(cls, mod: ModelOnDisk) -> dict[str, Any]: - return {} + def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: + _validate_is_dir(cls, mod) + _validate_override_fields(cls, fields) -class CheckpointConfigBase(ABC, BaseModel): - """Base class for checkpoint-style models.""" + _validate_class_name( + cls, + common_config_paths(mod.path), + { + "T5EncoderModel", + }, + ) - format: Literal[ModelFormat.Checkpoint, ModelFormat.BnbQuantizednf4b, ModelFormat.GGUFQuantized] = Field( - description="Format of the provided checkpoint model", default=ModelFormat.Checkpoint - ) - config_path: str = Field(description="path to the checkpoint model config file") - converted_at: Optional[float] = Field( - description="When this model was last converted to diffusers", default_factory=time.time - ) + cls._validate_filename_looks_like_bnb_quantized(mod) + cls._validate_model_looks_like_bnb_quantized(mod) -class DiffusersConfigBase(ABC, BaseModel): - """Base class for diffusers-style models.""" + return cls(**fields) + + @classmethod + def _validate_filename_looks_like_bnb_quantized(cls, mod: ModelOnDisk) -> None: + filename_looks_like_bnb = any(x for x in mod.weight_files() if "llm_int8" in x.as_posix()) + if not filename_looks_like_bnb: + raise NotAMatch(cls, "filename does not look like bnb quantized llm_int8") - format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers - repo_variant: Optional[ModelRepoVariant] = ModelRepoVariant.Default + @classmethod + def _validate_model_looks_like_bnb_quantized(cls, mod: ModelOnDisk) -> None: + has_scb_key_suffix = has_any_keys_ending_with(mod.load_state_dict(), "SCB") + if not has_scb_key_suffix: + raise NotAMatch(cls, "state dict does not look like bnb quantized llm_int8") + + def load_model(self, submodel_type: Optional[SubModelType] = None) -> AnyModel: + if not bnb_available: + raise ImportError( + "The bnb modules are not available. Please install bitsandbytes if available on your platform." 
+ ) + match submodel_type: + case SubModelType.Tokenizer2 | SubModelType.Tokenizer3: + return T5TokenizerFast.from_pretrained(Path(self.path) / "tokenizer_2", max_length=512) + case SubModelType.TextEncoder2 | SubModelType.TextEncoder3: + te2_model_path = Path(self.path) / "text_encoder_2" + model_config = AutoConfig.from_pretrained(te2_model_path) + with accelerate.init_empty_weights(): + model = AutoModelForTextEncoding.from_config(model_config) + model = quantize_model_llm_int8(model, modules_to_not_convert=set()) + + state_dict_path = te2_model_path / "bnb_llm_int8_model.safetensors" + state_dict = load_file(state_dict_path) + self._load_state_dict_into_t5(model, state_dict) + + return model + case _: + raise ValueError( + f"Only Tokenizer2, Tokenizer3, TextEncoder2, and TextEncoder3 submodels are supported. Received: {submodel_type.value if submodel_type else 'None'}" + ) -class LoRAConfigBase(ABC, BaseModel): + @staticmethod + def _load_state_dict_into_t5(model: T5EncoderModel, state_dict: dict[str, torch.Tensor]) -> None: + # There is a shared reference to a single weight tensor in the model. + # Both "encoder.embed_tokens.weight" and "shared.weight" refer to the same tensor, so only the latter should + # be present in the state_dict. + missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False, assign=True) + assert len(unexpected_keys) == 0 + assert set(missing_keys) == {"encoder.embed_tokens.weight"} + # Assert that the layers we expect to be shared are actually shared. + assert model.encoder.embed_tokens.weight is model.shared.weight + + +class LoRA_Config_Base(ABC, BaseModel): """Base class for LoRA models.""" - type: Literal[ModelType.LoRA] = ModelType.LoRA + type: Literal[ModelType.LoRA] = Field(default=ModelType.LoRA) trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None) default_settings: Optional[LoraModelDefaultSettings] = Field( description="Default settings for this model", default=None ) - @classmethod - def flux_lora_format(cls, mod: ModelOnDisk): - key = "FLUX_LORA_FORMAT" - if key in mod.cache: - return mod.cache[key] - - from invokeai.backend.patches.lora_conversions.formats import flux_format_from_state_dict - - sd = mod.load_state_dict(mod.path) - value = flux_format_from_state_dict(sd, mod.metadata()) - mod.cache[key] = value - return value - - @classmethod - def base_model(cls, mod: ModelOnDisk) -> BaseModelType: - if cls.flux_lora_format(mod): - return BaseModelType.Flux - state_dict = mod.load_state_dict() - # If we've gotten here, we assume that the model is a Stable Diffusion model - token_vector_length = lora_token_vector_length(state_dict) - if token_vector_length == 768: - return BaseModelType.StableDiffusion1 - elif token_vector_length == 1024: - return BaseModelType.StableDiffusion2 - elif token_vector_length == 1280: - return BaseModelType.StableDiffusionXL # recognizes format at https://civitai.com/models/224641 - elif token_vector_length == 2048: - return BaseModelType.StableDiffusionXL - else: - raise InvalidModelConfigException("Unknown LoRA type") +def _get_flux_lora_format(mod: ModelOnDisk) -> FluxLoRAFormat | None: + # TODO(psyche): Moving this import to the function to avoid circular imports. Refactor later. 
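The `_load_state_dict_into_t5` helper above relies on PyTorch's handling of tied weights. Below is a minimal sketch of the same behaviour with a toy module (it needs `assign=True`, available in PyTorch 2.1+); the module and key names are invented, not T5's.

```python
import torch
from torch import nn


class TinyTiedEncoder(nn.Module):
    """Toy module with a tied embedding, analogous to T5's shared/encoder.embed_tokens pair."""

    def __init__(self) -> None:
        super().__init__()
        self.shared = nn.Embedding(10, 4)
        self.encoder_embed = self.shared  # same module registered under a second name


model = TinyTiedEncoder()

# A checkpoint that stores only the canonical copy of the tied weight.
state_dict = {"shared.weight": torch.zeros(10, 4)}

missing, unexpected = model.load_state_dict(state_dict, strict=False, assign=True)

# The second name for the tied tensor is reported missing, but nothing is unexpected,
# and both attributes still point at the same (newly assigned) tensor.
assert unexpected == []
assert set(missing) == {"encoder_embed.weight"}
assert model.encoder_embed.weight is model.shared.weight
```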
+ from invokeai.backend.patches.lora_conversions.formats import flux_format_from_state_dict + state_dict = mod.load_state_dict(mod.path) + value = flux_format_from_state_dict(state_dict, mod.metadata()) + return value -class T5EncoderConfigBase(ABC, BaseModel): - """Base class for diffusers-style models.""" - type: Literal[ModelType.T5Encoder] = ModelType.T5Encoder +class LoRA_OMI_Config_Base(LoRA_Config_Base): + format: Literal[ModelFormat.OMI] = Field(default=ModelFormat.OMI) + @classmethod + def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: + _validate_is_file(cls, mod) -class T5EncoderConfig(T5EncoderConfigBase, LegacyProbeMixin, ModelConfigBase): - format: Literal[ModelFormat.T5Encoder] = ModelFormat.T5Encoder + _validate_override_fields(cls, fields) + cls._validate_looks_like_omi_lora(mod) -class T5EncoderBnbQuantizedLlmInt8bConfig(T5EncoderConfigBase, LegacyProbeMixin, ModelConfigBase): - format: Literal[ModelFormat.BnbQuantizedLlmInt8b] = ModelFormat.BnbQuantizedLlmInt8b + cls._validate_base(mod) + return cls(**fields) -class LoRAOmiConfig(LoRAConfigBase, ModelConfigBase): - format: Literal[ModelFormat.OMI] = ModelFormat.OMI + @classmethod + def _validate_base(cls, mod: ModelOnDisk) -> None: + """Raise `NotAMatch` if the model base does not match this config class.""" + expected_base = cls.model_fields["base"].default + recognized_base = cls._get_base_or_raise(mod) + if expected_base is not recognized_base: + raise NotAMatch(cls, f"base is {recognized_base}, not {expected_base}") @classmethod - def matches(cls, mod: ModelOnDisk) -> bool: - if mod.path.is_dir(): - return False + def _validate_looks_like_omi_lora(cls, mod: ModelOnDisk) -> None: + """Raise `NotAMatch` if the model metadata does not look like an OMI LoRA.""" + flux_format = _get_flux_lora_format(mod) + if flux_format in [FluxLoRAFormat.Control, FluxLoRAFormat.Diffusers]: + raise NotAMatch(cls, "model looks like ControlLoRA or Diffusers LoRA") metadata = mod.metadata() - return ( + + metadata_looks_like_omi_lora = ( bool(metadata.get("modelspec.sai_model_spec")) and metadata.get("ot_branch") == "omi_format" - and metadata["modelspec.architecture"].split("/")[1].lower() == "lora" + and metadata.get("modelspec.architecture", "").split("/")[1].lower() == "lora" ) + if not metadata_looks_like_omi_lora: + raise NotAMatch(cls, "metadata does not look like OMI LoRA") + @classmethod - def parse(cls, mod: ModelOnDisk) -> dict[str, Any]: + def _get_base_or_raise(cls, mod: ModelOnDisk) -> Literal[BaseModelType.Flux, BaseModelType.StableDiffusionXL]: metadata = mod.metadata() architecture = metadata["modelspec.architecture"] if architecture == stable_diffusion_xl_1_lora: - base = BaseModelType.StableDiffusionXL + return BaseModelType.StableDiffusionXL elif architecture == flux_dev_1_lora: - base = BaseModelType.Flux + return BaseModelType.Flux else: - raise InvalidModelConfigException(f"Unrecognised/unsupported architecture for OMI LoRA: {architecture}") + raise NotAMatch(cls, f"unrecognised/unsupported architecture for OMI LoRA: {architecture}") - return {"base": base} +class LoRA_OMI_SDXL_Config(LoRA_OMI_Config_Base, Config_Base): + base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL) -class LoRALyCORISConfig(LoRAConfigBase, ModelConfigBase): - """Model config for LoRA/Lycoris models.""" + def load_model(self, submodel_type: Optional[SubModelType] = None) -> AnyModel: + if submodel_type is not None: + raise ValueError("LoRA models have no submodels.") - 
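The OMI check above keys off safetensors header metadata. Here is a rough sketch of reading that metadata with the `safetensors` library; the heuristic is paraphrased (taking the last path segment of the architecture string), not copied verbatim.

```python
from pathlib import Path

from safetensors import safe_open


def read_safetensors_metadata(path: Path) -> dict[str, str]:
    # The safetensors header may carry arbitrary string metadata written at export time.
    with safe_open(path.as_posix(), framework="pt", device="cpu") as f:
        return f.metadata() or {}


def looks_like_omi_lora(metadata: dict[str, str]) -> bool:
    # Paraphrase of the heuristic above: a modelspec header, the OMI branch marker,
    # and an architecture string whose last segment is "lora".
    return (
        bool(metadata.get("modelspec.sai_model_spec"))
        and metadata.get("ot_branch") == "omi_format"
        and metadata.get("modelspec.architecture", "").split("/")[-1].lower() == "lora"
    )
```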
format: Literal[ModelFormat.LyCORIS] = ModelFormat.LyCORIS + model_path = Path(self.path) - @classmethod - def matches(cls, mod: ModelOnDisk) -> bool: - if mod.path.is_dir(): - return False + # Load the state dict from the model file. + if model_path.suffix == ".safetensors": + state_dict = load_file(model_path.absolute().as_posix(), device="cpu") + else: + state_dict = torch.load(model_path, map_location="cpu") - # Avoid false positive match against ControlLoRA and Diffusers - if cls.flux_lora_format(mod) in [FluxLoRAFormat.Control, FluxLoRAFormat.Diffusers]: - return False + # Strip 'bundle_emb' keys - these are unused and currently cause downstream errors. + state_dict = {k: v for k, v in state_dict.items() if not k.startswith("bundle_emb")} - state_dict = mod.load_state_dict() - for key in state_dict.keys(): - if isinstance(key, int): - continue + # Convert from OMI format + state_dict = convert_from_omi(state_dict, self.base) - if key.startswith(("lora_te_", "lora_unet_", "lora_te1_", "lora_te2_", "lora_transformer_")): - return True - # "lora_A.weight" and "lora_B.weight" are associated with models in PEFT format. We don't support all PEFT - # LoRA models, but as of the time of writing, we support Diffusers FLUX PEFT LoRA models. - if key.endswith(("to_k_lora.up.weight", "to_q_lora.down.weight", "lora_A.weight", "lora_B.weight")): - return True + # Apply SDXL-specific key conversions + state_dict = convert_sdxl_keys_to_diffusers_format(state_dict) + model = lora_model_from_sd_state_dict(state_dict=state_dict) - return False + model.to(dtype=TorchDevice.choose_torch_dtype()) + return model - @classmethod - def parse(cls, mod: ModelOnDisk) -> dict[str, Any]: - return { - "base": cls.base_model(mod), - } +class LoRA_OMI_FLUX_Config(LoRA_OMI_Config_Base, Config_Base): + base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux) -class ControlAdapterConfigBase(ABC, BaseModel): - default_settings: Optional[ControlAdapterDefaultSettings] = Field( - description="Default settings for this model", default=None - ) + def load_model(self, submodel_type: Optional[SubModelType] = None) -> AnyModel: + if submodel_type is not None: + raise ValueError("LoRA models have no submodels.") + model_path = Path(self.path) -class ControlLoRALyCORISConfig(ControlAdapterConfigBase, LegacyProbeMixin, ModelConfigBase): - """Model config for Control LoRA models.""" + # Load the state dict from the model file. + if model_path.suffix == ".safetensors": + state_dict = load_file(model_path.absolute().as_posix(), device="cpu") + else: + state_dict = torch.load(model_path, map_location="cpu") - type: Literal[ModelType.ControlLoRa] = ModelType.ControlLoRa - trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None) - format: Literal[ModelFormat.LyCORIS] = ModelFormat.LyCORIS + # Strip 'bundle_emb' keys - these are unused and currently cause downstream errors. + state_dict = {k: v for k, v in state_dict.items() if not k.startswith("bundle_emb")} + # Convert from OMI format + state_dict = convert_from_omi(state_dict, self.base) -class ControlLoRADiffusersConfig(ControlAdapterConfigBase, LegacyProbeMixin, ModelConfigBase): - """Model config for Control LoRA models.""" + # HACK(ryand): We set alpha=None for diffusers PEFT format models. These models are typically + # distributed as a single file without the associated metadata containing the alpha value. We chose + # alpha=None, because this is treated as alpha=rank internally in `LoRALayerBase.scale()`. 
alpha=rank + # is a popular choice. For example, in the diffusers training scripts: + # https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth_lora_flux.py#L1194 + model = lora_model_from_flux_diffusers_state_dict(state_dict=state_dict, alpha=None) - type: Literal[ModelType.ControlLoRa] = ModelType.ControlLoRa - trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None) - format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers + model.to(dtype=TorchDevice.choose_torch_dtype()) + return model -class LoRADiffusersConfig(LoRAConfigBase, ModelConfigBase): - """Model config for LoRA/Diffusers models.""" +class LoRA_LyCORIS_Config_Base(LoRA_Config_Base): + """Model config for LoRA/Lycoris models.""" - format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers + type: Literal[ModelType.LoRA] = Field(default=ModelType.LoRA) + format: Literal[ModelFormat.LyCORIS] = Field(default=ModelFormat.LyCORIS) @classmethod - def matches(cls, mod: ModelOnDisk) -> bool: - if mod.path.is_file(): - return cls.flux_lora_format(mod) == FluxLoRAFormat.Diffusers + def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: + _validate_is_file(cls, mod) - suffixes = ["bin", "safetensors"] - weight_files = [mod.path / f"pytorch_lora_weights.{sfx}" for sfx in suffixes] - return any(wf.exists() for wf in weight_files) + _validate_override_fields(cls, fields) - @classmethod - def parse(cls, mod: ModelOnDisk) -> dict[str, Any]: - return { - "base": cls.base_model(mod), - } + cls._validate_looks_like_lora(mod) + cls._validate_base(mod) -class VAECheckpointConfig(CheckpointConfigBase, LegacyProbeMixin, ModelConfigBase): - """Model config for standalone VAE models.""" + return cls(**fields) - type: Literal[ModelType.VAE] = ModelType.VAE + @classmethod + def _validate_base(cls, mod: ModelOnDisk) -> None: + """Raise `NotAMatch` if the model base does not match this config class.""" + expected_base = cls.model_fields["base"].default + recognized_base = cls._get_base_or_raise(mod) + if expected_base is not recognized_base: + raise NotAMatch(cls, f"base is {recognized_base}, not {expected_base}") + @classmethod + def _validate_looks_like_lora(cls, mod: ModelOnDisk) -> None: + # First rule out ControlLoRA and Diffusers LoRA + flux_format = _get_flux_lora_format(mod) + if flux_format in [FluxLoRAFormat.Control, FluxLoRAFormat.Diffusers]: + raise NotAMatch(cls, "model looks like ControlLoRA or Diffusers LoRA") + + # Note: Existence of these key prefixes/suffixes does not guarantee that this is a LoRA. + # Some main models have these keys, likely due to the creator merging in a LoRA. 
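The matching rules above lean on small state-dict helpers (`has_any_keys_starting_with` / `has_any_keys_ending_with`). A plausible standalone sketch of such helpers, assuming they simply scan string key names:

```python
from collections.abc import Iterable, Mapping
from typing import Any


def has_any_keys_starting_with(state_dict: Mapping[str, Any], prefixes: str | Iterable[str]) -> bool:
    prefix_tuple = (prefixes,) if isinstance(prefixes, str) else tuple(prefixes)
    return any(key.startswith(prefix_tuple) for key in state_dict)


def has_any_keys_ending_with(state_dict: Mapping[str, Any], suffixes: str | Iterable[str]) -> bool:
    suffix_tuple = (suffixes,) if isinstance(suffixes, str) else tuple(suffixes)
    return any(key.endswith(suffix_tuple) for key in state_dict)


# Example: bitsandbytes LLM.int8 checkpoints carry per-channel scale tensors under "...SCB" keys.
assert has_any_keys_ending_with({"encoder.block.0.SelfAttention.q.SCB": None}, "SCB")
assert has_any_keys_starting_with({"lora_unet_down_blocks_0.alpha": None}, {"lora_unet_", "lora_te_"})
```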
+ has_key_with_lora_prefix = has_any_keys_starting_with( + mod.load_state_dict(), + { + "lora_te_", + "lora_unet_", + "lora_te1_", + "lora_te2_", + "lora_transformer_", + }, + ) -class VAEDiffusersConfig(LegacyProbeMixin, ModelConfigBase): - """Model config for standalone VAE models (diffusers version).""" + has_key_with_lora_suffix = has_any_keys_ending_with( + mod.load_state_dict(), + { + "to_k_lora.up.weight", + "to_q_lora.down.weight", + "lora_A.weight", + "lora_B.weight", + }, + ) - type: Literal[ModelType.VAE] = ModelType.VAE - format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers + if not has_key_with_lora_prefix and not has_key_with_lora_suffix: + raise NotAMatch(cls, "model does not match LyCORIS LoRA heuristics") + @classmethod + def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType: + if _get_flux_lora_format(mod): + return BaseModelType.Flux -class ControlNetDiffusersConfig(DiffusersConfigBase, ControlAdapterConfigBase, LegacyProbeMixin, ModelConfigBase): - """Model config for ControlNet models (diffusers version).""" + state_dict = mod.load_state_dict() + # If we've gotten here, we assume that the model is a Stable Diffusion model + token_vector_length = lora_token_vector_length(state_dict) + if token_vector_length == 768: + return BaseModelType.StableDiffusion1 + elif token_vector_length == 1024: + return BaseModelType.StableDiffusion2 + elif token_vector_length == 1280: + return BaseModelType.StableDiffusionXL # recognizes format at https://civitai.com/models/224641 + elif token_vector_length == 2048: + return BaseModelType.StableDiffusionXL + else: + raise NotAMatch(cls, f"unrecognized token vector length {token_vector_length}") - type: Literal[ModelType.ControlNet] = ModelType.ControlNet - format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers +class LoRA_LyCORIS_SD1_Config(LoRA_LyCORIS_Config_Base, Config_Base): + base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1) -class ControlNetCheckpointConfig(CheckpointConfigBase, ControlAdapterConfigBase, LegacyProbeMixin, ModelConfigBase): - """Model config for ControlNet models (diffusers version).""" + def load_model(self, submodel_type: Optional[SubModelType] = None) -> AnyModel: + if submodel_type is not None: + raise ValueError("LoRA models have no submodels.") - type: Literal[ModelType.ControlNet] = ModelType.ControlNet + model_path = Path(self.path) + # Load the state dict from the model file. + if model_path.suffix == ".safetensors": + state_dict = load_file(model_path.absolute().as_posix(), device="cpu") + else: + state_dict = torch.load(model_path, map_location="cpu") -class TextualInversionFileConfig(LegacyProbeMixin, ModelConfigBase): - """Model config for textual inversion embeddings.""" + # Strip 'bundle_emb' keys - these are unused and currently cause downstream errors. 
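The base detection above reduces to a lookup on the LoRA's token-vector (text-encoder) width. A compact sketch of that mapping, with illustrative strings in place of InvokeAI's enums:

```python
def sd_base_from_token_vector_length(token_vector_length: int) -> str:
    # 768 -> SD 1.x, 1024 -> SD 2.x; 1280 and 2048 both indicate SDXL.
    mapping = {
        768: "sd-1",
        1024: "sd-2",
        1280: "sdxl",
        2048: "sdxl",
    }
    try:
        return mapping[token_vector_length]
    except KeyError:
        raise ValueError(f"unrecognized token vector length {token_vector_length}") from None


assert sd_base_from_token_vector_length(2048) == "sdxl"
```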
+ state_dict = {k: v for k, v in state_dict.items() if not k.startswith("bundle_emb")} - type: Literal[ModelType.TextualInversion] = ModelType.TextualInversion - format: Literal[ModelFormat.EmbeddingFile] = ModelFormat.EmbeddingFile + # SD1 models don't need any key conversions + model = lora_model_from_sd_state_dict(state_dict=state_dict) + model.to(dtype=TorchDevice.choose_torch_dtype()) + return model -class TextualInversionFolderConfig(LegacyProbeMixin, ModelConfigBase): - """Model config for textual inversion embeddings.""" - type: Literal[ModelType.TextualInversion] = ModelType.TextualInversion - format: Literal[ModelFormat.EmbeddingFolder] = ModelFormat.EmbeddingFolder +class LoRA_LyCORIS_SD2_Config(LoRA_LyCORIS_Config_Base, Config_Base): + base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2) + def load_model(self, submodel_type: Optional[SubModelType] = None) -> AnyModel: + if submodel_type is not None: + raise ValueError("LoRA models have no submodels.") -class MainConfigBase(ABC, BaseModel): - type: Literal[ModelType.Main] = ModelType.Main - trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None) - default_settings: Optional[MainModelDefaultSettings] = Field( - description="Default settings for this model", default=None - ) - variant: AnyVariant = ModelVariantType.Normal + model_path = Path(self.path) + # Load the state dict from the model file. + if model_path.suffix == ".safetensors": + state_dict = load_file(model_path.absolute().as_posix(), device="cpu") + else: + state_dict = torch.load(model_path, map_location="cpu") -class VideoConfigBase(ABC, BaseModel): - type: Literal[ModelType.Video] = ModelType.Video - trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None) - default_settings: Optional[MainModelDefaultSettings] = Field( - description="Default settings for this model", default=None - ) - variant: AnyVariant = ModelVariantType.Normal + # Strip 'bundle_emb' keys - these are unused and currently cause downstream errors. + state_dict = {k: v for k, v in state_dict.items() if not k.startswith("bundle_emb")} + # SD2 models don't need any key conversions + model = lora_model_from_sd_state_dict(state_dict=state_dict) -class MainCheckpointConfig(CheckpointConfigBase, MainConfigBase, LegacyProbeMixin, ModelConfigBase): - """Model config for main checkpoint models.""" + model.to(dtype=TorchDevice.choose_torch_dtype()) + return model - prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon - upcast_attention: bool = False +class LoRA_LyCORIS_SDXL_Config(LoRA_LyCORIS_Config_Base, Config_Base): + base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL) -class MainBnbQuantized4bCheckpointConfig(CheckpointConfigBase, MainConfigBase, LegacyProbeMixin, ModelConfigBase): - """Model config for main checkpoint models.""" + def load_model(self, submodel_type: Optional[SubModelType] = None) -> AnyModel: + if submodel_type is not None: + raise ValueError("LoRA models have no submodels.") - format: Literal[ModelFormat.BnbQuantizednf4b] = ModelFormat.BnbQuantizednf4b - prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon - upcast_attention: bool = False + model_path = Path(self.path) + # Load the state dict from the model file. 
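Each `load_model()` above repeats the same preamble: pick `safetensors` or `torch.load` based on the file extension, then strip `bundle_emb` keys. A sketch of that shared pattern as a standalone helper (the helper name is hypothetical):

```python
from pathlib import Path

import torch
from safetensors.torch import load_file


def load_lora_state_dict(model_path: Path) -> dict[str, torch.Tensor]:
    # Read the raw state dict from a single-file LoRA.
    if model_path.suffix == ".safetensors":
        state_dict = load_file(model_path.absolute().as_posix(), device="cpu")
    else:
        state_dict = torch.load(model_path, map_location="cpu")

    # Drop unused 'bundle_emb' entries, which otherwise cause downstream errors.
    return {k: v for k, v in state_dict.items() if not k.startswith("bundle_emb")}
```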
+ if model_path.suffix == ".safetensors": + state_dict = load_file(model_path.absolute().as_posix(), device="cpu") + else: + state_dict = torch.load(model_path, map_location="cpu") -class MainGGUFCheckpointConfig(CheckpointConfigBase, MainConfigBase, LegacyProbeMixin, ModelConfigBase): - """Model config for main checkpoint models.""" + # Strip 'bundle_emb' keys - these are unused and currently cause downstream errors. + state_dict = {k: v for k, v in state_dict.items() if not k.startswith("bundle_emb")} - format: Literal[ModelFormat.GGUFQuantized] = ModelFormat.GGUFQuantized - prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon - upcast_attention: bool = False + # Apply SDXL-specific key conversions + state_dict = convert_sdxl_keys_to_diffusers_format(state_dict) + model = lora_model_from_sd_state_dict(state_dict=state_dict) + model.to(dtype=TorchDevice.choose_torch_dtype()) + return model -class MainDiffusersConfig(DiffusersConfigBase, MainConfigBase, LegacyProbeMixin, ModelConfigBase): - """Model config for main diffusers models.""" - pass +class LoRA_LyCORIS_FLUX_Config(LoRA_LyCORIS_Config_Base, Config_Base): + base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux) + def load_model(self, submodel_type: Optional[SubModelType] = None) -> AnyModel: + if submodel_type is not None: + raise ValueError("LoRA models have no submodels.") -class IPAdapterConfigBase(ABC, BaseModel): - type: Literal[ModelType.IPAdapter] = ModelType.IPAdapter + model_path = Path(self.path) + # Load the state dict from the model file. + if model_path.suffix == ".safetensors": + state_dict = load_file(model_path.absolute().as_posix(), device="cpu") + else: + state_dict = torch.load(model_path, map_location="cpu") + + # Strip 'bundle_emb' keys - these are unused and currently cause downstream errors. + state_dict = {k: v for k, v in state_dict.items() if not k.startswith("bundle_emb")} + + # Detect and convert FLUX LoRA format + if is_state_dict_likely_in_flux_kohya_format(state_dict=state_dict): + model = lora_model_from_flux_kohya_state_dict(state_dict=state_dict) + elif is_state_dict_likely_in_flux_onetrainer_format(state_dict=state_dict): + model = lora_model_from_flux_onetrainer_state_dict(state_dict=state_dict) + elif is_state_dict_likely_flux_control(state_dict=state_dict): + model = lora_model_from_flux_control_state_dict(state_dict=state_dict) + elif is_state_dict_likely_in_flux_aitoolkit_format(state_dict=state_dict): + model = lora_model_from_flux_aitoolkit_state_dict(state_dict=state_dict) + else: + raise ValueError("LoRA model is in unsupported FLUX format") -class IPAdapterInvokeAIConfig(IPAdapterConfigBase, LegacyProbeMixin, ModelConfigBase): - """Model config for IP Adapter diffusers format models.""" + model.to(dtype=TorchDevice.choose_torch_dtype()) + return model - # TODO(ryand): Should we deprecate this field? From what I can tell, it hasn't been probed correctly for a long - # time. Need to go through the history to make sure I'm understanding this fully. 
- image_encoder_model_id: str - format: Literal[ModelFormat.InvokeAI] = ModelFormat.InvokeAI +class ControlAdapter_Config_Base(ABC, BaseModel): + default_settings: ControlAdapterDefaultSettings | None = Field(None) -class IPAdapterCheckpointConfig(IPAdapterConfigBase, LegacyProbeMixin, ModelConfigBase): - """Model config for IP Adapter checkpoint format models.""" - format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint +class ControlLoRA_LyCORIS_FLUX_Config(ControlAdapter_Config_Base, Config_Base): + """Model config for Control LoRA models.""" + + base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux) + type: Literal[ModelType.ControlLoRa] = Field(default=ModelType.ControlLoRa) + format: Literal[ModelFormat.LyCORIS] = Field(default=ModelFormat.LyCORIS) + + trigger_phrases: set[str] | None = Field(None) + + @classmethod + def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: + _validate_is_file(cls, mod) + + _validate_override_fields(cls, fields) + + cls._validate_looks_like_control_lora(mod) + + return cls(**fields) + + @classmethod + def _validate_looks_like_control_lora(cls, mod: ModelOnDisk) -> None: + state_dict = mod.load_state_dict() + + if not is_state_dict_likely_flux_control(state_dict): + raise NotAMatch(cls, "model state dict does not look like a Flux Control LoRA") + + def load_model(self, submodel_type: Optional[SubModelType] = None) -> AnyModel: + if submodel_type is not None: + raise ValueError("Control LoRA models have no submodels.") + + model_path = Path(self.path) + + # Load the state dict from the model file. + if model_path.suffix == ".safetensors": + state_dict = load_file(model_path.absolute().as_posix(), device="cpu") + else: + state_dict = torch.load(model_path, map_location="cpu") + + # Strip 'bundle_emb' keys - these are unused and currently cause downstream errors. + state_dict = {k: v for k, v in state_dict.items() if not k.startswith("bundle_emb")} + + # Load as Flux Control LoRA + model = lora_model_from_flux_control_state_dict(state_dict=state_dict) + + model.to(dtype=TorchDevice.choose_torch_dtype()) + return model + + +class LoRA_Diffusers_Config_Base(LoRA_Config_Base): + """Model config for LoRA/Diffusers models.""" + + # TODO(psyche): Needs base handling. For FLUX, the Diffusers format does not indicate a folder model; it indicates + # the weights format. FLUX Diffusers LoRAs are single files. 
+ + format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers) + + @classmethod + def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: + _validate_is_dir(cls, mod) + + _validate_override_fields(cls, fields) + + cls._validate_base(mod) + + return cls(**fields) + + @classmethod + def _validate_base(cls, mod: ModelOnDisk) -> None: + """Raise `NotAMatch` if the model base does not match this config class.""" + expected_base = cls.model_fields["base"].default + recognized_base = cls._get_base_or_raise(mod) + if expected_base is not recognized_base: + raise NotAMatch(cls, f"base is {recognized_base}, not {expected_base}") + + @classmethod + def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType: + if _get_flux_lora_format(mod): + return BaseModelType.Flux + + # If we've gotten here, we assume that the LoRA is a Stable Diffusion LoRA + path_to_weight_file = cls._get_weight_file_or_raise(mod) + state_dict = mod.load_state_dict(path_to_weight_file) + token_vector_length = lora_token_vector_length(state_dict) + + match token_vector_length: + case 768: + return BaseModelType.StableDiffusion1 + case 1024: + return BaseModelType.StableDiffusion2 + case 1280: + return BaseModelType.StableDiffusionXL # recognizes format at https://civitai.com/models/224641 + case 2048: + return BaseModelType.StableDiffusionXL + case _: + raise NotAMatch(cls, f"unrecognized token vector length {token_vector_length}") + + @classmethod + def _get_weight_file_or_raise(cls, mod: ModelOnDisk) -> Path: + suffixes = ["bin", "safetensors"] + weight_files = [mod.path / f"pytorch_lora_weights.{sfx}" for sfx in suffixes] + for wf in weight_files: + if wf.exists(): + return wf + raise NotAMatch(cls, "missing pytorch_lora_weights.bin or pytorch_lora_weights.safetensors") + + +class LoRA_Diffusers_SD1_Config(LoRA_Diffusers_Config_Base, Config_Base): + base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1) + + def load_model(self, submodel_type: Optional[SubModelType] = None) -> AnyModel: + if submodel_type is not None: + raise ValueError("LoRA models have no submodels.") + + # For Diffusers format, resolve to the actual weight file inside the directory + model_base_path = Path(self.path) + model_path = model_base_path + for ext in ["safetensors", "bin"]: + path = model_base_path / f"pytorch_lora_weights.{ext}" + if path.exists(): + model_path = path + break + + # Load the state dict from the model file. + if model_path.suffix == ".safetensors": + state_dict = load_file(model_path.absolute().as_posix(), device="cpu") + else: + state_dict = torch.load(model_path, map_location="cpu") + + # Strip 'bundle_emb' keys - these are unused and currently cause downstream errors. 
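The folder-format handling above resolves a Diffusers LoRA directory to its actual weight file. A small sketch of that lookup (hypothetical helper name, safetensors preferred over the pickle format):

```python
from pathlib import Path


def resolve_diffusers_lora_weight_file(model_dir: Path) -> Path:
    # Diffusers-format LoRA folders store their weights as pytorch_lora_weights.<ext>.
    for ext in ("safetensors", "bin"):
        candidate = model_dir / f"pytorch_lora_weights.{ext}"
        if candidate.exists():
            return candidate
    raise FileNotFoundError(
        f"{model_dir} has neither pytorch_lora_weights.safetensors nor pytorch_lora_weights.bin"
    )
```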
+ state_dict = {k: v for k, v in state_dict.items() if not k.startswith("bundle_emb")} + + # SD1 models don't need any key conversions + model = lora_model_from_sd_state_dict(state_dict=state_dict) + + model.to(dtype=TorchDevice.choose_torch_dtype()) + return model + + +class LoRA_Diffusers_SD2_Config(LoRA_Diffusers_Config_Base, Config_Base): + base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2) + + def load_model(self, submodel_type: Optional[SubModelType] = None) -> AnyModel: + if submodel_type is not None: + raise ValueError("LoRA models have no submodels.") + + # For Diffusers format, resolve to the actual weight file inside the directory + model_base_path = Path(self.path) + model_path = model_base_path + for ext in ["safetensors", "bin"]: + path = model_base_path / f"pytorch_lora_weights.{ext}" + if path.exists(): + model_path = path + break + + # Load the state dict from the model file. + if model_path.suffix == ".safetensors": + state_dict = load_file(model_path.absolute().as_posix(), device="cpu") + else: + state_dict = torch.load(model_path, map_location="cpu") + + # Strip 'bundle_emb' keys - these are unused and currently cause downstream errors. + state_dict = {k: v for k, v in state_dict.items() if not k.startswith("bundle_emb")} + + # SD2 models don't need any key conversions + model = lora_model_from_sd_state_dict(state_dict=state_dict) + + model.to(dtype=TorchDevice.choose_torch_dtype()) + return model + + +class LoRA_Diffusers_SDXL_Config(LoRA_Diffusers_Config_Base, Config_Base): + base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL) + + def load_model(self, submodel_type: Optional[SubModelType] = None) -> AnyModel: + if submodel_type is not None: + raise ValueError("LoRA models have no submodels.") + + # For Diffusers format, resolve to the actual weight file inside the directory + model_base_path = Path(self.path) + model_path = model_base_path + for ext in ["safetensors", "bin"]: + path = model_base_path / f"pytorch_lora_weights.{ext}" + if path.exists(): + model_path = path + break + + # Load the state dict from the model file. + if model_path.suffix == ".safetensors": + state_dict = load_file(model_path.absolute().as_posix(), device="cpu") + else: + state_dict = torch.load(model_path, map_location="cpu") + + # Strip 'bundle_emb' keys - these are unused and currently cause downstream errors. + state_dict = {k: v for k, v in state_dict.items() if not k.startswith("bundle_emb")} + + # Apply SDXL-specific key conversions + state_dict = convert_sdxl_keys_to_diffusers_format(state_dict) + model = lora_model_from_sd_state_dict(state_dict=state_dict) + + model.to(dtype=TorchDevice.choose_torch_dtype()) + return model + + +class LoRA_Diffusers_FLUX_Config(LoRA_Diffusers_Config_Base, Config_Base): + base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux) + + def load_model(self, submodel_type: Optional[SubModelType] = None) -> AnyModel: + if submodel_type is not None: + raise ValueError("LoRA models have no submodels.") + + # For FLUX Diffusers format, resolve to the actual weight file inside the directory + # (though these are typically single files, we still handle directory format for consistency) + model_base_path = Path(self.path) + model_path = model_base_path + for ext in ["safetensors", "bin"]: + path = model_base_path / f"pytorch_lora_weights.{ext}" + if path.exists(): + model_path = path + break + + # Load the state dict from the model file. 
+ if model_path.suffix == ".safetensors": + state_dict = load_file(model_path.absolute().as_posix(), device="cpu") + else: + state_dict = torch.load(model_path, map_location="cpu") + + # Strip 'bundle_emb' keys - these are unused and currently cause downstream errors. + state_dict = {k: v for k, v in state_dict.items() if not k.startswith("bundle_emb")} + + # HACK(ryand): We set alpha=None for diffusers PEFT format models. These models are typically + # distributed as a single file without the associated metadata containing the alpha value. We chose + # alpha=None, because this is treated as alpha=rank internally in `LoRALayerBase.scale()`. alpha=rank + # is a popular choice. For example, in the diffusers training scripts: + # https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth_lora_flux.py#L1194 + model = lora_model_from_flux_diffusers_state_dict(state_dict=state_dict, alpha=None) + + model.to(dtype=TorchDevice.choose_torch_dtype()) + return model + + +class VAE_Checkpoint_Config_Base(Checkpoint_Config_Base): + """Model config for standalone VAE models.""" + + type: Literal[ModelType.VAE] = Field(default=ModelType.VAE) + format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint) + + REGEX_TO_BASE: ClassVar[dict[str, BaseModelType]] = { + r"xl": BaseModelType.StableDiffusionXL, + r"sd2": BaseModelType.StableDiffusion2, + r"vae": BaseModelType.StableDiffusion1, + r"FLUX.1-schnell_ae": BaseModelType.Flux, + } + + @classmethod + def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: + _validate_is_file(cls, mod) + + _validate_override_fields(cls, fields) + + cls._validate_looks_like_vae(mod) + + cls._validate_base(mod) + + return cls(**fields) + + @classmethod + def _validate_base(cls, mod: ModelOnDisk) -> None: + """Raise `NotAMatch` if the model base does not match this config class.""" + expected_base = cls.model_fields["base"].default + recognized_base = cls._get_base_or_raise(mod) + if expected_base is not recognized_base: + raise NotAMatch(cls, f"base is {recognized_base}, not {expected_base}") + + @classmethod + def _validate_looks_like_vae(cls, mod: ModelOnDisk) -> None: + if not has_any_keys_starting_with( + mod.load_state_dict(), + { + "encoder.conv_in", + "decoder.conv_in", + }, + ): + raise NotAMatch(cls, "model does not match Checkpoint VAE heuristics") + + @classmethod + def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType: + # Heuristic: VAEs of all architectures have a similar structure; the best we can do is guess based on name + for regexp, base in cls.REGEX_TO_BASE.items(): + if re.search(regexp, mod.path.name, re.IGNORECASE): + return base + + raise NotAMatch(cls, "cannot determine base type") + + +class VAE_Checkpoint_SD1_Config(VAE_Checkpoint_Config_Base, Config_Base): + base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1) + + def load_model(self, submodel_type: Optional[SubModelType] = None) -> AnyModel: + if submodel_type is not None: + raise ValueError("VAE models have no submodels.") + + model_path = Path(self.path) + model = AutoencoderKL.from_single_file( + model_path, + torch_dtype=TorchDevice.choose_torch_dtype(), + ) + assert isinstance(model, AutoencoderKL) + return model + + +class VAE_Checkpoint_SD2_Config(VAE_Checkpoint_Config_Base, Config_Base): + base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2) + + def load_model(self, submodel_type: Optional[SubModelType] = None) -> AnyModel: + if 
submodel_type is not None: + raise ValueError("VAE models have no submodels.") + + model_path = Path(self.path) + model = AutoencoderKL.from_single_file( + model_path, + torch_dtype=TorchDevice.choose_torch_dtype(), + ) + assert isinstance(model, AutoencoderKL) + return model + + +class VAE_Checkpoint_SDXL_Config(VAE_Checkpoint_Config_Base, Config_Base): + base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL) + + def load_model(self, submodel_type: Optional[SubModelType] = None) -> AnyModel: + if submodel_type is not None: + raise ValueError("VAE models have no submodels.") + + model_path = Path(self.path) + model = AutoencoderKL.from_single_file( + model_path, + torch_dtype=TorchDevice.choose_torch_dtype(), + ) + assert isinstance(model, AutoencoderKL) + return model + + +class VAE_Checkpoint_FLUX_Config(VAE_Checkpoint_Config_Base, Config_Base): + base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux) + + def load_model(self, submodel_type: Optional[SubModelType] = None) -> AnyModel: + if submodel_type is not None: + raise ValueError("FLUX VAE models have no submodels.") + + model_path = Path(self.path) + + with accelerate.init_empty_weights(): + model = AutoEncoder(get_flux_ae_params()) + + sd = load_file(model_path) + model.load_state_dict(sd, assign=True) + + # VAE is broken in float16, which mps defaults to + torch_dtype = TorchDevice.choose_torch_dtype() + if torch_dtype == torch.float16: + try: + vae_dtype = torch.tensor([1.0], dtype=torch.bfloat16, device=TorchDevice.choose_torch_device()).dtype + except TypeError: + vae_dtype = torch.float32 + else: + vae_dtype = torch_dtype + model.to(vae_dtype) + + return model + + +class VAE_Diffusers_Config_Base(Diffusers_Config_Base): + """Model config for standalone VAE models (diffusers version).""" + + type: Literal[ModelType.VAE] = Field(default=ModelType.VAE) + format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers) + + @classmethod + def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: + _validate_is_dir(cls, mod) + + _validate_override_fields(cls, fields) + + _validate_class_name( + cls, + common_config_paths(mod.path), + { + "AutoencoderKL", + "AutoencoderTiny", + }, + ) + + cls._validate_base(mod) + + return cls(**fields) + + @classmethod + def _validate_base(cls, mod: ModelOnDisk) -> None: + """Raise `NotAMatch` if the model base does not match this config class.""" + expected_base = cls.model_fields["base"].default + recognized_base = cls._get_base_or_raise(mod) + if expected_base is not recognized_base: + raise NotAMatch(cls, f"base is {recognized_base}, not {expected_base}") + + @classmethod + def _config_looks_like_sdxl(cls, config: dict[str, Any]) -> bool: + # Heuristic: These config values that distinguish Stability's SD 1.x VAE from their SDXL VAE. + return config.get("scaling_factor", 0) == 0.13025 and config.get("sample_size") in [512, 1024] + + @classmethod + def _name_looks_like_sdxl(cls, mod: ModelOnDisk) -> bool: + # Heuristic: SD and SDXL VAE are the same shape (3-channel RGB to 4-channel float scaled down + # by a factor of 8), so we can't necessarily tell them apart by config hyperparameters. Best + # we can do is guess based on name. 
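A compact sketch of the two SDXL hints described above, combining the config fingerprint with the folder-name fallback (the config dict is assumed to come from the model's `config.json`):

```python
import re
from pathlib import Path
from typing import Any


def diffusers_vae_looks_like_sdxl(config: dict[str, Any], model_dir: Path) -> bool:
    # Fingerprint of Stability's SDXL VAE config: scaling_factor 0.13025 and a 512/1024 sample_size.
    if config.get("scaling_factor", 0) == 0.13025 and config.get("sample_size") in (512, 1024):
        return True

    # Fall back to an "xl" hint in the folder name (use the parent when the folder is just "vae").
    name = model_dir.name if model_dir.name != "vae" else model_dir.parent.name
    return bool(re.search(r"xl\b", name, re.IGNORECASE))
```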
+ return bool(re.search(r"xl\b", cls._guess_name(mod), re.IGNORECASE)) + + @classmethod + def _guess_name(cls, mod: ModelOnDisk) -> str: + name = mod.path.name + if name == "vae": + name = mod.path.parent.name + return name + + @classmethod + def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType: + config = _get_config_or_raise(cls, common_config_paths(mod.path)) + if cls._config_looks_like_sdxl(config): + return BaseModelType.StableDiffusionXL + elif cls._name_looks_like_sdxl(mod): + return BaseModelType.StableDiffusionXL + else: + # TODO(psyche): Figure out how to positively identify SD1 here, and raise if we can't. Until then, YOLO. + return BaseModelType.StableDiffusion1 + + +class VAE_Diffusers_SD1_Config(VAE_Diffusers_Config_Base, Config_Base): + base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1) + + def load_model(self, submodel_type: Optional[SubModelType] = None) -> AnyModel: + if submodel_type is not None: + raise ValueError("VAE models have no submodels.") + + model_path = Path(self.path) + model = AutoencoderKL.from_pretrained( + model_path, + torch_dtype=TorchDevice.choose_torch_dtype(), + ) + assert isinstance(model, AutoencoderKL) + return model + + +class VAE_Diffusers_SDXL_Config(VAE_Diffusers_Config_Base, Config_Base): + base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL) + + def load_model(self, submodel_type: Optional[SubModelType] = None) -> AnyModel: + if submodel_type is not None: + raise ValueError("VAE models have no submodels.") + + model_path = Path(self.path) + model = AutoencoderKL.from_pretrained( + model_path, + torch_dtype=TorchDevice.choose_torch_dtype(), + ) + assert isinstance(model, AutoencoderKL) + return model + + +class ControlNet_Diffusers_Config_Base(Diffusers_Config_Base, ControlAdapter_Config_Base): + """Model config for ControlNet models (diffusers version).""" + + type: Literal[ModelType.ControlNet] = Field(default=ModelType.ControlNet) + format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers) + + @classmethod + def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: + _validate_is_dir(cls, mod) + + _validate_override_fields(cls, fields) + + _validate_class_name( + cls, + common_config_paths(mod.path), + { + "ControlNetModel", + "FluxControlNetModel", + }, + ) + + cls._validate_base(mod) + + return cls(**fields) + + @classmethod + def _validate_base(cls, mod: ModelOnDisk) -> None: + """Raise `NotAMatch` if the model base does not match this config class.""" + expected_base = cls.model_fields["base"].default + recognized_base = cls._get_base_or_raise(mod) + if expected_base is not recognized_base: + raise NotAMatch(cls, f"base is {recognized_base}, not {expected_base}") + + @classmethod + def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType: + config = _get_config_or_raise(cls, common_config_paths(mod.path)) + + if config.get("_class_name") == "FluxControlNetModel": + return BaseModelType.Flux + + dimension = config.get("cross_attention_dim") + + match dimension: + case 768: + return BaseModelType.StableDiffusion1 + case 1024: + # No obvious way to distinguish between sd2-base and sd2-768, but we don't really differentiate them + # anyway. 
+ return BaseModelType.StableDiffusion2 + case 2048: + return BaseModelType.StableDiffusionXL + case _: + raise NotAMatch(cls, f"unrecognized cross_attention_dim {dimension}") + + +class ControlNet_Diffusers_SD1_Config(ControlNet_Diffusers_Config_Base, Config_Base): + base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1) + + def load_model(self, submodel_type: Optional[SubModelType] = None) -> AnyModel: + if submodel_type is not None: + raise ValueError("ControlNet models have no submodels.") + + model_path = Path(self.path) + config = _get_config_or_raise(type(self), model_path / "config.json") + model_class = _get_hf_load_class_from_config(config) + + variant = self.repo_variant.value if self.repo_variant else None + try: + model: AnyModel = model_class.from_pretrained( + model_path, + torch_dtype=TorchDevice.choose_torch_dtype(), + variant=variant, + ) + except OSError as e: + if variant and "no file named" in str(e): + # try without the variant, just in case user's preferences changed + model = model_class.from_pretrained( + model_path, + torch_dtype=TorchDevice.choose_torch_dtype(), + ) + else: + raise e + return model + + +class ControlNet_Diffusers_SD2_Config(ControlNet_Diffusers_Config_Base, Config_Base): + base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2) + + def load_model(self, submodel_type: Optional[SubModelType] = None) -> AnyModel: + if submodel_type is not None: + raise ValueError("ControlNet models have no submodels.") + + model_path = Path(self.path) + config = _get_config_or_raise(type(self), model_path / "config.json") + model_class = _get_hf_load_class_from_config(config) + + variant = self.repo_variant.value if self.repo_variant else None + try: + model: AnyModel = model_class.from_pretrained( + model_path, + torch_dtype=TorchDevice.choose_torch_dtype(), + variant=variant, + ) + except OSError as e: + if variant and "no file named" in str(e): + # try without the variant, just in case user's preferences changed + model = model_class.from_pretrained( + model_path, + torch_dtype=TorchDevice.choose_torch_dtype(), + ) + else: + raise e + return model + + +class ControlNet_Diffusers_SDXL_Config(ControlNet_Diffusers_Config_Base, Config_Base): + base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL) + + def load_model(self, submodel_type: Optional[SubModelType] = None) -> AnyModel: + if submodel_type is not None: + raise ValueError("ControlNet models have no submodels.") + + model_path = Path(self.path) + config = _get_config_or_raise(type(self), model_path / "config.json") + model_class = _get_hf_load_class_from_config(config) + + variant = self.repo_variant.value if self.repo_variant else None + try: + model: AnyModel = model_class.from_pretrained( + model_path, + torch_dtype=TorchDevice.choose_torch_dtype(), + variant=variant, + ) + except OSError as e: + if variant and "no file named" in str(e): + # try without the variant, just in case user's preferences changed + model = model_class.from_pretrained( + model_path, + torch_dtype=TorchDevice.choose_torch_dtype(), + ) + else: + raise e + return model + + +class ControlNet_Diffusers_FLUX_Config(ControlNet_Diffusers_Config_Base, Config_Base): + base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux) + + def load_model(self, submodel_type: Optional[SubModelType] = None) -> AnyModel: + if submodel_type is not None: + raise ValueError("ControlNet models have no submodels.") + + # Load from 
the diffusers directory weight file + model_path = Path(self.path) / "diffusion_pytorch_model.safetensors" + sd = load_file(model_path) + + # Detect the FLUX ControlNet model type from the state dict + if is_state_dict_xlabs_controlnet(sd): + with accelerate.init_empty_weights(): + # HACK(ryand): Is it safe to assume dev here? + model = XLabsControlNetFlux(get_flux_transformers_params(FluxVariantType.Dev)) + + model.load_state_dict(sd, assign=True) + return model + elif is_state_dict_instantx_controlnet(sd): + sd = convert_diffusers_instantx_state_dict_to_bfl_format(sd) + flux_params = infer_flux_params_from_state_dict(sd) + num_control_modes = infer_instantx_num_control_modes_from_state_dict(sd) + + with accelerate.init_empty_weights(): + model = InstantXControlNetFlux(flux_params, num_control_modes) + + model.load_state_dict(sd, assign=True) + return model + else: + raise ValueError("Do not recognize the state dict as an XLabs or InstantX ControlNet model.") + + +class ControlNet_Checkpoint_Config_Base(Checkpoint_Config_Base, ControlAdapter_Config_Base): + """Model config for ControlNet models (diffusers version).""" + + type: Literal[ModelType.ControlNet] = Field(default=ModelType.ControlNet) + format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint) + + @classmethod + def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: + _validate_is_file(cls, mod) + + _validate_override_fields(cls, fields) + + cls._validate_looks_like_controlnet(mod) + + cls._validate_base(mod) + + return cls(**fields) + + @classmethod + def _validate_base(cls, mod: ModelOnDisk) -> None: + """Raise `NotAMatch` if the model base does not match this config class.""" + expected_base = cls.model_fields["base"].default + recognized_base = cls._get_base_or_raise(mod) + if expected_base is not recognized_base: + raise NotAMatch(cls, f"base is {recognized_base}, not {expected_base}") + + @classmethod + def _validate_looks_like_controlnet(cls, mod: ModelOnDisk) -> None: + if has_any_keys_starting_with( + mod.load_state_dict(), + { + "controlnet", + "control_model", + "input_blocks", + # XLabs FLUX ControlNet models have keys starting with "controlnet_blocks." + # For example: https://huggingface.co/XLabs-AI/flux-controlnet-collections/blob/86ab1e915a389d5857135c00e0d350e9e38a9048/flux-canny-controlnet_v2.safetensors + # TODO(ryand): This is very fragile. XLabs FLUX ControlNet models also contain keys starting with + # "double_blocks.", which we check for above. But, I'm afraid to modify this logic because it is so + # delicate. + "controlnet_blocks", + }, + ): + raise NotAMatch(cls, "state dict does not look like a ControlNet checkpoint") + + @classmethod + def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType: + state_dict = mod.load_state_dict() + + if is_state_dict_xlabs_controlnet(state_dict) or is_state_dict_instantx_controlnet(state_dict): + # TODO(ryand): Should I distinguish between XLabs, InstantX and other ControlNet models by implementing + # get_format()? 
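The Diffusers ControlNet loaders above all share the same recovery path: try `from_pretrained` with the stored repo variant, and retry without it if the variant's files are missing. A trimmed sketch of that pattern, assuming a standard diffusers `ControlNetModel` folder (the dtype choice here is illustrative):

```python
from pathlib import Path

import torch
from diffusers import ControlNetModel


def load_controlnet_with_variant_fallback(model_dir: Path, variant: str | None) -> ControlNetModel:
    try:
        return ControlNetModel.from_pretrained(model_dir, torch_dtype=torch.float16, variant=variant)
    except OSError as e:
        # If the requested variant's weight files are absent (e.g. no *.fp16.safetensors),
        # retry without a variant rather than failing the load outright.
        if variant and "no file named" in str(e):
            return ControlNetModel.from_pretrained(model_dir, torch_dtype=torch.float16)
        raise
```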
+ return BaseModelType.Flux + + for key in ( + "control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight", + "controlnet_mid_block.bias", + "input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight", + "down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_k.weight", + ): + if key not in state_dict: + continue + width = state_dict[key].shape[-1] + match width: + case 768: + return BaseModelType.StableDiffusion1 + case 1024: + return BaseModelType.StableDiffusion2 + case 2048: + return BaseModelType.StableDiffusionXL + case 1280: + return BaseModelType.StableDiffusionXL + case _: + pass + + raise NotAMatch(cls, "unable to determine base type from state dict") + + +class ControlNet_Checkpoint_SD1_Config(ControlNet_Checkpoint_Config_Base, Config_Base): + base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1) + + def load_model(self, submodel_type: Optional[SubModelType] = None) -> AnyModel: + if submodel_type is not None: + raise ValueError("ControlNet models have no submodels.") + + model: AnyModel = ControlNetModel.from_single_file( + self.path, + torch_dtype=TorchDevice.choose_torch_dtype(), + ) + return model + + +class ControlNet_Checkpoint_SD2_Config(ControlNet_Checkpoint_Config_Base, Config_Base): + base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2) + + def load_model(self, submodel_type: Optional[SubModelType] = None) -> AnyModel: + if submodel_type is not None: + raise ValueError("ControlNet models have no submodels.") + + model: AnyModel = ControlNetModel.from_single_file( + self.path, + torch_dtype=TorchDevice.choose_torch_dtype(), + ) + return model + + +class ControlNet_Checkpoint_SDXL_Config(ControlNet_Checkpoint_Config_Base, Config_Base): + base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL) + + def load_model(self, submodel_type: Optional[SubModelType] = None) -> AnyModel: + if submodel_type is not None: + raise ValueError("ControlNet models have no submodels.") + + model: AnyModel = ControlNetModel.from_single_file( + self.path, + torch_dtype=TorchDevice.choose_torch_dtype(), + ) + return model + + +class ControlNet_Checkpoint_FLUX_Config(ControlNet_Checkpoint_Config_Base, Config_Base): + base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux) + + def load_model(self, submodel_type: Optional[SubModelType] = None) -> AnyModel: + if submodel_type is not None: + raise ValueError("ControlNet models have no submodels.") + + model_path = Path(self.path) + sd = load_file(model_path) + + # Detect the FLUX ControlNet model type from the state dict + if is_state_dict_xlabs_controlnet(sd): + with accelerate.init_empty_weights(): + # HACK(ryand): Is it safe to assume dev here? 
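A sketch of the attention-width probe above, which walks a few well-known ControlNet checkpoint keys and maps the trailing dimension to a base (illustrative strings in place of the enums):

```python
import torch

WIDTH_TO_BASE = {768: "sd-1", 1024: "sd-2", 1280: "sdxl", 2048: "sdxl"}

PROBE_KEYS = (
    "control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight",
    "controlnet_mid_block.bias",
    "input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight",
    "down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_k.weight",
)


def controlnet_base_from_state_dict(state_dict: dict[str, torch.Tensor]) -> str:
    # Check each known probe key in turn; the last dimension of the first recognized hit decides the base.
    for key in PROBE_KEYS:
        if key in state_dict:
            base = WIDTH_TO_BASE.get(state_dict[key].shape[-1])
            if base is not None:
                return base
    raise ValueError("unable to determine base type from state dict")
```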
+ model = XLabsControlNetFlux(get_flux_transformers_params(FluxVariantType.Dev)) + + model.load_state_dict(sd, assign=True) + return model + elif is_state_dict_instantx_controlnet(sd): + sd = convert_diffusers_instantx_state_dict_to_bfl_format(sd) + flux_params = infer_flux_params_from_state_dict(sd) + num_control_modes = infer_instantx_num_control_modes_from_state_dict(sd) + + with accelerate.init_empty_weights(): + model = InstantXControlNetFlux(flux_params, num_control_modes) + + model.load_state_dict(sd, assign=True) + return model + else: + raise ValueError("Do not recognize the state dict as an XLabs or InstantX ControlNet model.") + + +class TI_Config_Base(ABC, BaseModel): + type: Literal[ModelType.TextualInversion] = Field(default=ModelType.TextualInversion) + + @classmethod + def _validate_base(cls, mod: ModelOnDisk, path: Path | None = None) -> None: + """Raise `NotAMatch` if the model base does not match this config class.""" + expected_base = cls.model_fields["base"].default + recognized_base = cls._get_base_or_raise(mod, path) + if expected_base is not recognized_base: + raise NotAMatch(cls, f"base is {recognized_base}, not {expected_base}") + + @classmethod + def _file_looks_like_embedding(cls, mod: ModelOnDisk, path: Path | None = None) -> bool: + try: + p = path or mod.path + + if not p.exists(): + return False + + if p.is_dir(): + return False + + if p.name in [f"learned_embeds.{s}" for s in mod.weight_files()]: + return True + + state_dict = mod.load_state_dict(p) + + # Heuristic: textual inversion embeddings have these keys + if any(key in {"string_to_param", "emb_params", "clip_g"} for key in state_dict.keys()): + return True + + # Heuristic: small state dict with all tensor values + if (len(state_dict)) < 10 and all(isinstance(v, torch.Tensor) for v in state_dict.values()): + return True + + return False + except Exception: + return False + + @classmethod + def _get_base_or_raise(cls, mod: ModelOnDisk, path: Path | None = None) -> BaseModelType: + p = path or mod.path + + try: + state_dict = mod.load_state_dict(p) + except Exception as e: + raise NotAMatch(cls, f"unable to load state dict from {p}: {e}") from e + + try: + if "string_to_token" in state_dict: + token_dim = list(state_dict["string_to_param"].values())[0].shape[-1] + elif "emb_params" in state_dict: + token_dim = state_dict["emb_params"].shape[-1] + elif "clip_g" in state_dict: + token_dim = state_dict["clip_g"].shape[-1] + else: + token_dim = list(state_dict.values())[0].shape[0] + except Exception as e: + raise NotAMatch(cls, f"unable to determine token dimension from state dict in {p}: {e}") from e + + match token_dim: + case 768: + return BaseModelType.StableDiffusion1 + case 1024: + return BaseModelType.StableDiffusion2 + case 1280: + return BaseModelType.StableDiffusionXL + case _: + raise NotAMatch(cls, f"unrecognized token dimension {token_dim}") + + +class TI_File_Config_Base(TI_Config_Base): + """Model config for textual inversion embeddings.""" + + format: Literal[ModelFormat.EmbeddingFile] = Field(default=ModelFormat.EmbeddingFile) + + @classmethod + def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: + _validate_is_file(cls, mod) + + _validate_override_fields(cls, fields) + + if not cls._file_looks_like_embedding(mod): + raise NotAMatch(cls, "model does not look like a textual inversion embedding file") + + cls._validate_base(mod) + + return cls(**fields) + + +class TI_File_SD1_Config(TI_File_Config_Base, Config_Base): + base: Literal[BaseModelType.StableDiffusion1] 
= Field(default=BaseModelType.StableDiffusion1) + + def load_model(self, submodel_type: Optional[SubModelType] = None) -> AnyModel: + if submodel_type is not None: + raise ValueError("Textual Inversion models have no submodels.") + + model_path = Path(self.path) + model = TextualInversionModelRaw.from_checkpoint( + file_path=model_path, + dtype=TorchDevice.choose_torch_dtype(), + ) + return model + + +class TI_File_SD2_Config(TI_File_Config_Base, Config_Base): + base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2) + + def load_model(self, submodel_type: Optional[SubModelType] = None) -> AnyModel: + if submodel_type is not None: + raise ValueError("Textual Inversion models have no submodels.") + + model_path = Path(self.path) + model = TextualInversionModelRaw.from_checkpoint( + file_path=model_path, + dtype=TorchDevice.choose_torch_dtype(), + ) + return model + + +class TI_File_SDXL_Config(TI_File_Config_Base, Config_Base): + base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL) + + def load_model(self, submodel_type: Optional[SubModelType] = None) -> AnyModel: + if submodel_type is not None: + raise ValueError("Textual Inversion models have no submodels.") + + model_path = Path(self.path) + model = TextualInversionModelRaw.from_checkpoint( + file_path=model_path, + dtype=TorchDevice.choose_torch_dtype(), + ) + return model + + +class TI_Folder_Config_Base(TI_Config_Base): + """Model config for textual inversion embeddings.""" + + format: Literal[ModelFormat.EmbeddingFolder] = Field(default=ModelFormat.EmbeddingFolder) + + @classmethod + def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: + _validate_is_dir(cls, mod) + + _validate_override_fields(cls, fields) + + for p in mod.weight_files(): + if cls._file_looks_like_embedding(mod, p): + cls._validate_base(mod, p) + return cls(**fields) + + raise NotAMatch(cls, "model does not look like a textual inversion embedding folder") + + +class TI_Folder_SD1_Config(TI_Folder_Config_Base, Config_Base): + base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1) + + def load_model(self, submodel_type: Optional[SubModelType] = None) -> AnyModel: + if submodel_type is not None: + raise ValueError("Textual Inversion models have no submodels.") + + model_path = Path(self.path) / "learned_embeds.bin" + if not model_path.exists(): + raise OSError(f"The embedding file at {model_path} was not found") + + model = TextualInversionModelRaw.from_checkpoint( + file_path=model_path, + dtype=TorchDevice.choose_torch_dtype(), + ) + return model + + +class TI_Folder_SD2_Config(TI_Folder_Config_Base, Config_Base): + base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2) + + def load_model(self, submodel_type: Optional[SubModelType] = None) -> AnyModel: + if submodel_type is not None: + raise ValueError("Textual Inversion models have no submodels.") + + model_path = Path(self.path) / "learned_embeds.bin" + if not model_path.exists(): + raise OSError(f"The embedding file at {model_path} was not found") + + model = TextualInversionModelRaw.from_checkpoint( + file_path=model_path, + dtype=TorchDevice.choose_torch_dtype(), + ) + return model + + +class TI_Folder_SDXL_Config(TI_Folder_Config_Base, Config_Base): + base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL) + + def load_model(self, submodel_type: Optional[SubModelType] = None) -> AnyModel: + 
if submodel_type is not None: + raise ValueError("Textual Inversion models have no submodels.") + + model_path = Path(self.path) / "learned_embeds.bin" + if not model_path.exists(): + raise OSError(f"The embedding file at {model_path} was not found") + + model = TextualInversionModelRaw.from_checkpoint( + file_path=model_path, + dtype=TorchDevice.choose_torch_dtype(), + ) + return model + + +class Main_Config_Base(ABC, BaseModel): + type: Literal[ModelType.Main] = Field(default=ModelType.Main) + trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None) + default_settings: Optional[MainModelDefaultSettings] = Field( + description="Default settings for this model", default=None + ) + + +def _has_bnb_nf4_keys(state_dict: dict[str | int, Any]) -> bool: + bnb_nf4_keys = { + "double_blocks.0.img_attn.proj.weight.quant_state.bitsandbytes__nf4", + "model.diffusion_model.double_blocks.0.img_attn.proj.weight.quant_state.bitsandbytes__nf4", + } + return any(key in state_dict for key in bnb_nf4_keys) + + +def _has_ggml_tensors(state_dict: dict[str | int, Any]) -> bool: + return any(isinstance(v, GGMLTensor) for v in state_dict.values()) + + +def _has_main_keys(state_dict: dict[str | int, Any]) -> bool: + for key in state_dict.keys(): + if isinstance(key, int): + continue + elif key.startswith( + ( + "cond_stage_model.", + "first_stage_model.", + "model.diffusion_model.", + # Some FLUX checkpoint files contain transformer keys prefixed with "model.diffusion_model". + # This prefix is typically used to distinguish between multiple models bundled in a single file. + "model.diffusion_model.double_blocks.", + ) + ): + return True + elif key.startswith("double_blocks.") and "ip_adapter" not in key: + # FLUX models in the official BFL format contain keys with the "double_blocks." prefix, but we must be + # careful to avoid false positives on XLabs FLUX IP-Adapter models. 
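+            # For example, a plain BFL-format FLUX transformer contains keys like
+            # "double_blocks.0.img_attn.qkv.weight", whereas XLabs FLUX IP-Adapters embed
+            # "ip_adapter" within their "double_blocks.*" keys, which the check above excludes.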
+ return True + return False + + +class Main_Checkpoint_Config_Base(Checkpoint_Config_Base, Main_Config_Base): + """Model config for main checkpoint models.""" + + format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint) + + prediction_type: SchedulerPredictionType = Field() + variant: ModelVariantType = Field() + + @classmethod + def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: + _validate_is_file(cls, mod) + + _validate_override_fields(cls, fields) + + cls._validate_looks_like_main_model(mod) + + cls._validate_base(mod) + + prediction_type = fields.get("prediction_type") or cls._get_scheduler_prediction_type_or_raise(mod) + + variant = fields.get("variant") or cls._get_variant_or_raise(mod) + + return cls(**fields, prediction_type=prediction_type, variant=variant) + + @classmethod + def _validate_base(cls, mod: ModelOnDisk) -> None: + """Raise `NotAMatch` if the model base does not match this config class.""" + expected_base = cls.model_fields["base"].default + recognized_base = cls._get_base_or_raise(mod) + if expected_base is not recognized_base: + raise NotAMatch(cls, f"base is {recognized_base}, not {expected_base}") + + @classmethod + def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType: + state_dict = mod.load_state_dict() + + key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight" + if key_name in state_dict and state_dict[key_name].shape[-1] == 768: + return BaseModelType.StableDiffusion1 + if key_name in state_dict and state_dict[key_name].shape[-1] == 1024: + return BaseModelType.StableDiffusion2 + + key_name = "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight" + if key_name in state_dict and state_dict[key_name].shape[-1] == 2048: + return BaseModelType.StableDiffusionXL + elif key_name in state_dict and state_dict[key_name].shape[-1] == 1280: + return BaseModelType.StableDiffusionXLRefiner + + raise NotAMatch(cls, "unable to determine base type from state dict") + + @classmethod + def _get_scheduler_prediction_type_or_raise(cls, mod: ModelOnDisk) -> SchedulerPredictionType: + base = cls.model_fields["base"].default + + if base is BaseModelType.StableDiffusion2: + state_dict = mod.load_state_dict() + key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight" + if key_name in state_dict and state_dict[key_name].shape[-1] == 1024: + if "global_step" in state_dict: + if state_dict["global_step"] == 220000: + return SchedulerPredictionType.Epsilon + elif state_dict["global_step"] == 110000: + return SchedulerPredictionType.VPrediction + return SchedulerPredictionType.VPrediction + else: + return SchedulerPredictionType.Epsilon + + @classmethod + def _get_variant_or_raise(cls, mod: ModelOnDisk) -> ModelVariantType: + base = cls.model_fields["base"].default + + state_dict = mod.load_state_dict() + key_name = "model.diffusion_model.input_blocks.0.0.weight" + + if key_name not in state_dict: + raise NotAMatch(cls, "unable to determine model variant from state dict") + + in_channels = state_dict["model.diffusion_model.input_blocks.0.0.weight"].shape[1] + + match in_channels: + case 4: + return ModelVariantType.Normal + case 5: + # Only SD2 has a depth variant + assert base is BaseModelType.StableDiffusion2, f"unexpected unet in_channels 5 for base '{base}'" + return ModelVariantType.Depth + case 9: + return ModelVariantType.Inpaint + case _: + raise NotAMatch(cls, f"unrecognized unet in_channels {in_channels} for base '{base}'") 
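+
+    # Note on the variant mapping above: the UNet conv_in channel count encodes the checkpoint
+    # variant - 4 = plain latent input, 5 = latent + depth map (SD2 "depth" models only), and
+    # 9 = latent + mask + masked-image latents (inpainting checkpoints).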
+ + @classmethod + def _validate_looks_like_main_model(cls, mod: ModelOnDisk) -> None: + has_main_model_keys = _has_main_keys(mod.load_state_dict()) + if not has_main_model_keys: + raise NotAMatch(cls, "state dict does not look like a main model") + + +class Main_Checkpoint_SD1_Config(Main_Checkpoint_Config_Base, Config_Base): + base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1) + + def load_model(self, submodel_type: Optional[SubModelType] = None) -> AnyModel: + load_class = _get_sd_checkpoint_pipeline_class(self.base, self.variant) + + with SilenceWarnings(): + pipeline = load_class.from_single_file(self.path, torch_dtype=TorchDevice.choose_torch_dtype()) + + if not submodel_type: + return pipeline + + _cache_sd_submodels(self.key, pipeline, submodel_type) + return getattr(pipeline, submodel_type.value) + + +class Main_Checkpoint_SD2_Config(Main_Checkpoint_Config_Base, Config_Base): + base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2) + + def load_model(self, submodel_type: Optional[SubModelType] = None) -> AnyModel: + load_class = _get_sd_checkpoint_pipeline_class(self.base, self.variant) + + with SilenceWarnings(): + pipeline = load_class.from_single_file(self.path, torch_dtype=TorchDevice.choose_torch_dtype()) + + if not submodel_type: + return pipeline + + _cache_sd_submodels(self.key, pipeline, submodel_type) + return getattr(pipeline, submodel_type.value) + + +class Main_Checkpoint_SDXL_Config(Main_Checkpoint_Config_Base, Config_Base): + base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL) + + def load_model(self, submodel_type: Optional[SubModelType] = None) -> AnyModel: + load_class = _get_sd_checkpoint_pipeline_class(self.base, self.variant) + + with SilenceWarnings(): + pipeline = load_class.from_single_file(self.path, torch_dtype=TorchDevice.choose_torch_dtype()) + + if not submodel_type: + return pipeline + + _cache_sd_submodels(self.key, pipeline, submodel_type) + return getattr(pipeline, submodel_type.value) + + +class Main_Checkpoint_SDXLRefiner_Config(Main_Checkpoint_Config_Base, Config_Base): + base: Literal[BaseModelType.StableDiffusionXLRefiner] = Field(default=BaseModelType.StableDiffusionXLRefiner) + + def load_model(self, submodel_type: Optional[SubModelType] = None) -> AnyModel: + load_class = _get_sd_checkpoint_pipeline_class(self.base, self.variant) + + with SilenceWarnings(): + pipeline = load_class.from_single_file(self.path, torch_dtype=TorchDevice.choose_torch_dtype()) + + if not submodel_type: + return pipeline + + _cache_sd_submodels(self.key, pipeline, submodel_type) + return getattr(pipeline, submodel_type.value) + + +def _get_flux_variant(state_dict: dict[str | int, Any]) -> FluxVariantType | None: + # FLUX Model variant types are distinguished by input channels and the presence of certain keys. + + # Input channels are derived from the shape of either "img_in.weight" or "model.diffusion_model.img_in.weight". 
+ # + # Known models that use the latter key: + # - https://civitai.com/models/885098?modelVersionId=990775 + # - https://civitai.com/models/1018060?modelVersionId=1596255 + # - https://civitai.com/models/978314/ultrareal-fine-tune?modelVersionId=1413133 + # + # Input channels for known FLUX models: + # - Unquantized Dev and Schnell have in_channels=64 + # - BNB-NF4 Dev and Schnell have in_channels=1 + # - FLUX Fill has in_channels=384 + # - Unsure of quantized FLUX Fill models + # - Unsure of GGUF-quantized models + + in_channels = None + for key in {"img_in.weight", "model.diffusion_model.img_in.weight"}: + if key in state_dict: + in_channels = state_dict[key].shape[1] + break + + if in_channels is None: + # TODO(psyche): Should we have a graceful fallback here? Previously we fell back to the "normal" variant, + # but this variant is no longer used for FLUX models. If we get here, but the model is definitely a FLUX + # model, we should figure out a good fallback value. + return None + + # Because FLUX Dev and Schnell models have the same in_channels, we need to check for the presence of + # certain keys to distinguish between them. + is_flux_dev = ( + "guidance_in.out_layer.weight" in state_dict + or "model.diffusion_model.guidance_in.out_layer.weight" in state_dict + ) + + if is_flux_dev and in_channels == 384: + return FluxVariantType.DevFill + elif is_flux_dev: + return FluxVariantType.Dev + else: + # Must be a Schnell model...? + return FluxVariantType.Schnell + + +class Main_Checkpoint_FLUX_Config(Checkpoint_Config_Base, Main_Config_Base, Config_Base): + """Model config for main checkpoint models.""" + + format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint) + base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux) + + variant: FluxVariantType = Field() + + @classmethod + def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: + _validate_is_file(cls, mod) + + _validate_override_fields(cls, fields) + + cls._validate_looks_like_main_model(mod) + + cls._validate_is_flux(mod) + + cls._validate_does_not_look_like_bnb_quantized(mod) + + cls._validate_does_not_look_like_gguf_quantized(mod) + + variant = fields.get("variant") or cls._get_variant_or_raise(mod) + + return cls(**fields, variant=variant) + + @classmethod + def _validate_is_flux(cls, mod: ModelOnDisk) -> None: + if not has_any_keys( + mod.load_state_dict(), + { + "double_blocks.0.img_attn.norm.key_norm.scale", + "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale", + }, + ): + raise NotAMatch(cls, "state dict does not look like a FLUX checkpoint") + + @classmethod + def _get_variant_or_raise(cls, mod: ModelOnDisk) -> FluxVariantType: + # FLUX Model variant types are distinguished by input channels and the presence of certain keys. + state_dict = mod.load_state_dict() + variant = _get_flux_variant(state_dict) + + if variant is None: + # TODO(psyche): Should we have a graceful fallback here? Previously we fell back to the "normal" variant, + # but this variant is no longer used for FLUX models. If we get here, but the model is definitely a FLUX + # model, we should figure out a good fallback value. 
+ raise NotAMatch(cls, "unable to determine model variant from state dict") + + return variant + + @classmethod + def _validate_looks_like_main_model(cls, mod: ModelOnDisk) -> None: + has_main_model_keys = _has_main_keys(mod.load_state_dict()) + if not has_main_model_keys: + raise NotAMatch(cls, "state dict does not look like a main model") + + @classmethod + def _validate_does_not_look_like_bnb_quantized(cls, mod: ModelOnDisk) -> None: + has_bnb_nf4_keys = _has_bnb_nf4_keys(mod.load_state_dict()) + if has_bnb_nf4_keys: + raise NotAMatch(cls, "state dict looks like bnb quantized nf4") + + @classmethod + def _validate_does_not_look_like_gguf_quantized(cls, mod: ModelOnDisk): + has_ggml_tensors = _has_ggml_tensors(mod.load_state_dict()) + if has_ggml_tensors: + raise NotAMatch(cls, "state dict looks like GGUF quantized") + + def load_model(self, submodel_type: Optional[SubModelType] = None) -> AnyModel: + if submodel_type != SubModelType.Transformer: + raise ValueError( + f"Only Transformer submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}" + ) + + model_path = Path(self.path) + + with accelerate.init_empty_weights(): + model = Flux(get_flux_transformers_params(self.variant)) + + sd = load_file(model_path) + if "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale" in sd: + sd = convert_bundle_to_flux_transformer_checkpoint(sd) + new_sd_size = sum([ten.nelement() * torch.bfloat16.itemsize for ten in sd.values()]) + app_config.ram_cache.make_room(new_sd_size) + for k in sd.keys(): + # We need to cast to bfloat16 due to it being the only currently supported dtype for inference + sd[k] = sd[k].to(torch.bfloat16) + model.load_state_dict(sd, assign=True) + return model + + +class Main_BnBNF4_FLUX_Config(Checkpoint_Config_Base, Main_Config_Base, Config_Base): + """Model config for main checkpoint models.""" + + base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux) + format: Literal[ModelFormat.BnbQuantizednf4b] = Field(default=ModelFormat.BnbQuantizednf4b) + + variant: FluxVariantType = Field() + + @classmethod + def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: + _validate_is_file(cls, mod) + + _validate_override_fields(cls, fields) + + cls._validate_looks_like_main_model(mod) + + cls._validate_model_looks_like_bnb_quantized(mod) + + variant = fields.get("variant") or cls._get_variant_or_raise(mod) + + return cls(**fields, variant=variant) + + @classmethod + def _get_variant_or_raise(cls, mod: ModelOnDisk) -> FluxVariantType: + # FLUX Model variant types are distinguished by input channels and the presence of certain keys. + state_dict = mod.load_state_dict() + variant = _get_flux_variant(state_dict) + + if variant is None: + # TODO(psyche): Should we have a graceful fallback here? Previously we fell back to the "normal" variant, + # but this variant is no longer used for FLUX models. If we get here, but the model is definitely a FLUX + # model, we should figure out a good fallback value. 
+ raise NotAMatch(cls, "unable to determine model variant from state dict") + + return variant + + @classmethod + def _validate_looks_like_main_model(cls, mod: ModelOnDisk) -> None: + has_main_model_keys = _has_main_keys(mod.load_state_dict()) + if not has_main_model_keys: + raise NotAMatch(cls, "state dict does not look like a main model") + + @classmethod + def _validate_model_looks_like_bnb_quantized(cls, mod: ModelOnDisk) -> None: + has_bnb_nf4_keys = _has_bnb_nf4_keys(mod.load_state_dict()) + if not has_bnb_nf4_keys: + raise NotAMatch(cls, "state dict does not look like bnb quantized nf4") + + def load_model(self, submodel_type: Optional[SubModelType] = None) -> AnyModel: + if submodel_type != SubModelType.Transformer: + raise ValueError( + f"Only Transformer submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}" + ) + + if not bnb_available: + raise ImportError( + "The bnb modules are not available. Please install bitsandbytes if available on your platform." + ) + + model_path = Path(self.path) + + with SilenceWarnings(): + with accelerate.init_empty_weights(): + model = Flux(get_flux_transformers_params(self.variant)) + model = quantize_model_nf4(model, modules_to_not_convert=set(), compute_dtype=torch.bfloat16) + sd = load_file(model_path) + if "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale" in sd: + sd = convert_bundle_to_flux_transformer_checkpoint(sd) + model.load_state_dict(sd, assign=True) + return model + + +class Main_GGUF_FLUX_Config(Checkpoint_Config_Base, Main_Config_Base, Config_Base): + """Model config for main checkpoint models.""" + + base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux) + format: Literal[ModelFormat.GGUFQuantized] = Field(default=ModelFormat.GGUFQuantized) + + variant: FluxVariantType = Field() + + @classmethod + def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: + _validate_is_file(cls, mod) + + _validate_override_fields(cls, fields) + + cls._validate_looks_like_main_model(mod) + + cls._validate_looks_like_gguf_quantized(mod) + + variant = fields.get("variant") or cls._get_variant_or_raise(mod) + + return cls(**fields, variant=variant) + + @classmethod + def _get_variant_or_raise(cls, mod: ModelOnDisk) -> FluxVariantType: + # FLUX Model variant types are distinguished by input channels and the presence of certain keys. + state_dict = mod.load_state_dict() + variant = _get_flux_variant(state_dict) + + if variant is None: + # TODO(psyche): Should we have a graceful fallback here? Previously we fell back to the "normal" variant, + # but this variant is no longer used for FLUX models. If we get here, but the model is definitely a FLUX + # model, we should figure out a good fallback value. 
+ raise NotAMatch(cls, "unable to determine model variant from state dict") + + return variant + + @classmethod + def _validate_looks_like_main_model(cls, mod: ModelOnDisk) -> None: + has_main_model_keys = _has_main_keys(mod.load_state_dict()) + if not has_main_model_keys: + raise NotAMatch(cls, "state dict does not look like a main model") + + @classmethod + def _validate_looks_like_gguf_quantized(cls, mod: ModelOnDisk) -> None: + has_ggml_tensors = _has_ggml_tensors(mod.load_state_dict()) + if not has_ggml_tensors: + raise NotAMatch(cls, "state dict does not look like GGUF quantized") + + def load_model(self, submodel_type: Optional[SubModelType] = None) -> AnyModel: + if submodel_type != SubModelType.Transformer: + raise ValueError( + f"Only Transformer submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}" + ) + + model_path = Path(self.path) + + with accelerate.init_empty_weights(): + model = Flux(get_flux_transformers_params(self.variant)) + + # HACK(ryand): We shouldn't be hard-coding the compute_dtype here. + sd = gguf_sd_loader(model_path, compute_dtype=torch.bfloat16) + + # HACK(ryand): There are some broken GGUF models in circulation that have the wrong shape for img_in.weight. + # We override the shape here to fix the issue. + # Example model with this issue (Q4_K_M): https://civitai.com/models/705823/ggufk-flux-unchained-km-quants + img_in_weight = sd.get("img_in.weight", None) + if img_in_weight is not None and img_in_weight._ggml_quantization_type in TORCH_COMPATIBLE_QTYPES: + expected_img_in_weight_shape = model.img_in.weight.shape + img_in_weight.quantized_data = img_in_weight.quantized_data.view(expected_img_in_weight_shape) + img_in_weight.tensor_shape = expected_img_in_weight_shape + + model.load_state_dict(sd, assign=True) + return model + + +class Main_Diffusers_Config_Base(Diffusers_Config_Base, Main_Config_Base): + prediction_type: SchedulerPredictionType = Field() + variant: ModelVariantType = Field() + + @classmethod + def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: + _validate_is_dir(cls, mod) + + _validate_override_fields(cls, fields) + + _validate_class_name( + cls, + common_config_paths(mod.path), + { + # SD 1.x and 2.x + "StableDiffusionPipeline", + "StableDiffusionInpaintPipeline", + # SDXL + "StableDiffusionXLPipeline", + "StableDiffusionXLInpaintPipeline", + # SDXL Refiner + "StableDiffusionXLImg2ImgPipeline", + # TODO(psyche): Do we actually support LCM models? I don't see using this class anywhere in the codebase. + "LatentConsistencyModelPipeline", + }, + ) + + cls._validate_base(mod) + + variant = fields.get("variant") or cls._get_variant_or_raise(mod) + + prediction_type = fields.get("prediction_type") or cls._get_scheduler_prediction_type_or_raise(mod) + + repo_variant = fields.get("repo_variant") or cls._get_repo_variant_or_raise(mod) + + return cls( + **fields, + variant=variant, + prediction_type=prediction_type, + repo_variant=repo_variant, + ) + + @classmethod + def _validate_base(cls, mod: ModelOnDisk) -> None: + """Raise `NotAMatch` if the model base does not match this config class.""" + expected_base = cls.model_fields["base"].default + recognized_base = cls._get_base_or_raise(mod) + if expected_base is not recognized_base: + raise NotAMatch(cls, f"base is {recognized_base}, not {expected_base}") + + @classmethod + def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType: + # Handle pipelines with a UNet (i.e SD 1.x, SD2.x, SDXL). 
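+        # The UNet's cross_attention_dim mirrors the text-encoder width: 768 (CLIP ViT-L, SD 1.x),
+        # 1024 (OpenCLIP ViT-H, SD 2.x), 1280 (OpenCLIP bigG alone, SDXL refiner) and
+        # 2048 (CLIP ViT-L + bigG concatenated, SDXL base).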
+ unet_config_path = mod.path / "unet" / "config.json" + if unet_config_path.exists(): + with open(unet_config_path) as file: + unet_conf = json.load(file) + cross_attention_dim = unet_conf.get("cross_attention_dim") + match cross_attention_dim: + case 768: + return BaseModelType.StableDiffusion1 + case 1024: + return BaseModelType.StableDiffusion2 + case 1280: + return BaseModelType.StableDiffusionXLRefiner + case 2048: + return BaseModelType.StableDiffusionXL + case _: + raise NotAMatch(cls, f"unrecognized cross_attention_dim {cross_attention_dim}") + + raise NotAMatch(cls, "unable to determine base type") + + @classmethod + def _get_scheduler_prediction_type_or_raise(cls, mod: ModelOnDisk) -> SchedulerPredictionType: + scheduler_conf = _get_config_or_raise(cls, mod.path / "scheduler" / "scheduler_config.json") + + # TODO(psyche): Is epsilon the right default or should we raise if it's not present? + prediction_type = scheduler_conf.get("prediction_type", "epsilon") + + match prediction_type: + case "v_prediction": + return SchedulerPredictionType.VPrediction + case "epsilon": + return SchedulerPredictionType.Epsilon + case _: + raise NotAMatch(cls, f"unrecognized scheduler prediction_type {prediction_type}") + + @classmethod + def _get_variant_or_raise(cls, mod: ModelOnDisk) -> ModelVariantType: + base = cls.model_fields["base"].default + unet_config = _get_config_or_raise(cls, mod.path / "unet" / "config.json") + in_channels = unet_config.get("in_channels") + + match in_channels: + case 4: + return ModelVariantType.Normal + case 5: + # Only SD2 has a depth variant + assert base is BaseModelType.StableDiffusion2, f"unexpected unet in_channels 5 for base '{base}'" + return ModelVariantType.Depth + case 9: + return ModelVariantType.Inpaint + case _: + raise NotAMatch(cls, f"unrecognized unet in_channels {in_channels} for base '{base}'") + + +class Main_Diffusers_SD1_Config(Main_Diffusers_Config_Base, Config_Base): + base: Literal[BaseModelType.StableDiffusion1] = Field(BaseModelType.StableDiffusion1) + + def load_model(self, submodel_type: Optional[SubModelType] = None) -> AnyModel: + if submodel_type is None: + raise ValueError("A submodel type must be provided when loading main pipelines.") + + model_path = Path(self.path) + config = _get_config_or_raise(self.__class__, {model_path / "model_index.json"}) + load_class = _hf_definition_to_type(module=config[submodel_type.value][0], class_name=config[submodel_type.value][1]) + + variant = self.repo_variant.value if self.repo_variant else None + model_path = model_path / submodel_type.value + try: + result: AnyModel = load_class.from_pretrained( + model_path, + torch_dtype=TorchDevice.choose_torch_dtype(), + variant=variant, + ) + except OSError as e: + if variant and "no file named" in str(e): + result = load_class.from_pretrained(model_path, torch_dtype=TorchDevice.choose_torch_dtype()) + else: + raise e + + return result + + +class Main_Diffusers_SD2_Config(Main_Diffusers_Config_Base, Config_Base): + base: Literal[BaseModelType.StableDiffusion2] = Field(BaseModelType.StableDiffusion2) + + def load_model(self, submodel_type: Optional[SubModelType] = None) -> AnyModel: + if submodel_type is None: + raise ValueError("A submodel type must be provided when loading main pipelines.") + + model_path = Path(self.path) + config = _get_config_or_raise(self.__class__, {model_path / "model_index.json"}) + load_class = _hf_definition_to_type(module=config[submodel_type.value][0], class_name=config[submodel_type.value][1]) + + variant = 
self.repo_variant.value if self.repo_variant else None + model_path = model_path / submodel_type.value + try: + result: AnyModel = load_class.from_pretrained( + model_path, + torch_dtype=TorchDevice.choose_torch_dtype(), + variant=variant, + ) + except OSError as e: + if variant and "no file named" in str(e): + result = load_class.from_pretrained(model_path, torch_dtype=TorchDevice.choose_torch_dtype()) + else: + raise e + + return result + + +class Main_Diffusers_SDXL_Config(Main_Diffusers_Config_Base, Config_Base): + base: Literal[BaseModelType.StableDiffusionXL] = Field(BaseModelType.StableDiffusionXL) + + def load_model(self, submodel_type: Optional[SubModelType] = None) -> AnyModel: + if submodel_type is None: + raise ValueError("A submodel type must be provided when loading main pipelines.") + + model_path = Path(self.path) + config = _get_config_or_raise(self.__class__, {model_path / "model_index.json"}) + load_class = _hf_definition_to_type(module=config[submodel_type.value][0], class_name=config[submodel_type.value][1]) + + variant = self.repo_variant.value if self.repo_variant else None + model_path = model_path / submodel_type.value + try: + result: AnyModel = load_class.from_pretrained( + model_path, + torch_dtype=TorchDevice.choose_torch_dtype(), + variant=variant, + ) + except OSError as e: + if variant and "no file named" in str(e): + result = load_class.from_pretrained(model_path, torch_dtype=TorchDevice.choose_torch_dtype()) + else: + raise e + + return result + +class Main_Diffusers_SDXLRefiner_Config(Main_Diffusers_Config_Base, Config_Base): + base: Literal[BaseModelType.StableDiffusionXLRefiner] = Field(BaseModelType.StableDiffusionXLRefiner) -class CLIPEmbedDiffusersConfig(DiffusersConfigBase): - """Model config for Clip Embeddings.""" + def load_model(self, submodel_type: Optional[SubModelType] = None) -> AnyModel: + if submodel_type is None: + raise ValueError("A submodel type must be provided when loading main pipelines.") - variant: ClipVariantType = Field(description="Clip variant for this model") - type: Literal[ModelType.CLIPEmbed] = ModelType.CLIPEmbed - format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers + model_path = Path(self.path) + config = _get_config_or_raise(self.__class__, {model_path / "model_index.json"}) + load_class = _hf_definition_to_type(module=config[submodel_type.value][0], class_name=config[submodel_type.value][1]) + + variant = self.repo_variant.value if self.repo_variant else None + model_path = model_path / submodel_type.value + try: + result: AnyModel = load_class.from_pretrained( + model_path, + torch_dtype=TorchDevice.choose_torch_dtype(), + variant=variant, + ) + except OSError as e: + if variant and "no file named" in str(e): + result = load_class.from_pretrained(model_path, torch_dtype=TorchDevice.choose_torch_dtype()) + else: + raise e + return result -class CLIPGEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, LegacyProbeMixin, ModelConfigBase): - """Model config for CLIP-G Embeddings.""" - variant: Literal[ClipVariantType.G] = ClipVariantType.G +class Main_Diffusers_SD3_Config(Diffusers_Config_Base, Main_Config_Base, Config_Base): + base: Literal[BaseModelType.StableDiffusion3] = Field(BaseModelType.StableDiffusion3) @classmethod - def get_tag(cls) -> Tag: - return Tag(f"{ModelType.CLIPEmbed.value}.{ModelFormat.Diffusers.value}.{ClipVariantType.G.value}") + def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: + _validate_is_dir(cls, mod) + + _validate_override_fields(cls, fields) + + # This 
check implies the base type - no further validation needed. + _validate_class_name( + cls, + common_config_paths(mod.path), + { + "StableDiffusion3Pipeline", + "SD3Transformer2DModel", + }, + ) + submodels = fields.get("submodels") or cls._get_submodels_or_raise(mod) -class CLIPLEmbedDiffusersConfig(CLIPEmbedDiffusersConfig, LegacyProbeMixin, ModelConfigBase): - """Model config for CLIP-L Embeddings.""" + repo_variant = fields.get("repo_variant") or cls._get_repo_variant_or_raise(mod) - variant: Literal[ClipVariantType.L] = ClipVariantType.L + return cls( + **fields, + submodels=submodels, + repo_variant=repo_variant, + ) @classmethod - def get_tag(cls) -> Tag: - return Tag(f"{ModelType.CLIPEmbed.value}.{ModelFormat.Diffusers.value}.{ClipVariantType.L.value}") + def _get_submodels_or_raise(cls, mod: ModelOnDisk) -> dict[SubModelType, SubmodelDefinition]: + # Example: https://huggingface.co/stabilityai/stable-diffusion-3.5-medium/blob/main/model_index.json + config = _get_config_or_raise(cls, common_config_paths(mod.path)) + + submodels: dict[SubModelType, SubmodelDefinition] = {} + + for key, value in config.items(): + # Anything that starts with an underscore is top-level metadata, not a submodel + if key.startswith("_") or not (isinstance(value, list) and len(value) == 2): + continue + # The key is something like "transformer" and is a submodel - it will be in a dir of the same name. + # The value value is something like ["diffusers", "SD3Transformer2DModel"] + _library_name, class_name = value + + match class_name: + case "CLIPTextModelWithProjection": + model_type = ModelType.CLIPEmbed + path_or_prefix = (mod.path / key).resolve().as_posix() + + # We need to read the config to determine the variant of the CLIP model. + clip_embed_config = _get_config_or_raise( + cls, {mod.path / key / "config.json", mod.path / key / "model_index.json"} + ) + variant = _get_clip_variant_type_from_config(clip_embed_config) + submodels[SubModelType(key)] = SubmodelDefinition( + path_or_prefix=path_or_prefix, + model_type=model_type, + variant=variant, + ) + case "SD3Transformer2DModel": + model_type = ModelType.Main + path_or_prefix = (mod.path / key).resolve().as_posix() + variant = None + submodels[SubModelType(key)] = SubmodelDefinition( + path_or_prefix=path_or_prefix, + model_type=model_type, + variant=variant, + ) + case _: + pass + + return submodels + + def load_model(self, submodel_type: Optional[SubModelType] = None) -> AnyModel: + if submodel_type is None: + raise ValueError("A submodel type must be provided when loading main pipelines.") + + model_path = Path(self.path) + config = _get_config_or_raise(self.__class__, {model_path / "model_index.json"}) + load_class = _hf_definition_to_type(module=config[submodel_type.value][0], class_name=config[submodel_type.value][1]) + + variant = self.repo_variant.value if self.repo_variant else None + model_path = model_path / submodel_type.value + try: + result: AnyModel = load_class.from_pretrained( + model_path, + torch_dtype=TorchDevice.choose_torch_dtype(), + variant=variant, + ) + except OSError as e: + if variant and "no file named" in str(e): + result = load_class.from_pretrained(model_path, torch_dtype=TorchDevice.choose_torch_dtype()) + else: + raise e + + return result + + +class Main_Diffusers_CogView4_Config(Diffusers_Config_Base, Main_Config_Base, Config_Base): + base: Literal[BaseModelType.CogView4] = Field(BaseModelType.CogView4) + + @classmethod + def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: + 
_validate_is_dir(cls, mod) + + _validate_override_fields(cls, fields) + + # This check implies the base type - no further validation needed. + _validate_class_name( + cls, + common_config_paths(mod.path), + { + "CogView4Pipeline", + }, + ) + + repo_variant = fields.get("repo_variant") or cls._get_repo_variant_or_raise(mod) + + return cls( + **fields, + repo_variant=repo_variant, + ) + + def load_model(self, submodel_type: Optional[SubModelType] = None) -> AnyModel: + if submodel_type is None: + raise ValueError("A submodel type must be provided when loading main pipelines.") + + model_path = Path(self.path) + config = _get_config_or_raise(self.__class__, {model_path / "model_index.json"}) + load_class = _hf_definition_to_type(module=config[submodel_type.value][0], class_name=config[submodel_type.value][1]) + + variant = self.repo_variant.value if self.repo_variant else None + model_path = model_path / submodel_type.value + + # We force bfloat16 for CogView4 models. It produces black images with float16. + dtype = torch.bfloat16 + try: + result: AnyModel = load_class.from_pretrained( + model_path, + torch_dtype=dtype, + variant=variant, + ) + except OSError as e: + if variant and "no file named" in str(e): + result = load_class.from_pretrained(model_path, torch_dtype=dtype) + else: + raise e + + return result + + +class IPAdapter_Config_Base(ABC, BaseModel): + type: Literal[ModelType.IPAdapter] = Field(default=ModelType.IPAdapter) + + +class IPAdapter_InvokeAI_Config_Base(IPAdapter_Config_Base): + """Model config for IP Adapter diffusers format models.""" + + format: Literal[ModelFormat.InvokeAI] = Field(default=ModelFormat.InvokeAI) + + # TODO(ryand): Should we deprecate this field? From what I can tell, it hasn't been probed correctly for a long + # time. Need to go through the history to make sure I'm understanding this fully. 
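+    # (Historically this value appears to have been read from the bundled image_encoder.txt
+    # metadata file that from_model_on_disk checks for below - treat it as informational only.)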
+ image_encoder_model_id: str = Field() + + @classmethod + def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: + _validate_is_dir(cls, mod) + + _validate_override_fields(cls, fields) + + cls._validate_has_weights_file(mod) + + cls._validate_has_image_encoder_metadata_file(mod) + + cls._validate_base(mod) + + return cls(**fields) + + @classmethod + def _validate_base(cls, mod: ModelOnDisk) -> None: + """Raise `NotAMatch` if the model base does not match this config class.""" + expected_base = cls.model_fields["base"].default + recognized_base = cls._get_base_or_raise(mod) + if expected_base is not recognized_base: + raise NotAMatch(cls, f"base is {recognized_base}, not {expected_base}") + + @classmethod + def _validate_has_weights_file(cls, mod: ModelOnDisk) -> None: + weights_file = mod.path / "ip_adapter.bin" + if not weights_file.exists(): + raise NotAMatch(cls, "missing ip_adapter.bin weights file") + + @classmethod + def _validate_has_image_encoder_metadata_file(cls, mod: ModelOnDisk) -> None: + image_encoder_metadata_file = mod.path / "image_encoder.txt" + if not image_encoder_metadata_file.exists(): + raise NotAMatch(cls, "missing image_encoder.txt metadata file") + + @classmethod + def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType: + state_dict = mod.load_state_dict() + + try: + cross_attention_dim = state_dict["ip_adapter"]["1.to_k_ip.weight"].shape[-1] + except Exception as e: + raise NotAMatch(cls, f"unable to determine cross attention dimension: {e}") from e + + match cross_attention_dim: + case 1280: + return BaseModelType.StableDiffusionXL + case 768: + return BaseModelType.StableDiffusion1 + case 1024: + return BaseModelType.StableDiffusion2 + case _: + raise NotAMatch(cls, f"unrecognized cross attention dimension {cross_attention_dim}") + + +class IPAdapter_InvokeAI_SD1_Config(IPAdapter_InvokeAI_Config_Base, Config_Base): + base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1) + + def load_model(self, submodel_type: Optional[SubModelType] = None) -> AnyModel: + if submodel_type is not None: + raise ValueError("IP-Adapter models have no submodels.") + + model_path = Path(self.path) + model = build_ip_adapter( + ip_adapter_ckpt_path=model_path, + device=torch.device("cpu"), + dtype=TorchDevice.choose_torch_dtype(), + ) + return model + + +class IPAdapter_InvokeAI_SD2_Config(IPAdapter_InvokeAI_Config_Base, Config_Base): + base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2) + + def load_model(self, submodel_type: Optional[SubModelType] = None) -> AnyModel: + if submodel_type is not None: + raise ValueError("IP-Adapter models have no submodels.") + + model_path = Path(self.path) + model = build_ip_adapter( + ip_adapter_ckpt_path=model_path, + device=torch.device("cpu"), + dtype=TorchDevice.choose_torch_dtype(), + ) + return model + + +class IPAdapter_InvokeAI_SDXL_Config(IPAdapter_InvokeAI_Config_Base, Config_Base): + base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL) + + def load_model(self, submodel_type: Optional[SubModelType] = None) -> AnyModel: + if submodel_type is not None: + raise ValueError("IP-Adapter models have no submodels.") + + model_path = Path(self.path) + model = build_ip_adapter( + ip_adapter_ckpt_path=model_path, + device=torch.device("cpu"), + dtype=TorchDevice.choose_torch_dtype(), + ) + return model + + +class IPAdapter_Checkpoint_Config_Base(IPAdapter_Config_Base): + """Model config for IP 
Adapter checkpoint format models.""" + + format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint) + + @classmethod + def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: + _validate_is_file(cls, mod) + + _validate_override_fields(cls, fields) + + cls._validate_looks_like_ip_adapter(mod) + + cls._validate_base(mod) + + return cls(**fields) + + @classmethod + def _validate_base(cls, mod: ModelOnDisk) -> None: + """Raise `NotAMatch` if the model base does not match this config class.""" + expected_base = cls.model_fields["base"].default + recognized_base = cls._get_base_or_raise(mod) + if expected_base is not recognized_base: + raise NotAMatch(cls, f"base is {recognized_base}, not {expected_base}") + + @classmethod + def _validate_looks_like_ip_adapter(cls, mod: ModelOnDisk) -> None: + if not has_any_keys_starting_with( + mod.load_state_dict(), + { + "image_proj.", + "ip_adapter.", + # XLabs FLUX IP-Adapter models have keys startinh with "ip_adapter_proj_model.". + "ip_adapter_proj_model.", + }, + ): + raise NotAMatch(cls, "model does not match Checkpoint IP Adapter heuristics") + + @classmethod + def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType: + state_dict = mod.load_state_dict() + + if is_state_dict_xlabs_ip_adapter(state_dict): + return BaseModelType.Flux + + try: + cross_attention_dim = state_dict["ip_adapter.1.to_k_ip.weight"].shape[-1] + except Exception as e: + raise NotAMatch(cls, f"unable to determine cross attention dimension: {e}") from e + + match cross_attention_dim: + case 1280: + return BaseModelType.StableDiffusionXL + case 768: + return BaseModelType.StableDiffusion1 + case 1024: + return BaseModelType.StableDiffusion2 + case _: + raise NotAMatch(cls, f"unrecognized cross attention dimension {cross_attention_dim}") + + +class IPAdapter_Checkpoint_SD1_Config(IPAdapter_Checkpoint_Config_Base, Config_Base): + base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1) + + def load_model(self, submodel_type: Optional[SubModelType] = None) -> AnyModel: + if submodel_type is not None: + raise ValueError("IP-Adapter models have no submodels.") + + model_path = Path(self.path) + model = build_ip_adapter( + ip_adapter_ckpt_path=model_path, + device=torch.device("cpu"), + dtype=TorchDevice.choose_torch_dtype(), + ) + return model + + +class IPAdapter_Checkpoint_SD2_Config(IPAdapter_Checkpoint_Config_Base, Config_Base): + base: Literal[BaseModelType.StableDiffusion2] = Field(default=BaseModelType.StableDiffusion2) + + def load_model(self, submodel_type: Optional[SubModelType] = None) -> AnyModel: + if submodel_type is not None: + raise ValueError("IP-Adapter models have no submodels.") + + model_path = Path(self.path) + model = build_ip_adapter( + ip_adapter_ckpt_path=model_path, + device=torch.device("cpu"), + dtype=TorchDevice.choose_torch_dtype(), + ) + return model + + +class IPAdapter_Checkpoint_SDXL_Config(IPAdapter_Checkpoint_Config_Base, Config_Base): + base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL) + + def load_model(self, submodel_type: Optional[SubModelType] = None) -> AnyModel: + if submodel_type is not None: + raise ValueError("IP-Adapter models have no submodels.") + + model_path = Path(self.path) + model = build_ip_adapter( + ip_adapter_ckpt_path=model_path, + device=torch.device("cpu"), + dtype=TorchDevice.choose_torch_dtype(), + ) + return model + + +class IPAdapter_Checkpoint_FLUX_Config(IPAdapter_Checkpoint_Config_Base, 
Config_Base): + base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux) + + def load_model(self, submodel_type: Optional[SubModelType] = None) -> AnyModel: + if submodel_type is not None: + raise ValueError("FLUX IP-Adapter models have no submodels.") + + model_path = Path(self.path) + sd = load_file(model_path) + + params = infer_xlabs_ip_adapter_params_from_state_dict(sd) + + with accelerate.init_empty_weights(): + model = XlabsIpAdapterFlux(params=params) + + model.load_xlabs_state_dict(sd, assign=True) + return model + + +def _get_clip_variant_type_from_config(config: dict[str, Any]) -> ClipVariantType | None: + try: + hidden_size = config.get("hidden_size") + match hidden_size: + case 1280: + return ClipVariantType.G + case 768: + return ClipVariantType.L + case _: + return None + except Exception: + return None + + +class CLIPEmbed_Diffusers_Config_Base(Diffusers_Config_Base): + base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any) + type: Literal[ModelType.CLIPEmbed] = Field(default=ModelType.CLIPEmbed) + format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers) + + @classmethod + def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: + _validate_is_dir(cls, mod) + + _validate_override_fields(cls, fields) + + _validate_class_name( + cls, + { + mod.path / "config.json", + mod.path / "text_encoder" / "config.json", + }, + { + "CLIPModel", + "CLIPTextModel", + "CLIPTextModelWithProjection", + }, + ) + + cls._validate_variant(mod) + + return cls(**fields) + + @classmethod + def _validate_variant(cls, mod: ModelOnDisk) -> None: + """Raise `NotAMatch` if the model variant does not match this config class.""" + expected_variant = cls.model_fields["variant"].default + config = _get_config_or_raise( + cls, + { + mod.path / "config.json", + mod.path / "text_encoder" / "config.json", + }, + ) + recognized_variant = _get_clip_variant_type_from_config(config) + + if recognized_variant is None: + raise NotAMatch(cls, "unable to determine CLIP variant from config") + if expected_variant is not recognized_variant: + raise NotAMatch(cls, f"variant is {recognized_variant}, not {expected_variant}") -class CLIPVisionDiffusersConfig(DiffusersConfigBase, LegacyProbeMixin, ModelConfigBase): + +class CLIPEmbed_Diffusers_G_Config(CLIPEmbed_Diffusers_Config_Base, Config_Base): + variant: Literal[ClipVariantType.G] = Field(default=ClipVariantType.G) + + def load_model(self, submodel_type: Optional[SubModelType] = None) -> AnyModel: + match submodel_type: + case SubModelType.Tokenizer: + return CLIPTokenizer.from_pretrained(Path(self.path) / "tokenizer") + case SubModelType.TextEncoder: + return CLIPTextModel.from_pretrained(Path(self.path) / "text_encoder") + case _: + raise ValueError( + f"Only Tokenizer and TextEncoder submodels are supported. Received: {submodel_type.value if submodel_type else 'None'}" + ) + + +class CLIPEmbed_Diffusers_L_Config(CLIPEmbed_Diffusers_Config_Base, Config_Base): + variant: Literal[ClipVariantType.L] = Field(default=ClipVariantType.L) + + def load_model(self, submodel_type: Optional[SubModelType] = None) -> AnyModel: + match submodel_type: + case SubModelType.Tokenizer: + return CLIPTokenizer.from_pretrained(Path(self.path) / "tokenizer") + case SubModelType.TextEncoder: + return CLIPTextModel.from_pretrained(Path(self.path) / "text_encoder") + case _: + raise ValueError( + f"Only Tokenizer and TextEncoder submodels are supported. 
Received: {submodel_type.value if submodel_type else 'None'}" + ) + + +class CLIPVision_Diffusers_Config(Diffusers_Config_Base, Config_Base): """Model config for CLIPVision.""" - type: Literal[ModelType.CLIPVision] = ModelType.CLIPVision - format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers + base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any) + type: Literal[ModelType.CLIPVision] = Field(default=ModelType.CLIPVision) + format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers) + + @classmethod + def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: + _validate_is_dir(cls, mod) + + _validate_override_fields(cls, fields) + + _validate_class_name( + cls, + common_config_paths(mod.path), + { + "CLIPVisionModelWithProjection", + }, + ) + + return cls(**fields) + + def load_model(self, submodel_type: Optional[SubModelType] = None) -> AnyModel: + if submodel_type is not None: + raise ValueError("CLIPVision models have no submodels.") + + model_path = Path(self.path) + model = CLIPVisionModelWithProjection.from_pretrained( + model_path, torch_dtype="auto", local_files_only=True + ) + assert isinstance(model, CLIPVisionModelWithProjection) + + return model -class T2IAdapterConfig(DiffusersConfigBase, ControlAdapterConfigBase, LegacyProbeMixin, ModelConfigBase): + +class T2IAdapter_Diffusers_Config_Base(Diffusers_Config_Base, ControlAdapter_Config_Base): """Model config for T2I.""" - type: Literal[ModelType.T2IAdapter] = ModelType.T2IAdapter - format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers + type: Literal[ModelType.T2IAdapter] = Field(default=ModelType.T2IAdapter) + format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers) + + @classmethod + def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: + _validate_is_dir(cls, mod) + + _validate_override_fields(cls, fields) + + _validate_class_name( + cls, + common_config_paths(mod.path), + { + "T2IAdapter", + }, + ) + + cls._validate_base(mod) + + return cls(**fields) + + @classmethod + def _validate_base(cls, mod: ModelOnDisk) -> None: + """Raise `NotAMatch` if the model base does not match this config class.""" + expected_base = cls.model_fields["base"].default + recognized_base = cls._get_base_or_raise(mod) + if expected_base is not recognized_base: + raise NotAMatch(cls, f"base is {recognized_base}, not {expected_base}") + + @classmethod + def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType: + config = _get_config_or_raise(cls, common_config_paths(mod.path)) + + adapter_type = config.get("adapter_type") + match adapter_type: + case "full_adapter_xl": + return BaseModelType.StableDiffusionXL + case "full_adapter" | "light_adapter": + return BaseModelType.StableDiffusion1 + case _: + raise NotAMatch(cls, f"unrecognized adapter_type '{adapter_type}'") -class SpandrelImageToImageConfig(LegacyProbeMixin, ModelConfigBase): + +class T2IAdapter_Diffusers_SD1_Config(T2IAdapter_Diffusers_Config_Base, Config_Base): + base: Literal[BaseModelType.StableDiffusion1] = Field(default=BaseModelType.StableDiffusion1) + + def load_model(self, submodel_type: Optional[SubModelType] = None) -> AnyModel: + if submodel_type is not None: + raise ValueError("T2IAdapter models have no submodels.") + + model_path = Path(self.path) + config = _get_config_or_raise(type(self), model_path / "config.json") + model_class = _get_hf_load_class_from_config(config) + + variant = self.repo_variant.value if self.repo_variant else None + try: + model: 
AnyModel = model_class.from_pretrained( + model_path, + torch_dtype=TorchDevice.choose_torch_dtype(), + variant=variant, + ) + except OSError as e: + if variant and "no file named" in str(e): + # try without the variant, just in case user's preferences changed + model = model_class.from_pretrained( + model_path, + torch_dtype=TorchDevice.choose_torch_dtype(), + ) + else: + raise e + return model + + +class T2IAdapter_Diffusers_SDXL_Config(T2IAdapter_Diffusers_Config_Base, Config_Base): + base: Literal[BaseModelType.StableDiffusionXL] = Field(default=BaseModelType.StableDiffusionXL) + + def load_model(self, submodel_type: Optional[SubModelType] = None) -> AnyModel: + if submodel_type is not None: + raise ValueError("T2IAdapter models have no submodels.") + + model_path = Path(self.path) + config = _get_config_or_raise(type(self), model_path / "config.json") + model_class = _get_hf_load_class_from_config(config) + + variant = self.repo_variant.value if self.repo_variant else None + try: + model: AnyModel = model_class.from_pretrained( + model_path, + torch_dtype=TorchDevice.choose_torch_dtype(), + variant=variant, + ) + except OSError as e: + if variant and "no file named" in str(e): + # try without the variant, just in case user's preferences changed + model = model_class.from_pretrained( + model_path, + torch_dtype=TorchDevice.choose_torch_dtype(), + ) + else: + raise e + return model + + +class Spandrel_Checkpoint_Config(Config_Base): """Model config for Spandrel Image to Image models.""" - _MATCH_SPEED: ClassVar[MatchSpeed] = MatchSpeed.SLOW # requires loading the model from disk + base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any) + type: Literal[ModelType.SpandrelImageToImage] = Field(default=ModelType.SpandrelImageToImage) + format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint) + + @classmethod + def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: + _validate_is_file(cls, mod) + + _validate_override_fields(cls, fields) + + cls._validate_spandrel_loads_model(mod) + + return cls(**fields) + + def load_model(self, submodel_type: Optional[SubModelType] = None) -> AnyModel: + if submodel_type is not None: + raise ValueError("Unexpected submodel requested for Spandrel model.") - type: Literal[ModelType.SpandrelImageToImage] = ModelType.SpandrelImageToImage - format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint + model_path = Path(self.path) + model = SpandrelImageToImageModel.load_from_file(model_path) + torch_dtype = TorchDevice.choose_torch_dtype() + if not model.supports_dtype(torch_dtype): + logger.warning( + f"The configured dtype ('{torch_dtype}') is not supported by the {model.get_model_type_name()} " + "model. Falling back to 'float32'." + ) + torch_dtype = torch.float32 + model.to(dtype=torch_dtype) + + return model -class SigLIPConfig(DiffusersConfigBase, LegacyProbeMixin, ModelConfigBase): + @classmethod + def _validate_spandrel_loads_model(cls, mod: ModelOnDisk) -> None: + try: + # It would be nice to avoid having to load the Spandrel model from disk here. A couple of options were + # explored to avoid this: + # 1. Call `SpandrelImageToImageModel.load_from_state_dict(ckpt)`, where `ckpt` is a state_dict on the meta + # device. Unfortunately, some Spandrel models perform operations during initialization that are not + # supported on meta tensors. + # 2. Spandrel has internal logic to determine a model's type from its state_dict before loading the model. 
+ # This logic is not exposed in spandrel's public API. We could copy the logic here, but then we have to + # maintain it, and the risk of false positive detections is higher. + SpandrelImageToImageModel.load_from_file(mod.path) + except Exception as e: + raise NotAMatch(cls, "model does not match SpandrelImageToImage heuristics") from e + + +class SigLIP_Diffusers_Config(Diffusers_Config_Base, Config_Base): """Model config for SigLIP.""" - type: Literal[ModelType.SigLIP] = ModelType.SigLIP - format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers + type: Literal[ModelType.SigLIP] = Field(default=ModelType.SigLIP) + format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers) + base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any) + + def load_model(self, submodel_type: Optional[SubModelType] = None) -> AnyModel: + if submodel_type is not None: + raise ValueError("SigLIP models have no submodels.") + + model_path = Path(self.path) + model = SiglipVisionModel.from_pretrained(model_path, local_files_only=True, torch_dtype="auto") + assert isinstance(model, SiglipVisionModel) + return model + + @classmethod + def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: + _validate_is_dir(cls, mod) + + _validate_override_fields(cls, fields) + + _validate_class_name( + cls, + common_config_paths(mod.path), + { + "SiglipModel", + }, + ) + + return cls(**fields) -class FluxReduxConfig(LegacyProbeMixin, ModelConfigBase): +class FLUXRedux_Checkpoint_Config(Config_Base): """Model config for FLUX Tools Redux model.""" - type: Literal[ModelType.FluxRedux] = ModelType.FluxRedux - format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint + type: Literal[ModelType.FluxRedux] = Field(default=ModelType.FluxRedux) + format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint) + base: Literal[BaseModelType.Flux] = Field(default=BaseModelType.Flux) + def load_model(self, submodel_type: Optional[SubModelType] = None) -> AnyModel: + if submodel_type is not None: + raise ValueError("FLUX Redux models have no submodels.") -class LlavaOnevisionConfig(DiffusersConfigBase, ModelConfigBase): - """Model config for Llava Onevision models.""" + model_path = Path(self.path) + sd = load_file(model_path) - type: Literal[ModelType.LlavaOnevision] = ModelType.LlavaOnevision - format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers + with accelerate.init_empty_weights(): + model = FluxReduxModel() + + model.load_state_dict(sd, assign=True) + model.to(dtype=torch.bfloat16) + return model @classmethod - def matches(cls, mod: ModelOnDisk) -> bool: - if mod.path.is_file(): - return False + def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: + _validate_is_file(cls, mod) - config_path = mod.path / "config.json" - try: - with open(config_path, "r") as file: - config = json.load(file) - except FileNotFoundError: - return False + _validate_override_fields(cls, fields) + + if not is_state_dict_likely_flux_redux(mod.load_state_dict()): + raise NotAMatch(cls, "model does not match FLUX Tools Redux heuristics") + + return cls(**fields) + + +class LlavaOnevision_Diffusers_Config(Diffusers_Config_Base, Config_Base): + """Model config for Llava Onevision models.""" + + type: Literal[ModelType.LlavaOnevision] = Field(default=ModelType.LlavaOnevision) + base: Literal[BaseModelType.Any] = Field(default=BaseModelType.Any) + variant: Literal[ModelVariantType.Normal] = Field(default=ModelVariantType.Normal) - architectures = 
config.get("architectures") - return architectures and architectures[0] == "LlavaOnevisionForConditionalGeneration" + def load_model(self, submodel_type: Optional[SubModelType] = None) -> AnyModel: + if submodel_type is not None: + raise ValueError("LlavaOnevision models have no submodels.") + + model_path = Path(self.path) + model = LlavaOnevisionForConditionalGeneration.from_pretrained( + model_path, local_files_only=True, torch_dtype="auto" + ) + assert isinstance(model, LlavaOnevisionForConditionalGeneration) + return model @classmethod - def parse(cls, mod: ModelOnDisk) -> dict[str, Any]: - return { - "base": BaseModelType.Any, - "variant": ModelVariantType.Normal, - } + def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: + _validate_is_dir(cls, mod) + + _validate_override_fields(cls, fields) + + _validate_class_name( + cls, + common_config_paths(mod.path), + { + "LlavaOnevisionForConditionalGeneration", + }, + ) + + return cls(**fields) -class ApiModelConfig(MainConfigBase, ModelConfigBase): +class ExternalAPI_Config_Base(ABC, BaseModel): """Model config for API-based models.""" - format: Literal[ModelFormat.Api] = ModelFormat.Api + format: Literal[ModelFormat.Api] = Field(default=ModelFormat.Api) @classmethod - def matches(cls, mod: ModelOnDisk) -> bool: - # API models are not stored on disk, so we can't match them. - return False + def from_model_on_disk(cls, mod: ModelOnDisk, fields: dict[str, Any]) -> Self: + raise NotAMatch(cls, "External API models cannot be built from disk") - @classmethod - def parse(cls, mod: ModelOnDisk) -> dict[str, Any]: - raise NotImplementedError("API models are not parsed from disk.") +class ExternalAPI_ChatGPT4o_Config(ExternalAPI_Config_Base, Main_Config_Base, Config_Base): + base: Literal[BaseModelType.ChatGPT4o] = Field(default=BaseModelType.ChatGPT4o) -class VideoApiModelConfig(VideoConfigBase, ModelConfigBase): - """Model config for API-based video models.""" - format: Literal[ModelFormat.Api] = ModelFormat.Api +class ExternalAPI_Gemini2_5_Config(ExternalAPI_Config_Base, Main_Config_Base, Config_Base): + base: Literal[BaseModelType.Gemini2_5] = Field(default=BaseModelType.Gemini2_5) - @classmethod - def matches(cls, mod: ModelOnDisk) -> bool: - # API models are not stored on disk, so we can't match them. - return False - @classmethod - def parse(cls, mod: ModelOnDisk) -> dict[str, Any]: - raise NotImplementedError("API models are not parsed from disk.") +class ExternalAPI_Imagen3_Config(ExternalAPI_Config_Base, Main_Config_Base, Config_Base): + base: Literal[BaseModelType.Imagen3] = Field(default=BaseModelType.Imagen3) -def get_model_discriminator_value(v: Any) -> str: - """ - Computes the discriminator value for a model config. 
-    https://docs.pydantic.dev/latest/concepts/unions/#discriminated-unions-with-callable-discriminator
-    """
-    format_ = type_ = variant_ = None
+class ExternalAPI_Imagen4_Config(ExternalAPI_Config_Base, Main_Config_Base, Config_Base):
+    base: Literal[BaseModelType.Imagen4] = Field(default=BaseModelType.Imagen4)
-    if isinstance(v, dict):
-        format_ = v.get("format")
-        if isinstance(format_, Enum):
-            format_ = format_.value
-        type_ = v.get("type")
-        if isinstance(type_, Enum):
-            type_ = type_.value
+class ExternalAPI_FluxKontext_Config(ExternalAPI_Config_Base, Main_Config_Base, Config_Base):
+    base: Literal[BaseModelType.FluxKontext] = Field(default=BaseModelType.FluxKontext)
+
+
+class VideoConfigBase(ABC, BaseModel):
+    type: Literal[ModelType.Video] = Field(default=ModelType.Video)
+    trigger_phrases: set[str] | None = Field(description="Set of trigger phrases for this model", default=None)
+    default_settings: MainModelDefaultSettings | None = Field(
+        description="Default settings for this model", default=None
+    )
-        variant_ = v.get("variant")
-        if isinstance(variant_, Enum):
-            variant_ = variant_.value
-    else:
-        format_ = v.format.value
-        type_ = v.type.value
-        variant_ = getattr(v, "variant", None)
-        if variant_:
-            variant_ = variant_.value
-    # Ideally, each config would be uniquely identified with a combination of fields
-    # i.e. (type, format, variant) without any special cases. Alas...
+class ExternalAPI_Veo3_Config(ExternalAPI_Config_Base, VideoConfigBase, Config_Base):
+    base: Literal[BaseModelType.Veo3] = Field(default=BaseModelType.Veo3)
-    # Previously, CLIPEmbed did not have any variants, meaning older database entries lack a variant field.
-    # To maintain compatibility, we default to ClipVariantType.L in this case.
-    if type_ == ModelType.CLIPEmbed.value and format_ == ModelFormat.Diffusers.value:
-        variant_ = variant_ or ClipVariantType.L.value
-        return f"{type_}.{format_}.{variant_}"
-    return f"{type_}.{format_}"
+
+class ExternalAPI_Runway_Config(ExternalAPI_Config_Base, VideoConfigBase, Config_Base):
+    base: Literal[BaseModelType.Runway] = Field(default=BaseModelType.Runway)
 # The types are listed explicitly because IDEs/LSPs can't identify the correct types
 # when AnyModelConfig is constructed dynamically using ModelConfigBase.all_config_classes
 AnyModelConfig = Annotated[
     Union[
-        Annotated[MainDiffusersConfig, MainDiffusersConfig.get_tag()],
-        Annotated[MainCheckpointConfig, MainCheckpointConfig.get_tag()],
-        Annotated[MainBnbQuantized4bCheckpointConfig, MainBnbQuantized4bCheckpointConfig.get_tag()],
-        Annotated[MainGGUFCheckpointConfig, MainGGUFCheckpointConfig.get_tag()],
-        Annotated[VAEDiffusersConfig, VAEDiffusersConfig.get_tag()],
-        Annotated[VAECheckpointConfig, VAECheckpointConfig.get_tag()],
-        Annotated[ControlNetDiffusersConfig, ControlNetDiffusersConfig.get_tag()],
-        Annotated[ControlNetCheckpointConfig, ControlNetCheckpointConfig.get_tag()],
-        Annotated[LoRALyCORISConfig, LoRALyCORISConfig.get_tag()],
-        Annotated[LoRAOmiConfig, LoRAOmiConfig.get_tag()],
-        Annotated[ControlLoRALyCORISConfig, ControlLoRALyCORISConfig.get_tag()],
-        Annotated[ControlLoRADiffusersConfig, ControlLoRADiffusersConfig.get_tag()],
-        Annotated[LoRADiffusersConfig, LoRADiffusersConfig.get_tag()],
-        Annotated[T5EncoderConfig, T5EncoderConfig.get_tag()],
-        Annotated[T5EncoderBnbQuantizedLlmInt8bConfig, T5EncoderBnbQuantizedLlmInt8bConfig.get_tag()],
-        Annotated[TextualInversionFileConfig, TextualInversionFileConfig.get_tag()],
-        Annotated[TextualInversionFolderConfig,
TextualInversionFolderConfig.get_tag()], - Annotated[IPAdapterInvokeAIConfig, IPAdapterInvokeAIConfig.get_tag()], - Annotated[IPAdapterCheckpointConfig, IPAdapterCheckpointConfig.get_tag()], - Annotated[T2IAdapterConfig, T2IAdapterConfig.get_tag()], - Annotated[SpandrelImageToImageConfig, SpandrelImageToImageConfig.get_tag()], - Annotated[CLIPVisionDiffusersConfig, CLIPVisionDiffusersConfig.get_tag()], - Annotated[CLIPLEmbedDiffusersConfig, CLIPLEmbedDiffusersConfig.get_tag()], - Annotated[CLIPGEmbedDiffusersConfig, CLIPGEmbedDiffusersConfig.get_tag()], - Annotated[SigLIPConfig, SigLIPConfig.get_tag()], - Annotated[FluxReduxConfig, FluxReduxConfig.get_tag()], - Annotated[LlavaOnevisionConfig, LlavaOnevisionConfig.get_tag()], - Annotated[ApiModelConfig, ApiModelConfig.get_tag()], - Annotated[VideoApiModelConfig, VideoApiModelConfig.get_tag()], - Annotated[UnknownModelConfig, UnknownModelConfig.get_tag()], + # Main (Pipeline) - diffusers format + Annotated[Main_Diffusers_SD1_Config, Main_Diffusers_SD1_Config.get_tag()], + Annotated[Main_Diffusers_SD2_Config, Main_Diffusers_SD2_Config.get_tag()], + Annotated[Main_Diffusers_SDXL_Config, Main_Diffusers_SDXL_Config.get_tag()], + Annotated[Main_Diffusers_SDXLRefiner_Config, Main_Diffusers_SDXLRefiner_Config.get_tag()], + Annotated[Main_Diffusers_SD3_Config, Main_Diffusers_SD3_Config.get_tag()], + Annotated[Main_Diffusers_CogView4_Config, Main_Diffusers_CogView4_Config.get_tag()], + # Main (Pipeline) - checkpoint format + Annotated[Main_Checkpoint_SD1_Config, Main_Checkpoint_SD1_Config.get_tag()], + Annotated[Main_Checkpoint_SD2_Config, Main_Checkpoint_SD2_Config.get_tag()], + Annotated[Main_Checkpoint_SDXL_Config, Main_Checkpoint_SDXL_Config.get_tag()], + Annotated[Main_Checkpoint_SDXLRefiner_Config, Main_Checkpoint_SDXLRefiner_Config.get_tag()], + Annotated[Main_Checkpoint_FLUX_Config, Main_Checkpoint_FLUX_Config.get_tag()], + # Main (Pipeline) - quantized formats + Annotated[Main_BnBNF4_FLUX_Config, Main_BnBNF4_FLUX_Config.get_tag()], + Annotated[Main_GGUF_FLUX_Config, Main_GGUF_FLUX_Config.get_tag()], + # VAE - checkpoint format + Annotated[VAE_Checkpoint_SD1_Config, VAE_Checkpoint_SD1_Config.get_tag()], + Annotated[VAE_Checkpoint_SD2_Config, VAE_Checkpoint_SD2_Config.get_tag()], + Annotated[VAE_Checkpoint_SDXL_Config, VAE_Checkpoint_SDXL_Config.get_tag()], + Annotated[VAE_Checkpoint_FLUX_Config, VAE_Checkpoint_FLUX_Config.get_tag()], + # VAE - diffusers format + Annotated[VAE_Diffusers_SD1_Config, VAE_Diffusers_SD1_Config.get_tag()], + Annotated[VAE_Diffusers_SDXL_Config, VAE_Diffusers_SDXL_Config.get_tag()], + # ControlNet - checkpoint format + Annotated[ControlNet_Checkpoint_SD1_Config, ControlNet_Checkpoint_SD1_Config.get_tag()], + Annotated[ControlNet_Checkpoint_SD2_Config, ControlNet_Checkpoint_SD2_Config.get_tag()], + Annotated[ControlNet_Checkpoint_SDXL_Config, ControlNet_Checkpoint_SDXL_Config.get_tag()], + Annotated[ControlNet_Checkpoint_FLUX_Config, ControlNet_Checkpoint_FLUX_Config.get_tag()], + # ControlNet - diffusers format + Annotated[ControlNet_Diffusers_SD1_Config, ControlNet_Diffusers_SD1_Config.get_tag()], + Annotated[ControlNet_Diffusers_SD2_Config, ControlNet_Diffusers_SD2_Config.get_tag()], + Annotated[ControlNet_Diffusers_SDXL_Config, ControlNet_Diffusers_SDXL_Config.get_tag()], + Annotated[ControlNet_Diffusers_FLUX_Config, ControlNet_Diffusers_FLUX_Config.get_tag()], + # LoRA - LyCORIS format + Annotated[LoRA_LyCORIS_SD1_Config, LoRA_LyCORIS_SD1_Config.get_tag()], + Annotated[LoRA_LyCORIS_SD2_Config, 
LoRA_LyCORIS_SD2_Config.get_tag()],
+        Annotated[LoRA_LyCORIS_SDXL_Config, LoRA_LyCORIS_SDXL_Config.get_tag()],
+        Annotated[LoRA_LyCORIS_FLUX_Config, LoRA_LyCORIS_FLUX_Config.get_tag()],
+        # LoRA - OMI format
+        Annotated[LoRA_OMI_SDXL_Config, LoRA_OMI_SDXL_Config.get_tag()],
+        Annotated[LoRA_OMI_FLUX_Config, LoRA_OMI_FLUX_Config.get_tag()],
+        # LoRA - diffusers format
+        Annotated[LoRA_Diffusers_SD1_Config, LoRA_Diffusers_SD1_Config.get_tag()],
+        Annotated[LoRA_Diffusers_SD2_Config, LoRA_Diffusers_SD2_Config.get_tag()],
+        Annotated[LoRA_Diffusers_SDXL_Config, LoRA_Diffusers_SDXL_Config.get_tag()],
+        Annotated[LoRA_Diffusers_FLUX_Config, LoRA_Diffusers_FLUX_Config.get_tag()],
+        # ControlLoRA - LyCORIS format
+        Annotated[ControlLoRA_LyCORIS_FLUX_Config, ControlLoRA_LyCORIS_FLUX_Config.get_tag()],
+        # T5 Encoder - all formats
+        Annotated[T5Encoder_T5Encoder_Config, T5Encoder_T5Encoder_Config.get_tag()],
+        Annotated[T5Encoder_BnBLLMint8_Config, T5Encoder_BnBLLMint8_Config.get_tag()],
+        # TI - file format
+        Annotated[TI_File_SD1_Config, TI_File_SD1_Config.get_tag()],
+        Annotated[TI_File_SD2_Config, TI_File_SD2_Config.get_tag()],
+        Annotated[TI_File_SDXL_Config, TI_File_SDXL_Config.get_tag()],
+        # TI - folder format
+        Annotated[TI_Folder_SD1_Config, TI_Folder_SD1_Config.get_tag()],
+        Annotated[TI_Folder_SD2_Config, TI_Folder_SD2_Config.get_tag()],
+        Annotated[TI_Folder_SDXL_Config, TI_Folder_SDXL_Config.get_tag()],
+        # IP Adapter - InvokeAI format
+        Annotated[IPAdapter_InvokeAI_SD1_Config, IPAdapter_InvokeAI_SD1_Config.get_tag()],
+        Annotated[IPAdapter_InvokeAI_SD2_Config, IPAdapter_InvokeAI_SD2_Config.get_tag()],
+        Annotated[IPAdapter_InvokeAI_SDXL_Config, IPAdapter_InvokeAI_SDXL_Config.get_tag()],
+        # IP Adapter - checkpoint format
+        Annotated[IPAdapter_Checkpoint_SD1_Config, IPAdapter_Checkpoint_SD1_Config.get_tag()],
+        Annotated[IPAdapter_Checkpoint_SD2_Config, IPAdapter_Checkpoint_SD2_Config.get_tag()],
+        Annotated[IPAdapter_Checkpoint_SDXL_Config, IPAdapter_Checkpoint_SDXL_Config.get_tag()],
+        Annotated[IPAdapter_Checkpoint_FLUX_Config, IPAdapter_Checkpoint_FLUX_Config.get_tag()],
+        # T2I Adapter - diffusers format
+        Annotated[T2IAdapter_Diffusers_SD1_Config, T2IAdapter_Diffusers_SD1_Config.get_tag()],
+        Annotated[T2IAdapter_Diffusers_SDXL_Config, T2IAdapter_Diffusers_SDXL_Config.get_tag()],
+        # Misc models
+        Annotated[Spandrel_Checkpoint_Config, Spandrel_Checkpoint_Config.get_tag()],
+        Annotated[CLIPEmbed_Diffusers_G_Config, CLIPEmbed_Diffusers_G_Config.get_tag()],
+        Annotated[CLIPEmbed_Diffusers_L_Config, CLIPEmbed_Diffusers_L_Config.get_tag()],
+        Annotated[CLIPVision_Diffusers_Config, CLIPVision_Diffusers_Config.get_tag()],
+        Annotated[SigLIP_Diffusers_Config, SigLIP_Diffusers_Config.get_tag()],
+        Annotated[FLUXRedux_Checkpoint_Config, FLUXRedux_Checkpoint_Config.get_tag()],
+        Annotated[LlavaOnevision_Diffusers_Config, LlavaOnevision_Diffusers_Config.get_tag()],
+        # API models
+        Annotated[ExternalAPI_ChatGPT4o_Config, ExternalAPI_ChatGPT4o_Config.get_tag()],
+        Annotated[ExternalAPI_Gemini2_5_Config, ExternalAPI_Gemini2_5_Config.get_tag()],
+        Annotated[ExternalAPI_Imagen3_Config, ExternalAPI_Imagen3_Config.get_tag()],
+        Annotated[ExternalAPI_Imagen4_Config, ExternalAPI_Imagen4_Config.get_tag()],
+        Annotated[ExternalAPI_FluxKontext_Config, ExternalAPI_FluxKontext_Config.get_tag()],
+        Annotated[ExternalAPI_Veo3_Config, ExternalAPI_Veo3_Config.get_tag()],
+        Annotated[ExternalAPI_Runway_Config, ExternalAPI_Runway_Config.get_tag()],
+        # Unknown model (fallback)
+        Annotated[Unknown_Config,
Unknown_Config.get_tag()], ], - Discriminator(get_model_discriminator_value), + Discriminator(Config_Base.get_model_discriminator_value), ] -AnyModelConfigValidator = TypeAdapter(AnyModelConfig) -AnyDefaultSettings: TypeAlias = Union[MainModelDefaultSettings, LoraModelDefaultSettings, ControlAdapterDefaultSettings] +AnyModelConfigValidator = TypeAdapter[AnyModelConfig](AnyModelConfig) class ModelConfigFactory: @staticmethod def make_config(model_data: Dict[str, Any], timestamp: Optional[float] = None) -> AnyModelConfig: """Return the appropriate config object from raw dict values.""" - model = AnyModelConfigValidator.validate_python(model_data) # type: ignore - if isinstance(model, CheckpointConfigBase) and timestamp: + model = AnyModelConfigValidator.validate_python(model_data) + if isinstance(model, Checkpoint_Config_Base) and timestamp: model.converted_at = timestamp validate_hash(model.hash) - return model # type: ignore + return model + + @staticmethod + def build_common_fields( + mod: ModelOnDisk, + overrides: dict[str, Any] | None = None, + ) -> dict[str, Any]: + """Builds the common fields for all model configs. + + Args: + mod: The model on disk to extract fields from. + overrides: A optional dictionary of fields to override. These fields will take precedence over the values + extracted from the model on disk. + + - Casts string fields to their Enum types. + - Does not validate the fields against the model config schema. + """ + + _overrides: dict[str, Any] = overrides or {} + fields: dict[str, Any] = {} + + if "type" in _overrides: + fields["type"] = ModelType(_overrides["type"]) + + if "format" in _overrides: + fields["format"] = ModelFormat(_overrides["format"]) + + if "base" in _overrides: + fields["base"] = BaseModelType(_overrides["base"]) + + if "source_type" in _overrides: + fields["source_type"] = ModelSourceType(_overrides["source_type"]) + + if "variant" in _overrides: + fields["variant"] = variant_type_adapter.validate_strings(_overrides["variant"]) + + fields["path"] = mod.path.as_posix() + fields["source"] = _overrides.get("source") or fields["path"] + fields["source_type"] = _overrides.get("source_type") or ModelSourceType.Path + fields["name"] = _overrides.get("name") or mod.name + fields["hash"] = _overrides.get("hash") or mod.hash() + fields["key"] = _overrides.get("key") or uuid_string() + fields["description"] = _overrides.get("description") + fields["file_size"] = _overrides.get("file_size") or mod.size() + + return fields + + @staticmethod + def from_model_on_disk( + mod: str | Path | ModelOnDisk, + overrides: dict[str, Any] | None = None, + hash_algo: HASHING_ALGORITHMS = "blake3_single", + ) -> AnyModelConfig: + """ + Returns the best matching ModelConfig instance from a model's file/folder path. + Raises InvalidModelConfigException if no valid configuration is found. + Created to deprecate ModelProbe.probe + """ + if isinstance(mod, Path | str): + mod = ModelOnDisk(Path(mod), hash_algo) + + # We will always need these fields to build any model config. + fields = ModelConfigFactory.build_common_fields(mod, overrides) + + # Store results as a mapping of config class to either an instance of that class or an exception + # that was raised when trying to build it. + results: dict[str, AnyModelConfig | Exception] = {} + + # Try to build an instance of each model config class that uses the classify API. + # Each class will either return an instance of itself or raise NotAMatch if it doesn't match. 
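+        # Illustrative example (not exhaustive): for a FLUX Redux .safetensors file,
+        # FLUXRedux_Checkpoint_Config.from_model_on_disk(mod, fields) returns an instance because the state dict
+        # passes is_state_dict_likely_flux_redux(), while directory-based classes such as SigLIP_Diffusers_Config
+        # raise NotAMatch from their _validate_is_dir check before any heavier heuristics run.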
+        # Other exceptions may be raised if something unexpected happens during matching or building.
+        for config_class in Config_Base.CONFIG_CLASSES:
+            class_name = config_class.__name__
+            try:
+                instance = config_class.from_model_on_disk(mod, fields)
+                results[class_name] = instance
+            except NotAMatch as e:
+                results[class_name] = e
+                logger.debug(f"No match for {config_class.__name__} on model {mod.name}")
+            except ValidationError as e:
+                # This means the model matched, but we couldn't create the pydantic model instance for the config.
+                # Maybe invalid overrides were provided?
+                results[class_name] = e
+                logger.warning(f"Schema validation error for {config_class.__name__} on model {mod.name}: {e}")
+            except Exception as e:
+                results[class_name] = e
+                logger.warning(f"Unexpected exception while matching {mod.name} to {config_class.__name__}: {e}")
+
+        matches = [r for r in results.values() if isinstance(r, Config_Base)]
+
+        if not matches and app_config.allow_unknown_models:
+            logger.warning(f"Unable to identify model {mod.name}, falling back to Unknown_Config")
+            return Unknown_Config(**fields)
+
+        if not matches:
+            # No config class matched and unknown models are not allowed. Fail with a clear error instead of
+            # falling through to matches[0] below and raising an IndexError.
+            raise InvalidModelConfigException(f"Unable to identify model {mod.name} as any known model type")
+
+        if len(matches) > 1:
+            # We have multiple matches, but at most one of them is correct. We need to pick one.
+            #
+            # Known cases:
+            # - SD main models can look like a LoRA when they have LoRA weights merged in. Prefer the main model.
+            # - SD main models in diffusers format can look like a CLIP Embed; they have a text_encoder folder with
+            #   a config.json file. Prefer the main model.
+
+            # Sort the matches according to known special cases.
+            def sort_key(m: AnyModelConfig) -> int:
+                match m.type:
+                    case ModelType.Main:
+                        return 0
+                    case ModelType.LoRA:
+                        return 1
+                    case ModelType.CLIPEmbed:
+                        return 2
+                    case _:
+                        return 3
+
+            matches.sort(key=sort_key)
+            logger.warning(
+                f"Multiple model config classes matched for model {mod.name}: {[type(m).__name__ for m in matches]}. Using {type(matches[0]).__name__}."
+ ) + + instance = matches[0] + logger.info(f"Model {mod.name} classified as {type(instance).__name__}") + return instance diff --git a/invokeai/backend/model_manager/legacy_probe.py b/invokeai/backend/model_manager/legacy_probe.py deleted file mode 100644 index 36fd82667d7..00000000000 --- a/invokeai/backend/model_manager/legacy_probe.py +++ /dev/null @@ -1,1169 +0,0 @@ -import json -import re -from pathlib import Path -from typing import Any, Callable, Dict, Literal, Optional, Union - -import picklescan.scanner as pscan -import safetensors.torch -import spandrel -import torch - -import invokeai.backend.util.logging as logger -from invokeai.app.services.config.config_default import get_config -from invokeai.app.util.misc import uuid_string -from invokeai.backend.flux.controlnet.state_dict_utils import ( - is_state_dict_instantx_controlnet, - is_state_dict_xlabs_controlnet, -) -from invokeai.backend.flux.flux_state_dict_utils import get_flux_in_channels_from_state_dict -from invokeai.backend.flux.ip_adapter.state_dict_utils import is_state_dict_xlabs_ip_adapter -from invokeai.backend.flux.redux.flux_redux_state_dict_utils import is_state_dict_likely_flux_redux -from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS, ModelHash -from invokeai.backend.model_manager.config import ( - AnyModelConfig, - ControlAdapterDefaultSettings, - InvalidModelConfigException, - LoraModelDefaultSettings, - MainModelDefaultSettings, - ModelConfigFactory, - SubmodelDefinition, -) -from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import ConfigLoader -from invokeai.backend.model_manager.model_on_disk import ModelOnDisk -from invokeai.backend.model_manager.taxonomy import ( - AnyVariant, - BaseModelType, - ModelFormat, - ModelRepoVariant, - ModelSourceType, - ModelType, - ModelVariantType, - SchedulerPredictionType, - SubModelType, -) -from invokeai.backend.model_manager.util.model_util import ( - get_clip_variant_type, - lora_token_vector_length, - read_checkpoint_meta, -) -from invokeai.backend.patches.lora_conversions.flux_control_lora_utils import is_state_dict_likely_flux_control -from invokeai.backend.patches.lora_conversions.flux_diffusers_lora_conversion_utils import ( - is_state_dict_likely_in_flux_diffusers_format, -) -from invokeai.backend.patches.lora_conversions.flux_kohya_lora_conversion_utils import ( - is_state_dict_likely_in_flux_kohya_format, -) -from invokeai.backend.patches.lora_conversions.flux_onetrainer_lora_conversion_utils import ( - is_state_dict_likely_in_flux_onetrainer_format, -) -from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor -from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader -from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel -from invokeai.backend.util.silence_warnings import SilenceWarnings - -CkptType = Dict[str | int, Any] - -LEGACY_CONFIGS: Dict[BaseModelType, Dict[ModelVariantType, Union[str, Dict[SchedulerPredictionType, str]]]] = { - BaseModelType.StableDiffusion1: { - ModelVariantType.Normal: { - SchedulerPredictionType.Epsilon: "v1-inference.yaml", - SchedulerPredictionType.VPrediction: "v1-inference-v.yaml", - }, - ModelVariantType.Inpaint: "v1-inpainting-inference.yaml", - }, - BaseModelType.StableDiffusion2: { - ModelVariantType.Normal: { - SchedulerPredictionType.Epsilon: "v2-inference.yaml", - SchedulerPredictionType.VPrediction: "v2-inference-v.yaml", - }, - ModelVariantType.Inpaint: { - SchedulerPredictionType.Epsilon: "v2-inpainting-inference.yaml", - 
SchedulerPredictionType.VPrediction: "v2-inpainting-inference-v.yaml", - }, - ModelVariantType.Depth: "v2-midas-inference.yaml", - }, - BaseModelType.StableDiffusionXL: { - ModelVariantType.Normal: "sd_xl_base.yaml", - ModelVariantType.Inpaint: "sd_xl_inpaint.yaml", - }, - BaseModelType.StableDiffusionXLRefiner: { - ModelVariantType.Normal: "sd_xl_refiner.yaml", - }, -} - - -class ProbeBase(object): - """Base class for probes.""" - - def __init__(self, model_path: Path): - self.model_path = model_path - - def get_base_type(self) -> BaseModelType: - """Get model base type.""" - raise NotImplementedError - - def get_format(self) -> ModelFormat: - """Get model file format.""" - raise NotImplementedError - - def get_variant_type(self) -> Optional[ModelVariantType]: - """Get model variant type.""" - return None - - def get_scheduler_prediction_type(self) -> Optional[SchedulerPredictionType]: - """Get model scheduler prediction type.""" - return None - - def get_image_encoder_model_id(self) -> Optional[str]: - """Get image encoder (IP adapters only).""" - return None - - -class ModelProbe(object): - PROBES: Dict[str, Dict[ModelType, type[ProbeBase]]] = { - "diffusers": {}, - "checkpoint": {}, - "onnx": {}, - } - - CLASS2TYPE = { - "FluxPipeline": ModelType.Main, - "StableDiffusionPipeline": ModelType.Main, - "StableDiffusionInpaintPipeline": ModelType.Main, - "StableDiffusionXLPipeline": ModelType.Main, - "StableDiffusionXLImg2ImgPipeline": ModelType.Main, - "StableDiffusionXLInpaintPipeline": ModelType.Main, - "StableDiffusion3Pipeline": ModelType.Main, - "LatentConsistencyModelPipeline": ModelType.Main, - "AutoencoderKL": ModelType.VAE, - "AutoencoderTiny": ModelType.VAE, - "ControlNetModel": ModelType.ControlNet, - "CLIPVisionModelWithProjection": ModelType.CLIPVision, - "T2IAdapter": ModelType.T2IAdapter, - "CLIPModel": ModelType.CLIPEmbed, - "CLIPTextModel": ModelType.CLIPEmbed, - "T5EncoderModel": ModelType.T5Encoder, - "FluxControlNetModel": ModelType.ControlNet, - "SD3Transformer2DModel": ModelType.Main, - "CLIPTextModelWithProjection": ModelType.CLIPEmbed, - "SiglipModel": ModelType.SigLIP, - "LlavaOnevisionForConditionalGeneration": ModelType.LlavaOnevision, - "CogView4Pipeline": ModelType.Main, - } - - TYPE2VARIANT: Dict[ModelType, Callable[[str], Optional[AnyVariant]]] = {ModelType.CLIPEmbed: get_clip_variant_type} - - @classmethod - def register_probe( - cls, format: Literal["diffusers", "checkpoint", "onnx"], model_type: ModelType, probe_class: type[ProbeBase] - ) -> None: - cls.PROBES[format][model_type] = probe_class - - @classmethod - def probe( - cls, model_path: Path, fields: Optional[Dict[str, Any]] = None, hash_algo: HASHING_ALGORITHMS = "blake3_single" - ) -> AnyModelConfig: - """ - Probe the model at model_path and return its configuration record. - - :param model_path: Path to the model file (checkpoint) or directory (diffusers). - :param fields: An optional dictionary that can be used to override probed - fields. Typically used for fields that don't probe well, such as prediction_type. - - Returns: The appropriate model configuration derived from ModelConfigBase. 
- """ - if fields is None: - fields = {} - - model_path = model_path.resolve() - - format_type = ModelFormat.Diffusers if model_path.is_dir() else ModelFormat.Checkpoint - model_info = None - model_type = ModelType(fields["type"]) if "type" in fields and fields["type"] else None - if not model_type: - if format_type is ModelFormat.Diffusers: - model_type = cls.get_model_type_from_folder(model_path) - else: - model_type = cls.get_model_type_from_checkpoint(model_path) - format_type = ModelFormat.ONNX if model_type == ModelType.ONNX else format_type - - probe_class = cls.PROBES[format_type].get(model_type) - if not probe_class: - raise InvalidModelConfigException(f"Unhandled combination of {format_type} and {model_type}") - - probe = probe_class(model_path) - - fields["source_type"] = fields.get("source_type") or ModelSourceType.Path - fields["source"] = fields.get("source") or model_path.as_posix() - fields["key"] = fields.get("key", uuid_string()) - fields["path"] = model_path.as_posix() - fields["type"] = fields.get("type") or model_type - fields["base"] = fields.get("base") or probe.get_base_type() - variant_func = cls.TYPE2VARIANT.get(fields["type"], None) - fields["variant"] = ( - fields.get("variant") or (variant_func and variant_func(model_path.as_posix())) or probe.get_variant_type() - ) - fields["prediction_type"] = fields.get("prediction_type") or probe.get_scheduler_prediction_type() - fields["image_encoder_model_id"] = fields.get("image_encoder_model_id") or probe.get_image_encoder_model_id() - fields["name"] = fields.get("name") or cls.get_model_name(model_path) - fields["description"] = ( - fields.get("description") or f"{fields['base'].value} {model_type.value} model {fields['name']}" - ) - fields["format"] = ModelFormat(fields.get("format")) if "format" in fields else probe.get_format() - fields["hash"] = fields.get("hash") or ModelHash(algorithm=hash_algo).hash(model_path) - fields["file_size"] = fields.get("file_size") or ModelOnDisk(model_path).size() - - fields["default_settings"] = fields.get("default_settings") - - if not fields["default_settings"]: - if fields["type"] in {ModelType.ControlNet, ModelType.T2IAdapter, ModelType.ControlLoRa}: - fields["default_settings"] = get_default_settings_control_adapters(fields["name"]) - if fields["type"] in {ModelType.LoRA}: - fields["default_settings"] = get_default_settings_lora() - elif fields["type"] is ModelType.Main: - fields["default_settings"] = get_default_settings_main(fields["base"]) - - if format_type == ModelFormat.Diffusers and isinstance(probe, FolderProbeBase): - fields["repo_variant"] = fields.get("repo_variant") or probe.get_repo_variant() - - # additional fields needed for main and controlnet models - if fields["type"] in [ModelType.Main, ModelType.ControlNet, ModelType.VAE] and fields["format"] in [ - ModelFormat.Checkpoint, - ModelFormat.BnbQuantizednf4b, - ModelFormat.GGUFQuantized, - ]: - ckpt_config_path = cls._get_checkpoint_config_path( - model_path, - model_type=fields["type"], - base_type=fields["base"], - variant_type=fields["variant"], - prediction_type=fields["prediction_type"], - ) - fields["config_path"] = str(ckpt_config_path) - - # additional fields needed for main non-checkpoint models - elif fields["type"] == ModelType.Main and fields["format"] in [ - ModelFormat.ONNX, - ModelFormat.Olive, - ModelFormat.Diffusers, - ]: - fields["upcast_attention"] = fields.get("upcast_attention") or ( - fields["base"] == BaseModelType.StableDiffusion2 - and fields["prediction_type"] == 
SchedulerPredictionType.VPrediction - ) - - get_submodels = getattr(probe, "get_submodels", None) - if fields["base"] == BaseModelType.StableDiffusion3 and callable(get_submodels): - fields["submodels"] = get_submodels() - - model_info = ModelConfigFactory.make_config(fields) # , key=fields.get("key", None)) - return model_info - - @classmethod - def get_model_name(cls, model_path: Path) -> str: - if model_path.suffix in {".safetensors", ".bin", ".pt", ".ckpt"}: - return model_path.stem - else: - return model_path.name - - @classmethod - def get_model_type_from_checkpoint(cls, model_path: Path, checkpoint: Optional[CkptType] = None) -> ModelType: - if model_path.suffix not in (".bin", ".pt", ".ckpt", ".safetensors", ".pth", ".gguf"): - raise InvalidModelConfigException(f"{model_path}: unrecognized suffix") - - if model_path.name == "learned_embeds.bin": - return ModelType.TextualInversion - - ckpt = checkpoint if checkpoint else read_checkpoint_meta(model_path, scan=True) - ckpt = ckpt.get("state_dict", ckpt) - - if isinstance(ckpt, dict) and is_state_dict_likely_flux_control(ckpt): - return ModelType.ControlLoRa - - if isinstance(ckpt, dict) and is_state_dict_likely_flux_redux(ckpt): - return ModelType.FluxRedux - - for key in [str(k) for k in ckpt.keys()]: - if key.startswith( - ( - "cond_stage_model.", - "first_stage_model.", - "model.diffusion_model.", - # Some FLUX checkpoint files contain transformer keys prefixed with "model.diffusion_model". - # This prefix is typically used to distinguish between multiple models bundled in a single file. - "model.diffusion_model.double_blocks.", - ) - ): - # Keys starting with double_blocks are associated with Flux models - return ModelType.Main - # FLUX models in the official BFL format contain keys with the "double_blocks." prefix, but we must be - # careful to avoid false positives on XLabs FLUX IP-Adapter models. - elif key.startswith("double_blocks.") and "ip_adapter" not in key: - return ModelType.Main - elif key.startswith(("encoder.conv_in", "decoder.conv_in")): - return ModelType.VAE - elif key.startswith(("lora_te_", "lora_unet_", "lora_te1_", "lora_te2_", "lora_transformer_")): - return ModelType.LoRA - # "lora_A.weight" and "lora_B.weight" are associated with models in PEFT format. We don't support all PEFT - # LoRA models, but as of the time of writing, we support Diffusers FLUX PEFT LoRA models. - elif key.endswith(("to_k_lora.up.weight", "to_q_lora.down.weight", "lora_A.weight", "lora_B.weight")): - return ModelType.LoRA - elif key.startswith( - ( - "controlnet", - "control_model", - "input_blocks", - # XLabs FLUX ControlNet models have keys starting with "controlnet_blocks." - # For example: https://huggingface.co/XLabs-AI/flux-controlnet-collections/blob/86ab1e915a389d5857135c00e0d350e9e38a9048/flux-canny-controlnet_v2.safetensors - # TODO(ryand): This is very fragile. XLabs FLUX ControlNet models also contain keys starting with - # "double_blocks.", which we check for above. But, I'm afraid to modify this logic because it is so - # delicate. - "controlnet_blocks", - ) - ): - return ModelType.ControlNet - elif key.startswith( - ( - "image_proj.", - "ip_adapter.", - # XLabs FLUX IP-Adapter models have keys startinh with "ip_adapter_proj_model.". 
- "ip_adapter_proj_model.", - ) - ): - return ModelType.IPAdapter - elif key in {"emb_params", "string_to_param"}: - return ModelType.TextualInversion - - # diffusers-ti - if len(ckpt) < 10 and all(isinstance(v, torch.Tensor) for v in ckpt.values()): - return ModelType.TextualInversion - - # Check if the model can be loaded as a SpandrelImageToImageModel. - # This check is intentionally performed last, as it can be expensive (it requires loading the model from disk). - try: - # It would be nice to avoid having to load the Spandrel model from disk here. A couple of options were - # explored to avoid this: - # 1. Call `SpandrelImageToImageModel.load_from_state_dict(ckpt)`, where `ckpt` is a state_dict on the meta - # device. Unfortunately, some Spandrel models perform operations during initialization that are not - # supported on meta tensors. - # 2. Spandrel has internal logic to determine a model's type from its state_dict before loading the model. - # This logic is not exposed in spandrel's public API. We could copy the logic here, but then we have to - # maintain it, and the risk of false positive detections is higher. - SpandrelImageToImageModel.load_from_file(model_path) - return ModelType.SpandrelImageToImage - except spandrel.UnsupportedModelError: - pass - except Exception as e: - logger.warning( - f"Encountered error while probing to determine if {model_path} is a Spandrel model. Ignoring. Error: {e}" - ) - - raise InvalidModelConfigException(f"Unable to determine model type for {model_path}") - - @classmethod - def get_model_type_from_folder(cls, folder_path: Path) -> ModelType: - """Get the model type of a hugging-face style folder.""" - class_name = None - error_hint = None - for suffix in ["bin", "safetensors"]: - if (folder_path / f"learned_embeds.{suffix}").exists(): - return ModelType.TextualInversion - if (folder_path / f"pytorch_lora_weights.{suffix}").exists(): - return ModelType.LoRA - if (folder_path / "unet/model.onnx").exists(): - return ModelType.ONNX - if (folder_path / "image_encoder.txt").exists(): - return ModelType.IPAdapter - - config_path = None - for p in [ - folder_path / "model_index.json", # pipeline - folder_path / "config.json", # most diffusers - folder_path / "text_encoder_2" / "config.json", # T5 text encoder - folder_path / "text_encoder" / "config.json", # T5 CLIP - ]: - if p.exists(): - config_path = p - break - - if config_path: - with open(config_path, "r") as file: - conf = json.load(file) - if "_class_name" in conf: - class_name = conf["_class_name"] - elif "architectures" in conf: - class_name = conf["architectures"][0] - else: - class_name = None - else: - error_hint = f"No model_index.json or config.json found in {folder_path}." 
- - if class_name and (type := cls.CLASS2TYPE.get(class_name)): - return type - else: - error_hint = f"class {class_name} is not one of the supported classes [{', '.join(cls.CLASS2TYPE.keys())}]" - - # give up - raise InvalidModelConfigException( - f"Unable to determine model type for {folder_path}" + (f"; {error_hint}" if error_hint else "") - ) - - @classmethod - def _get_checkpoint_config_path( - cls, - model_path: Path, - model_type: ModelType, - base_type: BaseModelType, - variant_type: ModelVariantType, - prediction_type: SchedulerPredictionType, - ) -> Path: - # look for a YAML file adjacent to the model file first - possible_conf = model_path.with_suffix(".yaml") - if possible_conf.exists(): - return possible_conf.absolute() - - if model_type is ModelType.Main: - if base_type == BaseModelType.Flux: - # TODO: Decide between dev/schnell - checkpoint = ModelProbe._scan_and_load_checkpoint(model_path) - state_dict = checkpoint.get("state_dict") or checkpoint - - # HACK: For FLUX, config_file is used as a key into invokeai.backend.flux.util.params during model - # loading. When FLUX support was first added, it was decided that this was the easiest way to support - # the various FLUX formats rather than adding new model types/formats. Be careful when modifying this in - # the future. - if ( - "guidance_in.out_layer.weight" in state_dict - or "model.diffusion_model.guidance_in.out_layer.weight" in state_dict - ): - if variant_type == ModelVariantType.Normal: - config_file = "flux-dev" - elif variant_type == ModelVariantType.Inpaint: - config_file = "flux-dev-fill" - else: - raise ValueError(f"Unexpected FLUX variant type: {variant_type}") - else: - config_file = "flux-schnell" - else: - config_file = LEGACY_CONFIGS[base_type][variant_type] - if isinstance(config_file, dict): # need another tier for sd-2.x models - config_file = config_file[prediction_type] - config_file = f"stable-diffusion/{config_file}" - elif model_type is ModelType.ControlNet: - config_file = ( - "controlnet/cldm_v15.yaml" - if base_type is BaseModelType.StableDiffusion1 - else "controlnet/cldm_v21.yaml" - ) - elif model_type is ModelType.VAE: - config_file = ( - # For flux, this is a key in invokeai.backend.flux.util.ae_params - # Due to model type and format being the descriminator for model configs this - # is used rather than attempting to support flux with separate model types and format - # If changed in the future, please fix me - "flux" - if base_type is BaseModelType.Flux - else "stable-diffusion/v1-inference.yaml" - if base_type is BaseModelType.StableDiffusion1 - else "stable-diffusion/sd_xl_base.yaml" - if base_type is BaseModelType.StableDiffusionXL - else "stable-diffusion/v2-inference.yaml" - ) - else: - raise InvalidModelConfigException( - f"{model_path}: Unrecognized combination of model_type={model_type}, base_type={base_type}" - ) - return Path(config_file) - - @classmethod - def _scan_and_load_checkpoint(cls, model_path: Path) -> CkptType: - with SilenceWarnings(): - if model_path.suffix.endswith((".ckpt", ".pt", ".pth", ".bin")): - cls._scan_model(model_path.name, model_path) - model = torch.load(model_path, map_location="cpu") - assert isinstance(model, dict) - return model - elif model_path.suffix.endswith(".gguf"): - return gguf_sd_loader(model_path, compute_dtype=torch.float32) - else: - return safetensors.torch.load_file(model_path) - - @classmethod - def _scan_model(cls, model_name: str, checkpoint: Path) -> None: - """ - Apply picklescanner to the indicated checkpoint and issue a warning - 
and option to exit if an infected file is identified. - """ - # scan model - scan_result = pscan.scan_file_path(checkpoint) - if scan_result.infected_files != 0: - if get_config().unsafe_disable_picklescan: - logger.warning( - f"The model {model_name} is potentially infected by malware, but picklescan is disabled. " - "Proceeding with caution." - ) - else: - raise RuntimeError(f"The model {model_name} is potentially infected by malware. Aborting import.") - if scan_result.scan_err: - if get_config().unsafe_disable_picklescan: - logger.warning( - f"Error scanning the model at {model_name} for malware, but picklescan is disabled. " - "Proceeding with caution." - ) - else: - raise RuntimeError(f"Error scanning the model at {model_name} for malware. Aborting import.") - - -# Probing utilities -MODEL_NAME_TO_PREPROCESSOR = { - "canny": "canny_image_processor", - "mlsd": "mlsd_image_processor", - "depth": "depth_anything_image_processor", - "bae": "normalbae_image_processor", - "normal": "normalbae_image_processor", - "sketch": "pidi_image_processor", - "scribble": "lineart_image_processor", - "lineart anime": "lineart_anime_image_processor", - "lineart_anime": "lineart_anime_image_processor", - "lineart": "lineart_image_processor", - "soft": "hed_image_processor", - "softedge": "hed_image_processor", - "hed": "hed_image_processor", - "shuffle": "content_shuffle_image_processor", - "pose": "dw_openpose_image_processor", - "mediapipe": "mediapipe_face_processor", - "pidi": "pidi_image_processor", - "zoe": "zoe_depth_image_processor", - "color": "color_map_image_processor", -} - - -def get_default_settings_control_adapters(model_name: str) -> Optional[ControlAdapterDefaultSettings]: - for k, v in MODEL_NAME_TO_PREPROCESSOR.items(): - model_name_lower = model_name.lower() - if k in model_name_lower: - return ControlAdapterDefaultSettings(preprocessor=v) - return None - - -def get_default_settings_lora() -> LoraModelDefaultSettings: - return LoraModelDefaultSettings() - - -def get_default_settings_main(model_base: BaseModelType) -> Optional[MainModelDefaultSettings]: - if model_base is BaseModelType.StableDiffusion1 or model_base is BaseModelType.StableDiffusion2: - return MainModelDefaultSettings(width=512, height=512) - elif model_base is BaseModelType.StableDiffusionXL: - return MainModelDefaultSettings(width=1024, height=1024) - # We don't provide defaults for BaseModelType.StableDiffusionXLRefiner, as they are not standalone models. 
- return None - - -# ##################################################3 -# Checkpoint probing -# ##################################################3 - - -class CheckpointProbeBase(ProbeBase): - def __init__(self, model_path: Path): - super().__init__(model_path) - self.checkpoint = ModelProbe._scan_and_load_checkpoint(model_path) - - def get_format(self) -> ModelFormat: - state_dict = self.checkpoint.get("state_dict") or self.checkpoint - if ( - "double_blocks.0.img_attn.proj.weight.quant_state.bitsandbytes__nf4" in state_dict - or "model.diffusion_model.double_blocks.0.img_attn.proj.weight.quant_state.bitsandbytes__nf4" in state_dict - ): - return ModelFormat.BnbQuantizednf4b - elif any(isinstance(v, GGMLTensor) for v in state_dict.values()): - return ModelFormat.GGUFQuantized - return ModelFormat("checkpoint") - - def get_variant_type(self) -> ModelVariantType: - model_type = ModelProbe.get_model_type_from_checkpoint(self.model_path, self.checkpoint) - base_type = self.get_base_type() - if model_type != ModelType.Main: - return ModelVariantType.Normal - state_dict = self.checkpoint.get("state_dict") or self.checkpoint - - if base_type == BaseModelType.Flux: - in_channels = get_flux_in_channels_from_state_dict(state_dict) - - if in_channels is None: - # If we cannot find the in_channels, we assume that this is a normal variant. Log a warning. - logger.warning( - f"{self.model_path} does not have img_in.weight or model.diffusion_model.img_in.weight key. Assuming normal variant." - ) - return ModelVariantType.Normal - - # FLUX Model variant types are distinguished by input channels: - # - Unquantized Dev and Schnell have in_channels=64 - # - BNB-NF4 Dev and Schnell have in_channels=1 - # - FLUX Fill has in_channels=384 - # - Unsure of quantized FLUX Fill models - # - Unsure of GGUF-quantized models - if in_channels == 384: - # This is a FLUX Fill model. FLUX Fill needs special handling throughout the application. The variant - # type is used to determine whether to use the fill model or the base model. - return ModelVariantType.Inpaint - else: - # Fall back on "normal" variant type for all other FLUX models. 
- return ModelVariantType.Normal - - in_channels = state_dict["model.diffusion_model.input_blocks.0.0.weight"].shape[1] - if in_channels == 9: - return ModelVariantType.Inpaint - elif in_channels == 5: - return ModelVariantType.Depth - elif in_channels == 4: - return ModelVariantType.Normal - else: - raise InvalidModelConfigException( - f"Cannot determine variant type (in_channels={in_channels}) at {self.model_path}" - ) - - -class PipelineCheckpointProbe(CheckpointProbeBase): - def get_base_type(self) -> BaseModelType: - checkpoint = self.checkpoint - state_dict = self.checkpoint.get("state_dict") or checkpoint - if ( - "double_blocks.0.img_attn.norm.key_norm.scale" in state_dict - or "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale" in state_dict - ): - return BaseModelType.Flux - key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight" - if key_name in state_dict and state_dict[key_name].shape[-1] == 768: - return BaseModelType.StableDiffusion1 - if key_name in state_dict and state_dict[key_name].shape[-1] == 1024: - return BaseModelType.StableDiffusion2 - key_name = "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight" - if key_name in state_dict and state_dict[key_name].shape[-1] == 2048: - return BaseModelType.StableDiffusionXL - elif key_name in state_dict and state_dict[key_name].shape[-1] == 1280: - return BaseModelType.StableDiffusionXLRefiner - else: - raise InvalidModelConfigException("Cannot determine base type") - - def get_scheduler_prediction_type(self) -> SchedulerPredictionType: - """Return model prediction type.""" - type = self.get_base_type() - if type == BaseModelType.StableDiffusion2: - checkpoint = self.checkpoint - state_dict = self.checkpoint.get("state_dict") or checkpoint - key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight" - if key_name in state_dict and state_dict[key_name].shape[-1] == 1024: - if "global_step" in checkpoint: - if checkpoint["global_step"] == 220000: - return SchedulerPredictionType.Epsilon - elif checkpoint["global_step"] == 110000: - return SchedulerPredictionType.VPrediction - return SchedulerPredictionType.VPrediction # a guess for sd2 ckpts - - elif type == BaseModelType.StableDiffusion1: - return SchedulerPredictionType.Epsilon # a reasonable guess for sd1 ckpts - else: - return SchedulerPredictionType.Epsilon - - -class VaeCheckpointProbe(CheckpointProbeBase): - def get_base_type(self) -> BaseModelType: - # VAEs of all base types have the same structure, so we wimp out and - # guess using the name. - for regexp, basetype in [ - (r"xl", BaseModelType.StableDiffusionXL), - (r"sd2", BaseModelType.StableDiffusion2), - (r"vae", BaseModelType.StableDiffusion1), - (r"FLUX.1-schnell_ae", BaseModelType.Flux), - ]: - if re.search(regexp, self.model_path.name, re.IGNORECASE): - return basetype - raise InvalidModelConfigException("Cannot determine base type") - - -class LoRACheckpointProbe(CheckpointProbeBase): - """Class for LoRA checkpoints.""" - - def get_format(self) -> ModelFormat: - if is_state_dict_likely_in_flux_diffusers_format(self.checkpoint): - # TODO(ryand): This is an unusual case. In other places throughout the codebase, we treat - # ModelFormat.Diffusers as meaning that the model is in a directory. In this case, the model is a single - # file, but the weight keys are in the diffusers format. 
- return ModelFormat.Diffusers - return ModelFormat.LyCORIS - - def get_base_type(self) -> BaseModelType: - if ( - is_state_dict_likely_in_flux_kohya_format(self.checkpoint) - or is_state_dict_likely_in_flux_onetrainer_format(self.checkpoint) - or is_state_dict_likely_in_flux_diffusers_format(self.checkpoint) - or is_state_dict_likely_flux_control(self.checkpoint) - ): - return BaseModelType.Flux - - # If we've gotten here, we assume that the model is a Stable Diffusion model. - token_vector_length = lora_token_vector_length(self.checkpoint) - if token_vector_length == 768: - return BaseModelType.StableDiffusion1 - elif token_vector_length == 1024: - return BaseModelType.StableDiffusion2 - elif token_vector_length == 1280: - return BaseModelType.StableDiffusionXL # recognizes format at https://civitai.com/models/224641 - elif token_vector_length == 2048: - return BaseModelType.StableDiffusionXL - else: - raise InvalidModelConfigException(f"Unknown LoRA type: {self.model_path}") - - -class TextualInversionCheckpointProbe(CheckpointProbeBase): - """Class for probing embeddings.""" - - def get_format(self) -> ModelFormat: - return ModelFormat.EmbeddingFile - - def get_base_type(self) -> BaseModelType: - checkpoint = self.checkpoint - if "string_to_token" in checkpoint: - token_dim = list(checkpoint["string_to_param"].values())[0].shape[-1] - elif "emb_params" in checkpoint: - token_dim = checkpoint["emb_params"].shape[-1] - elif "clip_g" in checkpoint: - token_dim = checkpoint["clip_g"].shape[-1] - else: - token_dim = list(checkpoint.values())[0].shape[0] - if token_dim == 768: - return BaseModelType.StableDiffusion1 - elif token_dim == 1024: - return BaseModelType.StableDiffusion2 - elif token_dim == 1280: - return BaseModelType.StableDiffusionXL - else: - raise InvalidModelConfigException(f"{self.model_path}: Could not determine base type") - - -class ControlNetCheckpointProbe(CheckpointProbeBase): - """Class for probing controlnets.""" - - def get_base_type(self) -> BaseModelType: - checkpoint = self.checkpoint - if is_state_dict_xlabs_controlnet(checkpoint) or is_state_dict_instantx_controlnet(checkpoint): - # TODO(ryand): Should I distinguish between XLabs, InstantX and other ControlNet models by implementing - # get_format()? 
- return BaseModelType.Flux - - for key_name in ( - "control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight", - "controlnet_mid_block.bias", - "input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight", - "down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_k.weight", - ): - if key_name not in checkpoint: - continue - width = checkpoint[key_name].shape[-1] - if width == 768: - return BaseModelType.StableDiffusion1 - elif width == 1024: - return BaseModelType.StableDiffusion2 - elif width == 2048: - return BaseModelType.StableDiffusionXL - elif width == 1280: - return BaseModelType.StableDiffusionXL - raise InvalidModelConfigException(f"{self.model_path}: Unable to determine base type") - - -class IPAdapterCheckpointProbe(CheckpointProbeBase): - """Class for probing IP Adapters""" - - def get_base_type(self) -> BaseModelType: - checkpoint = self.checkpoint - - if is_state_dict_xlabs_ip_adapter(checkpoint): - return BaseModelType.Flux - - for key in checkpoint.keys(): - if not key.startswith(("image_proj.", "ip_adapter.")): - continue - cross_attention_dim = checkpoint["ip_adapter.1.to_k_ip.weight"].shape[-1] - if cross_attention_dim == 768: - return BaseModelType.StableDiffusion1 - elif cross_attention_dim == 1024: - return BaseModelType.StableDiffusion2 - elif cross_attention_dim == 2048: - return BaseModelType.StableDiffusionXL - else: - raise InvalidModelConfigException( - f"IP-Adapter had unexpected cross-attention dimension: {cross_attention_dim}." - ) - raise InvalidModelConfigException(f"{self.model_path}: Unable to determine base type") - - -class CLIPVisionCheckpointProbe(CheckpointProbeBase): - def get_base_type(self) -> BaseModelType: - raise NotImplementedError() - - -class T2IAdapterCheckpointProbe(CheckpointProbeBase): - def get_base_type(self) -> BaseModelType: - raise NotImplementedError() - - -class SpandrelImageToImageCheckpointProbe(CheckpointProbeBase): - def get_base_type(self) -> BaseModelType: - return BaseModelType.Any - - -class SigLIPCheckpointProbe(CheckpointProbeBase): - def get_base_type(self) -> BaseModelType: - raise NotImplementedError() - - -class FluxReduxCheckpointProbe(CheckpointProbeBase): - def get_base_type(self) -> BaseModelType: - return BaseModelType.Flux - - -class LlavaOnevisionCheckpointProbe(CheckpointProbeBase): - def get_base_type(self) -> BaseModelType: - raise NotImplementedError() - - -######################################################## -# classes for probing folders -####################################################### -class FolderProbeBase(ProbeBase): - def get_variant_type(self) -> ModelVariantType: - return ModelVariantType.Normal - - def get_format(self) -> ModelFormat: - return ModelFormat("diffusers") - - def get_repo_variant(self) -> ModelRepoVariant: - # get all files ending in .bin or .safetensors - weight_files = list(self.model_path.glob("**/*.safetensors")) - weight_files.extend(list(self.model_path.glob("**/*.bin"))) - for x in weight_files: - if ".fp16" in x.suffixes: - return ModelRepoVariant.FP16 - if "openvino_model" in x.name: - return ModelRepoVariant.OpenVINO - if "flax_model" in x.name: - return ModelRepoVariant.Flax - if x.suffix == ".onnx": - return ModelRepoVariant.ONNX - return ModelRepoVariant.Default - - -class PipelineFolderProbe(FolderProbeBase): - def get_base_type(self) -> BaseModelType: - # Handle pipelines with a UNet (i.e SD 1.x, SD2, SDXL). 
- config_path = self.model_path / "unet" / "config.json" - if config_path.exists(): - with open(config_path) as file: - unet_conf = json.load(file) - if unet_conf["cross_attention_dim"] == 768: - return BaseModelType.StableDiffusion1 - elif unet_conf["cross_attention_dim"] == 1024: - return BaseModelType.StableDiffusion2 - elif unet_conf["cross_attention_dim"] == 1280: - return BaseModelType.StableDiffusionXLRefiner - elif unet_conf["cross_attention_dim"] == 2048: - return BaseModelType.StableDiffusionXL - else: - raise InvalidModelConfigException(f"Unknown base model for {self.model_path}") - - # Handle pipelines with a transformer (i.e. SD3). - config_path = self.model_path / "transformer" / "config.json" - if config_path.exists(): - with open(config_path) as file: - transformer_conf = json.load(file) - if transformer_conf["_class_name"] == "SD3Transformer2DModel": - return BaseModelType.StableDiffusion3 - elif transformer_conf["_class_name"] == "CogView4Transformer2DModel": - return BaseModelType.CogView4 - else: - raise InvalidModelConfigException(f"Unknown base model for {self.model_path}") - - raise InvalidModelConfigException(f"Unknown base model for {self.model_path}") - - def get_scheduler_prediction_type(self) -> SchedulerPredictionType: - with open(self.model_path / "scheduler" / "scheduler_config.json", "r") as file: - scheduler_conf = json.load(file) - if scheduler_conf.get("prediction_type", "epsilon") == "v_prediction": - return SchedulerPredictionType.VPrediction - elif scheduler_conf.get("prediction_type", "epsilon") == "epsilon": - return SchedulerPredictionType.Epsilon - else: - raise InvalidModelConfigException("Unknown scheduler prediction type: {scheduler_conf['prediction_type']}") - - def get_submodels(self) -> Dict[SubModelType, SubmodelDefinition]: - config = ConfigLoader.load_config(self.model_path, config_name="model_index.json") - submodels: Dict[SubModelType, SubmodelDefinition] = {} - for key, value in config.items(): - if key.startswith("_") or not (isinstance(value, list) and len(value) == 2): - continue - model_loader = str(value[1]) - if model_type := ModelProbe.CLASS2TYPE.get(model_loader): - variant_func = ModelProbe.TYPE2VARIANT.get(model_type, None) - submodels[SubModelType(key)] = SubmodelDefinition( - path_or_prefix=(self.model_path / key).resolve().as_posix(), - model_type=model_type, - variant=variant_func and variant_func((self.model_path / key).as_posix()), - ) - - return submodels - - def get_variant_type(self) -> ModelVariantType: - # This only works for pipelines! Any kind of - # exception results in our returning the - # "normal" variant type - try: - config_file = self.model_path / "unet" / "config.json" - with open(config_file, "r") as file: - conf = json.load(file) - - in_channels = conf["in_channels"] - if in_channels == 9: - return ModelVariantType.Inpaint - elif in_channels == 5: - return ModelVariantType.Depth - elif in_channels == 4: - return ModelVariantType.Normal - except Exception: - pass - return ModelVariantType.Normal - - -class VaeFolderProbe(FolderProbeBase): - def get_base_type(self) -> BaseModelType: - if self._config_looks_like_sdxl(): - return BaseModelType.StableDiffusionXL - elif self._name_looks_like_sdxl(): - # but SD and SDXL VAE are the same shape (3-channel RGB to 4-channel float scaled down - # by a factor of 8), we can't necessarily tell them apart by config hyperparameters. 
- return BaseModelType.StableDiffusionXL - else: - return BaseModelType.StableDiffusion1 - - def _config_looks_like_sdxl(self) -> bool: - # config values that distinguish Stability's SD 1.x VAE from their SDXL VAE. - config_file = self.model_path / "config.json" - if not config_file.exists(): - raise InvalidModelConfigException(f"Cannot determine base type for {self.model_path}") - with open(config_file, "r") as file: - config = json.load(file) - return config.get("scaling_factor", 0) == 0.13025 and config.get("sample_size") in [512, 1024] - - def _name_looks_like_sdxl(self) -> bool: - return bool(re.search(r"xl\b", self._guess_name(), re.IGNORECASE)) - - def _guess_name(self) -> str: - name = self.model_path.name - if name == "vae": - name = self.model_path.parent.name - return name - - -class TextualInversionFolderProbe(FolderProbeBase): - def get_format(self) -> ModelFormat: - return ModelFormat.EmbeddingFolder - - def get_base_type(self) -> BaseModelType: - path = self.model_path / "learned_embeds.bin" - if not path.exists(): - raise InvalidModelConfigException( - f"{self.model_path.as_posix()} does not contain expected 'learned_embeds.bin' file" - ) - return TextualInversionCheckpointProbe(path).get_base_type() - - -class T5EncoderFolderProbe(FolderProbeBase): - def get_base_type(self) -> BaseModelType: - return BaseModelType.Any - - def get_format(self) -> ModelFormat: - path = self.model_path / "text_encoder_2" - if (path / "model.safetensors.index.json").exists(): - return ModelFormat.T5Encoder - files = list(path.glob("*.safetensors")) - if len(files) == 0: - raise InvalidModelConfigException(f"{self.model_path.as_posix()}: no .safetensors files found") - - # shortcut: look for the quantization in the name - if any(x for x in files if "llm_int8" in x.as_posix()): - return ModelFormat.BnbQuantizedLlmInt8b - - # more reliable path: probe contents for a 'SCB' key - ckpt = read_checkpoint_meta(files[0], scan=True) - if any("SCB" in x for x in ckpt.keys()): - return ModelFormat.BnbQuantizedLlmInt8b - - raise InvalidModelConfigException(f"{self.model_path.as_posix()}: unknown model format") - - -class ONNXFolderProbe(PipelineFolderProbe): - def get_base_type(self) -> BaseModelType: - # Due to the way the installer is set up, the configuration file for safetensors - # will come along for the ride if both the onnx and safetensors forms - # share the same directory. We take advantage of this here. - if (self.model_path / "unet" / "config.json").exists(): - return super().get_base_type() - else: - logger.warning('Base type probing is not implemented for ONNX models. 
Assuming "sd-1"') - return BaseModelType.StableDiffusion1 - - def get_format(self) -> ModelFormat: - return ModelFormat("onnx") - - def get_variant_type(self) -> ModelVariantType: - return ModelVariantType.Normal - - -class ControlNetFolderProbe(FolderProbeBase): - def get_base_type(self) -> BaseModelType: - config_file = self.model_path / "config.json" - if not config_file.exists(): - raise InvalidModelConfigException(f"Cannot determine base type for {self.model_path}") - with open(config_file, "r") as file: - config = json.load(file) - - if config.get("_class_name", None) == "FluxControlNetModel": - return BaseModelType.Flux - - # no obvious way to distinguish between sd2-base and sd2-768 - dimension = config["cross_attention_dim"] - if dimension == 768: - return BaseModelType.StableDiffusion1 - if dimension == 1024: - return BaseModelType.StableDiffusion2 - if dimension == 2048: - return BaseModelType.StableDiffusionXL - raise InvalidModelConfigException(f"Unable to determine model base for {self.model_path}") - - -class LoRAFolderProbe(FolderProbeBase): - def get_base_type(self) -> BaseModelType: - model_file = None - for suffix in ["safetensors", "bin"]: - base_file = self.model_path / f"pytorch_lora_weights.{suffix}" - if base_file.exists(): - model_file = base_file - break - if not model_file: - raise InvalidModelConfigException("Unknown LoRA format encountered") - return LoRACheckpointProbe(model_file).get_base_type() - - -class IPAdapterFolderProbe(FolderProbeBase): - def get_format(self) -> ModelFormat: - return ModelFormat.InvokeAI - - def get_base_type(self) -> BaseModelType: - model_file = self.model_path / "ip_adapter.bin" - if not model_file.exists(): - raise InvalidModelConfigException("Unknown IP-Adapter model format.") - - state_dict = torch.load(model_file, map_location="cpu") - cross_attention_dim = state_dict["ip_adapter"]["1.to_k_ip.weight"].shape[-1] - if cross_attention_dim == 768: - return BaseModelType.StableDiffusion1 - elif cross_attention_dim == 1024: - return BaseModelType.StableDiffusion2 - elif cross_attention_dim == 2048: - return BaseModelType.StableDiffusionXL - else: - raise InvalidModelConfigException( - f"IP-Adapter had unexpected cross-attention dimension: {cross_attention_dim}." 
- ) - - def get_image_encoder_model_id(self) -> Optional[str]: - encoder_id_path = self.model_path / "image_encoder.txt" - if not encoder_id_path.exists(): - return None - with open(encoder_id_path, "r") as f: - image_encoder_model = f.readline().strip() - return image_encoder_model - - -class CLIPVisionFolderProbe(FolderProbeBase): - def get_base_type(self) -> BaseModelType: - return BaseModelType.Any - - -class CLIPEmbedFolderProbe(FolderProbeBase): - def get_base_type(self) -> BaseModelType: - return BaseModelType.Any - - -class SpandrelImageToImageFolderProbe(FolderProbeBase): - def get_base_type(self) -> BaseModelType: - raise NotImplementedError() - - -class SigLIPFolderProbe(FolderProbeBase): - def get_base_type(self) -> BaseModelType: - return BaseModelType.Any - - -class FluxReduxFolderProbe(FolderProbeBase): - def get_base_type(self) -> BaseModelType: - raise NotImplementedError() - - -class LlaveOnevisionFolderProbe(FolderProbeBase): - def get_base_type(self) -> BaseModelType: - return BaseModelType.Any - - -class T2IAdapterFolderProbe(FolderProbeBase): - def get_base_type(self) -> BaseModelType: - config_file = self.model_path / "config.json" - if not config_file.exists(): - raise InvalidModelConfigException(f"Cannot determine base type for {self.model_path}") - with open(config_file, "r") as file: - config = json.load(file) - - adapter_type = config.get("adapter_type", None) - if adapter_type == "full_adapter_xl": - return BaseModelType.StableDiffusionXL - elif adapter_type == "full_adapter" or "light_adapter": - # I haven't seen any T2I adapter models for SD2, so assume that this is an SD1 adapter. - return BaseModelType.StableDiffusion1 - else: - raise InvalidModelConfigException( - f"Unable to determine base model for '{self.model_path}' (adapter_type = {adapter_type})." 
- ) - - -# Register probe classes -ModelProbe.register_probe("diffusers", ModelType.Main, PipelineFolderProbe) -ModelProbe.register_probe("diffusers", ModelType.VAE, VaeFolderProbe) -ModelProbe.register_probe("diffusers", ModelType.LoRA, LoRAFolderProbe) -ModelProbe.register_probe("diffusers", ModelType.ControlLoRa, LoRAFolderProbe) -ModelProbe.register_probe("diffusers", ModelType.TextualInversion, TextualInversionFolderProbe) -ModelProbe.register_probe("diffusers", ModelType.T5Encoder, T5EncoderFolderProbe) -ModelProbe.register_probe("diffusers", ModelType.ControlNet, ControlNetFolderProbe) -ModelProbe.register_probe("diffusers", ModelType.IPAdapter, IPAdapterFolderProbe) -ModelProbe.register_probe("diffusers", ModelType.CLIPEmbed, CLIPEmbedFolderProbe) -ModelProbe.register_probe("diffusers", ModelType.CLIPVision, CLIPVisionFolderProbe) -ModelProbe.register_probe("diffusers", ModelType.T2IAdapter, T2IAdapterFolderProbe) -ModelProbe.register_probe("diffusers", ModelType.SpandrelImageToImage, SpandrelImageToImageFolderProbe) -ModelProbe.register_probe("diffusers", ModelType.SigLIP, SigLIPFolderProbe) -ModelProbe.register_probe("diffusers", ModelType.FluxRedux, FluxReduxFolderProbe) -ModelProbe.register_probe("diffusers", ModelType.LlavaOnevision, LlaveOnevisionFolderProbe) - -ModelProbe.register_probe("checkpoint", ModelType.Main, PipelineCheckpointProbe) -ModelProbe.register_probe("checkpoint", ModelType.VAE, VaeCheckpointProbe) -ModelProbe.register_probe("checkpoint", ModelType.LoRA, LoRACheckpointProbe) -ModelProbe.register_probe("checkpoint", ModelType.ControlLoRa, LoRACheckpointProbe) -ModelProbe.register_probe("checkpoint", ModelType.TextualInversion, TextualInversionCheckpointProbe) -ModelProbe.register_probe("checkpoint", ModelType.ControlNet, ControlNetCheckpointProbe) -ModelProbe.register_probe("checkpoint", ModelType.IPAdapter, IPAdapterCheckpointProbe) -ModelProbe.register_probe("checkpoint", ModelType.CLIPVision, CLIPVisionCheckpointProbe) -ModelProbe.register_probe("checkpoint", ModelType.T2IAdapter, T2IAdapterCheckpointProbe) -ModelProbe.register_probe("checkpoint", ModelType.SpandrelImageToImage, SpandrelImageToImageCheckpointProbe) -ModelProbe.register_probe("checkpoint", ModelType.SigLIP, SigLIPCheckpointProbe) -ModelProbe.register_probe("checkpoint", ModelType.FluxRedux, FluxReduxCheckpointProbe) -ModelProbe.register_probe("checkpoint", ModelType.LlavaOnevision, LlavaOnevisionCheckpointProbe) - -ModelProbe.register_probe("onnx", ModelType.ONNX, ONNXFolderProbe) diff --git a/invokeai/backend/model_manager/load/__init__.py b/invokeai/backend/model_manager/load/__init__.py index eba7bd16a32..ec0abbfc731 100644 --- a/invokeai/backend/model_manager/load/__init__.py +++ b/invokeai/backend/model_manager/load/__init__.py @@ -3,18 +3,9 @@ Init file for the model loader. 
""" -from importlib import import_module -from pathlib import Path - from invokeai.backend.model_manager.load.load_base import LoadedModel, LoadedModelWithoutConfig, ModelLoaderBase from invokeai.backend.model_manager.load.load_default import ModelLoader from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache -from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry, ModelLoaderRegistryBase - -# This registers the subclasses that implement loaders of specific model types -loaders = [x.stem for x in Path(Path(__file__).parent, "model_loaders").glob("*.py") if x.stem != "__init__"] -for module in loaders: - import_module(f"{__package__}.model_loaders.{module}") __all__ = [ "LoadedModel", @@ -22,6 +13,4 @@ "ModelCache", "ModelLoaderBase", "ModelLoader", - "ModelLoaderRegistryBase", - "ModelLoaderRegistry", ] diff --git a/invokeai/backend/model_manager/load/load_base.py b/invokeai/backend/model_manager/load/load_base.py index 458fc0cfc0c..75191517c76 100644 --- a/invokeai/backend/model_manager/load/load_base.py +++ b/invokeai/backend/model_manager/load/load_base.py @@ -12,9 +12,7 @@ import torch from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.backend.model_manager.config import ( - AnyModelConfig, -) +from invokeai.backend.model_manager.config import AnyModelConfig from invokeai.backend.model_manager.load.model_cache.cache_record import CacheRecord from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache from invokeai.backend.model_manager.taxonomy import AnyModel, SubModelType diff --git a/invokeai/backend/model_manager/load/load_default.py b/invokeai/backend/model_manager/load/load_default.py index 3c26a956b76..139a7d2940b 100644 --- a/invokeai/backend/model_manager/load/load_default.py +++ b/invokeai/backend/model_manager/load/load_default.py @@ -6,7 +6,7 @@ from typing import Optional from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.backend.model_manager.config import AnyModelConfig, DiffusersConfigBase, InvalidModelConfigException +from invokeai.backend.model_manager.config import AnyModelConfig, Diffusers_Config_Base, InvalidModelConfigException from invokeai.backend.model_manager.load.load_base import LoadedModel, ModelLoaderBase from invokeai.backend.model_manager.load.model_cache.cache_record import CacheRecord from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache, get_model_cache_key @@ -90,7 +90,7 @@ def get_size_fs( return calc_model_size_by_fs( model_path=model_path, subfolder=submodel_type.value if submodel_type else None, - variant=config.repo_variant if isinstance(config, DiffusersConfigBase) else None, + variant=config.repo_variant if isinstance(config, Diffusers_Config_Base) else None, ) # This needs to be implemented in the subclass diff --git a/invokeai/backend/model_manager/load/model_loader_registry.py b/invokeai/backend/model_manager/load/model_loader_registry.py deleted file mode 100644 index ecc4d1fe93b..00000000000 --- a/invokeai/backend/model_manager/load/model_loader_registry.py +++ /dev/null @@ -1,101 +0,0 @@ -# Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Development team -""" -This module implements a system in which model loaders register the -type, base and format of models that they know how to load. 
- -Use like this: - - cls, model_config, submodel_type = ModelLoaderRegistry.get_implementation(model_config, submodel_type) # type: ignore - loaded_model = cls( - app_config=app_config, - logger=logger, - ram_cache=ram_cache, - convert_cache=convert_cache - ).load_model(model_config, submodel_type) - -""" - -from abc import ABC, abstractmethod -from typing import Callable, Dict, Optional, Tuple, Type, TypeVar - -from invokeai.backend.model_manager.config import ( - AnyModelConfig, - ModelConfigBase, -) -from invokeai.backend.model_manager.load import ModelLoaderBase -from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelFormat, ModelType, SubModelType - - -class ModelLoaderRegistryBase(ABC): - """This class allows model loaders to register their type, base and format.""" - - @classmethod - @abstractmethod - def register( - cls, type: ModelType, format: ModelFormat, base: BaseModelType = BaseModelType.Any - ) -> Callable[[Type[ModelLoaderBase]], Type[ModelLoaderBase]]: - """Define a decorator which registers the subclass of loader.""" - - @classmethod - @abstractmethod - def get_implementation( - cls, config: AnyModelConfig, submodel_type: Optional[SubModelType] - ) -> Tuple[Type[ModelLoaderBase], ModelConfigBase, Optional[SubModelType]]: - """ - Get subclass of ModelLoaderBase registered to handle base and type. - - Parameters: - :param config: Model configuration record, as returned by ModelRecordService - :param submodel_type: Submodel to fetch (main models only) - :return: tuple(loader_class, model_config, submodel_type) - - Note that the returned model config may be different from one what passed - in, in the event that a submodel type is provided. - """ - - -TModelLoader = TypeVar("TModelLoader", bound=ModelLoaderBase) - - -class ModelLoaderRegistry(ModelLoaderRegistryBase): - """ - This class allows model loaders to register their type, base and format. 
- """ - - _registry: Dict[str, Type[ModelLoaderBase]] = {} - - @classmethod - def register( - cls, type: ModelType, format: ModelFormat, base: BaseModelType = BaseModelType.Any - ) -> Callable[[Type[TModelLoader]], Type[TModelLoader]]: - """Define a decorator which registers the subclass of loader.""" - - def decorator(subclass: Type[TModelLoader]) -> Type[TModelLoader]: - key = cls._to_registry_key(base, type, format) - if key in cls._registry: - raise Exception( - f"{subclass.__name__} is trying to register as a loader for {base}/{type}/{format}, but this type of model has already been registered by {cls._registry[key].__name__}" - ) - cls._registry[key] = subclass - return subclass - - return decorator - - @classmethod - def get_implementation( - cls, config: AnyModelConfig, submodel_type: Optional[SubModelType] - ) -> Tuple[Type[ModelLoaderBase], ModelConfigBase, Optional[SubModelType]]: - """Get subclass of ModelLoaderBase registered to handle base and type.""" - - key1 = cls._to_registry_key(config.base, config.type, config.format) # for a specific base type - key2 = cls._to_registry_key(BaseModelType.Any, config.type, config.format) # with wildcard Any - implementation = cls._registry.get(key1) or cls._registry.get(key2) - if not implementation: - raise NotImplementedError( - f"No subclass of LoadedModel is registered for base={config.base}, type={config.type}, format={config.format}" - ) - return implementation, config, submodel_type - - @staticmethod - def _to_registry_key(base: BaseModelType, type: ModelType, format: ModelFormat) -> str: - return "-".join([base.value, type.value, format.value]) diff --git a/invokeai/backend/model_manager/load/model_loaders/__init__.py b/invokeai/backend/model_manager/load/model_loaders/__init__.py deleted file mode 100644 index 962cba54811..00000000000 --- a/invokeai/backend/model_manager/load/model_loaders/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -""" -Init file for model_loaders. 
-""" diff --git a/invokeai/backend/model_manager/load/model_loaders/clip_vision.py b/invokeai/backend/model_manager/load/model_loaders/clip_vision.py deleted file mode 100644 index 29d7bc691cf..00000000000 --- a/invokeai/backend/model_manager/load/model_loaders/clip_vision.py +++ /dev/null @@ -1,37 +0,0 @@ -from pathlib import Path -from typing import Optional - -from transformers import CLIPVisionModelWithProjection - -from invokeai.backend.model_manager.config import ( - AnyModelConfig, - DiffusersConfigBase, -) -from invokeai.backend.model_manager.load.load_default import ModelLoader -from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry -from invokeai.backend.model_manager.taxonomy import AnyModel, BaseModelType, ModelFormat, ModelType, SubModelType - - -@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.CLIPVision, format=ModelFormat.Diffusers) -class ClipVisionLoader(ModelLoader): - """Class to load CLIPVision models.""" - - def _load_model( - self, - config: AnyModelConfig, - submodel_type: Optional[SubModelType] = None, - ) -> AnyModel: - if not isinstance(config, DiffusersConfigBase): - raise ValueError("Only DiffusersConfigBase models are currently supported here.") - - if submodel_type is not None: - raise Exception("There are no submodels in CLIP Vision models.") - - model_path = Path(config.path) - - model = CLIPVisionModelWithProjection.from_pretrained( - model_path, torch_dtype=self._torch_dtype, local_files_only=True - ) - assert isinstance(model, CLIPVisionModelWithProjection) - - return model diff --git a/invokeai/backend/model_manager/load/model_loaders/cogview4.py b/invokeai/backend/model_manager/load/model_loaders/cogview4.py deleted file mode 100644 index e7669a33c42..00000000000 --- a/invokeai/backend/model_manager/load/model_loaders/cogview4.py +++ /dev/null @@ -1,60 +0,0 @@ -from pathlib import Path -from typing import Optional - -import torch - -from invokeai.backend.model_manager.config import ( - AnyModelConfig, - CheckpointConfigBase, - DiffusersConfigBase, -) -from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry -from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader -from invokeai.backend.model_manager.taxonomy import ( - AnyModel, - BaseModelType, - ModelFormat, - ModelType, - SubModelType, -) - - -@ModelLoaderRegistry.register(base=BaseModelType.CogView4, type=ModelType.Main, format=ModelFormat.Diffusers) -class CogView4DiffusersModel(GenericDiffusersLoader): - """Class to load CogView4 main models.""" - - def _load_model( - self, - config: AnyModelConfig, - submodel_type: Optional[SubModelType] = None, - ) -> AnyModel: - if isinstance(config, CheckpointConfigBase): - raise NotImplementedError("CheckpointConfigBase is not implemented for CogView4 models.") - - if submodel_type is None: - raise Exception("A submodel type must be provided when loading main pipelines.") - - model_path = Path(config.path) - load_class = self.get_hf_load_class(model_path, submodel_type) - repo_variant = config.repo_variant if isinstance(config, DiffusersConfigBase) else None - variant = repo_variant.value if repo_variant else None - model_path = model_path / submodel_type.value - - # We force bfloat16 for CogView4 models. It produces black images with float16. I haven't tracked down - # specifically which model(s) is/are responsible. 
- dtype = torch.bfloat16 - try: - result: AnyModel = load_class.from_pretrained( - model_path, - torch_dtype=dtype, - variant=variant, - ) - except OSError as e: - if variant and "no file named" in str( - e - ): # try without the variant, just in case user's preferences changed - result = load_class.from_pretrained(model_path, torch_dtype=dtype) - else: - raise e - - return result diff --git a/invokeai/backend/model_manager/load/model_loaders/controlnet.py b/invokeai/backend/model_manager/load/model_loaders/controlnet.py deleted file mode 100644 index 5bf93db3816..00000000000 --- a/invokeai/backend/model_manager/load/model_loaders/controlnet.py +++ /dev/null @@ -1,55 +0,0 @@ -# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team -"""Class for ControlNet model loading in InvokeAI.""" - -from typing import Optional - -from diffusers import ControlNetModel - -from invokeai.backend.model_manager.config import ( - AnyModelConfig, - ControlNetCheckpointConfig, -) -from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry -from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader -from invokeai.backend.model_manager.taxonomy import ( - AnyModel, - BaseModelType, - ModelFormat, - ModelType, - SubModelType, -) - - -@ModelLoaderRegistry.register( - base=BaseModelType.StableDiffusion1, type=ModelType.ControlNet, format=ModelFormat.Diffusers -) -@ModelLoaderRegistry.register( - base=BaseModelType.StableDiffusion1, type=ModelType.ControlNet, format=ModelFormat.Checkpoint -) -@ModelLoaderRegistry.register( - base=BaseModelType.StableDiffusion2, type=ModelType.ControlNet, format=ModelFormat.Diffusers -) -@ModelLoaderRegistry.register( - base=BaseModelType.StableDiffusion2, type=ModelType.ControlNet, format=ModelFormat.Checkpoint -) -@ModelLoaderRegistry.register( - base=BaseModelType.StableDiffusionXL, type=ModelType.ControlNet, format=ModelFormat.Diffusers -) -@ModelLoaderRegistry.register( - base=BaseModelType.StableDiffusionXL, type=ModelType.ControlNet, format=ModelFormat.Checkpoint -) -class ControlNetLoader(GenericDiffusersLoader): - """Class to load ControlNet models.""" - - def _load_model( - self, - config: AnyModelConfig, - submodel_type: Optional[SubModelType] = None, - ) -> AnyModel: - if isinstance(config, ControlNetCheckpointConfig): - return ControlNetModel.from_single_file( - config.path, - torch_dtype=self._torch_dtype, - ) - else: - return super()._load_model(config, submodel_type) diff --git a/invokeai/backend/model_manager/load/model_loaders/flux.py b/invokeai/backend/model_manager/load/model_loaders/flux.py deleted file mode 100644 index 6ea7b539252..00000000000 --- a/invokeai/backend/model_manager/load/model_loaders/flux.py +++ /dev/null @@ -1,424 +0,0 @@ -# Copyright (c) 2024, Brandon W. 
Rising and the InvokeAI Development Team -"""Class for Flux model loading in InvokeAI.""" - -from pathlib import Path -from typing import Optional - -import accelerate -import torch -from safetensors.torch import load_file -from transformers import ( - AutoConfig, - AutoModelForTextEncoding, - CLIPTextModel, - CLIPTokenizer, - T5EncoderModel, - T5TokenizerFast, -) - -from invokeai.app.services.config.config_default import get_config -from invokeai.backend.flux.controlnet.instantx_controlnet_flux import InstantXControlNetFlux -from invokeai.backend.flux.controlnet.state_dict_utils import ( - convert_diffusers_instantx_state_dict_to_bfl_format, - infer_flux_params_from_state_dict, - infer_instantx_num_control_modes_from_state_dict, - is_state_dict_instantx_controlnet, - is_state_dict_xlabs_controlnet, -) -from invokeai.backend.flux.controlnet.xlabs_controlnet_flux import XLabsControlNetFlux -from invokeai.backend.flux.ip_adapter.state_dict_utils import infer_xlabs_ip_adapter_params_from_state_dict -from invokeai.backend.flux.ip_adapter.xlabs_ip_adapter_flux import ( - XlabsIpAdapterFlux, -) -from invokeai.backend.flux.model import Flux -from invokeai.backend.flux.modules.autoencoder import AutoEncoder -from invokeai.backend.flux.redux.flux_redux_model import FluxReduxModel -from invokeai.backend.flux.util import ae_params, params -from invokeai.backend.model_manager.config import ( - AnyModelConfig, - CheckpointConfigBase, - CLIPEmbedDiffusersConfig, - ControlNetCheckpointConfig, - ControlNetDiffusersConfig, - FluxReduxConfig, - IPAdapterCheckpointConfig, - MainBnbQuantized4bCheckpointConfig, - MainCheckpointConfig, - MainGGUFCheckpointConfig, - T5EncoderBnbQuantizedLlmInt8bConfig, - T5EncoderConfig, - VAECheckpointConfig, -) -from invokeai.backend.model_manager.load.load_default import ModelLoader -from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry -from invokeai.backend.model_manager.taxonomy import ( - AnyModel, - BaseModelType, - ModelFormat, - ModelType, - SubModelType, -) -from invokeai.backend.model_manager.util.model_util import ( - convert_bundle_to_flux_transformer_checkpoint, -) -from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader -from invokeai.backend.quantization.gguf.utils import TORCH_COMPATIBLE_QTYPES -from invokeai.backend.util.silence_warnings import SilenceWarnings - -try: - from invokeai.backend.quantization.bnb_llm_int8 import quantize_model_llm_int8 - from invokeai.backend.quantization.bnb_nf4 import quantize_model_nf4 - - bnb_available = True -except ImportError: - bnb_available = False - -app_config = get_config() - - -@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.VAE, format=ModelFormat.Checkpoint) -class FluxVAELoader(ModelLoader): - """Class to load VAE models.""" - - def _load_model( - self, - config: AnyModelConfig, - submodel_type: Optional[SubModelType] = None, - ) -> AnyModel: - if not isinstance(config, VAECheckpointConfig): - raise ValueError("Only VAECheckpointConfig models are currently supported here.") - model_path = Path(config.path) - - with accelerate.init_empty_weights(): - model = AutoEncoder(ae_params[config.config_path]) - sd = load_file(model_path) - model.load_state_dict(sd, assign=True) - # VAE is broken in float16, which mps defaults to - if self._torch_dtype == torch.float16: - try: - vae_dtype = torch.tensor([1.0], dtype=torch.bfloat16, device=self._torch_device).dtype - except TypeError: - vae_dtype = torch.float32 - else: - vae_dtype = self._torch_dtype - 
model.to(vae_dtype) - - return model - - -@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.CLIPEmbed, format=ModelFormat.Diffusers) -class ClipCheckpointModel(ModelLoader): - """Class to load main models.""" - - def _load_model( - self, - config: AnyModelConfig, - submodel_type: Optional[SubModelType] = None, - ) -> AnyModel: - if not isinstance(config, CLIPEmbedDiffusersConfig): - raise ValueError("Only CLIPEmbedDiffusersConfig models are currently supported here.") - - match submodel_type: - case SubModelType.Tokenizer: - return CLIPTokenizer.from_pretrained(Path(config.path) / "tokenizer") - case SubModelType.TextEncoder: - return CLIPTextModel.from_pretrained(Path(config.path) / "text_encoder") - - raise ValueError( - f"Only Tokenizer and TextEncoder submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}" - ) - - -@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.T5Encoder, format=ModelFormat.BnbQuantizedLlmInt8b) -class BnbQuantizedLlmInt8bCheckpointModel(ModelLoader): - """Class to load main models.""" - - def _load_model( - self, - config: AnyModelConfig, - submodel_type: Optional[SubModelType] = None, - ) -> AnyModel: - if not isinstance(config, T5EncoderBnbQuantizedLlmInt8bConfig): - raise ValueError("Only T5EncoderBnbQuantizedLlmInt8bConfig models are currently supported here.") - if not bnb_available: - raise ImportError( - "The bnb modules are not available. Please install bitsandbytes if available on your platform." - ) - match submodel_type: - case SubModelType.Tokenizer2 | SubModelType.Tokenizer3: - return T5TokenizerFast.from_pretrained(Path(config.path) / "tokenizer_2", max_length=512) - case SubModelType.TextEncoder2 | SubModelType.TextEncoder3: - te2_model_path = Path(config.path) / "text_encoder_2" - model_config = AutoConfig.from_pretrained(te2_model_path) - with accelerate.init_empty_weights(): - model = AutoModelForTextEncoding.from_config(model_config) - model = quantize_model_llm_int8(model, modules_to_not_convert=set()) - - state_dict_path = te2_model_path / "bnb_llm_int8_model.safetensors" - state_dict = load_file(state_dict_path) - self._load_state_dict_into_t5(model, state_dict) - - return model - - raise ValueError( - f"Only Tokenizer and TextEncoder submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}" - ) - - @classmethod - def _load_state_dict_into_t5(cls, model: T5EncoderModel, state_dict: dict[str, torch.Tensor]): - # There is a shared reference to a single weight tensor in the model. - # Both "encoder.embed_tokens.weight" and "shared.weight" refer to the same tensor, so only the latter should - # be present in the state_dict. - missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False, assign=True) - assert len(unexpected_keys) == 0 - assert set(missing_keys) == {"encoder.embed_tokens.weight"} - # Assert that the layers we expect to be shared are actually shared. 
- assert model.encoder.embed_tokens.weight is model.shared.weight - - -@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.T5Encoder, format=ModelFormat.T5Encoder) -class T5EncoderCheckpointModel(ModelLoader): - """Class to load main models.""" - - def _load_model( - self, - config: AnyModelConfig, - submodel_type: Optional[SubModelType] = None, - ) -> AnyModel: - if not isinstance(config, T5EncoderConfig): - raise ValueError("Only T5EncoderConfig models are currently supported here.") - - match submodel_type: - case SubModelType.Tokenizer2 | SubModelType.Tokenizer3: - return T5TokenizerFast.from_pretrained(Path(config.path) / "tokenizer_2", max_length=512) - case SubModelType.TextEncoder2 | SubModelType.TextEncoder3: - return T5EncoderModel.from_pretrained( - Path(config.path) / "text_encoder_2", torch_dtype="auto", low_cpu_mem_usage=True - ) - - raise ValueError( - f"Only Tokenizer and TextEncoder submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}" - ) - - -@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.Main, format=ModelFormat.Checkpoint) -class FluxCheckpointModel(ModelLoader): - """Class to load main models.""" - - def _load_model( - self, - config: AnyModelConfig, - submodel_type: Optional[SubModelType] = None, - ) -> AnyModel: - if not isinstance(config, CheckpointConfigBase): - raise ValueError("Only CheckpointConfigBase models are currently supported here.") - - match submodel_type: - case SubModelType.Transformer: - return self._load_from_singlefile(config) - - raise ValueError( - f"Only Transformer submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}" - ) - - def _load_from_singlefile( - self, - config: AnyModelConfig, - ) -> AnyModel: - assert isinstance(config, MainCheckpointConfig) - model_path = Path(config.path) - - with accelerate.init_empty_weights(): - model = Flux(params[config.config_path]) - - sd = load_file(model_path) - if "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale" in sd: - sd = convert_bundle_to_flux_transformer_checkpoint(sd) - new_sd_size = sum([ten.nelement() * torch.bfloat16.itemsize for ten in sd.values()]) - self._ram_cache.make_room(new_sd_size) - for k in sd.keys(): - # We need to cast to bfloat16 due to it being the only currently supported dtype for inference - sd[k] = sd[k].to(torch.bfloat16) - model.load_state_dict(sd, assign=True) - return model - - -@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.Main, format=ModelFormat.GGUFQuantized) -class FluxGGUFCheckpointModel(ModelLoader): - """Class to load GGUF main models.""" - - def _load_model( - self, - config: AnyModelConfig, - submodel_type: Optional[SubModelType] = None, - ) -> AnyModel: - if not isinstance(config, CheckpointConfigBase): - raise ValueError("Only CheckpointConfigBase models are currently supported here.") - - match submodel_type: - case SubModelType.Transformer: - return self._load_from_singlefile(config) - - raise ValueError( - f"Only Transformer submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}" - ) - - def _load_from_singlefile( - self, - config: AnyModelConfig, - ) -> AnyModel: - assert isinstance(config, MainGGUFCheckpointConfig) - model_path = Path(config.path) - - with accelerate.init_empty_weights(): - model = Flux(params[config.config_path]) - - # HACK(ryand): We shouldn't be hard-coding the compute_dtype here. 
- sd = gguf_sd_loader(model_path, compute_dtype=torch.bfloat16) - - # HACK(ryand): There are some broken GGUF models in circulation that have the wrong shape for img_in.weight. - # We override the shape here to fix the issue. - # Example model with this issue (Q4_K_M): https://civitai.com/models/705823/ggufk-flux-unchained-km-quants - img_in_weight = sd.get("img_in.weight", None) - if img_in_weight is not None and img_in_weight._ggml_quantization_type in TORCH_COMPATIBLE_QTYPES: - expected_img_in_weight_shape = model.img_in.weight.shape - img_in_weight.quantized_data = img_in_weight.quantized_data.view(expected_img_in_weight_shape) - img_in_weight.tensor_shape = expected_img_in_weight_shape - - model.load_state_dict(sd, assign=True) - return model - - -@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.Main, format=ModelFormat.BnbQuantizednf4b) -class FluxBnbQuantizednf4bCheckpointModel(ModelLoader): - """Class to load main models.""" - - def _load_model( - self, - config: AnyModelConfig, - submodel_type: Optional[SubModelType] = None, - ) -> AnyModel: - if not isinstance(config, CheckpointConfigBase): - raise ValueError("Only CheckpointConfigBase models are currently supported here.") - - match submodel_type: - case SubModelType.Transformer: - return self._load_from_singlefile(config) - - raise ValueError( - f"Only Transformer submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}" - ) - - def _load_from_singlefile( - self, - config: AnyModelConfig, - ) -> AnyModel: - assert isinstance(config, MainBnbQuantized4bCheckpointConfig) - if not bnb_available: - raise ImportError( - "The bnb modules are not available. Please install bitsandbytes if available on your platform." - ) - model_path = Path(config.path) - - with SilenceWarnings(): - with accelerate.init_empty_weights(): - model = Flux(params[config.config_path]) - model = quantize_model_nf4(model, modules_to_not_convert=set(), compute_dtype=torch.bfloat16) - sd = load_file(model_path) - if "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale" in sd: - sd = convert_bundle_to_flux_transformer_checkpoint(sd) - model.load_state_dict(sd, assign=True) - return model - - -@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.ControlNet, format=ModelFormat.Checkpoint) -@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.ControlNet, format=ModelFormat.Diffusers) -class FluxControlnetModel(ModelLoader): - """Class to load FLUX ControlNet models.""" - - def _load_model( - self, - config: AnyModelConfig, - submodel_type: Optional[SubModelType] = None, - ) -> AnyModel: - if isinstance(config, ControlNetCheckpointConfig): - model_path = Path(config.path) - elif isinstance(config, ControlNetDiffusersConfig): - # If this is a diffusers directory, we simply ignore the config file and load from the weight file. - model_path = Path(config.path) / "diffusion_pytorch_model.safetensors" - else: - raise ValueError(f"Unexpected ControlNet model config type: {type(config)}") - - sd = load_file(model_path) - - # Detect the FLUX ControlNet model type from the state dict. 
- if is_state_dict_xlabs_controlnet(sd): - return self._load_xlabs_controlnet(sd) - elif is_state_dict_instantx_controlnet(sd): - return self._load_instantx_controlnet(sd) - else: - raise ValueError("Do not recognize the state dict as an XLabs or InstantX ControlNet model.") - - def _load_xlabs_controlnet(self, sd: dict[str, torch.Tensor]) -> AnyModel: - with accelerate.init_empty_weights(): - # HACK(ryand): Is it safe to assume dev here? - model = XLabsControlNetFlux(params["flux-dev"]) - - model.load_state_dict(sd, assign=True) - return model - - def _load_instantx_controlnet(self, sd: dict[str, torch.Tensor]) -> AnyModel: - sd = convert_diffusers_instantx_state_dict_to_bfl_format(sd) - flux_params = infer_flux_params_from_state_dict(sd) - num_control_modes = infer_instantx_num_control_modes_from_state_dict(sd) - - with accelerate.init_empty_weights(): - model = InstantXControlNetFlux(flux_params, num_control_modes) - - model.load_state_dict(sd, assign=True) - return model - - -@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.IPAdapter, format=ModelFormat.Checkpoint) -class FluxIpAdapterModel(ModelLoader): - """Class to load FLUX IP-Adapter models.""" - - def _load_model( - self, - config: AnyModelConfig, - submodel_type: Optional[SubModelType] = None, - ) -> AnyModel: - if not isinstance(config, IPAdapterCheckpointConfig): - raise ValueError(f"Unexpected model config type: {type(config)}.") - - sd = load_file(Path(config.path)) - - params = infer_xlabs_ip_adapter_params_from_state_dict(sd) - - with accelerate.init_empty_weights(): - model = XlabsIpAdapterFlux(params=params) - - model.load_xlabs_state_dict(sd, assign=True) - return model - - -@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.FluxRedux, format=ModelFormat.Checkpoint) -class FluxReduxModelLoader(ModelLoader): - """Class to load FLUX Redux models.""" - - def _load_model( - self, - config: AnyModelConfig, - submodel_type: Optional[SubModelType] = None, - ) -> AnyModel: - if not isinstance(config, FluxReduxConfig): - raise ValueError(f"Unexpected model config type: {type(config)}.") - - sd = load_file(Path(config.path)) - - with accelerate.init_empty_weights(): - model = FluxReduxModel() - - model.load_state_dict(sd, assign=True) - model.to(dtype=torch.bfloat16) - return model diff --git a/invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py b/invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py deleted file mode 100644 index 8a690583d5d..00000000000 --- a/invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py +++ /dev/null @@ -1,102 +0,0 @@ -# Copyright (c) 2024, Lincoln D. 
Stein and the InvokeAI Development Team -"""Class for simple diffusers model loading in InvokeAI.""" - -import sys -from pathlib import Path -from typing import Any, Optional - -from diffusers.configuration_utils import ConfigMixin -from diffusers.models.modeling_utils import ModelMixin - -from invokeai.backend.model_manager.config import AnyModelConfig, DiffusersConfigBase, InvalidModelConfigException -from invokeai.backend.model_manager.load.load_default import ModelLoader -from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry -from invokeai.backend.model_manager.taxonomy import ( - AnyModel, - BaseModelType, - ModelFormat, - ModelType, - SubModelType, -) - - -@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.T2IAdapter, format=ModelFormat.Diffusers) -class GenericDiffusersLoader(ModelLoader): - """Class to load simple diffusers models.""" - - def _load_model( - self, - config: AnyModelConfig, - submodel_type: Optional[SubModelType] = None, - ) -> AnyModel: - model_path = Path(config.path) - model_class = self.get_hf_load_class(model_path) - if submodel_type is not None: - raise Exception(f"There are no submodels in models of type {model_class}") - repo_variant = config.repo_variant if isinstance(config, DiffusersConfigBase) else None - variant = repo_variant.value if repo_variant else None - try: - result: AnyModel = model_class.from_pretrained(model_path, torch_dtype=self._torch_dtype, variant=variant) - except OSError as e: - if variant and "no file named" in str( - e - ): # try without the variant, just in case user's preferences changed - result = model_class.from_pretrained(model_path, torch_dtype=self._torch_dtype) - else: - raise e - return result - - # TO DO: Add exception handling - def get_hf_load_class(self, model_path: Path, submodel_type: Optional[SubModelType] = None) -> ModelMixin: - """Given the model path and submodel, returns the diffusers ModelMixin subclass needed to load.""" - result = None - if submodel_type: - try: - config = self._load_diffusers_config(model_path, config_name="model_index.json") - module, class_name = config[submodel_type.value] - result = self._hf_definition_to_type(module=module, class_name=class_name) - except KeyError as e: - raise InvalidModelConfigException( - f'The "{submodel_type}" submodel is not available for this model.' 
- ) from e - else: - try: - config = self._load_diffusers_config(model_path, config_name="config.json") - if class_name := config.get("_class_name"): - result = self._hf_definition_to_type(module="diffusers", class_name=class_name) - elif class_name := config.get("architectures"): - result = self._hf_definition_to_type(module="transformers", class_name=class_name[0]) - else: - raise InvalidModelConfigException("Unable to decipher Load Class based on given config.json") - except KeyError as e: - raise InvalidModelConfigException("An expected config.json file is missing from this model.") from e - assert result is not None - return result - - # TO DO: Add exception handling - def _hf_definition_to_type(self, module: str, class_name: str) -> ModelMixin: # fix with correct type - if module in [ - "diffusers", - "transformers", - "invokeai.backend.quantization.fast_quantized_transformers_model", - "invokeai.backend.quantization.fast_quantized_diffusion_model", - ]: - res_type = sys.modules[module] - else: - res_type = sys.modules["diffusers"].pipelines - result: ModelMixin = getattr(res_type, class_name) - return result - - def _load_diffusers_config(self, model_path: Path, config_name: str = "config.json") -> dict[str, Any]: - return ConfigLoader.load_config(model_path, config_name=config_name) - - -class ConfigLoader(ConfigMixin): - """Subclass of ConfigMixin for loading diffusers configuration files.""" - - @classmethod - def load_config(cls, *args: Any, **kwargs: Any) -> dict[str, Any]: # pyright: ignore [reportIncompatibleMethodOverride] - """Load a diffusrs ConfigMixin configuration.""" - cls.config_name = kwargs.pop("config_name") - # TODO(psyche): the types on this diffusers method are not correct - return super().load_config(*args, **kwargs) # type: ignore diff --git a/invokeai/backend/model_manager/load/model_loaders/ip_adapter.py b/invokeai/backend/model_manager/load/model_loaders/ip_adapter.py deleted file mode 100644 index d103bc5dbcb..00000000000 --- a/invokeai/backend/model_manager/load/model_loaders/ip_adapter.py +++ /dev/null @@ -1,34 +0,0 @@ -# Copyright (c) 2024, Lincoln D. 
Stein and the InvokeAI Development Team -"""Class for IP Adapter model loading in InvokeAI.""" - -from pathlib import Path -from typing import Optional - -import torch - -from invokeai.backend.ip_adapter.ip_adapter import build_ip_adapter -from invokeai.backend.model_manager.config import AnyModelConfig -from invokeai.backend.model_manager.load import ModelLoader, ModelLoaderRegistry -from invokeai.backend.model_manager.taxonomy import AnyModel, BaseModelType, ModelFormat, ModelType, SubModelType -from invokeai.backend.raw_model import RawModel - - -@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.IPAdapter, format=ModelFormat.InvokeAI) -@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.IPAdapter, format=ModelFormat.Checkpoint) -class IPAdapterInvokeAILoader(ModelLoader): - """Class to load IP Adapter diffusers models.""" - - def _load_model( - self, - config: AnyModelConfig, - submodel_type: Optional[SubModelType] = None, - ) -> AnyModel: - if submodel_type is not None: - raise ValueError("There are no submodels in an IP-Adapter model.") - model_path = Path(config.path) - model: RawModel = build_ip_adapter( - ip_adapter_ckpt_path=model_path, - device=torch.device("cpu"), - dtype=self._torch_dtype, - ) - return model diff --git a/invokeai/backend/model_manager/load/model_loaders/llava_onevision.py b/invokeai/backend/model_manager/load/model_loaders/llava_onevision.py deleted file mode 100644 index b508137f814..00000000000 --- a/invokeai/backend/model_manager/load/model_loaders/llava_onevision.py +++ /dev/null @@ -1,31 +0,0 @@ -from pathlib import Path -from typing import Optional - -from transformers import LlavaOnevisionForConditionalGeneration - -from invokeai.backend.model_manager.config import ( - AnyModelConfig, -) -from invokeai.backend.model_manager.load.load_default import ModelLoader -from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry -from invokeai.backend.model_manager.taxonomy import AnyModel, BaseModelType, ModelFormat, ModelType, SubModelType - - -@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.LlavaOnevision, format=ModelFormat.Diffusers) -class LlavaOnevisionModelLoader(ModelLoader): - """Class for loading LLaVA Onevision VLLM models.""" - - def _load_model( - self, - config: AnyModelConfig, - submodel_type: Optional[SubModelType] = None, - ) -> AnyModel: - if submodel_type is not None: - raise ValueError("Unexpected submodel requested for LLaVA OneVision model.") - - model_path = Path(config.path) - model = LlavaOnevisionForConditionalGeneration.from_pretrained( - model_path, local_files_only=True, torch_dtype=self._torch_dtype - ) - assert isinstance(model, LlavaOnevisionForConditionalGeneration) - return model diff --git a/invokeai/backend/model_manager/load/model_loaders/lora.py b/invokeai/backend/model_manager/load/model_loaders/lora.py deleted file mode 100644 index 98f54224fad..00000000000 --- a/invokeai/backend/model_manager/load/model_loaders/lora.py +++ /dev/null @@ -1,142 +0,0 @@ -# Copyright (c) 2024, Lincoln D. 
Stein and the InvokeAI Development Team -"""Class for LoRA model loading in InvokeAI.""" - -from logging import Logger -from pathlib import Path -from typing import Optional - -import torch -from safetensors.torch import load_file - -from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.backend.model_manager.config import AnyModelConfig -from invokeai.backend.model_manager.load.load_default import ModelLoader -from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache -from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry -from invokeai.backend.model_manager.omi.omi import convert_from_omi -from invokeai.backend.model_manager.taxonomy import ( - AnyModel, - BaseModelType, - ModelFormat, - ModelType, - SubModelType, -) -from invokeai.backend.patches.lora_conversions.flux_aitoolkit_lora_conversion_utils import ( - is_state_dict_likely_in_flux_aitoolkit_format, - lora_model_from_flux_aitoolkit_state_dict, -) -from invokeai.backend.patches.lora_conversions.flux_control_lora_utils import ( - is_state_dict_likely_flux_control, - lora_model_from_flux_control_state_dict, -) -from invokeai.backend.patches.lora_conversions.flux_diffusers_lora_conversion_utils import ( - lora_model_from_flux_diffusers_state_dict, -) -from invokeai.backend.patches.lora_conversions.flux_kohya_lora_conversion_utils import ( - is_state_dict_likely_in_flux_kohya_format, - lora_model_from_flux_kohya_state_dict, -) -from invokeai.backend.patches.lora_conversions.flux_onetrainer_lora_conversion_utils import ( - is_state_dict_likely_in_flux_onetrainer_format, - lora_model_from_flux_onetrainer_state_dict, -) -from invokeai.backend.patches.lora_conversions.sd_lora_conversion_utils import lora_model_from_sd_state_dict -from invokeai.backend.patches.lora_conversions.sdxl_lora_conversion_utils import convert_sdxl_keys_to_diffusers_format - - -@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.LoRA, format=ModelFormat.OMI) -@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusionXL, type=ModelType.LoRA, format=ModelFormat.OMI) -@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.LoRA, format=ModelFormat.Diffusers) -@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.LoRA, format=ModelFormat.LyCORIS) -@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.ControlLoRa, format=ModelFormat.LyCORIS) -@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.ControlLoRa, format=ModelFormat.Diffusers) -class LoRALoader(ModelLoader): - """Class to load LoRA models.""" - - # We cheat a little bit to get access to the model base - def __init__( - self, - app_config: InvokeAIAppConfig, - logger: Logger, - ram_cache: ModelCache, - ): - """Initialize the loader.""" - super().__init__(app_config, logger, ram_cache) - self._model_base: Optional[BaseModelType] = None - - def _load_model( - self, - config: AnyModelConfig, - submodel_type: Optional[SubModelType] = None, - ) -> AnyModel: - if submodel_type is not None: - raise ValueError("There are no submodels in a LoRA model.") - model_path = Path(config.path) - assert self._model_base is not None - - # Load the state dict from the model file. - if model_path.suffix == ".safetensors": - state_dict = load_file(model_path.absolute().as_posix(), device="cpu") - else: - state_dict = torch.load(model_path, map_location="cpu") - - # Strip 'bundle_emb' keys - these are unused and currently cause downstream errors. 
- # To revisit later to determine if they're needed/useful. - state_dict = {k: v for k, v in state_dict.items() if not k.startswith("bundle_emb")} - - # At the time of writing, we support the OMI standard for base models Flux and SDXL - if config.format == ModelFormat.OMI and self._model_base in [ - BaseModelType.StableDiffusionXL, - BaseModelType.Flux, - ]: - state_dict = convert_from_omi(state_dict, config.base) # type: ignore - - # Apply state_dict key conversions, if necessary. - if self._model_base == BaseModelType.StableDiffusionXL: - state_dict = convert_sdxl_keys_to_diffusers_format(state_dict) - model = lora_model_from_sd_state_dict(state_dict=state_dict) - elif self._model_base == BaseModelType.Flux: - if config.format in [ModelFormat.Diffusers, ModelFormat.OMI]: - # HACK(ryand): We set alpha=None for diffusers PEFT format models. These models are typically - # distributed as a single file without the associated metadata containing the alpha value. We chose - # alpha=None, because this is treated as alpha=rank internally in `LoRALayerBase.scale()`. alpha=rank - # is a popular choice. For example, in the diffusers training scripts: - # https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth_lora_flux.py#L1194 - model = lora_model_from_flux_diffusers_state_dict(state_dict=state_dict, alpha=None) - elif config.format == ModelFormat.LyCORIS: - if is_state_dict_likely_in_flux_kohya_format(state_dict=state_dict): - model = lora_model_from_flux_kohya_state_dict(state_dict=state_dict) - elif is_state_dict_likely_in_flux_onetrainer_format(state_dict=state_dict): - model = lora_model_from_flux_onetrainer_state_dict(state_dict=state_dict) - elif is_state_dict_likely_flux_control(state_dict=state_dict): - model = lora_model_from_flux_control_state_dict(state_dict=state_dict) - elif is_state_dict_likely_in_flux_aitoolkit_format(state_dict=state_dict): - model = lora_model_from_flux_aitoolkit_state_dict(state_dict=state_dict) - else: - raise ValueError("LoRA model is in unsupported FLUX format") - else: - raise ValueError(f"LoRA model is in unsupported FLUX format: {config.format}") - elif self._model_base in [BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2]: - # Currently, we don't apply any conversions for SD1 and SD2 LoRA models. - model = lora_model_from_sd_state_dict(state_dict=state_dict) - else: - raise ValueError(f"Unsupported LoRA base model: {self._model_base}") - - model.to(dtype=self._torch_dtype) - return model - - def _get_model_path(self, config: AnyModelConfig) -> Path: - # cheating a little - we remember this variable for using in the subsequent call to _load_model() - self._model_base = config.base - - model_base_path = self._app_config.models_path - model_path = model_base_path / config.path - - if config.format == ModelFormat.Diffusers: - for ext in ["safetensors", "bin"]: # return path to the safetensors file inside the folder - path = model_base_path / config.path / f"pytorch_lora_weights.{ext}" - if path.exists(): - model_path = path - break - - return model_path.resolve() diff --git a/invokeai/backend/model_manager/load/model_loaders/onnx.py b/invokeai/backend/model_manager/load/model_loaders/onnx.py deleted file mode 100644 index 3078d622b4e..00000000000 --- a/invokeai/backend/model_manager/load/model_loaders/onnx.py +++ /dev/null @@ -1,42 +0,0 @@ -# Copyright (c) 2024, Lincoln D. 
Stein and the InvokeAI Development Team -"""Class for Onnx model loading in InvokeAI.""" - -# This should work the same as Stable Diffusion pipelines -from pathlib import Path -from typing import Optional - -from invokeai.backend.model_manager.config import AnyModelConfig -from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry -from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader -from invokeai.backend.model_manager.taxonomy import ( - AnyModel, - BaseModelType, - ModelFormat, - ModelType, - SubModelType, -) - - -@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ONNX, format=ModelFormat.ONNX) -@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ONNX, format=ModelFormat.Olive) -class OnnyxDiffusersModel(GenericDiffusersLoader): - """Class to load onnx models.""" - - def _load_model( - self, - config: AnyModelConfig, - submodel_type: Optional[SubModelType] = None, - ) -> AnyModel: - if not submodel_type is not None: - raise Exception("A submodel type must be provided when loading onnx pipelines.") - model_path = Path(config.path) - load_class = self.get_hf_load_class(model_path, submodel_type) - repo_variant = getattr(config, "repo_variant", None) - variant = repo_variant.value if repo_variant else None - model_path = model_path / submodel_type.value - result: AnyModel = load_class.from_pretrained( - model_path, - torch_dtype=self._torch_dtype, - variant=variant, - ) - return result diff --git a/invokeai/backend/model_manager/load/model_loaders/sig_lip.py b/invokeai/backend/model_manager/load/model_loaders/sig_lip.py deleted file mode 100644 index bdf38887a3a..00000000000 --- a/invokeai/backend/model_manager/load/model_loaders/sig_lip.py +++ /dev/null @@ -1,28 +0,0 @@ -from pathlib import Path -from typing import Optional - -from transformers import SiglipVisionModel - -from invokeai.backend.model_manager.config import ( - AnyModelConfig, -) -from invokeai.backend.model_manager.load.load_default import ModelLoader -from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry -from invokeai.backend.model_manager.taxonomy import AnyModel, BaseModelType, ModelFormat, ModelType, SubModelType - - -@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.SigLIP, format=ModelFormat.Diffusers) -class SigLIPModelLoader(ModelLoader): - """Class for loading SigLIP models.""" - - def _load_model( - self, - config: AnyModelConfig, - submodel_type: Optional[SubModelType] = None, - ) -> AnyModel: - if submodel_type is not None: - raise ValueError("Unexpected submodel requested for LLaVA OneVision model.") - - model_path = Path(config.path) - model = SiglipVisionModel.from_pretrained(model_path, local_files_only=True, torch_dtype=self._torch_dtype) - return model diff --git a/invokeai/backend/model_manager/load/model_loaders/spandrel_image_to_image.py b/invokeai/backend/model_manager/load/model_loaders/spandrel_image_to_image.py deleted file mode 100644 index 44cb0277fc4..00000000000 --- a/invokeai/backend/model_manager/load/model_loaders/spandrel_image_to_image.py +++ /dev/null @@ -1,41 +0,0 @@ -from pathlib import Path -from typing import Optional - -import torch - -from invokeai.backend.model_manager.config import ( - AnyModelConfig, -) -from invokeai.backend.model_manager.load.load_default import ModelLoader -from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry -from invokeai.backend.model_manager.taxonomy 
import AnyModel, BaseModelType, ModelFormat, ModelType, SubModelType -from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel - - -@ModelLoaderRegistry.register( - base=BaseModelType.Any, type=ModelType.SpandrelImageToImage, format=ModelFormat.Checkpoint -) -class SpandrelImageToImageModelLoader(ModelLoader): - """Class for loading Spandrel Image-to-Image models (i.e. models wrapped by spandrel.ImageModelDescriptor).""" - - def _load_model( - self, - config: AnyModelConfig, - submodel_type: Optional[SubModelType] = None, - ) -> AnyModel: - if submodel_type is not None: - raise ValueError("Unexpected submodel requested for Spandrel model.") - - model_path = Path(config.path) - model = SpandrelImageToImageModel.load_from_file(model_path) - - torch_dtype = self._torch_dtype - if not model.supports_dtype(torch_dtype): - self._logger.warning( - f"The configured dtype ('{self._torch_dtype}') is not supported by the {model.get_model_type_name()} " - "model. Falling back to 'float32'." - ) - torch_dtype = torch.float32 - model.to(dtype=torch_dtype) - - return model diff --git a/invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py b/invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py deleted file mode 100644 index aa692478cad..00000000000 --- a/invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py +++ /dev/null @@ -1,137 +0,0 @@ -# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team -"""Class for StableDiffusion model loading in InvokeAI.""" - -from pathlib import Path -from typing import Optional - -from diffusers import ( - StableDiffusionInpaintPipeline, - StableDiffusionPipeline, - StableDiffusionXLInpaintPipeline, - StableDiffusionXLPipeline, -) - -from invokeai.backend.model_manager.config import ( - AnyModelConfig, - CheckpointConfigBase, - DiffusersConfigBase, - MainCheckpointConfig, -) -from invokeai.backend.model_manager.load.model_cache.model_cache import get_model_cache_key -from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry -from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader -from invokeai.backend.model_manager.taxonomy import ( - AnyModel, - BaseModelType, - ModelFormat, - ModelType, - ModelVariantType, - SubModelType, -) -from invokeai.backend.util.silence_warnings import SilenceWarnings - -VARIANT_TO_IN_CHANNEL_MAP = { - ModelVariantType.Normal: 4, - ModelVariantType.Depth: 5, - ModelVariantType.Inpaint: 9, -} - - -@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion1, type=ModelType.Main, format=ModelFormat.Diffusers) -@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion2, type=ModelType.Main, format=ModelFormat.Diffusers) -@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusionXL, type=ModelType.Main, format=ModelFormat.Diffusers) -@ModelLoaderRegistry.register( - base=BaseModelType.StableDiffusionXLRefiner, type=ModelType.Main, format=ModelFormat.Diffusers -) -@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion3, type=ModelType.Main, format=ModelFormat.Diffusers) -@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion1, type=ModelType.Main, format=ModelFormat.Checkpoint) -@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion2, type=ModelType.Main, format=ModelFormat.Checkpoint) -@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusionXL, type=ModelType.Main, format=ModelFormat.Checkpoint) -@ModelLoaderRegistry.register( 
- base=BaseModelType.StableDiffusionXLRefiner, type=ModelType.Main, format=ModelFormat.Checkpoint -) -class StableDiffusionDiffusersModel(GenericDiffusersLoader): - """Class to load main models.""" - - def _load_model( - self, - config: AnyModelConfig, - submodel_type: Optional[SubModelType] = None, - ) -> AnyModel: - if isinstance(config, CheckpointConfigBase): - return self._load_from_singlefile(config, submodel_type) - - if submodel_type is None: - raise Exception("A submodel type must be provided when loading main pipelines.") - - model_path = Path(config.path) - load_class = self.get_hf_load_class(model_path, submodel_type) - repo_variant = config.repo_variant if isinstance(config, DiffusersConfigBase) else None - variant = repo_variant.value if repo_variant else None - model_path = model_path / submodel_type.value - try: - result: AnyModel = load_class.from_pretrained( - model_path, - torch_dtype=self._torch_dtype, - variant=variant, - ) - except OSError as e: - if variant and "no file named" in str( - e - ): # try without the variant, just in case user's preferences changed - result = load_class.from_pretrained(model_path, torch_dtype=self._torch_dtype) - else: - raise e - - return result - - def _load_from_singlefile( - self, - config: AnyModelConfig, - submodel_type: Optional[SubModelType] = None, - ) -> AnyModel: - load_classes = { - BaseModelType.StableDiffusion1: { - ModelVariantType.Normal: StableDiffusionPipeline, - ModelVariantType.Inpaint: StableDiffusionInpaintPipeline, - }, - BaseModelType.StableDiffusion2: { - ModelVariantType.Normal: StableDiffusionPipeline, - ModelVariantType.Inpaint: StableDiffusionInpaintPipeline, - }, - BaseModelType.StableDiffusionXL: { - ModelVariantType.Normal: StableDiffusionXLPipeline, - ModelVariantType.Inpaint: StableDiffusionXLInpaintPipeline, - }, - BaseModelType.StableDiffusionXLRefiner: { - ModelVariantType.Normal: StableDiffusionXLPipeline, - }, - } - assert isinstance(config, MainCheckpointConfig) - try: - load_class = load_classes[config.base][config.variant] - except KeyError as e: - raise Exception(f"No diffusers pipeline known for base={config.base}, variant={config.variant}") from e - - # Without SilenceWarnings we get log messages like this: - # site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`. - # warnings.warn( - # Some weights of the model checkpoint were not used when initializing CLIPTextModel: - # ['text_model.embeddings.position_ids'] - # Some weights of the model checkpoint were not used when initializing CLIPTextModelWithProjection: - # ['text_model.embeddings.position_ids'] - - with SilenceWarnings(): - pipeline = load_class.from_single_file(config.path, torch_dtype=self._torch_dtype) - - if not submodel_type: - return pipeline - - # Proactively load the various submodels into the RAM cache so that we don't have to re-load - # the entire pipeline every time a new submodel is needed. 
- for subtype in SubModelType: - if subtype == submodel_type: - continue - if submodel := getattr(pipeline, subtype.value, None): - self._ram_cache.put(get_model_cache_key(config.key, subtype), model=submodel) - return getattr(pipeline, submodel_type.value) diff --git a/invokeai/backend/model_manager/load/model_loaders/textual_inversion.py b/invokeai/backend/model_manager/load/model_loaders/textual_inversion.py deleted file mode 100644 index 60ae4ea08b7..00000000000 --- a/invokeai/backend/model_manager/load/model_loaders/textual_inversion.py +++ /dev/null @@ -1,52 +0,0 @@ -# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team -"""Class for TI model loading in InvokeAI.""" - -from pathlib import Path -from typing import Optional - -from invokeai.backend.model_manager.config import AnyModelConfig -from invokeai.backend.model_manager.load.load_default import ModelLoader -from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry -from invokeai.backend.model_manager.taxonomy import ( - AnyModel, - BaseModelType, - ModelFormat, - ModelType, - SubModelType, -) -from invokeai.backend.textual_inversion import TextualInversionModelRaw - - -@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.TextualInversion, format=ModelFormat.EmbeddingFile) -@ModelLoaderRegistry.register( - base=BaseModelType.Any, type=ModelType.TextualInversion, format=ModelFormat.EmbeddingFolder -) -class TextualInversionLoader(ModelLoader): - """Class to load TI models.""" - - def _load_model( - self, - config: AnyModelConfig, - submodel_type: Optional[SubModelType] = None, - ) -> AnyModel: - if submodel_type is not None: - raise ValueError("There are no submodels in a TI model.") - model = TextualInversionModelRaw.from_checkpoint( - file_path=config.path, - dtype=self._torch_dtype, - ) - return model - - # override - def _get_model_path(self, config: AnyModelConfig) -> Path: - model_path = self._app_config.models_path / config.path - - if config.format == ModelFormat.EmbeddingFolder: - path = model_path / "learned_embeds.bin" - else: - path = model_path - - if not path.exists(): - raise OSError(f"The embedding file at {path} was not found") - - return path diff --git a/invokeai/backend/model_manager/load/model_loaders/vae.py b/invokeai/backend/model_manager/load/model_loaders/vae.py deleted file mode 100644 index 365fa0a547c..00000000000 --- a/invokeai/backend/model_manager/load/model_loaders/vae.py +++ /dev/null @@ -1,36 +0,0 @@ -# Copyright (c) 2024, Lincoln D. 
Stein and the InvokeAI Development Team -"""Class for VAE model loading in InvokeAI.""" - -from typing import Optional - -from diffusers import AutoencoderKL - -from invokeai.backend.model_manager.config import AnyModelConfig, VAECheckpointConfig -from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry -from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader -from invokeai.backend.model_manager.taxonomy import ( - AnyModel, - BaseModelType, - ModelFormat, - ModelType, - SubModelType, -) - - -@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.VAE, format=ModelFormat.Diffusers) -@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.VAE, format=ModelFormat.Checkpoint) -class VAELoader(GenericDiffusersLoader): - """Class to load VAE models.""" - - def _load_model( - self, - config: AnyModelConfig, - submodel_type: Optional[SubModelType] = None, - ) -> AnyModel: - if isinstance(config, VAECheckpointConfig): - return AutoencoderKL.from_single_file( - config.path, - torch_dtype=self._torch_dtype, - ) - else: - return super()._load_model(config, submodel_type) diff --git a/invokeai/backend/model_manager/merge.py b/invokeai/backend/model_manager/merge.py deleted file mode 100644 index 03056b10f59..00000000000 --- a/invokeai/backend/model_manager/merge.py +++ /dev/null @@ -1,163 +0,0 @@ -""" -invokeai.backend.model_manager.merge exports: -merge_diffusion_models() -- combine multiple models by location and return a pipeline object -merge_diffusion_models_and_commit() -- combine multiple models by ModelManager ID and write to the models tables - -Copyright (c) 2023 Lincoln Stein and the InvokeAI Development Team -""" - -import warnings -from enum import Enum -from pathlib import Path -from typing import Any, List, Optional, Set - -import torch -from diffusers import AutoPipelineForText2Image -from diffusers.utils import logging as dlogging - -from invokeai.app.services.model_install import ModelInstallServiceBase -from invokeai.app.services.model_records.model_records_base import ModelRecordChanges -from invokeai.backend.model_manager import AnyModelConfig, BaseModelType, ModelType, ModelVariantType -from invokeai.backend.model_manager.config import MainDiffusersConfig -from invokeai.backend.util.devices import TorchDevice - - -class MergeInterpolationMethod(str, Enum): - WeightedSum = "weighted_sum" - Sigmoid = "sigmoid" - InvSigmoid = "inv_sigmoid" - AddDifference = "add_difference" - - -class ModelMerger(object): - """Wrapper class for model merge function.""" - - def __init__(self, installer: ModelInstallServiceBase): - """ - Initialize a ModelMerger object with the model installer. - """ - self._installer = installer - self._dtype = TorchDevice.choose_torch_dtype() - - def merge_diffusion_models( - self, - model_paths: List[Path], - alpha: float = 0.5, - interp: Optional[MergeInterpolationMethod] = None, - force: bool = False, - variant: Optional[str] = None, - **kwargs: Any, - ) -> Any: # pipe.merge is an untyped function. - """ - :param model_paths: up to three models, designated by their local paths or HuggingFace repo_ids - :param alpha: The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints are merged. A 0.8 alpha - would mean that the first model checkpoints would affect the final result far less than an alpha of 0.2 - :param interp: The interpolation method to use for the merging. 
Supports "sigmoid", "inv_sigmoid", "add_difference" and None. - Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_difference" is supported. - :param force: Whether to ignore mismatch in model_config.json for the current models. Defaults to False. - - **kwargs - the default DiffusionPipeline.get_config_dict kwargs: - cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map - """ - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - verbosity = dlogging.get_verbosity() - dlogging.set_verbosity_error() - dtype = torch.float16 if variant == "fp16" else self._dtype - - # Note that checkpoint_merger will not work with downloaded HuggingFace fp16 models - # until upstream https://github.com/huggingface/diffusers/pull/6670 is merged and released. - pipe = AutoPipelineForText2Image.from_pretrained( - model_paths[0], - custom_pipeline="checkpoint_merger", - torch_dtype=dtype, - variant=variant, - ) # type: ignore - merged_pipe = pipe.merge( - pretrained_model_name_or_path_list=model_paths, - alpha=alpha, - interp=interp.value if interp else None, # diffusers API treats None as "weighted sum" - force=force, - torch_dtype=dtype, - variant=variant, - **kwargs, - ) - dlogging.set_verbosity(verbosity) - return merged_pipe - - def merge_diffusion_models_and_save( - self, - model_keys: List[str], - merged_model_name: str, - alpha: float = 0.5, - force: bool = False, - interp: Optional[MergeInterpolationMethod] = None, - merge_dest_directory: Optional[Path] = None, - variant: Optional[str] = None, - **kwargs: Any, - ) -> AnyModelConfig: - """ - :param models: up to three models, designated by their registered InvokeAI model name - :param merged_model_name: name for new model - :param alpha: The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints are merged. A 0.8 alpha - would mean that the first model checkpoints would affect the final result far less than an alpha of 0.2 - :param interp: The interpolation method to use for the merging. Supports "weighted_average", "sigmoid", "inv_sigmoid", "add_difference" and None. - Passing None uses the default interpolation which is weighted sum interpolation. For merging three checkpoints, only "add_difference" is supported. Add_difference is A+(B-C). - :param force: Whether to ignore mismatch in model_config.json for the current models. Defaults to False. - :param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended) - **kwargs - the default DiffusionPipeline.get_config_dict kwargs: - cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map - """ - model_paths: List[Path] = [] - model_names: List[str] = [] - config = self._installer.app_config - store = self._installer.record_store - base_models: Set[BaseModelType] = set() - variant = None if self._installer.app_config.precision == "float32" else "fp16" - - assert len(model_keys) <= 2 or interp == MergeInterpolationMethod.AddDifference, ( - "When merging three models, only the 'add_difference' merge method is supported" - ) - - for key in model_keys: - info = store.get_model(key) - model_names.append(info.name) - assert isinstance(info, MainDiffusersConfig), ( - f"{info.name} ({info.key}) is not a diffusers model. 
It must be optimized before merging" - ) - assert info.variant == ModelVariantType("normal"), ( - f"{info.name} ({info.key}) is a {info.variant} model, which cannot currently be merged" - ) - - # tally base models used - base_models.add(info.base) - model_paths.extend([config.models_path / info.path]) - - assert len(base_models) == 1, f"All models to merge must have same base model, but found bases {base_models}" - base_model = base_models.pop() - - merge_method = None if interp == "weighted_sum" else MergeInterpolationMethod(interp) - merged_pipe = self.merge_diffusion_models(model_paths, alpha, merge_method, force, variant=variant, **kwargs) - dump_path = ( - Path(merge_dest_directory) - if merge_dest_directory - else config.models_path / base_model.value / ModelType.Main.value - ) - dump_path.mkdir(parents=True, exist_ok=True) - dump_path = dump_path / merged_model_name - - dtype = torch.float16 if variant == "fp16" else self._dtype - merged_pipe.save_pretrained(dump_path.as_posix(), safe_serialization=True, torch_dtype=dtype, variant=variant) - - # register model and get its unique key - key = self._installer.register_path(dump_path) - - # update model's config - model_config = self._installer.record_store.get_model(key) - model_config.name = merged_model_name - model_config.description = f"Merge of models {', '.join(model_names)}" - - self._installer.record_store.update_model( - key, ModelRecordChanges(name=model_config.name, description=model_config.description) - ) - return model_config diff --git a/invokeai/backend/model_manager/model_on_disk.py b/invokeai/backend/model_manager/model_on_disk.py index 502ca596a62..a86e94d3a4c 100644 --- a/invokeai/backend/model_manager/model_on_disk.py +++ b/invokeai/backend/model_manager/model_on_disk.py @@ -30,7 +30,8 @@ def __init__(self, path: Path, hash_algo: HASHING_ALGORITHMS = "blake3_single"): self.hash_algo = hash_algo # Having a cache helps users of ModelOnDisk (i.e. 
configs) to save state # This prevents redundant computations during matching and parsing - self.cache = {"_CACHED_STATE_DICTS": {}} + self._state_dict_cache: dict[Path, Any] = {} + self._metadata_cache: dict[Path, Any] = {} def hash(self) -> str: return ModelHash(algorithm=self.hash_algo).hash(self.path) @@ -47,13 +48,18 @@ def weight_files(self) -> set[Path]: return {f for f in self.path.rglob("*") if f.suffix in extensions} def metadata(self, path: Optional[Path] = None) -> dict[str, str]: + path = path or self.path + if path in self._metadata_cache: + return self._metadata_cache[path] try: with safe_open(self.path, framework="pt", device="cpu") as f: metadata = f.metadata() assert isinstance(metadata, dict) - return metadata except Exception: - return {} + metadata = {} + + self._metadata_cache[path] = metadata + return metadata def repo_variant(self) -> Optional[ModelRepoVariant]: if self.path.is_file(): @@ -73,10 +79,8 @@ def repo_variant(self) -> Optional[ModelRepoVariant]: return ModelRepoVariant.Default def load_state_dict(self, path: Optional[Path] = None) -> StateDict: - sd_cache = self.cache["_CACHED_STATE_DICTS"] - - if path in sd_cache: - return sd_cache[path] + if path in self._state_dict_cache: + return self._state_dict_cache[path] path = self.resolve_weight_file(path) @@ -111,7 +115,7 @@ def load_state_dict(self, path: Optional[Path] = None) -> StateDict: raise ValueError(f"Unrecognized model extension: {path.suffix}") state_dict = checkpoint.get("state_dict", checkpoint) - sd_cache[path] = state_dict + self._state_dict_cache[path] = state_dict return state_dict def resolve_weight_file(self, path: Optional[Path] = None) -> Path: diff --git a/invokeai/backend/model_manager/single_file_config_files.py b/invokeai/backend/model_manager/single_file_config_files.py new file mode 100644 index 00000000000..22fe646b550 --- /dev/null +++ b/invokeai/backend/model_manager/single_file_config_files.py @@ -0,0 +1,86 @@ +from dataclasses import dataclass + +from invokeai.backend.model_manager.taxonomy import ( + BaseModelType, + ModelType, + ModelVariantType, + SchedulerPredictionType, +) + + +@dataclass(frozen=True) +class LegacyConfigKey: + type: ModelType + base: BaseModelType + variant: ModelVariantType | None = None + pred: SchedulerPredictionType | None = None + + +LEGACY_CONFIG_MAP: dict[LegacyConfigKey, str] = { + LegacyConfigKey( + ModelType.Main, + BaseModelType.StableDiffusion1, + ModelVariantType.Normal, + SchedulerPredictionType.Epsilon, + ): "stable-diffusion/v1-inference.yaml", + LegacyConfigKey( + ModelType.Main, + BaseModelType.StableDiffusion1, + ModelVariantType.Normal, + SchedulerPredictionType.VPrediction, + ): "stable-diffusion/v1-inference-v.yaml", + LegacyConfigKey( + ModelType.Main, + BaseModelType.StableDiffusion1, + ModelVariantType.Inpaint, + ): "stable-diffusion/v1-inpainting-inference.yaml", + LegacyConfigKey( + ModelType.Main, + BaseModelType.StableDiffusion2, + ModelVariantType.Normal, + SchedulerPredictionType.Epsilon, + ): "stable-diffusion/v2-inference.yaml", + LegacyConfigKey( + ModelType.Main, + BaseModelType.StableDiffusion2, + ModelVariantType.Normal, + SchedulerPredictionType.VPrediction, + ): "stable-diffusion/v2-inference-v.yaml", + LegacyConfigKey( + ModelType.Main, + BaseModelType.StableDiffusion2, + ModelVariantType.Inpaint, + SchedulerPredictionType.Epsilon, + ): "stable-diffusion/v2-inpainting-inference.yaml", + LegacyConfigKey( + ModelType.Main, + BaseModelType.StableDiffusion2, + ModelVariantType.Inpaint, + 
SchedulerPredictionType.VPrediction, + ): "stable-diffusion/v2-inpainting-inference-v.yaml", + LegacyConfigKey( + ModelType.Main, + BaseModelType.StableDiffusion2, + ModelVariantType.Depth, + ): "stable-diffusion/v2-midas-inference.yaml", + LegacyConfigKey( + ModelType.Main, + BaseModelType.StableDiffusionXL, + ModelVariantType.Normal, + ): "stable-diffusion/sd_xl_base.yaml", + LegacyConfigKey( + ModelType.Main, + BaseModelType.StableDiffusionXL, + ModelVariantType.Inpaint, + ): "stable-diffusion/sd_xl_inpaint.yaml", + LegacyConfigKey( + ModelType.Main, + BaseModelType.StableDiffusionXLRefiner, + ModelVariantType.Normal, + ): "stable-diffusion/sd_xl_refiner.yaml", + LegacyConfigKey(ModelType.ControlNet, BaseModelType.StableDiffusion1): "controlnet/cldm_v15.yaml", + LegacyConfigKey(ModelType.ControlNet, BaseModelType.StableDiffusion2): "controlnet/cldm_v21.yaml", + LegacyConfigKey(ModelType.VAE, BaseModelType.StableDiffusion1): "stable-diffusion/v1-inference.yaml", + LegacyConfigKey(ModelType.VAE, BaseModelType.StableDiffusion2): "stable-diffusion/v2-inference.yaml", + LegacyConfigKey(ModelType.VAE, BaseModelType.StableDiffusionXL): "stable-diffusion/sd_xl_base.yaml", +} diff --git a/invokeai/backend/model_manager/taxonomy.py b/invokeai/backend/model_manager/taxonomy.py index 07f8c8f5def..99a31f438d1 100644 --- a/invokeai/backend/model_manager/taxonomy.py +++ b/invokeai/backend/model_manager/taxonomy.py @@ -1,39 +1,70 @@ from enum import Enum from typing import Dict, TypeAlias, Union -import diffusers import onnxruntime as ort import torch -from diffusers import ModelMixin +from diffusers.models.modeling_utils import ModelMixin +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from pydantic import TypeAdapter from invokeai.backend.raw_model import RawModel # ModelMixin is the base class for all diffusers and transformers models # RawModel is the InvokeAI wrapper class for ip_adapters, loras, textual_inversion and onnx runtime -AnyModel = Union[ - ModelMixin, RawModel, torch.nn.Module, Dict[str, torch.Tensor], diffusers.DiffusionPipeline, ort.InferenceSession +AnyModel: TypeAlias = Union[ + ModelMixin, + RawModel, + torch.nn.Module, + Dict[str, torch.Tensor], + DiffusionPipeline, + ort.InferenceSession, ] +"""Type alias for any kind of runtime, in-memory model representation. For example, a torch module or diffusers pipeline.""" class BaseModelType(str, Enum): - """Base model type.""" + """An enumeration of base model architectures. For example, Stable Diffusion 1.x, Stable Diffusion 2.x, FLUX, etc. + + Every model config must have a base architecture type. + + Not all models are associated with a base architecture. For example, CLIP models are their own thing, not related + to any particular model architecture. To simplify internal APIs and make it easier to work with models, we use a + fallback/null value `BaseModelType.Any` for these models, instead of making the model base optional.""" Any = "any" + """`Any` is essentially a fallback/null value for models with no base architecture association. 
+ For example, CLIP models are not related to Stable Diffusion, FLUX, or any other model architecture.""" StableDiffusion1 = "sd-1" + """Indicates the model is associated with the Stable Diffusion 1.x model architecture, including 1.4 and 1.5.""" StableDiffusion2 = "sd-2" + """Indicates the model is associated with the Stable Diffusion 2.x model architecture, including 2.0 and 2.1.""" StableDiffusion3 = "sd-3" + """Indicates the model is associated with the Stable Diffusion 3.5 model architecture.""" StableDiffusionXL = "sdxl" + """Indicates the model is associated with the Stable Diffusion XL model architecture.""" StableDiffusionXLRefiner = "sdxl-refiner" + """Indicates the model is associated with the Stable Diffusion XL Refiner model architecture.""" Flux = "flux" + """Indicates the model is associated with the FLUX.1 model architecture, including FLUX Dev, Schnell and Fill.""" CogView4 = "cogview4" + """Indicates the model is associated with the CogView 4 model architecture.""" Imagen3 = "imagen3" + """Indicates the model is associated with the Google Imagen 3 model architecture. This is an external API model.""" Imagen4 = "imagen4" + """Indicates the model is associated with the Google Imagen 4 model architecture. This is an external API model.""" Gemini2_5 = "gemini-2.5" + """Indicates the model is associated with the Google Gemini 2.5 Flash Image model architecture. This is an external API model.""" ChatGPT4o = "chatgpt-4o" + """Indicates the model is associated with the OpenAI ChatGPT 4o Image model architecture. This is an external API model.""" FluxKontext = "flux-kontext" + """Indicates the model is associated with the FLUX Kontext model architecture. This is an external API model; local FLUX + Kontext models use the base `Flux`.""" Veo3 = "veo3" + """Indicates the model is associated with the Google Veo 3 video model architecture. This is an external API model.""" Runway = "runway" + """Indicates the model is associated with the Runway video model architecture.
This is an external API model.""" Unknown = "unknown" + """Indicates the model's base architecture is unknown.""" class ModelType(str, Enum): @@ -92,6 +123,12 @@ class ModelVariantType(str, Enum): Depth = "depth" +class FluxVariantType(str, Enum): + Schnell = "schnell" + Dev = "dev" + DevFill = "dev_fill" + + class ModelFormat(str, Enum): """Storage format of model.""" @@ -149,4 +186,7 @@ class FluxLoRAFormat(str, Enum): AIToolkit = "flux.aitoolkit" -AnyVariant: TypeAlias = Union[ModelVariantType, ClipVariantType, None] +AnyVariant: TypeAlias = Union[ModelVariantType, ClipVariantType, FluxVariantType] +variant_type_adapter = TypeAdapter[ModelVariantType | ClipVariantType | FluxVariantType]( + ModelVariantType | ClipVariantType | FluxVariantType +) diff --git a/invokeai/backend/model_manager/util/model_util.py b/invokeai/backend/model_manager/util/model_util.py index 4fa095b5999..c153129353b 100644 --- a/invokeai/backend/model_manager/util/model_util.py +++ b/invokeai/backend/model_manager/util/model_util.py @@ -83,14 +83,14 @@ def read_checkpoint_meta(path: Union[str, Path], scan: bool = True) -> Dict[str, return checkpoint -def lora_token_vector_length(checkpoint: Dict[str, torch.Tensor]) -> Optional[int]: +def lora_token_vector_length(checkpoint: dict[str | int, torch.Tensor]) -> Optional[int]: """ Given a checkpoint in memory, return the lora token vector length :param checkpoint: The checkpoint """ - def _get_shape_1(key: str, tensor: torch.Tensor, checkpoint: Dict[str, torch.Tensor]) -> Optional[int]: + def _get_shape_1(key: str, tensor: torch.Tensor, checkpoint: dict[str | int, torch.Tensor]) -> Optional[int]: lora_token_vector_length = None if "." not in key: @@ -136,6 +136,8 @@ def _get_shape_1(key: str, tensor: torch.Tensor, checkpoint: Dict[str, torch.Ten lora_te1_length = None lora_te2_length = None for key, tensor in checkpoint.items(): + if isinstance(key, int): + continue if key.startswith("lora_unet_") and ("_attn2_to_k." in key or "_attn2_to_v." in key): lora_token_vector_length = _get_shape_1(key, tensor, checkpoint) elif key.startswith("lora_unet_") and ( diff --git a/invokeai/backend/model_patcher.py b/invokeai/backend/model_patcher.py index a1d8bbed0a5..04f99495609 100644 --- a/invokeai/backend/model_patcher.py +++ b/invokeai/backend/model_patcher.py @@ -5,10 +5,10 @@ import pickle from contextlib import contextmanager -from typing import Any, Iterator, List, Optional, Tuple, Type, Union +from typing import Any, Generator, Iterator, List, Optional, Tuple, Type, Union import torch -from diffusers import UNet2DConditionModel +from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer from invokeai.app.shared.models import FreeUConfig @@ -146,7 +146,7 @@ def apply_clip_skip( cls, text_encoder: Union[CLIPTextModel, CLIPTextModelWithProjection], clip_skip: int, - ) -> None: + ) -> Generator[None, Any, Any]: skipped_layers = [] try: for _i in range(clip_skip): @@ -164,7 +164,7 @@ def apply_freeu( cls, unet: UNet2DConditionModel, freeu_config: Optional[FreeUConfig] = None, - ) -> None: + ) -> Generator[None, Any, Any]: did_apply_freeu = False try: assert hasattr(unet, "enable_freeu") # mypy doesn't pick up this attribute? 
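The taxonomy additions above introduce the FluxVariantType enum and a pydantic variant_type_adapter, which together give the model manager a single entry point for parsing a serialized variant value into whichever variant enum it belongs to. The following is a minimal sketch of that usage, assuming pydantic v2; the enums are re-declared locally so the snippet stands alone, and the member names are illustrative (only the string values are taken from the diff).

from enum import Enum

from pydantic import TypeAdapter


class ModelVariantType(str, Enum):
    Normal = "normal"
    Inpaint = "inpaint"
    Depth = "depth"


class ClipVariantType(str, Enum):
    Large = "large"
    Gigantic = "gigantic"


class FluxVariantType(str, Enum):
    Schnell = "schnell"
    Dev = "dev"
    DevFill = "dev_fill"


# Mirrors the adapter added in taxonomy.py: a single TypeAdapter over the union of all variant enums.
variant_type_adapter = TypeAdapter[ModelVariantType | ClipVariantType | FluxVariantType](
    ModelVariantType | ClipVariantType | FluxVariantType
)

# Each string value is defined by exactly one of the enums, so the adapter resolves it to the
# correct member; an unknown value raises a pydantic ValidationError.
assert variant_type_adapter.validate_python("schnell") is FluxVariantType.Schnell
assert variant_type_adapter.validate_python("inpaint") is ModelVariantType.Inpaint
assert variant_type_adapter.validate_python("gigantic") is ClipVariantType.Gigantic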
diff --git a/invokeai/backend/patches/lora_conversions/flux_aitoolkit_lora_conversion_utils.py b/invokeai/backend/patches/lora_conversions/flux_aitoolkit_lora_conversion_utils.py index 6ca06a0355f..f3c202268a7 100644 --- a/invokeai/backend/patches/lora_conversions/flux_aitoolkit_lora_conversion_utils.py +++ b/invokeai/backend/patches/lora_conversions/flux_aitoolkit_lora_conversion_utils.py @@ -12,7 +12,10 @@ from invokeai.backend.util import InvokeAILogger -def is_state_dict_likely_in_flux_aitoolkit_format(state_dict: dict[str, Any], metadata: dict[str, Any] = None) -> bool: +def is_state_dict_likely_in_flux_aitoolkit_format( + state_dict: dict[str | int, Any], + metadata: dict[str, Any] | None = None, +) -> bool: if metadata: try: software = json.loads(metadata.get("software", "{}")) @@ -20,7 +23,7 @@ def is_state_dict_likely_in_flux_aitoolkit_format(state_dict: dict[str, Any], me return False return software.get("name") == "ai-toolkit" # metadata got lost somewhere - return any("diffusion_model" == k.split(".", 1)[0] for k in state_dict.keys()) + return any("diffusion_model" == k.split(".", 1)[0] for k in state_dict.keys() if isinstance(k, str)) @dataclass diff --git a/invokeai/backend/patches/lora_conversions/flux_control_lora_utils.py b/invokeai/backend/patches/lora_conversions/flux_control_lora_utils.py index fa9cc764628..1762a4d5f4c 100644 --- a/invokeai/backend/patches/lora_conversions/flux_control_lora_utils.py +++ b/invokeai/backend/patches/lora_conversions/flux_control_lora_utils.py @@ -18,14 +18,16 @@ FLUX_CONTROL_TRANSFORMER_KEY_REGEX = r"(\w+\.)+(lora_A\.weight|lora_B\.weight|lora_B\.bias|scale)" -def is_state_dict_likely_flux_control(state_dict: Dict[str, Any]) -> bool: +def is_state_dict_likely_flux_control(state_dict: dict[str | int, Any]) -> bool: """Checks if the provided state dict is likely in the FLUX Control LoRA format. This is intended to be a high-precision detector, but it is not guaranteed to have perfect precision. (A perfect-precision detector would require checking all keys against a whitelist and verifying tensor shapes.) """ - all_keys_match = all(re.match(FLUX_CONTROL_TRANSFORMER_KEY_REGEX, str(k)) for k in state_dict.keys()) + all_keys_match = all( + re.match(FLUX_CONTROL_TRANSFORMER_KEY_REGEX, k) for k in state_dict.keys() if isinstance(k, str) + ) # Check the shape of the img_in weight, because this layer shape is modified by FLUX control LoRAs. lora_a_weight = state_dict.get("img_in.lora_A.weight", None) diff --git a/invokeai/backend/patches/lora_conversions/flux_diffusers_lora_conversion_utils.py b/invokeai/backend/patches/lora_conversions/flux_diffusers_lora_conversion_utils.py index 188d118cc4d..f5b4bc66847 100644 --- a/invokeai/backend/patches/lora_conversions/flux_diffusers_lora_conversion_utils.py +++ b/invokeai/backend/patches/lora_conversions/flux_diffusers_lora_conversion_utils.py @@ -9,14 +9,16 @@ from invokeai.backend.patches.model_patch_raw import ModelPatchRaw -def is_state_dict_likely_in_flux_diffusers_format(state_dict: Dict[str, torch.Tensor]) -> bool: +def is_state_dict_likely_in_flux_diffusers_format(state_dict: dict[str | int, torch.Tensor]) -> bool: """Checks if the provided state dict is likely in the Diffusers FLUX LoRA format. This is intended to be a reasonably high-precision detector, but it is not guaranteed to have perfect precision. (A perfect-precision detector would require checking all keys against a whitelist and verifying tensor shapes.) """ # First, check that all keys end in "lora_A.weight" or "lora_B.weight" (i.e. 
are in PEFT format). - all_keys_in_peft_format = all(k.endswith(("lora_A.weight", "lora_B.weight")) for k in state_dict.keys()) + all_keys_in_peft_format = all( + k.endswith(("lora_A.weight", "lora_B.weight")) for k in state_dict.keys() if isinstance(k, str) + ) # Check if keys use transformer prefix transformer_prefix_keys = [ diff --git a/invokeai/backend/patches/lora_conversions/flux_kohya_lora_conversion_utils.py b/invokeai/backend/patches/lora_conversions/flux_kohya_lora_conversion_utils.py index 7b5f3468963..f5a6830c4f1 100644 --- a/invokeai/backend/patches/lora_conversions/flux_kohya_lora_conversion_utils.py +++ b/invokeai/backend/patches/lora_conversions/flux_kohya_lora_conversion_utils.py @@ -44,7 +44,7 @@ FLUX_KOHYA_T5_KEY_REGEX = r"lora_te2_encoder_block_(\d+)_layer_(\d+)_(DenseReluDense|SelfAttention)_(\w+)_?(\w+)?\.?.*" -def is_state_dict_likely_in_flux_kohya_format(state_dict: Dict[str, Any]) -> bool: +def is_state_dict_likely_in_flux_kohya_format(state_dict: dict[str | int, Any]) -> bool: """Checks if the provided state dict is likely in the Kohya FLUX LoRA format. This is intended to be a high-precision detector, but it is not guaranteed to have perfect precision. (A @@ -56,6 +56,7 @@ def is_state_dict_likely_in_flux_kohya_format(state_dict: Dict[str, Any]) -> boo or re.match(FLUX_KOHYA_CLIP_KEY_REGEX, k) or re.match(FLUX_KOHYA_T5_KEY_REGEX, k) for k in state_dict.keys() + if isinstance(k, str) ) diff --git a/invokeai/backend/patches/lora_conversions/flux_onetrainer_lora_conversion_utils.py b/invokeai/backend/patches/lora_conversions/flux_onetrainer_lora_conversion_utils.py index 0413f0ef49f..88aeee95e49 100644 --- a/invokeai/backend/patches/lora_conversions/flux_onetrainer_lora_conversion_utils.py +++ b/invokeai/backend/patches/lora_conversions/flux_onetrainer_lora_conversion_utils.py @@ -40,7 +40,7 @@ ) -def is_state_dict_likely_in_flux_onetrainer_format(state_dict: Dict[str, Any]) -> bool: +def is_state_dict_likely_in_flux_onetrainer_format(state_dict: dict[str | int, Any]) -> bool: """Checks if the provided state dict is likely in the OneTrainer FLUX LoRA format. This is intended to be a high-precision detector, but it is not guaranteed to have perfect precision. 
(A @@ -53,6 +53,7 @@ def is_state_dict_likely_in_flux_onetrainer_format(state_dict: Dict[str, Any]) - or re.match(FLUX_KOHYA_CLIP_KEY_REGEX, k) or re.match(FLUX_KOHYA_T5_KEY_REGEX, k) for k in state_dict.keys() + if isinstance(k, str) ) diff --git a/invokeai/backend/patches/lora_conversions/formats.py b/invokeai/backend/patches/lora_conversions/formats.py index 94f71e05ee6..4cde7c98f67 100644 --- a/invokeai/backend/patches/lora_conversions/formats.py +++ b/invokeai/backend/patches/lora_conversions/formats.py @@ -1,3 +1,5 @@ +from typing import Any + from invokeai.backend.model_manager.taxonomy import FluxLoRAFormat from invokeai.backend.patches.lora_conversions.flux_aitoolkit_lora_conversion_utils import ( is_state_dict_likely_in_flux_aitoolkit_format, @@ -14,7 +16,10 @@ ) -def flux_format_from_state_dict(state_dict: dict, metadata: dict | None = None) -> FluxLoRAFormat | None: +def flux_format_from_state_dict( + state_dict: dict[str | int, Any], + metadata: dict[str, Any] | None = None, +) -> FluxLoRAFormat | None: if is_state_dict_likely_in_flux_kohya_format(state_dict): return FluxLoRAFormat.Kohya elif is_state_dict_likely_in_flux_onetrainer_format(state_dict): diff --git a/invokeai/backend/quantization/scripts/load_flux_model_bnb_llm_int8.py b/invokeai/backend/quantization/scripts/load_flux_model_bnb_llm_int8.py index 045ebbbf2c4..8231e313fdc 100644 --- a/invokeai/backend/quantization/scripts/load_flux_model_bnb_llm_int8.py +++ b/invokeai/backend/quantization/scripts/load_flux_model_bnb_llm_int8.py @@ -4,7 +4,8 @@ from safetensors.torch import load_file, save_file from invokeai.backend.flux.model import Flux -from invokeai.backend.flux.util import params +from invokeai.backend.flux.util import get_flux_transformers_params +from invokeai.backend.model_manager.taxonomy import FluxVariantType from invokeai.backend.quantization.bnb_llm_int8 import quantize_model_llm_int8 from invokeai.backend.quantization.scripts.load_flux_model_bnb_nf4 import log_time @@ -22,7 +23,7 @@ def main(): with log_time("Initialize FLUX transformer on meta device"): # TODO(ryand): Determine if this is a schnell model or a dev model and load the appropriate config. - p = params["flux-schnell"] + p = get_flux_transformers_params(FluxVariantType.Schnell) # Initialize the model on the "meta" device. with accelerate.init_empty_weights(): diff --git a/invokeai/backend/quantization/scripts/load_flux_model_bnb_nf4.py b/invokeai/backend/quantization/scripts/load_flux_model_bnb_nf4.py index c8802b9e49e..6a4ee3abf93 100644 --- a/invokeai/backend/quantization/scripts/load_flux_model_bnb_nf4.py +++ b/invokeai/backend/quantization/scripts/load_flux_model_bnb_nf4.py @@ -7,7 +7,8 @@ from safetensors.torch import load_file, save_file from invokeai.backend.flux.model import Flux -from invokeai.backend.flux.util import params +from invokeai.backend.flux.util import get_flux_transformers_params +from invokeai.backend.model_manager.taxonomy import FluxVariantType from invokeai.backend.quantization.bnb_nf4 import quantize_model_nf4 @@ -35,7 +36,7 @@ def main(): # inference_dtype = torch.bfloat16 with log_time("Initialize FLUX transformer on meta device"): # TODO(ryand): Determine if this is a schnell model or a dev model and load the appropriate config. - p = params["flux-schnell"] + p = get_flux_transformers_params(FluxVariantType.Schnell) # Initialize the model on the "meta" device.
with accelerate.init_empty_weights(): diff --git a/invokeai/backend/util/hotfixes.py b/invokeai/backend/util/hotfixes.py index 95f2c904ad8..7e258b87795 100644 --- a/invokeai/backend/util/hotfixes.py +++ b/invokeai/backend/util/hotfixes.py @@ -23,6 +23,7 @@ from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel from torch import nn +from invokeai.backend.model_manager.taxonomy import BaseModelType, SchedulerPredictionType from invokeai.backend.util.logging import InvokeAILogger # TODO: create PR to diffusers @@ -407,7 +408,8 @@ def from_unet( use_linear_projection=unet.config.use_linear_projection, class_embed_type=unet.config.class_embed_type, num_class_embeds=unet.config.num_class_embeds, - upcast_attention=unet.config.upcast_attention, + upcast_attention=unet.config.base is BaseModelType.StableDiffusion2 + and unet.config.prediction_type is SchedulerPredictionType.VPrediction, resnet_time_scale_shift=unet.config.resnet_time_scale_shift, projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim, controlnet_conditioning_channel_order=controlnet_conditioning_channel_order, diff --git a/invokeai/frontend/web/src/features/modelManagerV2/models.ts b/invokeai/frontend/web/src/features/modelManagerV2/models.ts index ec4ddf1a1d5..0b4096e010b 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/models.ts +++ b/invokeai/frontend/web/src/features/modelManagerV2/models.ts @@ -1,4 +1,4 @@ -import type { BaseModelType, ModelFormat, ModelType, ModelVariantType } from 'features/nodes/types/common'; +import type { AnyModelVariant, BaseModelType, ModelFormat, ModelType } from 'features/nodes/types/common'; import type { AnyModelConfig } from 'services/api/types'; import { isCLIPEmbedModelConfig, @@ -219,10 +219,15 @@ export const MODEL_BASE_TO_SHORT_NAME: Record = { unknown: 'Unknown', }; -export const MODEL_VARIANT_TO_LONG_NAME: Record = { +export const MODEL_VARIANT_TO_LONG_NAME: Record = { normal: 'Normal', inpaint: 'Inpaint', depth: 'Depth', + dev: 'FLUX Dev', + dev_fill: 'FLUX Dev - Fill', + schnell: 'FLUX Schnell', + large: 'CLIP L', + gigantic: 'CLIP G', }; export const MODEL_FORMAT_TO_LONG_NAME: Record = { diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelFormatBadge.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelFormatBadge.tsx index 87923f9f00e..e139639f1f0 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelFormatBadge.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManagerPanel/ModelFormatBadge.tsx @@ -1,12 +1,12 @@ import { Badge } from '@invoke-ai/ui-library'; +import type { ModelFormat } from 'features/nodes/types/common'; import { memo } from 'react'; -import type { AnyModelConfig } from 'services/api/types'; type Props = { - format: AnyModelConfig['format']; + format: ModelFormat; }; -const FORMAT_NAME_MAP: Record = { +const FORMAT_NAME_MAP: Record = { diffusers: 'diffusers', lycoris: 'lycoris', checkpoint: 'checkpoint', @@ -20,9 +20,11 @@ const FORMAT_NAME_MAP: Record = { api: 'api', omi: 'omi', unknown: 'unknown', + olive: 'olive', + onnx: 'onnx', }; -const FORMAT_COLOR_MAP: Record = { +const FORMAT_COLOR_MAP: Record = { diffusers: 'base', omi: 'base', lycoris: 'base', @@ -36,6 +38,8 @@ const FORMAT_COLOR_MAP: Record = { gguf_quantized: 'base', api: 'base', unknown: 'red', + olive: 'base', + onnx: 'base', }; const ModelFormatBadge = ({ format }: Props) => { diff --git 
a/invokeai/frontend/web/src/features/nodes/types/common.test-d.ts b/invokeai/frontend/web/src/features/nodes/types/common.test-d.ts index e3fa3772bb8..c223747b931 100644 --- a/invokeai/frontend/web/src/features/nodes/types/common.test-d.ts +++ b/invokeai/frontend/web/src/features/nodes/types/common.test-d.ts @@ -12,6 +12,7 @@ import type { T2IAdapterField, zBaseModelType, zClipVariantType, + zFluxVariantType, zModelFormat, zModelVariantType, zSubModelType, @@ -45,6 +46,7 @@ describe('Common types', () => { test('ModelIdentifier', () => assert, S['SubModelType']>>()); test('ClipVariantType', () => assert, S['ClipVariantType']>>()); test('ModelVariantType', () => assert, S['ModelVariantType']>>()); + test('FluxVariantType', () => assert, S['FluxVariantType']>>()); test('ModelFormat', () => assert, S['ModelFormat']>>()); // Misc types diff --git a/invokeai/frontend/web/src/features/nodes/types/common.ts b/invokeai/frontend/web/src/features/nodes/types/common.ts index 4b97c2145d8..c51defd79c5 100644 --- a/invokeai/frontend/web/src/features/nodes/types/common.ts +++ b/invokeai/frontend/web/src/features/nodes/types/common.ts @@ -148,7 +148,9 @@ export const zSubModelType = z.enum([ export const zClipVariantType = z.enum(['large', 'gigantic']); export const zModelVariantType = z.enum(['normal', 'inpaint', 'depth']); -export type ModelVariantType = z.infer; +export const zFluxVariantType = z.enum(['dev', 'dev_fill', 'schnell']); +export const zAnyModelVariant = z.union([zModelVariantType, zClipVariantType, zFluxVariantType]); +export type AnyModelVariant = z.infer; export const zModelFormat = z.enum([ 'omi', 'diffusers', diff --git a/invokeai/frontend/web/src/features/nodes/types/field.ts b/invokeai/frontend/web/src/features/nodes/types/field.ts index 320a4ac521c..5b8634daa2b 100644 --- a/invokeai/frontend/web/src/features/nodes/types/field.ts +++ b/invokeai/frontend/web/src/features/nodes/types/field.ts @@ -10,15 +10,14 @@ import { z } from 'zod'; import type { ImageField } from './common'; import { + zAnyModelVariant, zBaseModelType, zBoardField, - zClipVariantType, zColorField, zImageField, zModelFormat, zModelIdentifierField, zModelType, - zModelVariantType, zSchedulerField, } from './common'; @@ -73,7 +72,7 @@ const zFieldInputTemplateBase = zFieldTemplateBase.extend({ ui_choice_labels: z.record(z.string(), z.string()).nullish(), ui_model_base: z.array(zBaseModelType).nullish(), ui_model_type: z.array(zModelType).nullish(), - ui_model_variant: z.array(zModelVariantType.or(zClipVariantType)).nullish(), + ui_model_variant: z.array(zAnyModelVariant).nullish(), ui_model_format: z.array(zModelFormat).nullish(), }); const zFieldOutputTemplateBase = zFieldTemplateBase.extend({ diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/Graph.test.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/Graph.test.ts index 24ef7123576..5a766e8d399 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/Graph.test.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/Graph.test.ts @@ -673,6 +673,8 @@ describe('Graph', () => { variant: 'inpaint', format: 'diffusers', repo_variant: 'fp16', + submodels: null, + usage_info: null, }); expect(field).toEqual({ key: 'b00ee8df-523d-40d2-9578-597283b07cb2', diff --git a/invokeai/frontend/web/src/features/settingsAccordions/components/GenerationSettingsAccordion/MainModelPicker.tsx b/invokeai/frontend/web/src/features/settingsAccordions/components/GenerationSettingsAccordion/MainModelPicker.tsx 
index aa527f29342..3f31d8ec769 100644 --- a/invokeai/frontend/web/src/features/settingsAccordions/components/GenerationSettingsAccordion/MainModelPicker.tsx +++ b/invokeai/frontend/web/src/features/settingsAccordions/components/GenerationSettingsAccordion/MainModelPicker.tsx @@ -25,9 +25,7 @@ export const MainModelPicker = memo(() => { const isFluxDevSelected = useMemo( () => - selectedModelConfig && - isCheckpointMainModelConfig(selectedModelConfig) && - selectedModelConfig.config_path === 'flux-dev', + selectedModelConfig && isCheckpointMainModelConfig(selectedModelConfig) && selectedModelConfig.variant === 'dev', [selectedModelConfig] ); diff --git a/invokeai/frontend/web/src/features/ui/layouts/InitialStateMainModelPicker.tsx b/invokeai/frontend/web/src/features/ui/layouts/InitialStateMainModelPicker.tsx index 9807ae9e690..b0aca495183 100644 --- a/invokeai/frontend/web/src/features/ui/layouts/InitialStateMainModelPicker.tsx +++ b/invokeai/frontend/web/src/features/ui/layouts/InitialStateMainModelPicker.tsx @@ -24,9 +24,7 @@ export const InitialStateMainModelPicker = memo(() => { const isFluxDevSelected = useMemo( () => - selectedModelConfig && - isCheckpointMainModelConfig(selectedModelConfig) && - selectedModelConfig.config_path === 'flux-dev', + selectedModelConfig && isCheckpointMainModelConfig(selectedModelConfig) && selectedModelConfig.variant === 'dev', [selectedModelConfig] ); diff --git a/invokeai/frontend/web/src/services/api/schema.ts b/invokeai/frontend/web/src/services/api/schema.ts index 4d9f9ab2ec3..388263e6ece 100644 --- a/invokeai/frontend/web/src/services/api/schema.ts +++ b/invokeai/frontend/web/src/services/api/schema.ts @@ -2200,6 +2200,7 @@ export type components = { */ type: "alpha_mask_to_tensor"; }; + AnyModelConfig: components["schemas"]["MainDiffusersConfig"] | components["schemas"]["MainCheckpointConfig"] | components["schemas"]["MainBnbQuantized4bCheckpointConfig"] | components["schemas"]["MainGGUFCheckpointConfig"] | components["schemas"]["VAEDiffusersConfig"] | components["schemas"]["VAECheckpointConfig"] | components["schemas"]["ControlNetDiffusersConfig"] | components["schemas"]["ControlNetCheckpointConfig"] | components["schemas"]["LoRALyCORISConfig"] | components["schemas"]["LoRAOmiConfig"] | components["schemas"]["ControlLoRALyCORISConfig"] | components["schemas"]["ControlLoRADiffusersConfig"] | components["schemas"]["LoRADiffusersConfig"] | components["schemas"]["T5EncoderConfig"] | components["schemas"]["T5EncoderBnbQuantizedLlmInt8bConfig"] | components["schemas"]["TextualInversionFileConfig"] | components["schemas"]["TextualInversionFolderConfig"] | components["schemas"]["IPAdapterInvokeAIConfig"] | components["schemas"]["IPAdapterCheckpointConfig"] | components["schemas"]["T2IAdapterConfig"] | components["schemas"]["SpandrelImageToImageConfig"] | components["schemas"]["CLIPVisionDiffusersConfig"] | components["schemas"]["CLIPLEmbedDiffusersConfig"] | components["schemas"]["CLIPGEmbedDiffusersConfig"] | components["schemas"]["SigLIPConfig"] | components["schemas"]["FluxReduxConfig"] | components["schemas"]["LlavaOnevisionConfig"] | components["schemas"]["ApiModelConfig"] | components["schemas"]["VideoApiModelConfig"] | components["schemas"]["UnknownModelConfig"]; /** * ApiModelConfig * @description Model config for API-based models. 
@@ -2231,19 +2232,10 @@ export type components = { */ name: string; /** - * Type - * @default main - * @constant - */ - type: "main"; - /** - * Format - * @default api - * @constant + * Description + * @description Model description */ - format: "api"; - /** @description The base model. */ - base: components["schemas"]["BaseModelType"]; + description: string | null; /** * Source * @description The original source of the model (path, URL or repo_id). @@ -2251,45 +2243,54 @@ export type components = { source: string; /** @description The type of source */ source_type: components["schemas"]["ModelSourceType"]; - /** - * Description - * @description Model description - */ - description?: string | null; /** * Source Api Response * @description The original API response from the source, as stringified JSON. */ - source_api_response?: string | null; + source_api_response: string | null; /** * Cover Image * @description Url for image to preview model */ - cover_image?: string | null; + cover_image: string | null; /** * Submodels * @description Loadable submodels in this model */ - submodels?: { + submodels: { [key: string]: components["schemas"]["SubmodelDefinition"]; } | null; /** * Usage Info * @description Usage information for this model */ - usage_info?: string | null; + usage_info: string | null; + /** + * Type + * @default main + * @constant + */ + type: "main"; /** * Trigger Phrases * @description Set of trigger phrases for this model */ - trigger_phrases?: string[] | null; + trigger_phrases: string[] | null; /** @description Default settings for this model */ - default_settings?: components["schemas"]["MainModelDefaultSettings"] | null; + default_settings: components["schemas"]["MainModelDefaultSettings"] | null; + /** Variant */ + variant: components["schemas"]["ModelVariantType"] | components["schemas"]["FluxVariantType"]; /** - * Variant - * @default normal + * Format + * @default api + * @constant + */ + format: "api"; + /** + * Base + * @enum {string} */ - variant?: components["schemas"]["ModelVariantType"] | components["schemas"]["ClipVariantType"] | null; + base: "chatgpt-4o" | "gemini-2.5" | "imagen3" | "imagen4" | "flux-kontext"; }; /** * AppConfig @@ -2464,7 +2465,13 @@ export type components = { }; /** * BaseModelType - * @description Base model type. + * @description An enumeration of base model architectures. For example, Stable Diffusion 1.x, Stable Diffusion 2.x, FLUX, etc. + * + * Every model config must have a base architecture type. + * + * Not all models are associated with a base architecture. For example, CLIP models are their own thing, not related + * to any particular model architecture. To simplify internal APIs and make it easier to work with models, we use a + * fallback/null value `BaseModelType.Any` for these models, instead of making the model base optional. * @enum {string} */ BaseModelType: "any" | "sd-1" | "sd-2" | "sd-3" | "sdxl" | "sdxl-refiner" | "flux" | "cogview4" | "imagen3" | "imagen4" | "gemini-2.5" | "chatgpt-4o" | "flux-kontext" | "veo3" | "runway" | "unknown"; @@ -3505,19 +3512,10 @@ export type components = { */ name: string; /** - * Type - * @default clip_embed - * @constant - */ - type: "clip_embed"; - /** - * Format - * @default diffusers - * @constant + * Description + * @description Model description */ - format: "diffusers"; - /** @description The base model. */ - base: components["schemas"]["BaseModelType"]; + description: string | null; /** * Source * @description The original source of the model (path, URL or repo_id). 
@@ -3525,41 +3523,54 @@ export type components = { source: string; /** @description The type of source */ source_type: components["schemas"]["ModelSourceType"]; - /** - * Description - * @description Model description - */ - description?: string | null; /** * Source Api Response * @description The original API response from the source, as stringified JSON. */ - source_api_response?: string | null; + source_api_response: string | null; /** * Cover Image * @description Url for image to preview model */ - cover_image?: string | null; + cover_image: string | null; /** * Submodels * @description Loadable submodels in this model */ - submodels?: { + submodels: { [key: string]: components["schemas"]["SubmodelDefinition"]; } | null; /** * Usage Info * @description Usage information for this model */ - usage_info?: string | null; + usage_info: string | null; + /** + * Format + * @default diffusers + * @constant + */ + format: "diffusers"; /** @default */ - repo_variant?: components["schemas"]["ModelRepoVariant"] | null; + repo_variant: components["schemas"]["ModelRepoVariant"] | null; + /** + * Base + * @default any + * @constant + */ + base: "any"; + /** + * Type + * @default clip_embed + * @constant + */ + type: "clip_embed"; /** * Variant * @default gigantic * @constant */ - variant?: "gigantic"; + variant: "gigantic"; }; /** * CLIPLEmbedDiffusersConfig @@ -3592,19 +3603,10 @@ export type components = { */ name: string; /** - * Type - * @default clip_embed - * @constant - */ - type: "clip_embed"; - /** - * Format - * @default diffusers - * @constant + * Description + * @description Model description */ - format: "diffusers"; - /** @description The base model. */ - base: components["schemas"]["BaseModelType"]; + description: string | null; /** * Source * @description The original source of the model (path, URL or repo_id). @@ -3612,41 +3614,54 @@ export type components = { source: string; /** @description The type of source */ source_type: components["schemas"]["ModelSourceType"]; - /** - * Description - * @description Model description - */ - description?: string | null; /** * Source Api Response * @description The original API response from the source, as stringified JSON. */ - source_api_response?: string | null; + source_api_response: string | null; /** * Cover Image * @description Url for image to preview model */ - cover_image?: string | null; + cover_image: string | null; /** * Submodels * @description Loadable submodels in this model */ - submodels?: { + submodels: { [key: string]: components["schemas"]["SubmodelDefinition"]; } | null; /** * Usage Info * @description Usage information for this model */ - usage_info?: string | null; + usage_info: string | null; + /** + * Format + * @default diffusers + * @constant + */ + format: "diffusers"; /** @default */ - repo_variant?: components["schemas"]["ModelRepoVariant"] | null; + repo_variant: components["schemas"]["ModelRepoVariant"] | null; + /** + * Base + * @default any + * @constant + */ + base: "any"; + /** + * Type + * @default clip_embed + * @constant + */ + type: "clip_embed"; /** * Variant * @default large * @constant */ - variant?: "large"; + variant: "large"; }; /** * CLIPOutput @@ -3755,19 +3770,10 @@ export type components = { */ name: string; /** - * Type - * @default clip_vision - * @constant - */ - type: "clip_vision"; - /** - * Format - * @default diffusers - * @constant + * Description + * @description Model description */ - format: "diffusers"; - /** @description The base model. 
*/ - base: components["schemas"]["BaseModelType"]; + description: string | null; /** * Source * @description The original source of the model (path, URL or repo_id). @@ -3775,35 +3781,48 @@ export type components = { source: string; /** @description The type of source */ source_type: components["schemas"]["ModelSourceType"]; - /** - * Description - * @description Model description - */ - description?: string | null; /** * Source Api Response * @description The original API response from the source, as stringified JSON. */ - source_api_response?: string | null; + source_api_response: string | null; /** * Cover Image * @description Url for image to preview model */ - cover_image?: string | null; + cover_image: string | null; /** * Submodels * @description Loadable submodels in this model */ - submodels?: { + submodels: { [key: string]: components["schemas"]["SubmodelDefinition"]; } | null; /** * Usage Info * @description Usage information for this model */ - usage_info?: string | null; + usage_info: string | null; + /** + * Format + * @default diffusers + * @constant + */ + format: "diffusers"; /** @default */ - repo_variant?: components["schemas"]["ModelRepoVariant"] | null; + repo_variant: components["schemas"]["ModelRepoVariant"] | null; + /** + * Base + * @default any + * @constant + */ + base: "any"; + /** + * Type + * @default clip_vision + * @constant + */ + type: "clip_vision"; }; /** * CV2 Infill @@ -5251,19 +5270,10 @@ export type components = { */ name: string; /** - * Type - * @default control_lora - * @constant - */ - type: "control_lora"; - /** - * Format - * @default diffusers - * @constant + * Description + * @description Model description */ - format: "diffusers"; - /** @description The base model. */ - base: components["schemas"]["BaseModelType"]; + description: string | null; /** * Source * @description The original source of the model (path, URL or repo_id). @@ -5271,40 +5281,48 @@ export type components = { source: string; /** @description The type of source */ source_type: components["schemas"]["ModelSourceType"]; - /** - * Description - * @description Model description - */ - description?: string | null; /** * Source Api Response * @description The original API response from the source, as stringified JSON. 
*/ - source_api_response?: string | null; + source_api_response: string | null; /** * Cover Image * @description Url for image to preview model */ - cover_image?: string | null; + cover_image: string | null; /** * Submodels * @description Loadable submodels in this model */ - submodels?: { + submodels: { [key: string]: components["schemas"]["SubmodelDefinition"]; } | null; /** * Usage Info * @description Usage information for this model */ - usage_info?: string | null; - /** @description Default settings for this model */ - default_settings?: components["schemas"]["ControlAdapterDefaultSettings"] | null; + usage_info: string | null; + default_settings: components["schemas"]["ControlAdapterDefaultSettings"] | null; /** - * Trigger Phrases - * @description Set of trigger phrases for this model + * Base + * @constant */ - trigger_phrases?: string[] | null; + base: "flux"; + /** + * Type + * @default control_lora + * @constant + */ + type: "control_lora"; + /** + * Format + * @default diffusers + * @constant + */ + format: "diffusers"; + /** Trigger Phrases */ + trigger_phrases: string[] | null; }; /** ControlLoRAField */ ControlLoRAField: { @@ -5349,19 +5367,10 @@ export type components = { */ name: string; /** - * Type - * @default control_lora - * @constant - */ - type: "control_lora"; - /** - * Format - * @default lycoris - * @constant + * Description + * @description Model description */ - format: "lycoris"; - /** @description The base model. */ - base: components["schemas"]["BaseModelType"]; + description: string | null; /** * Source * @description The original source of the model (path, URL or repo_id). @@ -5369,40 +5378,48 @@ export type components = { source: string; /** @description The type of source */ source_type: components["schemas"]["ModelSourceType"]; - /** - * Description - * @description Model description - */ - description?: string | null; /** * Source Api Response * @description The original API response from the source, as stringified JSON. */ - source_api_response?: string | null; + source_api_response: string | null; /** * Cover Image * @description Url for image to preview model */ - cover_image?: string | null; + cover_image: string | null; /** * Submodels * @description Loadable submodels in this model */ - submodels?: { + submodels: { [key: string]: components["schemas"]["SubmodelDefinition"]; } | null; /** * Usage Info * @description Usage information for this model */ - usage_info?: string | null; - /** @description Default settings for this model */ - default_settings?: components["schemas"]["ControlAdapterDefaultSettings"] | null; + usage_info: string | null; + default_settings: components["schemas"]["ControlAdapterDefaultSettings"] | null; /** - * Trigger Phrases - * @description Set of trigger phrases for this model + * Base + * @constant */ - trigger_phrases?: string[] | null; + base: "flux"; + /** + * Type + * @default control_lora + * @constant + */ + type: "control_lora"; + /** + * Format + * @default lycoris + * @constant + */ + format: "lycoris"; + /** Trigger Phrases */ + trigger_phrases: string[] | null; }; /** * ControlNetCheckpointConfig @@ -5435,20 +5452,10 @@ export type components = { */ name: string; /** - * Type - * @default controlnet - * @constant - */ - type: "controlnet"; - /** - * Format - * @description Format of the provided checkpoint model - * @default checkpoint - * @enum {string} + * Description + * @description Model description */ - format: "checkpoint" | "bnb_quantized_nf4b" | "gguf_quantized"; - /** @description The base model. 
*/ - base: components["schemas"]["BaseModelType"]; + description: string | null; /** * Source * @description The original source of the model (path, URL or repo_id). @@ -5456,58 +5463,69 @@ export type components = { source: string; /** @description The type of source */ source_type: components["schemas"]["ModelSourceType"]; - /** - * Description - * @description Model description - */ - description?: string | null; /** * Source Api Response * @description The original API response from the source, as stringified JSON. */ - source_api_response?: string | null; + source_api_response: string | null; /** * Cover Image * @description Url for image to preview model */ - cover_image?: string | null; + cover_image: string | null; /** * Submodels * @description Loadable submodels in this model */ - submodels?: { + submodels: { [key: string]: components["schemas"]["SubmodelDefinition"]; } | null; /** * Usage Info * @description Usage information for this model */ - usage_info?: string | null; - /** @description Default settings for this model */ - default_settings?: components["schemas"]["ControlAdapterDefaultSettings"] | null; + usage_info: string | null; + default_settings: components["schemas"]["ControlAdapterDefaultSettings"] | null; /** * Config Path - * @description path to the checkpoint model config file + * @description Path to the config for this model, if any. */ - config_path: string; + config_path: string | null; /** * Converted At * @description When this model was last converted to diffusers */ - converted_at?: number | null; - }; - /** - * ControlNetDiffusersConfig - * @description Model config for ControlNet models (diffusers version). - */ - ControlNetDiffusersConfig: { + converted_at: number | null; /** - * Key - * @description A unique key for this model. + * Base + * @enum {string} */ - key: string; + base: "sd-1" | "sd-2" | "sdxl" | "flux"; /** - * Hash + * Type + * @default controlnet + * @constant + */ + type: "controlnet"; + /** + * Format + * @default checkpoint + * @constant + */ + format: "checkpoint"; + }; + /** + * ControlNetDiffusersConfig + * @description Model config for ControlNet models (diffusers version). + */ + ControlNetDiffusersConfig: { + /** + * Key + * @description A unique key for this model. + */ + key: string; + /** + * Hash * @description The hash of the model file(s). */ hash: string; @@ -5527,19 +5545,10 @@ export type components = { */ name: string; /** - * Type - * @default controlnet - * @constant - */ - type: "controlnet"; - /** - * Format - * @default diffusers - * @constant + * Description + * @description Model description */ - format: "diffusers"; - /** @description The base model. */ - base: components["schemas"]["BaseModelType"]; + description: string | null; /** * Source * @description The original source of the model (path, URL or repo_id). @@ -5547,37 +5556,48 @@ export type components = { source: string; /** @description The type of source */ source_type: components["schemas"]["ModelSourceType"]; - /** - * Description - * @description Model description - */ - description?: string | null; /** * Source Api Response * @description The original API response from the source, as stringified JSON. 
*/ - source_api_response?: string | null; + source_api_response: string | null; /** * Cover Image * @description Url for image to preview model */ - cover_image?: string | null; + cover_image: string | null; /** * Submodels * @description Loadable submodels in this model */ - submodels?: { + submodels: { [key: string]: components["schemas"]["SubmodelDefinition"]; } | null; /** * Usage Info * @description Usage information for this model */ - usage_info?: string | null; - /** @description Default settings for this model */ - default_settings?: components["schemas"]["ControlAdapterDefaultSettings"] | null; + usage_info: string | null; + default_settings: components["schemas"]["ControlAdapterDefaultSettings"] | null; + /** + * Format + * @default diffusers + * @constant + */ + format: "diffusers"; /** @default */ - repo_variant?: components["schemas"]["ModelRepoVariant"] | null; + repo_variant: components["schemas"]["ModelRepoVariant"] | null; + /** + * Base + * @enum {string} + */ + base: "sd-1" | "sd-2" | "sdxl" | "flux"; + /** + * Type + * @default controlnet + * @constant + */ + type: "controlnet"; }; /** * ControlNet - SD1.5, SD2, SDXL @@ -8887,19 +8907,10 @@ export type components = { */ name: string; /** - * Type - * @default flux_redux - * @constant - */ - type: "flux_redux"; - /** - * Format - * @default checkpoint - * @constant + * Description + * @description Model description */ - format: "checkpoint"; - /** @description The base model. */ - base: components["schemas"]["BaseModelType"]; + description: string | null; /** * Source * @description The original source of the model (path, URL or repo_id). @@ -8907,33 +8918,46 @@ export type components = { source: string; /** @description The type of source */ source_type: components["schemas"]["ModelSourceType"]; - /** - * Description - * @description Model description - */ - description?: string | null; /** * Source Api Response * @description The original API response from the source, as stringified JSON. */ - source_api_response?: string | null; + source_api_response: string | null; /** * Cover Image * @description Url for image to preview model */ - cover_image?: string | null; + cover_image: string | null; /** * Submodels * @description Loadable submodels in this model */ - submodels?: { + submodels: { [key: string]: components["schemas"]["SubmodelDefinition"]; } | null; /** * Usage Info * @description Usage information for this model */ - usage_info?: string | null; + usage_info: string | null; + /** + * Type + * @default flux_redux + * @constant + */ + type: "flux_redux"; + /** + * Format + * @default checkpoint + * @constant + */ + format: "checkpoint"; + /** + * Base + * @default flux + * @constant + */ + base: "flux"; }; /** * FLUX Redux @@ -9162,6 +9186,11 @@ export type components = { */ type: "flux_vae_encode"; }; + /** + * FluxVariantType + * @enum {string} + */ + FluxVariantType: "schnell" | "dev" | "dev_fill"; /** FoundModel */ FoundModel: { /** @@ -9677,19 +9706,10 @@ export type components = { */ name: string; /** - * Type - * @default ip_adapter - * @constant - */ - type: "ip_adapter"; - /** - * Format - * @default checkpoint - * @constant + * Description + * @description Model description */ - format: "checkpoint"; - /** @description The base model. */ - base: components["schemas"]["BaseModelType"]; + description: string | null; /** * Source * @description The original source of the model (path, URL or repo_id). 
@@ -9697,33 +9717,45 @@ export type components = { source: string; /** @description The type of source */ source_type: components["schemas"]["ModelSourceType"]; - /** - * Description - * @description Model description - */ - description?: string | null; /** * Source Api Response * @description The original API response from the source, as stringified JSON. */ - source_api_response?: string | null; + source_api_response: string | null; /** * Cover Image * @description Url for image to preview model */ - cover_image?: string | null; + cover_image: string | null; /** * Submodels * @description Loadable submodels in this model */ - submodels?: { + submodels: { [key: string]: components["schemas"]["SubmodelDefinition"]; } | null; /** * Usage Info * @description Usage information for this model */ - usage_info?: string | null; + usage_info: string | null; + /** + * Type + * @default ip_adapter + * @constant + */ + type: "ip_adapter"; + /** + * Base + * @enum {string} + */ + base: "sd-1" | "sd-2" | "sdxl" | "flux"; + /** + * Format + * @default checkpoint + * @constant + */ + format: "checkpoint"; }; /** IPAdapterField */ IPAdapterField: { @@ -9881,19 +9913,10 @@ export type components = { */ name: string; /** - * Type - * @default ip_adapter - * @constant - */ - type: "ip_adapter"; - /** - * Format - * @default invokeai - * @constant + * Description + * @description Model description */ - format: "invokeai"; - /** @description The base model. */ - base: components["schemas"]["BaseModelType"]; + description: string | null; /** * Source * @description The original source of the model (path, URL or repo_id). @@ -9901,33 +9924,45 @@ export type components = { source: string; /** @description The type of source */ source_type: components["schemas"]["ModelSourceType"]; - /** - * Description - * @description Model description - */ - description?: string | null; /** * Source Api Response * @description The original API response from the source, as stringified JSON. */ - source_api_response?: string | null; + source_api_response: string | null; /** * Cover Image * @description Url for image to preview model */ - cover_image?: string | null; + cover_image: string | null; /** * Submodels * @description Loadable submodels in this model */ - submodels?: { + submodels: { [key: string]: components["schemas"]["SubmodelDefinition"]; } | null; /** * Usage Info * @description Usage information for this model */ - usage_info?: string | null; + usage_info: string | null; + /** + * Type + * @default ip_adapter + * @constant + */ + type: "ip_adapter"; + /** + * Base + * @enum {string} + */ + base: "sd-1" | "sd-2" | "sdxl"; + /** + * Format + * @default invokeai + * @constant + */ + format: "invokeai"; /** Image Encoder Model Id */ image_encoder_model_id: string; }; @@ -14037,19 +14072,10 @@ export type components = { */ name: string; /** - * Type - * @default llava_onevision - * @constant - */ - type: "llava_onevision"; - /** - * Format - * @default diffusers - * @constant + * Description + * @description Model description */ - format: "diffusers"; - /** @description The base model. */ - base: components["schemas"]["BaseModelType"]; + description: string | null; /** * Source * @description The original source of the model (path, URL or repo_id). 
@@ -14057,35 +14083,54 @@ export type components = { source: string; /** @description The type of source */ source_type: components["schemas"]["ModelSourceType"]; - /** - * Description - * @description Model description - */ - description?: string | null; /** * Source Api Response * @description The original API response from the source, as stringified JSON. */ - source_api_response?: string | null; + source_api_response: string | null; /** * Cover Image * @description Url for image to preview model */ - cover_image?: string | null; + cover_image: string | null; /** * Submodels * @description Loadable submodels in this model */ - submodels?: { + submodels: { [key: string]: components["schemas"]["SubmodelDefinition"]; } | null; /** * Usage Info * @description Usage information for this model */ - usage_info?: string | null; + usage_info: string | null; + /** + * Format + * @default diffusers + * @constant + */ + format: "diffusers"; /** @default */ - repo_variant?: components["schemas"]["ModelRepoVariant"] | null; + repo_variant: components["schemas"]["ModelRepoVariant"] | null; + /** + * Type + * @default llava_onevision + * @constant + */ + type: "llava_onevision"; + /** + * Base + * @default any + * @constant + */ + base: "any"; + /** + * Variant + * @default normal + * @constant + */ + variant: "normal"; }; /** * LLaVA OneVision VLLM @@ -14212,19 +14257,10 @@ export type components = { */ name: string; /** - * Type - * @default lora - * @constant - */ - type: "lora"; - /** - * Format - * @default diffusers - * @constant + * Description + * @description Model description */ - format: "diffusers"; - /** @description The base model. */ - base: components["schemas"]["BaseModelType"]; + description: string | null; /** * Source * @description The original source of the model (path, URL or repo_id). @@ -14232,40 +14268,52 @@ export type components = { source: string; /** @description The type of source */ source_type: components["schemas"]["ModelSourceType"]; - /** - * Description - * @description Model description - */ - description?: string | null; /** * Source Api Response * @description The original API response from the source, as stringified JSON. 
*/ - source_api_response?: string | null; + source_api_response: string | null; /** * Cover Image * @description Url for image to preview model */ - cover_image?: string | null; + cover_image: string | null; /** * Submodels * @description Loadable submodels in this model */ - submodels?: { + submodels: { [key: string]: components["schemas"]["SubmodelDefinition"]; } | null; /** * Usage Info * @description Usage information for this model */ - usage_info?: string | null; + usage_info: string | null; + /** + * Type + * @default lora + * @constant + */ + type: "lora"; /** * Trigger Phrases * @description Set of trigger phrases for this model */ - trigger_phrases?: string[] | null; + trigger_phrases: string[] | null; /** @description Default settings for this model */ - default_settings?: components["schemas"]["LoraModelDefaultSettings"] | null; + default_settings: components["schemas"]["LoraModelDefaultSettings"] | null; + /** + * Base + * @enum {string} + */ + base: "sd-1" | "sd-2" | "sdxl" | "flux"; + /** + * Format + * @default diffusers + * @constant + */ + format: "diffusers"; }; /** LoRAField */ LoRAField: { @@ -14385,19 +14433,10 @@ export type components = { */ name: string; /** - * Type - * @default lora - * @constant + * Description + * @description Model description */ - type: "lora"; - /** - * Format - * @default lycoris - * @constant - */ - format: "lycoris"; - /** @description The base model. */ - base: components["schemas"]["BaseModelType"]; + description: string | null; /** * Source * @description The original source of the model (path, URL or repo_id). @@ -14405,40 +14444,52 @@ export type components = { source: string; /** @description The type of source */ source_type: components["schemas"]["ModelSourceType"]; - /** - * Description - * @description Model description - */ - description?: string | null; /** * Source Api Response * @description The original API response from the source, as stringified JSON. */ - source_api_response?: string | null; + source_api_response: string | null; /** * Cover Image * @description Url for image to preview model */ - cover_image?: string | null; + cover_image: string | null; /** * Submodels * @description Loadable submodels in this model */ - submodels?: { + submodels: { [key: string]: components["schemas"]["SubmodelDefinition"]; } | null; /** * Usage Info * @description Usage information for this model */ - usage_info?: string | null; + usage_info: string | null; + /** + * Type + * @default lora + * @constant + */ + type: "lora"; /** * Trigger Phrases * @description Set of trigger phrases for this model */ - trigger_phrases?: string[] | null; + trigger_phrases: string[] | null; /** @description Default settings for this model */ - default_settings?: components["schemas"]["LoraModelDefaultSettings"] | null; + default_settings: components["schemas"]["LoraModelDefaultSettings"] | null; + /** + * Base + * @enum {string} + */ + base: "sd-1" | "sd-2" | "sdxl" | "flux"; + /** + * Format + * @default lycoris + * @constant + */ + format: "lycoris"; }; /** * LoRAMetadataField @@ -14481,19 +14532,10 @@ export type components = { */ name: string; /** - * Type - * @default lora - * @constant - */ - type: "lora"; - /** - * Format - * @default omi - * @constant + * Description + * @description Model description */ - format: "omi"; - /** @description The base model. */ - base: components["schemas"]["BaseModelType"]; + description: string | null; /** * Source * @description The original source of the model (path, URL or repo_id). 
@@ -14501,40 +14543,52 @@ export type components = { source: string; /** @description The type of source */ source_type: components["schemas"]["ModelSourceType"]; - /** - * Description - * @description Model description - */ - description?: string | null; /** * Source Api Response * @description The original API response from the source, as stringified JSON. */ - source_api_response?: string | null; + source_api_response: string | null; /** * Cover Image * @description Url for image to preview model */ - cover_image?: string | null; + cover_image: string | null; /** * Submodels * @description Loadable submodels in this model */ - submodels?: { + submodels: { [key: string]: components["schemas"]["SubmodelDefinition"]; } | null; /** * Usage Info * @description Usage information for this model */ - usage_info?: string | null; + usage_info: string | null; + /** + * Type + * @default lora + * @constant + */ + type: "lora"; /** * Trigger Phrases * @description Set of trigger phrases for this model */ - trigger_phrases?: string[] | null; + trigger_phrases: string[] | null; /** @description Default settings for this model */ - default_settings?: components["schemas"]["LoraModelDefaultSettings"] | null; + default_settings: components["schemas"]["LoraModelDefaultSettings"] | null; + /** + * Base + * @enum {string} + */ + base: "flux" | "sdxl"; + /** + * Format + * @default omi + * @constant + */ + format: "omi"; }; /** * Select LoRA @@ -14754,19 +14808,10 @@ export type components = { */ name: string; /** - * Type - * @default main - * @constant - */ - type: "main"; - /** - * Format - * @default bnb_quantized_nf4b - * @constant + * Description + * @description Model description */ - format: "bnb_quantized_nf4b"; - /** @description The base model. */ - base: components["schemas"]["BaseModelType"]; + description: string | null; /** * Source * @description The original source of the model (path, URL or repo_id). @@ -14774,62 +14819,72 @@ export type components = { source: string; /** @description The type of source */ source_type: components["schemas"]["ModelSourceType"]; - /** - * Description - * @description Model description - */ - description?: string | null; /** * Source Api Response * @description The original API response from the source, as stringified JSON. 
*/ - source_api_response?: string | null; + source_api_response: string | null; /** * Cover Image * @description Url for image to preview model */ - cover_image?: string | null; + cover_image: string | null; /** * Submodels * @description Loadable submodels in this model */ - submodels?: { + submodels: { [key: string]: components["schemas"]["SubmodelDefinition"]; } | null; /** * Usage Info * @description Usage information for this model */ - usage_info?: string | null; + usage_info: string | null; + /** + * Type + * @default main + * @constant + */ + type: "main"; /** * Trigger Phrases * @description Set of trigger phrases for this model */ - trigger_phrases?: string[] | null; + trigger_phrases: string[] | null; /** @description Default settings for this model */ - default_settings?: components["schemas"]["MainModelDefaultSettings"] | null; - /** - * Variant - * @default normal - */ - variant?: components["schemas"]["ModelVariantType"] | components["schemas"]["ClipVariantType"] | null; + default_settings: components["schemas"]["MainModelDefaultSettings"] | null; + /** Variant */ + variant: components["schemas"]["ModelVariantType"] | components["schemas"]["FluxVariantType"]; /** * Config Path - * @description path to the checkpoint model config file + * @description Path to the config for this model, if any. */ - config_path: string; + config_path: string | null; /** * Converted At * @description When this model was last converted to diffusers */ - converted_at?: number | null; + converted_at: number | null; + /** + * Base + * @default flux + * @constant + */ + base: "flux"; + /** + * Format + * @default bnb_quantized_nf4b + * @constant + */ + format: "bnb_quantized_nf4b"; /** @default epsilon */ - prediction_type?: components["schemas"]["SchedulerPredictionType"]; + prediction_type: components["schemas"]["SchedulerPredictionType"]; /** * Upcast Attention * @default false */ - upcast_attention?: boolean; + upcast_attention: boolean; }; /** * MainCheckpointConfig @@ -14862,20 +14917,10 @@ export type components = { */ name: string; /** - * Type - * @default main - * @constant - */ - type: "main"; - /** - * Format - * @description Format of the provided checkpoint model - * @default checkpoint - * @enum {string} + * Description + * @description Model description */ - format: "checkpoint" | "bnb_quantized_nf4b" | "gguf_quantized"; - /** @description The base model. */ - base: components["schemas"]["BaseModelType"]; + description: string | null; /** * Source * @description The original source of the model (path, URL or repo_id). @@ -14883,62 +14928,71 @@ export type components = { source: string; /** @description The type of source */ source_type: components["schemas"]["ModelSourceType"]; - /** - * Description - * @description Model description - */ - description?: string | null; /** * Source Api Response * @description The original API response from the source, as stringified JSON. 
*/ - source_api_response?: string | null; + source_api_response: string | null; /** * Cover Image * @description Url for image to preview model */ - cover_image?: string | null; + cover_image: string | null; /** * Submodels * @description Loadable submodels in this model */ - submodels?: { + submodels: { [key: string]: components["schemas"]["SubmodelDefinition"]; } | null; /** * Usage Info * @description Usage information for this model */ - usage_info?: string | null; + usage_info: string | null; + /** + * Type + * @default main + * @constant + */ + type: "main"; /** * Trigger Phrases * @description Set of trigger phrases for this model */ - trigger_phrases?: string[] | null; + trigger_phrases: string[] | null; /** @description Default settings for this model */ - default_settings?: components["schemas"]["MainModelDefaultSettings"] | null; - /** - * Variant - * @default normal - */ - variant?: components["schemas"]["ModelVariantType"] | components["schemas"]["ClipVariantType"] | null; + default_settings: components["schemas"]["MainModelDefaultSettings"] | null; + /** Variant */ + variant: components["schemas"]["ModelVariantType"] | components["schemas"]["FluxVariantType"]; /** * Config Path - * @description path to the checkpoint model config file + * @description Path to the config for this model, if any. */ - config_path: string; + config_path: string | null; /** * Converted At * @description When this model was last converted to diffusers */ - converted_at?: number | null; + converted_at: number | null; + /** + * Base + * @enum {string} + */ + base: "sd-1" | "sd-2" | "sd-3" | "sdxl" | "sdxl-refiner" | "flux" | "cogview4"; + /** + * Format + * @default checkpoint + * @constant + */ + format: "checkpoint"; /** @default epsilon */ - prediction_type?: components["schemas"]["SchedulerPredictionType"]; + prediction_type: components["schemas"]["SchedulerPredictionType"]; /** * Upcast Attention * @default false */ - upcast_attention?: boolean; + upcast_attention: boolean; }; /** * MainDiffusersConfig @@ -14971,19 +15025,10 @@ export type components = { */ name: string; /** - * Type - * @default main - * @constant - */ - type: "main"; - /** - * Format - * @default diffusers - * @constant + * Description + * @description Model description */ - format: "diffusers"; - /** @description The base model. */ - base: components["schemas"]["BaseModelType"]; + description: string | null; /** * Source * @description The original source of the model (path, URL or repo_id). @@ -14991,47 +15036,56 @@ export type components = { source: string; /** @description The type of source */ source_type: components["schemas"]["ModelSourceType"]; - /** - * Description - * @description Model description - */ - description?: string | null; /** * Source Api Response * @description The original API response from the source, as stringified JSON. 
*/ - source_api_response?: string | null; + source_api_response: string | null; /** * Cover Image * @description Url for image to preview model */ - cover_image?: string | null; + cover_image: string | null; /** * Submodels * @description Loadable submodels in this model */ - submodels?: { + submodels: { [key: string]: components["schemas"]["SubmodelDefinition"]; } | null; /** * Usage Info * @description Usage information for this model */ - usage_info?: string | null; + usage_info: string | null; + /** + * Type + * @default main + * @constant + */ + type: "main"; /** * Trigger Phrases * @description Set of trigger phrases for this model */ - trigger_phrases?: string[] | null; + trigger_phrases: string[] | null; /** @description Default settings for this model */ - default_settings?: components["schemas"]["MainModelDefaultSettings"] | null; + default_settings: components["schemas"]["MainModelDefaultSettings"] | null; + /** Variant */ + variant: components["schemas"]["ModelVariantType"] | components["schemas"]["FluxVariantType"]; /** - * Variant - * @default normal + * Format + * @default diffusers + * @constant */ - variant?: components["schemas"]["ModelVariantType"] | components["schemas"]["ClipVariantType"] | null; + format: "diffusers"; /** @default */ - repo_variant?: components["schemas"]["ModelRepoVariant"] | null; + repo_variant: components["schemas"]["ModelRepoVariant"] | null; + /** + * Base + * @enum {string} + */ + base: "sd-1" | "sd-2" | "sd-3" | "sdxl" | "sdxl-refiner" | "flux" | "cogview4"; }; /** * MainGGUFCheckpointConfig @@ -15064,19 +15118,10 @@ export type components = { */ name: string; /** - * Type - * @default main - * @constant - */ - type: "main"; - /** - * Format - * @default gguf_quantized - * @constant + * Description + * @description Model description */ - format: "gguf_quantized"; - /** @description The base model. */ - base: components["schemas"]["BaseModelType"]; + description: string | null; /** * Source * @description The original source of the model (path, URL or repo_id). @@ -15084,62 +15129,72 @@ export type components = { source: string; /** @description The type of source */ source_type: components["schemas"]["ModelSourceType"]; - /** - * Description - * @description Model description - */ - description?: string | null; /** * Source Api Response * @description The original API response from the source, as stringified JSON. 
*/ - source_api_response?: string | null; + source_api_response: string | null; /** * Cover Image * @description Url for image to preview model */ - cover_image?: string | null; + cover_image: string | null; /** * Submodels * @description Loadable submodels in this model */ - submodels?: { + submodels: { [key: string]: components["schemas"]["SubmodelDefinition"]; } | null; /** * Usage Info * @description Usage information for this model */ - usage_info?: string | null; + usage_info: string | null; + /** + * Type + * @default main + * @constant + */ + type: "main"; /** * Trigger Phrases * @description Set of trigger phrases for this model */ - trigger_phrases?: string[] | null; + trigger_phrases: string[] | null; /** @description Default settings for this model */ - default_settings?: components["schemas"]["MainModelDefaultSettings"] | null; - /** - * Variant - * @default normal - */ - variant?: components["schemas"]["ModelVariantType"] | components["schemas"]["ClipVariantType"] | null; + default_settings: components["schemas"]["MainModelDefaultSettings"] | null; + /** Variant */ + variant: components["schemas"]["ModelVariantType"] | components["schemas"]["FluxVariantType"]; /** * Config Path - * @description path to the checkpoint model config file + * @description Path to the config for this model, if any. */ - config_path: string; + config_path: string | null; /** * Converted At * @description When this model was last converted to diffusers */ - converted_at?: number | null; + converted_at: number | null; + /** + * Base + * @default flux + * @constant + */ + base: "flux"; + /** + * Format + * @default gguf_quantized + * @constant + */ + format: "gguf_quantized"; /** @default epsilon */ - prediction_type?: components["schemas"]["SchedulerPredictionType"]; + prediction_type: components["schemas"]["SchedulerPredictionType"]; /** * Upcast Attention * @default false */ - upcast_attention?: boolean; + upcast_attention: boolean; }; /** MainModelDefaultSettings */ MainModelDefaultSettings: { @@ -17428,7 +17483,7 @@ export type components = { * Variant * @description The variant of the model. */ - variant?: components["schemas"]["ModelVariantType"] | components["schemas"]["ClipVariantType"] | null; + variant?: components["schemas"]["ModelVariantType"] | components["schemas"]["ClipVariantType"] | components["schemas"]["FluxVariantType"] | null; /** @description The prediction type of the model. */ prediction_type?: components["schemas"]["SchedulerPredictionType"] | null; /** @@ -20105,19 +20160,10 @@ export type components = { */ name: string; /** - * Type - * @default siglip - * @constant - */ - type: "siglip"; - /** - * Format - * @default diffusers - * @constant + * Description + * @description Model description */ - format: "diffusers"; - /** @description The base model. */ - base: components["schemas"]["BaseModelType"]; + description: string | null; /** * Source * @description The original source of the model (path, URL or repo_id). @@ -20125,39 +20171,52 @@ export type components = { source: string; /** @description The type of source */ source_type: components["schemas"]["ModelSourceType"]; - /** - * Description - * @description Model description - */ - description?: string | null; /** * Source Api Response * @description The original API response from the source, as stringified JSON. 
*/ - source_api_response?: string | null; + source_api_response: string | null; /** * Cover Image * @description Url for image to preview model */ - cover_image?: string | null; + cover_image: string | null; /** * Submodels * @description Loadable submodels in this model */ - submodels?: { + submodels: { [key: string]: components["schemas"]["SubmodelDefinition"]; } | null; /** * Usage Info * @description Usage information for this model */ - usage_info?: string | null; - /** @default */ - repo_variant?: components["schemas"]["ModelRepoVariant"] | null; - }; - /** - * Image-to-Image (Autoscale) - * @description Run any spandrel image-to-image model (https://github.com/chaiNNer-org/spandrel) until the target scale is reached. + usage_info: string | null; + /** + * Format + * @default diffusers + * @constant + */ + format: "diffusers"; + /** @default */ + repo_variant: components["schemas"]["ModelRepoVariant"] | null; + /** + * Type + * @default siglip + * @constant + */ + type: "siglip"; + /** + * Base + * @default any + * @constant + */ + base: "any"; + }; + /** + * Image-to-Image (Autoscale) + * @description Run any spandrel image-to-image model (https://github.com/chaiNNer-org/spandrel) until the target scale is reached. */ SpandrelImageToImageAutoscaleInvocation: { /** @@ -20254,19 +20313,10 @@ export type components = { */ name: string; /** - * Type - * @default spandrel_image_to_image - * @constant - */ - type: "spandrel_image_to_image"; - /** - * Format - * @default checkpoint - * @constant + * Description + * @description Model description */ - format: "checkpoint"; - /** @description The base model. */ - base: components["schemas"]["BaseModelType"]; + description: string | null; /** * Source * @description The original source of the model (path, URL or repo_id). @@ -20274,33 +20324,46 @@ export type components = { source: string; /** @description The type of source */ source_type: components["schemas"]["ModelSourceType"]; - /** - * Description - * @description Model description - */ - description?: string | null; /** * Source Api Response * @description The original API response from the source, as stringified JSON. 
*/ - source_api_response?: string | null; + source_api_response: string | null; /** * Cover Image * @description Url for image to preview model */ - cover_image?: string | null; + cover_image: string | null; /** * Submodels * @description Loadable submodels in this model */ - submodels?: { + submodels: { [key: string]: components["schemas"]["SubmodelDefinition"]; } | null; /** * Usage Info * @description Usage information for this model */ - usage_info?: string | null; + usage_info: string | null; + /** + * Base + * @default any + * @constant + */ + base: "any"; + /** + * Type + * @default spandrel_image_to_image + * @constant + */ + type: "spandrel_image_to_image"; + /** + * Format + * @default checkpoint + * @constant + */ + format: "checkpoint"; }; /** * Image-to-Image @@ -20940,7 +21003,7 @@ export type components = { path_or_prefix: string; model_type: components["schemas"]["ModelType"]; /** Variant */ - variant?: components["schemas"]["ModelVariantType"] | components["schemas"]["ClipVariantType"] | null; + variant?: components["schemas"]["ModelVariantType"] | components["schemas"]["ClipVariantType"] | components["schemas"]["FluxVariantType"] | null; }; /** * Subtract Integers @@ -21014,19 +21077,10 @@ export type components = { */ name: string; /** - * Type - * @default t2i_adapter - * @constant - */ - type: "t2i_adapter"; - /** - * Format - * @default diffusers - * @constant + * Description + * @description Model description */ - format: "diffusers"; - /** @description The base model. */ - base: components["schemas"]["BaseModelType"]; + description: string | null; /** * Source * @description The original source of the model (path, URL or repo_id). @@ -21034,37 +21088,48 @@ export type components = { source: string; /** @description The type of source */ source_type: components["schemas"]["ModelSourceType"]; - /** - * Description - * @description Model description - */ - description?: string | null; /** * Source Api Response * @description The original API response from the source, as stringified JSON. */ - source_api_response?: string | null; + source_api_response: string | null; /** * Cover Image * @description Url for image to preview model */ - cover_image?: string | null; + cover_image: string | null; /** * Submodels * @description Loadable submodels in this model */ - submodels?: { + submodels: { [key: string]: components["schemas"]["SubmodelDefinition"]; } | null; /** * Usage Info * @description Usage information for this model */ - usage_info?: string | null; - /** @description Default settings for this model */ - default_settings?: components["schemas"]["ControlAdapterDefaultSettings"] | null; + usage_info: string | null; + default_settings: components["schemas"]["ControlAdapterDefaultSettings"] | null; + /** + * Format + * @default diffusers + * @constant + */ + format: "diffusers"; /** @default */ - repo_variant?: components["schemas"]["ModelRepoVariant"] | null; + repo_variant: components["schemas"]["ModelRepoVariant"] | null; + /** + * Base + * @enum {string} + */ + base: "sd-1" | "sd-2" | "sdxl"; + /** + * Type + * @default t2i_adapter + * @constant + */ + type: "t2i_adapter"; }; /** T2IAdapterField */ T2IAdapterField: { @@ -21242,19 +21307,10 @@ export type components = { */ name: string; /** - * Type - * @default t5_encoder - * @constant - */ - type: "t5_encoder"; - /** - * Format - * @default bnb_quantized_int8b - * @constant + * Description + * @description Model description */ - format: "bnb_quantized_int8b"; - /** @description The base model. 
*/ - base: components["schemas"]["BaseModelType"]; + description: string | null; /** * Source * @description The original source of the model (path, URL or repo_id). @@ -21262,33 +21318,46 @@ export type components = { source: string; /** @description The type of source */ source_type: components["schemas"]["ModelSourceType"]; - /** - * Description - * @description Model description - */ - description?: string | null; /** * Source Api Response * @description The original API response from the source, as stringified JSON. */ - source_api_response?: string | null; + source_api_response: string | null; /** * Cover Image * @description Url for image to preview model */ - cover_image?: string | null; + cover_image: string | null; /** * Submodels * @description Loadable submodels in this model */ - submodels?: { + submodels: { [key: string]: components["schemas"]["SubmodelDefinition"]; } | null; /** * Usage Info * @description Usage information for this model */ - usage_info?: string | null; + usage_info: string | null; + /** + * Base + * @default any + * @constant + */ + base: "any"; + /** + * Type + * @default t5_encoder + * @constant + */ + type: "t5_encoder"; + /** + * Format + * @default bnb_quantized_int8b + * @constant + */ + format: "bnb_quantized_int8b"; }; /** T5EncoderConfig */ T5EncoderConfig: { @@ -21318,19 +21387,10 @@ export type components = { */ name: string; /** - * Type - * @default t5_encoder - * @constant - */ - type: "t5_encoder"; - /** - * Format - * @default t5_encoder - * @constant + * Description + * @description Model description */ - format: "t5_encoder"; - /** @description The base model. */ - base: components["schemas"]["BaseModelType"]; + description: string | null; /** * Source * @description The original source of the model (path, URL or repo_id). @@ -21338,33 +21398,46 @@ export type components = { source: string; /** @description The type of source */ source_type: components["schemas"]["ModelSourceType"]; - /** - * Description - * @description Model description - */ - description?: string | null; /** * Source Api Response * @description The original API response from the source, as stringified JSON. */ - source_api_response?: string | null; + source_api_response: string | null; /** * Cover Image * @description Url for image to preview model */ - cover_image?: string | null; + cover_image: string | null; /** * Submodels * @description Loadable submodels in this model */ - submodels?: { + submodels: { [key: string]: components["schemas"]["SubmodelDefinition"]; } | null; /** * Usage Info * @description Usage information for this model */ - usage_info?: string | null; + usage_info: string | null; + /** + * Base + * @default any + * @constant + */ + base: "any"; + /** + * Type + * @default t5_encoder + * @constant + */ + type: "t5_encoder"; + /** + * Format + * @default t5_encoder + * @constant + */ + format: "t5_encoder"; }; /** T5EncoderField */ T5EncoderField: { @@ -21431,19 +21504,10 @@ export type components = { */ name: string; /** - * Type - * @default embedding - * @constant - */ - type: "embedding"; - /** - * Format - * @default embedding_file - * @constant + * Description + * @description Model description */ - format: "embedding_file"; - /** @description The base model. */ - base: components["schemas"]["BaseModelType"]; + description: string | null; /** * Source * @description The original source of the model (path, URL or repo_id). 
@@ -21451,33 +21515,45 @@ export type components = { source: string; /** @description The type of source */ source_type: components["schemas"]["ModelSourceType"]; - /** - * Description - * @description Model description - */ - description?: string | null; /** * Source Api Response * @description The original API response from the source, as stringified JSON. */ - source_api_response?: string | null; + source_api_response: string | null; /** * Cover Image * @description Url for image to preview model */ - cover_image?: string | null; + cover_image: string | null; /** * Submodels * @description Loadable submodels in this model */ - submodels?: { + submodels: { [key: string]: components["schemas"]["SubmodelDefinition"]; } | null; /** * Usage Info * @description Usage information for this model */ - usage_info?: string | null; + usage_info: string | null; + /** + * Base + * @enum {string} + */ + base: "sd-1" | "sd-2" | "sdxl"; + /** + * Type + * @default embedding + * @constant + */ + type: "embedding"; + /** + * Format + * @default embedding_file + * @constant + */ + format: "embedding_file"; }; /** * TextualInversionFolderConfig @@ -21510,19 +21586,10 @@ export type components = { */ name: string; /** - * Type - * @default embedding - * @constant - */ - type: "embedding"; - /** - * Format - * @default embedding_folder - * @constant + * Description + * @description Model description */ - format: "embedding_folder"; - /** @description The base model. */ - base: components["schemas"]["BaseModelType"]; + description: string | null; /** * Source * @description The original source of the model (path, URL or repo_id). @@ -21530,33 +21597,45 @@ export type components = { source: string; /** @description The type of source */ source_type: components["schemas"]["ModelSourceType"]; - /** - * Description - * @description Model description - */ - description?: string | null; /** * Source Api Response * @description The original API response from the source, as stringified JSON. */ - source_api_response?: string | null; + source_api_response: string | null; /** * Cover Image * @description Url for image to preview model */ - cover_image?: string | null; + cover_image: string | null; /** * Submodels * @description Loadable submodels in this model */ - submodels?: { + submodels: { [key: string]: components["schemas"]["SubmodelDefinition"]; } | null; /** * Usage Info * @description Usage information for this model */ - usage_info?: string | null; + usage_info: string | null; + /** + * Base + * @enum {string} + */ + base: "sd-1" | "sd-2" | "sdxl"; + /** + * Type + * @default embedding + * @constant + */ + type: "embedding"; + /** + * Format + * @default embedding_folder + * @constant + */ + format: "embedding_folder"; }; /** Tile */ Tile: { @@ -21968,23 +22047,10 @@ export type components = { */ name: string; /** - * Type - * @default unknown - * @constant - */ - type: "unknown"; - /** - * Format - * @default unknown - * @constant - */ - format: "unknown"; - /** - * Base - * @default unknown - * @constant + * Description + * @description Model description */ - base: "unknown"; + description: string | null; /** * Source * @description The original source of the model (path, URL or repo_id). 
@@ -21992,33 +22058,46 @@ export type components = { source: string; /** @description The type of source */ source_type: components["schemas"]["ModelSourceType"]; - /** - * Description - * @description Model description - */ - description?: string | null; /** * Source Api Response * @description The original API response from the source, as stringified JSON. */ - source_api_response?: string | null; + source_api_response: string | null; /** * Cover Image * @description Url for image to preview model */ - cover_image?: string | null; + cover_image: string | null; /** * Submodels * @description Loadable submodels in this model */ - submodels?: { + submodels: { [key: string]: components["schemas"]["SubmodelDefinition"]; } | null; /** * Usage Info * @description Usage information for this model */ - usage_info?: string | null; + usage_info: string | null; + /** + * Base + * @default unknown + * @constant + */ + base: "unknown"; + /** + * Type + * @default unknown + * @constant + */ + type: "unknown"; + /** + * Format + * @default unknown + * @constant + */ + format: "unknown"; }; /** * Unsharp Mask @@ -22146,20 +22225,10 @@ export type components = { */ name: string; /** - * Type - * @default vae - * @constant - */ - type: "vae"; - /** - * Format - * @description Format of the provided checkpoint model - * @default checkpoint - * @enum {string} + * Description + * @description Model description */ - format: "checkpoint" | "bnb_quantized_nf4b" | "gguf_quantized"; - /** @description The base model. */ - base: components["schemas"]["BaseModelType"]; + description: string | null; /** * Source * @description The original source of the model (path, URL or repo_id). @@ -22167,43 +22236,55 @@ export type components = { source: string; /** @description The type of source */ source_type: components["schemas"]["ModelSourceType"]; - /** - * Description - * @description Model description - */ - description?: string | null; /** * Source Api Response * @description The original API response from the source, as stringified JSON. */ - source_api_response?: string | null; + source_api_response: string | null; /** * Cover Image * @description Url for image to preview model */ - cover_image?: string | null; + cover_image: string | null; /** * Submodels * @description Loadable submodels in this model */ - submodels?: { + submodels: { [key: string]: components["schemas"]["SubmodelDefinition"]; } | null; /** * Usage Info * @description Usage information for this model */ - usage_info?: string | null; + usage_info: string | null; /** * Config Path - * @description path to the checkpoint model config file + * @description Path to the config for this model, if any. */ - config_path: string; + config_path: string | null; /** * Converted At * @description When this model was last converted to diffusers */ - converted_at?: number | null; + converted_at: number | null; + /** + * Base + * @enum {string} + */ + base: "sd-1" | "sd-2" | "sdxl" | "flux"; + /** + * Type + * @default vae + * @constant + */ + type: "vae"; + /** + * Format + * @default checkpoint + * @constant + */ + format: "checkpoint"; }; /** * VAEDiffusersConfig @@ -22236,19 +22317,10 @@ export type components = { */ name: string; /** - * Type - * @default vae - * @constant - */ - type: "vae"; - /** - * Format - * @default diffusers - * @constant + * Description + * @description Model description */ - format: "diffusers"; - /** @description The base model. 
*/ - base: components["schemas"]["BaseModelType"]; + description: string | null; /** * Source * @description The original source of the model (path, URL or repo_id). @@ -22256,33 +22328,47 @@ export type components = { source: string; /** @description The type of source */ source_type: components["schemas"]["ModelSourceType"]; - /** - * Description - * @description Model description - */ - description?: string | null; /** * Source Api Response * @description The original API response from the source, as stringified JSON. */ - source_api_response?: string | null; + source_api_response: string | null; /** * Cover Image * @description Url for image to preview model */ - cover_image?: string | null; + cover_image: string | null; /** * Submodels * @description Loadable submodels in this model */ - submodels?: { + submodels: { [key: string]: components["schemas"]["SubmodelDefinition"]; } | null; /** * Usage Info * @description Usage information for this model */ - usage_info?: string | null; + usage_info: string | null; + /** + * Format + * @default diffusers + * @constant + */ + format: "diffusers"; + /** @default */ + repo_variant: components["schemas"]["ModelRepoVariant"] | null; + /** + * Base + * @enum {string} + */ + base: "sd-1" | "sdxl"; + /** + * Type + * @default vae + * @constant + */ + type: "vae"; }; /** VAEField */ VAEField: { @@ -22404,19 +22490,10 @@ export type components = { */ name: string; /** - * Type - * @default video - * @constant - */ - type: "video"; - /** - * Format - * @default api - * @constant + * Description + * @description Model description */ - format: "api"; - /** @description The base model. */ - base: components["schemas"]["BaseModelType"]; + description: string | null; /** * Source * @description The original source of the model (path, URL or repo_id). @@ -22424,45 +22501,52 @@ export type components = { source: string; /** @description The type of source */ source_type: components["schemas"]["ModelSourceType"]; - /** - * Description - * @description Model description - */ - description?: string | null; /** * Source Api Response * @description The original API response from the source, as stringified JSON. 
*/ - source_api_response?: string | null; + source_api_response: string | null; /** * Cover Image * @description Url for image to preview model */ - cover_image?: string | null; + cover_image: string | null; /** * Submodels * @description Loadable submodels in this model */ - submodels?: { + submodels: { [key: string]: components["schemas"]["SubmodelDefinition"]; } | null; /** * Usage Info * @description Usage information for this model */ - usage_info?: string | null; + usage_info: string | null; + /** + * Type + * @default video + * @constant + */ + type: "video"; + /** + * Base + * @enum {string} + */ + base: "veo3" | "runway"; + /** + * Format + * @default api + * @constant + */ + format: "api"; /** * Trigger Phrases * @description Set of trigger phrases for this model */ - trigger_phrases?: string[] | null; + trigger_phrases: string[] | null; /** @description Default settings for this model */ - default_settings?: components["schemas"]["MainModelDefaultSettings"] | null; - /** - * Variant - * @default normal - */ - variant?: components["schemas"]["ModelVariantType"] | components["schemas"]["ClipVariantType"] | null; + default_settings: components["schemas"]["MainModelDefaultSettings"] | null; }; /** * VideoDTO diff --git a/invokeai/frontend/web/src/services/api/types.ts b/invokeai/frontend/web/src/services/api/types.ts index b9798c04d99..7506b29356f 100644 --- a/invokeai/frontend/web/src/services/api/types.ts +++ b/invokeai/frontend/web/src/services/api/types.ts @@ -106,51 +106,34 @@ export const isVideoDTO = (dto: ImageDTO | VideoDTO): dto is VideoDTO => { }; // Model Configs -export type ControlLoRAModelConfig = S['ControlLoRALyCORISConfig'] | S['ControlLoRADiffusersConfig']; -export type LoRAModelConfig = S['LoRADiffusersConfig'] | S['LoRALyCORISConfig'] | S['LoRAOmiConfig']; -export type VAEModelConfig = S['VAECheckpointConfig'] | S['VAEDiffusersConfig']; -export type ControlNetModelConfig = S['ControlNetDiffusersConfig'] | S['ControlNetCheckpointConfig']; -export type IPAdapterModelConfig = S['IPAdapterInvokeAIConfig'] | S['IPAdapterCheckpointConfig']; -export type T2IAdapterModelConfig = S['T2IAdapterConfig']; -export type CLIPLEmbedModelConfig = S['CLIPLEmbedDiffusersConfig']; -export type CLIPGEmbedModelConfig = S['CLIPGEmbedDiffusersConfig']; -export type CLIPEmbedModelConfig = CLIPLEmbedModelConfig | CLIPGEmbedModelConfig; -type LlavaOnevisionConfig = S['LlavaOnevisionConfig']; -export type T5EncoderModelConfig = S['T5EncoderConfig']; -export type T5EncoderBnbQuantizedLlmInt8bModelConfig = S['T5EncoderBnbQuantizedLlmInt8bConfig']; -export type SpandrelImageToImageModelConfig = S['SpandrelImageToImageConfig']; -type TextualInversionModelConfig = S['TextualInversionFileConfig'] | S['TextualInversionFolderConfig']; -type DiffusersModelConfig = S['MainDiffusersConfig']; -export type CheckpointModelConfig = S['MainCheckpointConfig']; -type CLIPVisionDiffusersConfig = S['CLIPVisionDiffusersConfig']; -type SigLipModelConfig = S['SigLIPConfig']; -export type FLUXReduxModelConfig = S['FluxReduxConfig']; -type ApiModelConfig = S['ApiModelConfig']; -export type VideoApiModelConfig = S['VideoApiModelConfig']; -type UnknownModelConfig = S['UnknownModelConfig']; -export type MainModelConfig = DiffusersModelConfig | CheckpointModelConfig | ApiModelConfig; +export type AnyModelConfig = S['AnyModelConfig']; +export type MainModelConfig = Extract<S['AnyModelConfig'], { type: 'main' }>; +export type ControlLoRAModelConfig = Extract<S['AnyModelConfig'], { type: 'control_lora' }>; +export type LoRAModelConfig = Extract<S['AnyModelConfig'], { type: 'lora' }>; +export type VAEModelConfig = Extract<S['AnyModelConfig'], { type: 'vae' }>; +export type 
ControlNetModelConfig = Extract<S['AnyModelConfig'], { type: 'controlnet' }>; +export type IPAdapterModelConfig = Extract<S['AnyModelConfig'], { type: 'ip_adapter' }>; +export type T2IAdapterModelConfig = Extract<S['AnyModelConfig'], { type: 't2i_adapter' }>; +export type CLIPLEmbedModelConfig = Extract<S['AnyModelConfig'], { type: 'clip_embed'; variant: 'large' }>; +export type CLIPGEmbedModelConfig = Extract<S['AnyModelConfig'], { type: 'clip_embed'; variant: 'gigantic' }>; +export type CLIPEmbedModelConfig = Extract<S['AnyModelConfig'], { type: 'clip_embed' }>; +type LlavaOnevisionConfig = Extract<S['AnyModelConfig'], { type: 'llava_onevision' }>; +export type T5EncoderModelConfig = Extract<S['AnyModelConfig'], { type: 't5_encoder'; format: 't5_encoder' }>; +export type T5EncoderBnbQuantizedLlmInt8bModelConfig = Extract< + S['AnyModelConfig'], + { type: 't5_encoder'; format: 'bnb_quantized_int8b' } +>; +export type SpandrelImageToImageModelConfig = Extract<S['AnyModelConfig'], { type: 'spandrel_image_to_image' }>; +export type CheckpointModelConfig = Extract<S['AnyModelConfig'], { type: 'main'; format: 'checkpoint' | 'bnb_quantized_nf4b' | 'gguf_quantized' }>; +type CLIPVisionDiffusersConfig = Extract<S['AnyModelConfig'], { type: 'clip_vision' }>; +type SigLipModelConfig = Extract<S['AnyModelConfig'], { type: 'siglip' }>; +export type FLUXReduxModelConfig = Extract<S['AnyModelConfig'], { type: 'flux_redux' }>; +type ApiModelConfig = Extract<S['AnyModelConfig'], { type: 'main'; format: 'api' }>; +export type VideoApiModelConfig = Extract<S['AnyModelConfig'], { type: 'video' }>; +type UnknownModelConfig = Extract<S['AnyModelConfig'], { type: 'unknown' }>; export type FLUXKontextModelConfig = MainModelConfig; export type ChatGPT4oModelConfig = ApiModelConfig; export type Gemini2_5ModelConfig = ApiModelConfig; -export type AnyModelConfig = - | ControlLoRAModelConfig - | LoRAModelConfig - | VAEModelConfig - | ControlNetModelConfig - | IPAdapterModelConfig - | T5EncoderModelConfig - | T5EncoderBnbQuantizedLlmInt8bModelConfig - | CLIPEmbedModelConfig - | T2IAdapterModelConfig - | SpandrelImageToImageModelConfig - | TextualInversionModelConfig - | MainModelConfig - | VideoApiModelConfig - | CLIPVisionDiffusersConfig - | SigLipModelConfig - | FLUXReduxModelConfig - | LlavaOnevisionConfig - | UnknownModelConfig; /** * Checks if a list of submodels contains any that match a given variant or type diff --git a/invokeai/frontend/web/src/services/events/setEventListeners.tsx b/invokeai/frontend/web/src/services/events/setEventListeners.tsx index 0cfa4cef26f..5804399f1f7 100644 --- a/invokeai/frontend/web/src/services/events/setEventListeners.tsx +++ b/invokeai/frontend/web/src/services/events/setEventListeners.tsx @@ -295,18 +295,7 @@ export const setEventListeners = ({ socket, store, setIsConnected }: SetEventLis const { id, config } = data; - if ( - config.type === 'unknown' || - config.base === 'unknown' || - /** - * Checking if type/base are 'unknown' technically narrows the config such that it's not possible for a config - * that passes to the `config.[type|base] === 'unknown'` checks. In the future, if we have more model config - * classes, this may change, so we will continue to check all three. Any one being 'unknown' is concerning - * enough to warrant a toast. 
-       */
-      /* @ts-expect-error See note above */
-      config.format === 'unknown'
-    ) {
+    if (config.type === 'unknown') {
       toast({
         id: 'UNKNOWN_MODEL',
         title: t('modelManager.unidentifiedModelTitle'),
diff --git a/scripts/classify-model.py b/scripts/classify-model.py
index 6411b4c7055..2ae253b72fe 100755
--- a/scripts/classify-model.py
+++ b/scripts/classify-model.py
@@ -7,7 +7,8 @@
 from typing import get_args
 
 from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS
-from invokeai.backend.model_manager import InvalidModelConfigException, ModelConfigBase, ModelProbe
+from invokeai.backend.model_manager import InvalidModelConfigException, ModelProbe
+from invokeai.backend.model_manager.config import ModelConfigFactory
 
 algos = ", ".join(set(get_args(HASHING_ALGORITHMS)))
 
@@ -30,7 +31,10 @@ def classify_with_fallback(path: Path, hash_algo: HASHING_ALGORITHMS):
     try:
         return ModelProbe.probe(path, hash_algo=hash_algo)
     except InvalidModelConfigException:
-        return ModelConfigBase.classify(path, hash_algo)
+        return ModelConfigFactory.from_model_on_disk(
+            mod=path,
+            hash_algo=hash_algo,
+        )
 
 
 for path in args.model_path:
diff --git a/tests/app/services/model_records/test_model_records_sql.py b/tests/app/services/model_records/test_model_records_sql.py
index c8b5698dd8b..41bbc2c024d 100644
--- a/tests/app/services/model_records/test_model_records_sql.py
+++ b/tests/app/services/model_records/test_model_records_sql.py
@@ -21,7 +21,7 @@
     ControlAdapterDefaultSettings,
     MainDiffusersConfig,
    MainModelDefaultSettings,
-    TextualInversionFileConfig,
+    TI_File_Config,
     VAEDiffusersConfig,
 )
 from invokeai.backend.model_manager.taxonomy import ModelSourceType
@@ -40,8 +40,8 @@ def store(
     return ModelRecordServiceSQL(db, logger)
 
 
-def example_ti_config(key: Optional[str] = None) -> TextualInversionFileConfig:
-    config = TextualInversionFileConfig(
+def example_ti_config(key: Optional[str] = None) -> TI_File_Config:
+    config = TI_File_Config(
         source="test/source/",
         source_type=ModelSourceType.Path,
         path="/tmp/pokemon.bin",
@@ -61,7 +61,7 @@ def test_type(store: ModelRecordServiceBase):
     config = example_ti_config("key1")
     store.add_model(config)
     config1 = store.get_model("key1")
-    assert isinstance(config1, TextualInversionFileConfig)
+    assert isinstance(config1, TI_File_Config)
 
 
 def test_raises_on_violating_uniqueness(store: ModelRecordServiceBase):
diff --git a/tests/backend/patches/lora_conversions/test_flux_aitoolkit_lora_conversion_utils.py b/tests/backend/patches/lora_conversions/test_flux_aitoolkit_lora_conversion_utils.py
index ed3e05a9b26..051ed210cd5 100644
--- a/tests/backend/patches/lora_conversions/test_flux_aitoolkit_lora_conversion_utils.py
+++ b/tests/backend/patches/lora_conversions/test_flux_aitoolkit_lora_conversion_utils.py
@@ -2,7 +2,8 @@
 import pytest
 
 from invokeai.backend.flux.model import Flux
-from invokeai.backend.flux.util import params
+from invokeai.backend.flux.util import get_flux_transformers_params
+from invokeai.backend.model_manager.taxonomy import ModelVariantType
 from invokeai.backend.patches.lora_conversions.flux_aitoolkit_lora_conversion_utils import (
     _group_state_by_submodel,
     is_state_dict_likely_in_flux_aitoolkit_format,
@@ -44,7 +45,7 @@ def test_flux_aitoolkit_transformer_state_dict_is_in_invoke_format():
 
     # Initialize a FLUX model on the meta device.
     with accelerate.init_empty_weights():
-        model = Flux(params["flux-schnell"])
+        model = Flux(get_flux_transformers_params(ModelVariantType.FluxSchnell))
     model_keys = set(model.state_dict().keys())
 
     for converted_key_prefix in converted_key_prefixes:
diff --git a/tests/backend/patches/lora_conversions/test_flux_kohya_lora_conversion_utils.py b/tests/backend/patches/lora_conversions/test_flux_kohya_lora_conversion_utils.py
index 52b8ecc9c9c..eb8846f456b 100644
--- a/tests/backend/patches/lora_conversions/test_flux_kohya_lora_conversion_utils.py
+++ b/tests/backend/patches/lora_conversions/test_flux_kohya_lora_conversion_utils.py
@@ -3,7 +3,8 @@
 import torch
 
 from invokeai.backend.flux.model import Flux
-from invokeai.backend.flux.util import params
+from invokeai.backend.flux.util import get_flux_transformers_params
+from invokeai.backend.model_manager.taxonomy import ModelVariantType
 from invokeai.backend.patches.lora_conversions.flux_kohya_lora_conversion_utils import (
     _convert_flux_transformer_kohya_state_dict_to_invoke_format,
     is_state_dict_likely_in_flux_kohya_format,
@@ -63,7 +64,7 @@ def test_convert_flux_transformer_kohya_state_dict_to_invoke_format():
 
     # Initialize a FLUX model on the meta device.
     with accelerate.init_empty_weights():
-        model = Flux(params["flux-dev"])
+        model = Flux(get_flux_transformers_params(ModelVariantType.FluxDev))
     model_keys = set(model.state_dict().keys())
 
     # Assert that the converted state dict matches the keys in the actual model.
diff --git a/tests/test_model_probe.py b/tests/test_model_probe.py
index 8ee4f8df1f5..03a7428382a 100644
--- a/tests/test_model_probe.py
+++ b/tests/test_model_probe.py
@@ -15,7 +15,7 @@
     AnyModelConfig,
     InvalidModelConfigException,
     MainDiffusersConfig,
-    ModelConfigBase,
+    Config_Base,
     ModelConfigFactory,
     get_model_discriminator_value,
 )
@@ -109,13 +109,13 @@ def test_probe_sd1_diffusers_inpainting(datadir: Path):
     assert config.repo_variant is ModelRepoVariant.FP16
 
 
-class MinimalConfigExample(ModelConfigBase):
+class MinimalConfigExample(Config_Base):
     type: ModelType = ModelType.Main
     format: ModelFormat = ModelFormat.Checkpoint
     fun_quote: str
 
     @classmethod
-    def matches(cls, mod: ModelOnDisk) -> bool:
+    def matches(cls, mod: ModelOnDisk, **overrides) -> bool:
         return mod.path.suffix == ".json"
 
     @classmethod
@@ -132,7 +132,10 @@ def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
 
 def test_minimal_working_example(datadir: Path):
     model_path = datadir / "minimal_config_model.json"
     overrides = {"base": BaseModelType.StableDiffusion1}
-    config = ModelConfigBase.classify(model_path, **overrides)
+    config = ModelConfigFactory.from_model_on_disk(
+        mod=model_path,
+        overrides=overrides,
+    )
     assert isinstance(config, MinimalConfigExample)
     assert config.base == BaseModelType.StableDiffusion1
@@ -160,7 +163,10 @@ def test_regression_against_model_probe(datadir: Path, override_model_loading):
         try:
             stripped_mod = StrippedModelOnDisk(path)
-            new_config = ModelConfigBase.classify(stripped_mod, hash=fake_hash, key=fake_key)
+            new_config = ModelConfigFactory.from_model_on_disk(
+                mod=stripped_mod,
+                overrides={"hash": fake_hash, "key": fake_key},
+            )
         except InvalidModelConfigException:
             pass
@@ -169,10 +175,10 @@ def test_regression_against_model_probe(datadir: Path, override_model_loading):
             assert legacy_config.model_dump_json() == new_config.model_dump_json()
         elif legacy_config:
-            assert type(legacy_config) in ModelConfigBase.USING_LEGACY_PROBE
+            assert type(legacy_config) in Config_Base.USING_LEGACY_PROBE
         elif new_config:
-            assert type(new_config) in ModelConfigBase.USING_CLASSIFY_API
+            assert type(new_config) in Config_Base.USING_CLASSIFY_API
         else:
             raise ValueError(f"Both probe and classify failed to classify model at path {path}.")
@@ -180,7 +186,7 @@ def test_regression_against_model_probe(datadir: Path, override_model_loading):
         config_type = type(legacy_config or new_config)
         configs_with_tests.add(config_type)
 
-    untested_configs = ModelConfigBase.all_config_classes() - configs_with_tests - {MinimalConfigExample}
+    untested_configs = Config_Base.all_config_classes() - configs_with_tests - {MinimalConfigExample}
     logger.warning(f"Function test_regression_against_model_probe missing test case for: {untested_configs}")
@@ -200,7 +206,7 @@ def test_serialisation_roundtrip():
     We need to ensure they are de-serialised into the original config with all relevant fields restored.
     """
     excluded = {MinimalConfigExample}
-    for config_cls in ModelConfigBase.all_config_classes() - excluded:
+    for config_cls in Config_Base.all_config_classes() - excluded:
         trials_per_class = 50
         configs_with_random_data = create_fake_configs(config_cls, trials_per_class)
@@ -215,7 +221,7 @@ def test_serialisation_roundtrip():
 
 def test_discriminator_tagging_for_config_instances():
     """Verify that each ModelConfig instance is assigned the correct, unique Pydantic discriminator tag."""
     excluded = {MinimalConfigExample}
-    config_classes = ModelConfigBase.all_config_classes() - excluded
+    config_classes = Config_Base.all_config_classes() - excluded
     tags = {c.get_tag() for c in config_classes}
     assert len(tags) == len(config_classes), "Each config should have its own unique tag"
@@ -240,10 +246,10 @@ def test_inheritance_order():
     It may be worth rethinking our config taxonomy in the future, but in the meantime this test can help
     prevent debugging effort.
     """
-    for config_cls in ModelConfigBase.all_config_classes():
+    for config_cls in Config_Base.all_config_classes():
         excluded = {abc.ABC, pydantic.BaseModel, object}
         inheritance_list = [cls for cls in config_cls.mro() if cls not in excluded]
-        assert inheritance_list[-1] is ModelConfigBase
+        assert inheritance_list[-1] is Config_Base
 
 
 def test_any_model_config_includes_all_config_classes():
@@ -256,7 +262,7 @@ def test_any_model_config_includes_all_config_classes():
         config_class, _ = get_args(annotated_pair)
         extracted.add(config_class)
 
-    expected = set(ModelConfigBase.all_config_classes()) - {MinimalConfigExample}
+    expected = set(Config_Base.all_config_classes()) - {MinimalConfigExample}
     assert extracted == expected
@@ -264,7 +270,7 @@ def test_config_uniquely_matches_model(datadir: Path):
     model_paths = ModelSearch().search(datadir / "stripped_models")
     for path in model_paths:
         mod = StrippedModelOnDisk(path)
-        matches = {cls for cls in ModelConfigBase.USING_CLASSIFY_API if cls.matches(mod)}
+        matches = {cls for cls in Config_Base.USING_CLASSIFY_API if cls.matches(mod)}
         assert len(matches) <= 1, f"Model at path {path} matches multiple config classes: {matches}"
         if not matches:
             logger.warning(f"Model at path {path} does not match any config classes using classify API.")