@@ -253,6 +253,8 @@ def prepare_embedding_providers(pipe, clip_skip) -> list[EmbeddingsProvider]:
     if getattr(pipe, "prior_pipe", None) is not None and getattr(pipe.prior_pipe, "tokenizer", None) is not None and getattr(pipe.prior_pipe, "text_encoder", None) is not None:
         provider = EmbeddingsProvider(padding_attention_mask_value=0, tokenizer=pipe.prior_pipe.tokenizer, text_encoder=pipe.prior_pipe.text_encoder, truncate=False, returned_embeddings_type=embedding_type, device=device)
         embeddings_providers.append(provider)
+        no_mask_provider = EmbeddingsProvider(padding_attention_mask_value=1, tokenizer=pipe.prior_pipe.tokenizer, text_encoder=pipe.prior_pipe.text_encoder, truncate=False, returned_embeddings_type=embedding_type, device=device)
+        embeddings_providers.append(no_mask_provider)
     elif getattr(pipe, "tokenizer", None) is not None and getattr(pipe, "text_encoder", None) is not None:
         provider = EmbeddingsProvider(tokenizer=pipe.tokenizer, text_encoder=pipe.text_encoder, truncate=False, returned_embeddings_type=embedding_type, device=device)
         embeddings_providers.append(provider)
@@ -262,7 +264,7 @@ def prepare_embedding_providers(pipe, clip_skip) -> list[EmbeddingsProvider]:
     return embeddings_providers


-def pad_to_same_length(pipe, embeds):
+def pad_to_same_length(pipe, embeds, empty_embedding_providers=None):
     if not hasattr(pipe, 'encode_prompt') and 'StableCascade' not in pipe.__class__.__name__:
         return embeds
     device = pipe.device if str(pipe.device) != 'meta' else devices.device
@@ -271,8 +273,8 @@ def pad_to_same_length(pipe, embeds):
     else:
         try:
             if 'StableCascade' in pipe.__class__.__name__:
-                empty_embed = pipe.prior_pipe.encode_prompt(device, 1, 1, False, prompt="")
-                empty_embed = [torch.nn.functional.normalize(empty_embed[0])]
+                empty_embed = empty_embedding_providers[0].get_embeddings_for_weighted_prompt_fragments(text_batch=[[""]], fragment_weights_batch=[[1]], should_return_tokens=False, device=device)
+                empty_embed = [empty_embed]
             else:
                 empty_embed = pipe.encode_prompt("")
         except TypeError:  # SD1.5
@@ -331,6 +333,11 @@ def get_weighted_text_embeddings(pipe, prompt: str = "", neg_prompt: str = "", c
         negative_weights.pop(0)

     embedding_providers = prepare_embedding_providers(pipe, clip_skip)
+    empty_embedding_providers = None
+    if 'StableCascade' in pipe.__class__.__name__:
+        empty_embedding_providers = [embedding_providers[1]]
+        embedding_providers = [embedding_providers[0]]
+
     prompt_embeds = []
     negative_prompt_embeds = []
     pooled_prompt_embeds = []
@@ -400,7 +407,7 @@ def get_weighted_text_embeddings(pipe, prompt: str = "", neg_prompt: str = "", c
         negative_pooled_prompt_embeds = None
     debug(f'Prompt: positive={prompt_embeds.shape if prompt_embeds is not None else None} pooled={pooled_prompt_embeds.shape if pooled_prompt_embeds is not None else None} negative={negative_prompt_embeds.shape if negative_prompt_embeds is not None else None} pooled={negative_pooled_prompt_embeds.shape if negative_pooled_prompt_embeds is not None else None}')
     if prompt_embeds.shape[1] != negative_prompt_embeds.shape[1]:
-        [prompt_embeds, negative_prompt_embeds] = pad_to_same_length(pipe, [prompt_embeds, negative_prompt_embeds])
+        [prompt_embeds, negative_prompt_embeds] = pad_to_same_length(pipe, [prompt_embeds, negative_prompt_embeds], empty_embedding_providers=empty_embedding_providers)
     if SD3:
         device = pipe.device if str(pipe.device) != 'meta' else devices.device
         t5_prompt_embed = pipe._get_t5_prompt_embeds(  # pylint: disable=protected-access
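For context, a minimal sketch of the kind of padding pad_to_same_length is assumed to perform once it has the empty-prompt embedding: the shorter of the positive/negative embeddings is extended along the sequence dimension with repeated copies of the empty embedding until both lengths match. The pad_embeds helper and the tensor shapes below are illustrative assumptions, not the file's actual implementation.

import torch

def pad_embeds(embeds: list[torch.Tensor], empty_embed: torch.Tensor) -> list[torch.Tensor]:
    # Illustrative only: repeat the empty-prompt embedding along the token axis
    # until every embedding reaches the longest sequence length in the list.
    max_len = max(e.shape[1] for e in embeds)
    padded = []
    for e in embeds:
        while e.shape[1] < max_len:
            pad = empty_embed[:, : max_len - e.shape[1], :]
            e = torch.cat([e, pad], dim=1)
        padded.append(e)
    return padded

# Hypothetical shapes: batch=1, 154 vs 77 tokens, 1280-dim prior embeddings
prompt_embeds = torch.randn(1, 154, 1280)
negative_prompt_embeds = torch.randn(1, 77, 1280)
empty_embed = torch.randn(1, 77, 1280)
prompt_embeds, negative_prompt_embeds = pad_embeds([prompt_embeds, negative_prompt_embeds], empty_embed)
assert prompt_embeds.shape[1] == negative_prompt_embeds.shape[1]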