Skip to content

Commit 6431296

Browse files
committed
Fix Cascade empty prompt encode
1 parent e1c4038 commit 6431296

File tree

1 file changed: +11 additions, −4 deletions

modules/prompt_parser_diffusers.py

+11-4
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,8 @@ def prepare_embedding_providers(pipe, clip_skip) -> list[EmbeddingsProvider]:
253253
if getattr(pipe, "prior_pipe", None) is not None and getattr(pipe.prior_pipe, "tokenizer", None) is not None and getattr(pipe.prior_pipe, "text_encoder", None) is not None:
254254
provider = EmbeddingsProvider(padding_attention_mask_value=0, tokenizer=pipe.prior_pipe.tokenizer, text_encoder=pipe.prior_pipe.text_encoder, truncate=False, returned_embeddings_type=embedding_type, device=device)
255255
embeddings_providers.append(provider)
256+
no_mask_provider = EmbeddingsProvider(padding_attention_mask_value=1, tokenizer=pipe.prior_pipe.tokenizer, text_encoder=pipe.prior_pipe.text_encoder, truncate=False, returned_embeddings_type=embedding_type, device=device)
257+
embeddings_providers.append(no_mask_provider)
256258
elif getattr(pipe, "tokenizer", None) is not None and getattr(pipe, "text_encoder", None) is not None:
257259
provider = EmbeddingsProvider(tokenizer=pipe.tokenizer, text_encoder=pipe.text_encoder, truncate=False, returned_embeddings_type=embedding_type, device=device)
258260
embeddings_providers.append(provider)
@@ -262,7 +264,7 @@ def prepare_embedding_providers(pipe, clip_skip) -> list[EmbeddingsProvider]:
262264
return embeddings_providers
263265

264266

265-
def pad_to_same_length(pipe, embeds):
267+
def pad_to_same_length(pipe, embeds, empty_embedding_providers=None):
266268
if not hasattr(pipe, 'encode_prompt') and 'StableCascade' not in pipe.__class__.__name__:
267269
return embeds
268270
device = pipe.device if str(pipe.device) != 'meta' else devices.device
@@ -271,8 +273,8 @@ def pad_to_same_length(pipe, embeds):
271273
else:
272274
try:
273275
if 'StableCascade' in pipe.__class__.__name__:
274-
empty_embed = pipe.prior_pipe.encode_prompt(device, 1, 1, False, prompt="")
275-
empty_embed = [torch.nn.functional.normalize(empty_embed[0])]
276+
empty_embed = empty_embedding_providers[0].get_embeddings_for_weighted_prompt_fragments(text_batch=[[""]], fragment_weights_batch=[[1]], should_return_tokens=False, device=device)
277+
empty_embed = [empty_embed]
276278
else:
277279
empty_embed = pipe.encode_prompt("")
278280
except TypeError: # SD1.5
@@ -331,6 +333,11 @@ def get_weighted_text_embeddings(pipe, prompt: str = "", neg_prompt: str = "", c
331333
negative_weights.pop(0)
332334

333335
embedding_providers = prepare_embedding_providers(pipe, clip_skip)
336+
empty_embedding_providers = None
337+
if 'StableCascade' in pipe.__class__.__name__:
338+
empty_embedding_providers = [embedding_providers[1]]
339+
embedding_providers = [embedding_providers[0]]
340+
334341
prompt_embeds = []
335342
negative_prompt_embeds = []
336343
pooled_prompt_embeds = []
@@ -400,7 +407,7 @@ def get_weighted_text_embeddings(pipe, prompt: str = "", neg_prompt: str = "", c
400407
negative_pooled_prompt_embeds = None
401408
debug(f'Prompt: positive={prompt_embeds.shape if prompt_embeds is not None else None} pooled={pooled_prompt_embeds.shape if pooled_prompt_embeds is not None else None} negative={negative_prompt_embeds.shape if negative_prompt_embeds is not None else None} pooled={negative_pooled_prompt_embeds.shape if negative_pooled_prompt_embeds is not None else None}')
402409
if prompt_embeds.shape[1] != negative_prompt_embeds.shape[1]:
403-
[prompt_embeds, negative_prompt_embeds] = pad_to_same_length(pipe, [prompt_embeds, negative_prompt_embeds])
410+
[prompt_embeds, negative_prompt_embeds] = pad_to_same_length(pipe, [prompt_embeds, negative_prompt_embeds], empty_embedding_providers=empty_embedding_providers)
404411
if SD3:
405412
device = pipe.device if str(pipe.device) != 'meta' else devices.device
406413
t5_prompt_embed = pipe._get_t5_prompt_embeds( # pylint: disable=protected-access

Comments (0)