From 94e2ceb1be8f148e48f98461f68a1e3ada349ed3 Mon Sep 17 00:00:00 2001
From: Brandon Rising
Date: Tue, 27 Feb 2024 12:26:51 -0500
Subject: [PATCH 1/5] Allow TIs to be either a key or a name in the prompt
 during our transition to using keys

---
 invokeai/app/invocations/compel.py | 72 +++++++++++++++++++-----------
 1 file changed, 47 insertions(+), 25 deletions(-)

diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py
index 50f53225137..ce9b1948ebd 100644
--- a/invokeai/app/invocations/compel.py
+++ b/invokeai/app/invocations/compel.py
@@ -3,7 +3,7 @@
 import torch
 from compel import Compel, ReturnedEmbeddingsType
 from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment
-from transformers import CLIPTokenizer
+from transformers import CLIPTokenizer, CLIPTextModel
 
 import invokeai.backend.util.logging as logger
 from invokeai.app.invocations.fields import (
@@ -18,7 +18,7 @@
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.app.util.ti_utils import extract_ti_triggers_from_prompt
 from invokeai.backend.lora import LoRAModelRaw
-from invokeai.backend.model_manager import ModelType
+from invokeai.backend.model_manager.config import ModelType
 from invokeai.backend.model_patcher import ModelPatcher
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
     BasicConditioningInfo,
@@ -70,7 +70,11 @@ class CompelInvocation(BaseInvocation):
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> ConditioningOutput:
         tokenizer_info = context.models.load(**self.clip.tokenizer.model_dump())
+        tokenizer_model = tokenizer_info.model
+        assert isinstance(tokenizer_model, CLIPTokenizer)
         text_encoder_info = context.models.load(**self.clip.text_encoder.model_dump())
+        text_encoder_model = text_encoder_info.model
+        assert isinstance(text_encoder_model, CLIPTextModel)
 
         def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
             for lora in self.clip.loras:
@@ -82,21 +86,29 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
 
         # loras = [(context.models.get(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
 
-        ti_list = []
+        ti_list: List[Tuple[str, TextualInversionModelRaw]] = []
         for trigger in extract_ti_triggers_from_prompt(self.prompt):
-            name = trigger[1:-1]
+            name_or_key = trigger[1:-1]
             try:
-                loaded_model = context.models.load(key=name).model
-                assert isinstance(loaded_model, TextualInversionModelRaw)
-                ti_list.append((name, loaded_model))
+                loaded_model = context.models.load(key=name_or_key)
+                model = loaded_model.model
+                assert isinstance(model, TextualInversionModelRaw)
+                ti_list.append((name_or_key, model))
             except UnknownModelException:
-                # print(e)
-                # import traceback
-                # print(traceback.format_exc())
-                print(f'Warn: trigger: "{trigger}" not found')
+                try:
+                    loaded_model = context.models.load_by_attrs(
+                        model_name=name_or_key, base_model=text_encoder_info.config.base, model_type=ModelType.TextualInversion
+                    )
+                    model = loaded_model.model
+                    assert isinstance(model, TextualInversionModelRaw)
+                    ti_list.append((name_or_key, model))
+                except UnknownModelException:
+                    logger.warning(f'trigger: "{trigger}" not found')
+            except ValueError:
+                logger.warning(f'trigger: "{trigger}" more than one similarly-named textual inversion models')
 
         with (
-            ModelPatcher.apply_ti(tokenizer_info.model, text_encoder_info.model, ti_list) as (
+            ModelPatcher.apply_ti(tokenizer_model, text_encoder_model, ti_list) as (
                 tokenizer,
                 ti_manager,
             ),
@@ -106,6 +118,7 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
             # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
             ModelPatcher.apply_clip_skip(text_encoder_info.model, self.clip.skipped_layers),
         ):
+            assert isinstance(text_encoder, CLIPTextModel)
             compel = Compel(
                 tokenizer=tokenizer,
                 text_encoder=text_encoder,
@@ -155,7 +168,11 @@ def run_clip_compel(
         zero_on_empty: bool,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[ExtraConditioningInfo]]:
         tokenizer_info = context.models.load(**clip_field.tokenizer.model_dump())
+        tokenizer_model = tokenizer_info.model
+        assert isinstance(tokenizer_model, CLIPTokenizer)
         text_encoder_info = context.models.load(**clip_field.text_encoder.model_dump())
+        text_encoder_model = text_encoder_info.model
+        assert isinstance(text_encoder_model, CLIPTextModel)
 
         # return zero on empty
         if prompt == "" and zero_on_empty:
@@ -189,25 +206,29 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
 
         # loras = [(context.models.get(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
 
-        ti_list = []
+        ti_list: List[Tuple[str, TextualInversionModelRaw]] = []
         for trigger in extract_ti_triggers_from_prompt(prompt):
-            name = trigger[1:-1]
+            name_or_key = trigger[1:-1]
             try:
-                ti_model = context.models.load_by_attrs(
-                    model_name=name, base_model=text_encoder_info.config.base, model_type=ModelType.TextualInversion
-                ).model
-                assert isinstance(ti_model, TextualInversionModelRaw)
-                ti_list.append((name, ti_model))
+                loaded_model = context.models.load(key=name_or_key)
+                model = loaded_model.model
+                assert isinstance(model, TextualInversionModelRaw)
+                ti_list.append((name_or_key, model))
             except UnknownModelException:
-                # print(e)
-                # import traceback
-                # print(traceback.format_exc())
-                logger.warning(f'trigger: "{trigger}" not found')
+                try:
+                    loaded_model = context.models.load_by_attrs(
+                        model_name=name_or_key, base_model=text_encoder_info.config.base, model_type=ModelType.TextualInversion
+                    )
+                    model = loaded_model.model
+                    assert isinstance(model, TextualInversionModelRaw)
+                    ti_list.append((name_or_key, model))
+                except UnknownModelException:
+                    logger.warning(f'trigger: "{trigger}" not found')
             except ValueError:
                 logger.warning(f'trigger: "{trigger}" more than one similarly-named textual inversion models')
 
         with (
-            ModelPatcher.apply_ti(tokenizer_info.model, text_encoder_info.model, ti_list) as (
+            ModelPatcher.apply_ti(tokenizer_model, text_encoder_model, ti_list) as (
                 tokenizer,
                 ti_manager,
             ),
@@ -215,8 +236,9 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
             # Apply the LoRA after text_encoder has been moved to its target device for faster patching.
             ModelPatcher.apply_lora(text_encoder, _lora_loader(), lora_prefix),
             # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
-            ModelPatcher.apply_clip_skip(text_encoder_info.model, clip_field.skipped_layers),
+            ModelPatcher.apply_clip_skip(text_encoder_model, clip_field.skipped_layers),
         ):
+            assert isinstance(text_encoder, CLIPTextModel)
             compel = Compel(
                 tokenizer=tokenizer,
                 text_encoder=text_encoder,
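Note on patch 1: the change above makes a `<trigger>` in the prompt resolve first as a model key and only then, for backwards compatibility, as a model name scoped to the text encoder's base model. Below is a minimal sketch of that resolution order; `FakeModelStore` and its data are invented stand-ins for `context.models`, not InvokeAI API.

```python
class UnknownModelException(Exception):
    pass


class FakeModelStore:
    """Hypothetical stand-in for context.models; only the lookup order mirrors the patch."""

    def __init__(self, by_key, by_name):
        self._by_key = by_key    # model key -> model
        self._by_name = by_name  # (model name, base model) -> model

    def load(self, key):
        if key not in self._by_key:
            raise UnknownModelException(key)
        return self._by_key[key]

    def load_by_attrs(self, model_name, base_model, model_type):
        matches = [m for (n, b), m in self._by_name.items() if n == model_name and b == base_model]
        if not matches:
            raise UnknownModelException(model_name)
        if len(matches) > 1:
            raise ValueError(model_name)  # the case the new `except ValueError` handles
        return matches[0]


def resolve_trigger(store, trigger, base):
    name_or_key = trigger[1:-1]  # strip the surrounding "<" and ">"
    try:
        # New behavior: the trigger text may be a model key...
        return store.load(key=name_or_key)
    except UnknownModelException:
        # ...falling back to the legacy name lookup for the active base model.
        return store.load_by_attrs(model_name=name_or_key, base_model=base, model_type="embedding")


store = FakeModelStore(
    by_key={"abc123": "ti resolved by key"},
    by_name={("easynegative", "sd-1"): "ti resolved by name"},
)
assert resolve_trigger(store, "<abc123>", "sd-1") == "ti resolved by key"
assert resolve_trigger(store, "<easynegative>", "sd-1") == "ti resolved by name"
```

Trying the key first means prompts written against the new key scheme never collide with legacy name lookups; names are only consulted when no key matches.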
From 902731996fcb8b25705da23487dd1b356e2fe486 Mon Sep 17 00:00:00 2001
From: Brandon Rising
Date: Tue, 27 Feb 2024 12:36:04 -0500
Subject: [PATCH 2/5] Fix one last reference to the uncasted model

---
 invokeai/app/invocations/compel.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py
index ce9b1948ebd..bc0c7955dc0 100644
--- a/invokeai/app/invocations/compel.py
+++ b/invokeai/app/invocations/compel.py
@@ -116,7 +116,7 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
             # Apply the LoRA after text_encoder has been moved to its target device for faster patching.
             ModelPatcher.apply_lora_text_encoder(text_encoder, _lora_loader()),
             # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
-            ModelPatcher.apply_clip_skip(text_encoder_info.model, self.clip.skipped_layers),
+            ModelPatcher.apply_clip_skip(text_encoder_model, self.clip.skipped_layers),
         ):
             assert isinstance(text_encoder, CLIPTextModel)
             compel = Compel(
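Note on patch 2: patches 1 and 2 bind the loaded models to `tokenizer_model`/`text_encoder_model` locals guarded by `assert isinstance(...)`, and this patch appears to catch the one call site still reading the un-narrowed `text_encoder_info.model`. A small sketch of the assert-based narrowing idiom; the names here are illustrative, not InvokeAI code.

```python
from typing import Any


def load_model() -> Any:
    # Stand-in for context.models.load(...).model, which is typed broadly.
    return "pretend this is a CLIPTextModel"


model = load_model()
# The assert both guards at runtime and narrows the static type, so later
# call sites type-check against the concrete class. In the patch the check
# is isinstance(model, CLIPTextModel).
assert isinstance(model, str)
print(model.upper())  # checkers now treat `model` as `str`
```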
From cb8be593fcb941651e255bd7131a6f50aca4d6bb Mon Sep 17 00:00:00 2001
From: Brandon Rising
Date: Tue, 27 Feb 2024 15:20:14 -0500
Subject: [PATCH 3/5] Extract TI loading logic into util, disallow it from
 ever failing a generation

---
 invokeai/app/invocations/compel.py | 44 ++----------------------------
 invokeai/app/util/ti_utils.py      | 42 ++++++++++++++++++++++++++--
 2 files changed, 42 insertions(+), 44 deletions(-)

diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py
index bc0c7955dc0..ca6ca644bba 100644
--- a/invokeai/app/invocations/compel.py
+++ b/invokeai/app/invocations/compel.py
@@ -16,7 +16,7 @@
 from invokeai.app.invocations.primitives import ConditioningOutput
 from invokeai.app.services.model_records import UnknownModelException
 from invokeai.app.services.shared.invocation_context import InvocationContext
-from invokeai.app.util.ti_utils import extract_ti_triggers_from_prompt
+from invokeai.app.util.ti_utils import generate_ti_list
 from invokeai.backend.lora import LoRAModelRaw
 from invokeai.backend.model_manager.config import ModelType
 from invokeai.backend.model_patcher import ModelPatcher
@@ -86,26 +86,7 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
 
         # loras = [(context.models.get(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
 
-        ti_list: List[Tuple[str, TextualInversionModelRaw]] = []
-        for trigger in extract_ti_triggers_from_prompt(self.prompt):
-            name_or_key = trigger[1:-1]
-            try:
-                loaded_model = context.models.load(key=name_or_key)
-                model = loaded_model.model
-                assert isinstance(model, TextualInversionModelRaw)
-                ti_list.append((name_or_key, model))
-            except UnknownModelException:
-                try:
-                    loaded_model = context.models.load_by_attrs(
-                        model_name=name_or_key, base_model=text_encoder_info.config.base, model_type=ModelType.TextualInversion
-                    )
-                    model = loaded_model.model
-                    assert isinstance(model, TextualInversionModelRaw)
-                    ti_list.append((name_or_key, model))
-                except UnknownModelException:
-                    logger.warning(f'trigger: "{trigger}" not found')
-            except ValueError:
-                logger.warning(f'trigger: "{trigger}" more than one similarly-named textual inversion models')
+        ti_list = generate_ti_list(self.prompt, text_encoder_info.config.base, context)
 
         with (
             ModelPatcher.apply_ti(tokenizer_model, text_encoder_model, ti_list) as (
@@ -206,26 +187,7 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
 
         # loras = [(context.models.get(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
 
-        ti_list: List[Tuple[str, TextualInversionModelRaw]] = []
-        for trigger in extract_ti_triggers_from_prompt(prompt):
-            name_or_key = trigger[1:-1]
-            try:
-                loaded_model = context.models.load(key=name_or_key)
-                model = loaded_model.model
-                assert isinstance(model, TextualInversionModelRaw)
-                ti_list.append((name_or_key, model))
-            except UnknownModelException:
-                try:
-                    loaded_model = context.models.load_by_attrs(
-                        model_name=name_or_key, base_model=text_encoder_info.config.base, model_type=ModelType.TextualInversion
-                    )
-                    model = loaded_model.model
-                    assert isinstance(model, TextualInversionModelRaw)
-                    ti_list.append((name_or_key, model))
-                except UnknownModelException:
-                    logger.warning(f'trigger: "{trigger}" not found')
-            except ValueError:
-                logger.warning(f'trigger: "{trigger}" more than one similarly-named textual inversion models')
+        ti_list = generate_ti_list(prompt, text_encoder_info.config.base, context)
 
         with (
             ModelPatcher.apply_ti(tokenizer_model, text_encoder_model, ti_list) as (
diff --git a/invokeai/app/util/ti_utils.py b/invokeai/app/util/ti_utils.py
index a66a832b42a..c2645e07021 100644
--- a/invokeai/app/util/ti_utils.py
+++ b/invokeai/app/util/ti_utils.py
@@ -1,8 +1,44 @@
 import re
+from typing import List, Tuple
 
+from invokeai.backend.model_manager.config import BaseModelType, ModelType
+from invokeai.backend.textual_inversion import TextualInversionModelRaw
+from invokeai.app.services.shared.invocation_context import InvocationContext
+from invokeai.app.services.model_records import UnknownModelException
+import invokeai.backend.util.logging as logger
+
 
-def extract_ti_triggers_from_prompt(prompt: str) -> list[str]:
-    ti_triggers = []
+def extract_ti_triggers_from_prompt(prompt: str) -> List[str]:
+    ti_triggers: List[str] = []
     for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", prompt):
-        ti_triggers.append(trigger)
+        ti_triggers.append(str(trigger))
     return ti_triggers
+
+def generate_ti_list(prompt: str, base: BaseModelType, context: InvocationContext) -> List[Tuple[str, TextualInversionModelRaw]]:
+    ti_list: List[Tuple[str, TextualInversionModelRaw]] = []
+    for trigger in extract_ti_triggers_from_prompt(prompt):
+        name_or_key = trigger[1:-1]
+        try:
+            loaded_model = context.models.load(key=name_or_key)
+            model = loaded_model.model
+            assert isinstance(model, TextualInversionModelRaw)
+            assert loaded_model.config.base == base
+            ti_list.append((name_or_key, model))
+        except UnknownModelException:
+            try:
+                loaded_model = context.models.load_by_attrs(
+                    model_name=name_or_key, base_model=base, model_type=ModelType.TextualInversion
+                )
+                model = loaded_model.model
+                assert isinstance(model, TextualInversionModelRaw)
+                assert loaded_model.config.base == base
+                ti_list.append((name_or_key, model))
+            except UnknownModelException:
+                pass
+        except ValueError:
+            logger.warning(f'trigger: "{trigger}" more than one similarly-named textual inversion models')
+        except AssertionError:
+            logger.warning(f'trigger: "{trigger}" not a valid textual inversion model for this graph')
+        except Exception:
+            logger.warning(f'Failed to load TI model for trigger: "{trigger}"')
+    return ti_list
\ No newline at end of file
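Note on patch 3: `generate_ti_list` is driven by `extract_ti_triggers_from_prompt`, whose matching behavior is unchanged here. A standalone check of that regex, reproduced under the assumption that only the pattern matters:

```python
import re


def extract_ti_triggers_from_prompt(prompt: str) -> list:
    # Same pattern as the utility: angle-bracketed runs of letters, digits,
    # dots, commas, spaces, underscores, and hyphens.
    return [str(t) for t in re.findall(r"<[a-zA-Z0-9., _-]+>", prompt)]


print(extract_ti_triggers_from_prompt("a photo of <easynegative>, <my_ti-v2.0> style"))
# -> ['<easynegative>', '<my_ti-v2.0>']
```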
From b65f259d8c8d8c449dc22fe30065cd4dc3ac8a06 Mon Sep 17 00:00:00 2001
From: Brandon Rising
Date: Tue, 27 Feb 2024 15:20:55 -0500
Subject: [PATCH 4/5] Ruff check

---
 invokeai/app/invocations/compel.py | 6 +-----
 invokeai/app/util/ti_utils.py      | 8 ++++----
 2 files changed, 5 insertions(+), 9 deletions(-)

diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py
index ca6ca644bba..771c811eea0 100644
--- a/invokeai/app/invocations/compel.py
+++ b/invokeai/app/invocations/compel.py
@@ -3,9 +3,8 @@
 import torch
 from compel import Compel, ReturnedEmbeddingsType
 from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment
-from transformers import CLIPTokenizer, CLIPTextModel
+from transformers import CLIPTextModel, CLIPTokenizer
 
-import invokeai.backend.util.logging as logger
 from invokeai.app.invocations.fields import (
     FieldDescriptions,
     Input,
@@ -14,11 +13,9 @@
     UIComponent,
 )
 from invokeai.app.invocations.primitives import ConditioningOutput
-from invokeai.app.services.model_records import UnknownModelException
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.app.util.ti_utils import generate_ti_list
 from invokeai.backend.lora import LoRAModelRaw
-from invokeai.backend.model_manager.config import ModelType
 from invokeai.backend.model_patcher import ModelPatcher
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
     BasicConditioningInfo,
@@ -26,7 +23,6 @@
     ExtraConditioningInfo,
     SDXLConditioningInfo,
 )
-from invokeai.backend.textual_inversion import TextualInversionModelRaw
 from invokeai.backend.util.devices import torch_dtype
 
 from .baseinvocation import (
diff --git a/invokeai/app/util/ti_utils.py b/invokeai/app/util/ti_utils.py
index c2645e07021..0d803408fd1 100644
--- a/invokeai/app/util/ti_utils.py
+++ b/invokeai/app/util/ti_utils.py
@@ -1,11 +1,11 @@
 import re
 from typing import List, Tuple
 
+import invokeai.backend.util.logging as logger
+from invokeai.app.services.model_records import UnknownModelException
+from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.backend.model_manager.config import BaseModelType, ModelType
 from invokeai.backend.textual_inversion import TextualInversionModelRaw
-from invokeai.app.services.shared.invocation_context import InvocationContext
-from invokeai.app.services.model_records import UnknownModelException
-import invokeai.backend.util.logging as logger
 
 
 def extract_ti_triggers_from_prompt(prompt: str) -> List[str]:
@@ -41,4 +41,4 @@ def generate_ti_list(prompt: str, base: BaseModelType, context: InvocationContex
             logger.warning(f'trigger: "{trigger}" not a valid textual inversion model for this graph')
         except Exception:
             logger.warning(f'Failed to load TI model for trigger: "{trigger}"')
-    return ti_list
\ No newline at end of file
+    return ti_list
From 605dc5a903124a908a3c23b7f9bf25f58c70c56f Mon Sep 17 00:00:00 2001
From: Brandon Rising
Date: Tue, 27 Feb 2024 15:24:23 -0500
Subject: [PATCH 5/5] Ruff format

---
 invokeai/app/util/ti_utils.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/invokeai/app/util/ti_utils.py b/invokeai/app/util/ti_utils.py
index 0d803408fd1..b5c884c9b7a 100644
--- a/invokeai/app/util/ti_utils.py
+++ b/invokeai/app/util/ti_utils.py
@@ -14,7 +14,10 @@ def extract_ti_triggers_from_prompt(prompt: str) -> List[str]:
         ti_triggers.append(str(trigger))
     return ti_triggers
 
-def generate_ti_list(prompt: str, base: BaseModelType, context: InvocationContext) -> List[Tuple[str, TextualInversionModelRaw]]:
+
+def generate_ti_list(
+    prompt: str, base: BaseModelType, context: InvocationContext
+) -> List[Tuple[str, TextualInversionModelRaw]]:
     ti_list: List[Tuple[str, TextualInversionModelRaw]] = []
     for trigger in extract_ti_triggers_from_prompt(prompt):
         name_or_key = trigger[1:-1]
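Closing note: per patch 3's subject, `generate_ti_list` is meant never to fail a generation; every per-trigger failure is caught and at most logged. A condensed sketch of that contract with a hypothetical `try_load` that always raises; the real function distinguishes more cases, including `AssertionError` for base-model mismatches.

```python
import logging

logger = logging.getLogger(__name__)


class UnknownModelException(Exception):
    pass


def try_load(trigger):
    # Hypothetical loader; here it always finds two similarly-named models.
    raise ValueError(trigger)


def collect(triggers):
    ti_list = []
    for trigger in triggers:
        try:
            ti_list.append((trigger, try_load(trigger)))
        except UnknownModelException:
            pass  # unknown triggers are skipped without noise
        except ValueError:
            logger.warning('trigger: "%s" matched more than one model', trigger)
        except Exception:
            logger.warning('failed to load TI model for trigger: "%s"', trigger)
    return ti_list


print(collect(["<easynegative>"]))  # -> [] and the generation proceeds
```

A bad trigger therefore costs only its embedding, never the whole run.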