Skip to content

Commit 8cf1daa

Browse files
Fix SDXL area composition sometimes not using the right pooled output.
1 parent d2f3229 commit 8cf1daa

File tree

2 files changed

+14
-3
lines changed

2 files changed

+14
-3
lines changed

comfy/model_base.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,9 +126,15 @@ def blank_inpaint_image_like(latent_image):
126126
cond_concat.append(blank_inpaint_image_like(noise))
127127
data = torch.cat(cond_concat, dim=1)
128128
out['c_concat'] = comfy.conds.CONDNoiseShape(data)
129+
129130
adm = self.encode_adm(**kwargs)
130131
if adm is not None:
131132
out['y'] = comfy.conds.CONDRegular(adm)
133+
134+
cross_attn = kwargs.get("cross_attn", None)
135+
if cross_attn is not None:
136+
out['c_crossattn'] = comfy.conds.CONDCrossAttn(cross_attn)
137+
132138
return out
133139

134140
def load_model_weights(self, sd, unet_prefix=""):
@@ -322,6 +328,10 @@ def extra_conds(self, **kwargs):
322328

323329
out['c_concat'] = comfy.conds.CONDNoiseShape(latent_image)
324330

331+
cross_attn = kwargs.get("cross_attn", None)
332+
if cross_attn is not None:
333+
out['c_crossattn'] = comfy.conds.CONDCrossAttn(cross_attn)
334+
325335
if "time_conditioning" in kwargs:
326336
out["time_context"] = comfy.conds.CONDCrossAttn(kwargs["time_conditioning"])
327337

comfy/samplers.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -599,6 +599,10 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model
599599
calculate_start_end_timesteps(model, negative)
600600
calculate_start_end_timesteps(model, positive)
601601

602+
if hasattr(model, 'extra_conds'):
603+
positive = encode_model_conds(model.extra_conds, positive, noise, device, "positive", latent_image=latent_image, denoise_mask=denoise_mask)
604+
negative = encode_model_conds(model.extra_conds, negative, noise, device, "negative", latent_image=latent_image, denoise_mask=denoise_mask)
605+
602606
#make sure each cond area has an opposite one with the same area
603607
for c in positive:
604608
create_cond_with_same_area_if_none(negative, c)
@@ -613,9 +617,6 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model
613617
if latent_image is not None:
614618
latent_image = model.process_latent_in(latent_image)
615619

616-
if hasattr(model, 'extra_conds'):
617-
positive = encode_model_conds(model.extra_conds, positive, noise, device, "positive", latent_image=latent_image, denoise_mask=denoise_mask)
618-
negative = encode_model_conds(model.extra_conds, negative, noise, device, "negative", latent_image=latent_image, denoise_mask=denoise_mask)
619620

620621
extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": model_options, "seed":seed}
621622

0 commit comments

Comments
 (0)