11import  copy 
22from  time  import  sleep 
3- from  typing  import  List ,  Optional , Literal , Dict 
3+ from  typing  import  Dict ,  List , Literal , Optional 
44
55from  pydantic  import  BaseModel , Field 
66
1212    invocation_output ,
1313)
1414from  invokeai .app .invocations .fields  import  FieldDescriptions , Input , InputField , OutputField , UIType 
15+ from  invokeai .app .services .model_records  import  ModelRecordChanges 
1516from  invokeai .app .services .shared .invocation_context  import  InvocationContext 
1617from  invokeai .app .shared .models  import  FreeUConfig 
17- from  invokeai .app .services .model_records  import  ModelRecordChanges 
18- from  invokeai .backend .model_manager .config  import  AnyModelConfig , BaseModelType , ModelType , SubModelType , ModelFormat 
18+ from  invokeai .backend .model_manager .config  import  AnyModelConfig , BaseModelType , ModelFormat , ModelType , SubModelType 
1919
2020
2121class  ModelIdentifierField (BaseModel ):
@@ -132,31 +132,22 @@ def invoke(self, context: InvocationContext) -> ModelIdentifierOutput:
132132
133133        return  ModelIdentifierOutput (model = self .model )
134134
135- T5_ENCODER_OPTIONS  =  Literal ["base" , "16b_quantized" , "8b_quantized" ]
135+ 
136+ T5_ENCODER_OPTIONS  =  Literal ["base" , "8b_quantized" ]
136137T5_ENCODER_MAP : Dict [str , Dict [str , str ]] =  {
137138    "base" : {
138-         "text_encoder_repo" : "black-forest-labs/FLUX.1-schnell::text_encoder_2" ,
139-         "tokenizer_repo" : "black-forest-labs/FLUX.1-schnell::tokenizer_2" ,
140-         "text_encoder_name" : "FLUX.1-schnell_text_encoder_2" ,
141-         "tokenizer_name" : "FLUX.1-schnell_tokenizer_2" ,
139+         "repo" : "invokeai/flux_dev::t5_xxl_encoder/base" ,
140+         "name" : "t5_base_encoder" ,
142141        "format" : ModelFormat .T5Encoder ,
143142    },
144143    "8b_quantized" : {
145-         "text_encoder_repo" : "hf_repo1" ,
146-         "tokenizer_repo" : "hf_repo1" ,
147-         "text_encoder_name" : "hf_repo1" ,
148-         "tokenizer_name" : "hf_repo1" ,
149-         "format" : ModelFormat .T5Encoder8b ,
150-     },
151-     "4b_quantized" : {
152-         "text_encoder_repo" : "hf_repo2" ,
153-         "tokenizer_repo" : "hf_repo2" ,
154-         "text_encoder_name" : "hf_repo2" ,
155-         "tokenizer_name" : "hf_repo2" ,
156-         "format" : ModelFormat .T5Encoder8b ,
144+         "repo" : "invokeai/flux_dev::t5_xxl_encoder/8b_quantized" ,
145+         "name" : "t5_8b_quantized_encoder" ,
146+         "format" : ModelFormat .T5Encoder ,
157147    },
158148}
159149
150+ 
160151@invocation_output ("flux_model_loader_output" ) 
161152class  FluxModelLoaderOutput (BaseInvocationOutput ):
162153    """Flux base model loader output""" 
@@ -176,7 +167,7 @@ class FluxModelLoaderInvocation(BaseInvocation):
176167        ui_type = UIType .FluxMainModel ,
177168        input = Input .Direct ,
178169    )
179-      
170+ 
180171    t5_encoder : T5_ENCODER_OPTIONS  =  InputField (description = "The T5 Encoder model to use." )
181172
182173    def  invoke (self , context : InvocationContext ) ->  FluxModelLoaderOutput :
@@ -189,7 +180,15 @@ def invoke(self, context: InvocationContext) -> FluxModelLoaderOutput:
189180        tokenizer2  =  self ._get_model (context , SubModelType .Tokenizer2 )
190181        clip_encoder  =  self ._get_model (context , SubModelType .TextEncoder )
191182        t5_encoder  =  self ._get_model (context , SubModelType .TextEncoder2 )
192-         vae  =  self ._install_model (context , SubModelType .VAE , "FLUX.1-schnell_ae" , "black-forest-labs/FLUX.1-schnell::ae.safetensors" , ModelFormat .Checkpoint , ModelType .VAE , BaseModelType .Flux )
183+         vae  =  self ._install_model (
184+             context ,
185+             SubModelType .VAE ,
186+             "FLUX.1-schnell_ae" ,
187+             "black-forest-labs/FLUX.1-schnell::ae.safetensors" ,
188+             ModelFormat .Checkpoint ,
189+             ModelType .VAE ,
190+             BaseModelType .Flux ,
191+         )
193192
194193        return  FluxModelLoaderOutput (
195194            transformer = TransformerField (transformer = transformer ),
@@ -198,33 +197,59 @@ def invoke(self, context: InvocationContext) -> FluxModelLoaderOutput:
198197            vae = VAEField (vae = vae ),
199198        )
200199
201-     def  _get_model (self , context : InvocationContext , submodel :SubModelType ) ->  ModelIdentifierField :
202-         match ( submodel ) :
200+     def  _get_model (self , context : InvocationContext , submodel :  SubModelType ) ->  ModelIdentifierField :
201+         match   submodel :
203202            case  SubModelType .Transformer :
204203                return  self .model .model_copy (update = {"submodel_type" : SubModelType .Transformer })
205204            case  submodel  if  submodel  in  [SubModelType .Tokenizer , SubModelType .TextEncoder ]:
206-                 return  self ._install_model (context , submodel , "clip-vit-large-patch14" , "openai/clip-vit-large-patch14" , ModelFormat .Diffusers , ModelType .CLIPEmbed , BaseModelType .Any )
207-             case  SubModelType .TextEncoder2 :
208-                 return  self ._install_model (context , submodel , T5_ENCODER_MAP [self .t5_encoder ]["text_encoder_name" ], T5_ENCODER_MAP [self .t5_encoder ]["text_encoder_repo" ], ModelFormat (T5_ENCODER_MAP [self .t5_encoder ]["format" ]), ModelType .T5Encoder , BaseModelType .Any )
209-             case  SubModelType .Tokenizer2 :
210-                 return  self ._install_model (context , submodel , T5_ENCODER_MAP [self .t5_encoder ]["tokenizer_name" ], T5_ENCODER_MAP [self .t5_encoder ]["tokenizer_repo" ], ModelFormat (T5_ENCODER_MAP [self .t5_encoder ]["format" ]), ModelType .T5Encoder , BaseModelType .Any )
205+                 return  self ._install_model (
206+                     context ,
207+                     submodel ,
208+                     "clip-vit-large-patch14" ,
209+                     "openai/clip-vit-large-patch14" ,
210+                     ModelFormat .Diffusers ,
211+                     ModelType .CLIPEmbed ,
212+                     BaseModelType .Any ,
213+                 )
214+             case  submodel  if  submodel  in  [SubModelType .Tokenizer2 , SubModelType .TextEncoder2 ]:
215+                 return  self ._install_model (
216+                     context ,
217+                     submodel ,
218+                     T5_ENCODER_MAP [self .t5_encoder ]["name" ],
219+                     T5_ENCODER_MAP [self .t5_encoder ]["repo" ],
220+                     ModelFormat (T5_ENCODER_MAP [self .t5_encoder ]["format" ]),
221+                     ModelType .T5Encoder ,
222+                     BaseModelType .Any ,
223+                 )
211224            case  _:
212-                 raise  Exception (f"{ submodel .value } is not a supported submodule for a flux model" )  
213- 
214-     def  _install_model (self , context : InvocationContext , submodel :SubModelType , name : str , repo_id : str , format : ModelFormat , type : ModelType , base : BaseModelType ):
215-         if  (models  :=  context .models .search_by_attrs (name = name , base = base , type = type )):
225+                 raise  Exception (f"{ submodel .value } is not a supported submodule for a flux model" )
226+ 
227+     def  _install_model (
228+         self ,
229+         context : InvocationContext ,
230+         submodel : SubModelType ,
231+         name : str ,
232+         repo_id : str ,
233+         format : ModelFormat ,
234+         type : ModelType ,
235+         base : BaseModelType ,
236+     ):
237+         if  models  :=  context .models .search_by_attrs (name = name , base = base , type = type ):
216238            if  len (models ) !=  1 :
217239                raise  Exception (f"Multiple models detected for selected model with name { name } " )
218240            return  ModelIdentifierField .from_config (models [0 ]).model_copy (update = {"submodel_type" : submodel })
219241        else :
220242            model_path  =  context .models .download_and_cache_model (repo_id )
221-             config  =  ModelRecordChanges (name   =   name , base   =   base , type = type , format = format )
243+             config  =  ModelRecordChanges (name = name , base = base , type = type , format = format )
222244            model_install_job  =  context .models .import_local_model (model_path = model_path , config = config )
223245            while  not  model_install_job .in_terminal_state :
224246                sleep (0.01 )
225247            if  not  model_install_job .config_out :
226248                raise  Exception (f"Failed to install { name } " )
227-             return  ModelIdentifierField .from_config (model_install_job .config_out ).model_copy (update = {"submodel_type" : submodel })
249+             return  ModelIdentifierField .from_config (model_install_job .config_out ).model_copy (
250+                 update = {"submodel_type" : submodel }
251+             )
252+ 
228253
229254@invocation ( 
230255    "main_model_loader" , 
0 commit comments