Don't zero out conditioning for SD1.5
- it does not like it.
Acly committed Mar 11, 2025
1 parent 6a80e3b commit 8959a73
Showing 1 changed file with 19 additions and 12 deletions.
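
In brief: when the prompt text is empty, the workflow swaps the encoded prompt for zeroed conditioning, which newer architectures accept as an "unconditional" input. SD1.5 checkpoints respond badly to zeroed conditioning, so the zero-out is now skipped for that architecture. Below is a minimal sketch of the resulting control flow, assuming the project's ComfyWorkflow graph API and the Clip tuple introduced in this commit; encode_empty_prompt is a hypothetical helper, not part of the change:

    # Sketch only: ComfyWorkflow, Clip, Arch and Output are the project's types.
    def encode_empty_prompt(w: ComfyWorkflow, clip: Clip) -> Output:
        cond = w.clip_text_encode(clip.model, "")
        if clip.arch is not Arch.sd15:
            # Newer models treat zeroed conditioning as "no prompt".
            # SD1.5 does not like it, so keep the encoded empty string.
            cond = w.conditioning_zero_out(cond)
        return cond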
ai_diffusion/workflow.py (+19 -12)
@@ -153,6 +153,8 @@ def load_checkpoint_with_lora(w: ComfyWorkflow, checkpoint: CheckpointInput, mod
     if arch.supports_attention_guidance and checkpoint.self_attention_guidance:
         model = w.apply_self_attention_guidance(model)
 
+    return model, Clip(clip, arch), vae
+
 
 def vae_decode(w: ComfyWorkflow, vae: Output, latent: Output, tiled: bool):
     if tiled:
@@ -253,26 +255,31 @@ def from_input(i: ControlInput):
         return Control(i.mode, ImageOutput(i.image), None, i.strength, i.range)
 
 
+class Clip(NamedTuple):
+    model: Output
+    arch: Arch
+
+
 class TextPrompt:
     text: str
     language: str
     # Cached values to avoid re-encoding the same text for multiple regions and passes
     _output: Output | None = None
-    _clip: Output | None = None  # can be different due to Lora hooks
+    _clip: Clip | None = None  # can be different due to Lora hooks
 
     def __init__(self, text: str, language: str):
         self.text = text
         self.language = language
 
-    def encode(self, w: ComfyWorkflow, clip: Output, style_prompt: str | None = None):
+    def encode(self, w: ComfyWorkflow, clip: Clip, style_prompt: str | None = None):
         text = self.text
         if text != "" and style_prompt:
             text = merge_prompt(text, style_prompt, self.language)
         if self._output is None or self._clip != clip:
             if text and self.language:
                 text = w.translate(text)
-            self._output = w.clip_text_encode(clip, text)
-            if text == "":
+            self._output = w.clip_text_encode(clip.model, text)
+            if text == "" and clip.arch is not Arch.sd15:
                 self._output = w.conditioning_zero_out(self._output)
             self._clip = clip
         return self._output
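
TextPrompt caches its encoded output keyed on the clip it was encoded with. Because Clip is a NamedTuple, the check "self._output is None or self._clip != clip" compares the model handle and the architecture by value, so a re-wrapped clip (for example one with LoRA hooks attached) invalidates the cache. A self-contained sketch of that equality behavior, with strings standing in for real Output handles:

    from enum import Enum
    from typing import NamedTuple

    class Arch(Enum):  # stand-in for the project's Arch enum
        sd15 = "sd15"
        sdxl = "sdxl"

    class Clip(NamedTuple):
        model: str  # an Output handle in the real code
        arch: Arch

    a = Clip("clip#1", Arch.sd15)
    b = Clip("clip#1", Arch.sd15)
    c = Clip("clip#2", Arch.sd15)  # e.g. after set_clip_hooks
    assert a == b  # same model and arch: the cached encoding is reused
    assert a != c  # hooked clip differs: the prompt is re-encoded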
@@ -286,7 +293,7 @@ class Region:
     control: list[Control] = field(default_factory=list)
     loras: list[LoraInput] = field(default_factory=list)
     is_background: bool = False
-    clip: Output | None = None
+    clip: Clip | None = None
 
     @staticmethod
     def from_input(i: RegionInput, index: int, language: str):
@@ -301,15 +308,15 @@ def from_input(i: RegionInput, index: int, language: str):
             is_background=index == 0,
         )
 
-    def patch_clip(self, w: ComfyWorkflow, clip: Output):
+    def patch_clip(self, w: ComfyWorkflow, clip: Clip):
         if self.clip is None:
             self.clip = clip
             if len(self.loras) > 0:
                 hooks = w.create_hook_lora([(lora.name, lora.strength) for lora in self.loras])
-                self.clip = w.set_clip_hooks(clip, hooks)
+                self.clip = Clip(w.set_clip_hooks(clip.model, hooks), clip.arch)
         return self.clip
 
-    def encode_prompt(self, w: ComfyWorkflow, clip: Output, style_prompt: str | None = None):
+    def encode_prompt(self, w: ComfyWorkflow, clip: Clip, style_prompt: str | None = None):
         return self.positive.encode(w, self.patch_clip(w, clip), style_prompt)
 
     def copy(self):
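
Note the re-wrap in patch_clip: w.set_clip_hooks returns a plain Output, so the architecture tag must be carried over explicitly by constructing a new Clip. The same pattern as a standalone helper, a hedged sketch with a hypothetical name and the project's types assumed:

    def with_lora_hooks(w: ComfyWorkflow, clip: Clip, loras: list[LoraInput]) -> Clip:
        # Swap the underlying clip Output for one with LoRA hooks attached,
        # while keeping the arch field intact.
        if not loras:
            return clip
        hooks = w.create_hook_lora([(lora.name, lora.strength) for lora in loras])
        return Clip(w.set_clip_hooks(clip.model, hooks), clip.arch)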
@@ -384,7 +391,7 @@ def downscale_all_control_images(cond: ConditioningInput, original: Extent, targ
 def encode_text_prompt(
     w: ComfyWorkflow,
     cond: Conditioning,
-    clip: Output,
+    clip: Clip,
     regions: Output | None,
 ):
     if len(cond.regions) <= 1 or all(len(r.loras) == 0 for r in cond.regions):
@@ -413,7 +420,7 @@ def apply_attention_mask(
     w: ComfyWorkflow,
     model: Output,
     cond: Conditioning,
-    clip: Output,
+    clip: Clip,
     shape: Extent | ImageReshape = no_reshape,
 ):
     if len(cond.regions) == 0:
@@ -643,7 +650,7 @@ def scale_refine_and_decode(
     sampling: SamplingInput,
     latent: Output,
     model: Output,
-    clip: Output,
+    clip: Clip,
     vae: Output,
     models: ModelDict,
     tiled_vae: bool,
@@ -1240,7 +1247,7 @@ def get_param(node: ComfyNode, expected_type: type | tuple[type, type] | None =
         sampling = _sampling_from_style(style, 1.0, is_live)
         model, clip, vae = load_checkpoint_with_lora(w, checkpoint_input, models)
         outputs[node.output(0)] = model
-        outputs[node.output(1)] = clip
+        outputs[node.output(1)] = clip.model
         outputs[node.output(2)] = vae
         outputs[node.output(3)] = style.style_prompt
         outputs[node.output(4)] = style.negative_prompt
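
The last hunk shows the boundary rule of the change: the Clip wrapper stays internal to workflow construction, while graph outputs consumed by external nodes expose the raw CLIP handle again. A sketch of such a call site, assuming the surrounding custom-graph builder and a hypothetical region variable:

    model, clip, vae = load_checkpoint_with_lora(w, checkpoint_input, models)
    positive = region.encode_prompt(w, clip)  # internal code can read clip.arch
    outputs[node.output(1)] = clip.model      # external nodes get a plain Output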
