Commit 3507870

Add 'sigmas' to transformer_options so that downstream code can know the full scope of the current sampling run; fix Hook Keyframes' guarantee_steps=1 inconsistent behavior when sampling is split across different Sampling nodes/sampling runs by referencing 'sigmas' (#6273)
Parent: 82ecb02 · Commit: 3507870

3 files changed: +26 −10 lines changed

comfy/hooks.py

Lines changed: 17 additions & 4 deletions
@@ -366,9 +366,15 @@ def __init__(self, strength: float, start_percent=0.0, guarantee_steps=1):
         self.start_t = 999999999.9
         self.guarantee_steps = guarantee_steps
 
+    def get_effective_guarantee_steps(self, max_sigma: torch.Tensor):
+        '''If keyframe starts before current sampling range (max_sigma), treat as 0.'''
+        if self.start_t > max_sigma:
+            return 0
+        return self.guarantee_steps
+
     def clone(self):
         c = HookKeyframe(strength=self.strength,
-                         start_percent=self.start_percent, guarantee_steps=self.guarantee_steps)
+                        start_percent=self.start_percent, guarantee_steps=self.guarantee_steps)
         c.start_t = self.start_t
         return c
 
@@ -408,6 +414,12 @@ def _set_first_as_current(self):
         else:
             self._current_keyframe = None
 
+    def has_guarantee_steps(self):
+        for kf in self.keyframes:
+            if kf.guarantee_steps > 0:
+                return True
+        return False
+
     def has_index(self, index: int):
         return index >= 0 and index < len(self.keyframes)
 
@@ -425,15 +437,16 @@ def initialize_timesteps(self, model: 'BaseModel'):
         for keyframe in self.keyframes:
             keyframe.start_t = model.model_sampling.percent_to_sigma(keyframe.start_percent)
 
-    def prepare_current_keyframe(self, curr_t: float) -> bool:
+    def prepare_current_keyframe(self, curr_t: float, transformer_options: dict[str, torch.Tensor]) -> bool:
         if self.is_empty():
             return False
         if curr_t == self._curr_t:
             return False
+        max_sigma = torch.max(transformer_options["sigmas"])
         prev_index = self._current_index
         prev_strength = self._current_strength
         # if met guaranteed steps, look for next keyframe in case need to switch
-        if self._current_used_steps >= self._current_keyframe.guarantee_steps:
+        if self._current_used_steps >= self._current_keyframe.get_effective_guarantee_steps(max_sigma):
             # if has next index, loop through and see if need to switch
             if self.has_index(self._current_index+1):
                 for i in range(self._current_index+1, len(self.keyframes)):
@@ -446,7 +459,7 @@ def prepare_current_keyframe(self, curr_t: float) -> bool:
                         self._current_keyframe = eval_c
                         self._current_used_steps = 0
                         # if guarantee_steps greater than zero, stop searching for other keyframes
-                        if self._current_keyframe.guarantee_steps > 0:
+                        if self._current_keyframe.get_effective_guarantee_steps(max_sigma) > 0:
                             break
                     # if eval_c is outside the percent range, stop looking further
                     else: break
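The new get_effective_guarantee_steps() is the heart of the fix: a keyframe whose start_t sigma lies above the largest sigma of the current run must have started during an earlier sampling run, so its guaranteed steps are treated as already spent instead of being re-honored. A minimal sketch of that comparison, using a trimmed stand-in class and made-up sigma schedules rather than the real comfy.hooks.HookKeyframe:

import torch

class HookKeyframe:
    """Trimmed stand-in for comfy.hooks.HookKeyframe (illustration only)."""
    def __init__(self, start_t: float, guarantee_steps: int = 1):
        self.start_t = start_t
        self.guarantee_steps = guarantee_steps

    def get_effective_guarantee_steps(self, max_sigma: torch.Tensor):
        # Keyframe started before this sampling range began -> no guarantee.
        if self.start_t > max_sigma:
            return 0
        return self.guarantee_steps

# A keyframe that activates at sigma 14.6 (early in sampling).
kf = HookKeyframe(start_t=14.6, guarantee_steps=1)

# First run covers sigmas 14.6..7.0: the keyframe starts inside the range,
# so its guarantee_steps is honored.
first_run_sigmas = torch.tensor([14.6, 10.0, 7.0])
print(kf.get_effective_guarantee_steps(torch.max(first_run_sigmas)))   # 1

# A second run (e.g. a later sampling node) covers 7.0..0.0: the keyframe
# began in a previous run, so its guarantee is treated as spent.
second_run_sigmas = torch.tensor([7.0, 3.0, 0.0])
print(kf.get_effective_guarantee_steps(torch.max(second_run_sigmas)))  # 0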

comfy/model_patcher.py

Lines changed: 3 additions & 2 deletions
@@ -919,11 +919,12 @@ def restore_hook_patches(self):
     def set_hook_mode(self, hook_mode: comfy.hooks.EnumHookMode):
         self.hook_mode = hook_mode
 
-    def prepare_hook_patches_current_keyframe(self, t: torch.Tensor, hook_group: comfy.hooks.HookGroup):
+    def prepare_hook_patches_current_keyframe(self, t: torch.Tensor, hook_group: comfy.hooks.HookGroup, model_options: dict[str]):
         curr_t = t[0]
         reset_current_hooks = False
+        transformer_options = model_options.get("transformer_options", {})
         for hook in hook_group.hooks:
-            changed = hook.hook_keyframe.prepare_current_keyframe(curr_t=curr_t)
+            changed = hook.hook_keyframe.prepare_current_keyframe(curr_t=curr_t, transformer_options=transformer_options)
             # if keyframe changed, remove any cached HookGroups that contain hook with the same hook_ref;
             # this will cause the weights to be recalculated when sampling
             if changed:
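prepare_hook_patches_current_keyframe() now threads model_options through to the keyframe logic, reading the nested dict with .get("transformer_options", {}) so the forwarding itself cannot fail on a missing key; the keyframe code then pulls "sigmas" out of the forwarded dict. A hedged sketch of that lookup pattern (the dict contents below are illustrative, not taken from a real run, and the .get("sigmas") fallback is a defensive variant, not the commit's exact code):

import torch

# Illustrative model_options, shaped like what inner_sample assembles:
model_options = {"transformer_options": {"sigmas": torch.tensor([14.6, 7.0, 0.0])}}

# The patcher forwards only the nested dict; a missing key yields {}.
transformer_options = model_options.get("transformer_options", {})

# Downstream keyframe code reduces the schedule to its maximum sigma.
sigmas = transformer_options.get("sigmas")
if sigmas is not None:
    print(torch.max(sigmas))  # tensor(14.6000)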

comfy/samplers.py

Lines changed: 6 additions & 4 deletions
@@ -144,7 +144,7 @@ def cond_cat(c_list):
 
     return out
 
-def finalize_default_conds(model: 'BaseModel', hooked_to_run: dict[comfy.hooks.HookGroup,list[tuple[tuple,int]]], default_conds: list[list[dict]], x_in, timestep):
+def finalize_default_conds(model: 'BaseModel', hooked_to_run: dict[comfy.hooks.HookGroup,list[tuple[tuple,int]]], default_conds: list[list[dict]], x_in, timestep, model_options):
     # need to figure out remaining unmasked area for conds
     default_mults = []
     for _ in default_conds:
@@ -183,7 +183,7 @@ def finalize_default_conds(model: 'BaseModel', hooked_to_run: dict[comfy.hooks.H
         # replace p's mult with calculated mult
         p = p._replace(mult=mult)
         if p.hooks is not None:
-            model.current_patcher.prepare_hook_patches_current_keyframe(timestep, p.hooks)
+            model.current_patcher.prepare_hook_patches_current_keyframe(timestep, p.hooks, model_options)
         hooked_to_run.setdefault(p.hooks, list())
         hooked_to_run[p.hooks] += [(p, i)]
 
@@ -218,7 +218,7 @@ def _calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Te
             if p is None:
                 continue
             if p.hooks is not None:
-                model.current_patcher.prepare_hook_patches_current_keyframe(timestep, p.hooks)
+                model.current_patcher.prepare_hook_patches_current_keyframe(timestep, p.hooks, model_options)
             hooked_to_run.setdefault(p.hooks, list())
             hooked_to_run[p.hooks] += [(p, i)]
         default_conds.append(default_c)
@@ -840,7 +840,9 @@ def inner_sample(self, noise, latent_image, device, sampler, sigmas, denoise_mas
 
         self.conds = process_conds(self.inner_model, noise, self.conds, device, latent_image, denoise_mask, seed)
 
-        extra_args = {"model_options": comfy.model_patcher.create_model_options_clone(self.model_options), "seed": seed}
+        extra_model_options = comfy.model_patcher.create_model_options_clone(self.model_options)
+        extra_model_options.setdefault("transformer_options", {})["sigmas"] = sigmas
+        extra_args = {"model_options": extra_model_options, "seed": seed}
 
         executor = comfy.patcher_extension.WrapperExecutor.new_class_executor(
             sampler.sample,
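With inner_sample() storing the run's full sigma schedule under transformer_options["sigmas"], any downstream consumer (hook keyframes here, but equally custom wrappers or patches) can tell where the current step sits relative to the whole run, even when a workflow splits sampling across several nodes. A hypothetical consumer, sketched under the assumption that it receives the same transformer_options dict (the helper name and math are illustrative, not part of this commit):

import torch

def fraction_of_run_complete(curr_sigma: torch.Tensor, transformer_options: dict) -> float:
    """Hypothetical helper: locate the current step within the full run's sigmas."""
    sigmas = transformer_options.get("sigmas")
    if sigmas is None:
        return 0.0  # caller never stored sigmas (e.g. older code paths)
    max_sigma, min_sigma = torch.max(sigmas), torch.min(sigmas)
    span = (max_sigma - min_sigma).clamp(min=1e-9)  # guard against div-by-zero
    return float((max_sigma - curr_sigma) / span)

# Example: a 20-step run from sigma 14.6 down to 0.
sigmas = torch.linspace(14.6, 0.0, steps=21)
opts = {"transformer_options": {"sigmas": sigmas}}
print(fraction_of_run_complete(sigmas[5], opts["transformer_options"]))  # 0.25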
