@@ -366,9 +366,15 @@ def __init__(self, strength: float, start_percent=0.0, guarantee_steps=1):
366
366
self .start_t = 999999999.9
367
367
self .guarantee_steps = guarantee_steps
368
368
369
+ def get_effective_guarantee_steps (self , max_sigma : torch .Tensor ):
370
+ '''If keyframe starts before current sampling range (max_sigma), treat as 0.'''
371
+ if self .start_t > max_sigma :
372
+ return 0
373
+ return self .guarantee_steps
374
+
369
375
def clone (self ):
370
376
c = HookKeyframe (strength = self .strength ,
371
- start_percent = self .start_percent , guarantee_steps = self .guarantee_steps )
377
+ start_percent = self .start_percent , guarantee_steps = self .guarantee_steps )
372
378
c .start_t = self .start_t
373
379
return c
374
380
@@ -408,6 +414,12 @@ def _set_first_as_current(self):
408
414
else :
409
415
self ._current_keyframe = None
410
416
417
+ def has_guarantee_steps (self ):
418
+ for kf in self .keyframes :
419
+ if kf .guarantee_steps > 0 :
420
+ return True
421
+ return False
422
+
411
423
def has_index (self , index : int ):
412
424
return index >= 0 and index < len (self .keyframes )
413
425
@@ -425,15 +437,16 @@ def initialize_timesteps(self, model: 'BaseModel'):
425
437
for keyframe in self .keyframes :
426
438
keyframe .start_t = model .model_sampling .percent_to_sigma (keyframe .start_percent )
427
439
428
- def prepare_current_keyframe (self , curr_t : float ) -> bool :
440
+ def prepare_current_keyframe (self , curr_t : float , transformer_options : dict [ str , torch . Tensor ] ) -> bool :
429
441
if self .is_empty ():
430
442
return False
431
443
if curr_t == self ._curr_t :
432
444
return False
445
+ max_sigma = torch .max (transformer_options ["sigmas" ])
433
446
prev_index = self ._current_index
434
447
prev_strength = self ._current_strength
435
448
# if met guaranteed steps, look for next keyframe in case need to switch
436
- if self ._current_used_steps >= self ._current_keyframe .guarantee_steps :
449
+ if self ._current_used_steps >= self ._current_keyframe .get_effective_guarantee_steps ( max_sigma ) :
437
450
# if has next index, loop through and see if need to switch
438
451
if self .has_index (self ._current_index + 1 ):
439
452
for i in range (self ._current_index + 1 , len (self .keyframes )):
@@ -446,7 +459,7 @@ def prepare_current_keyframe(self, curr_t: float) -> bool:
446
459
self ._current_keyframe = eval_c
447
460
self ._current_used_steps = 0
448
461
# if guarantee_steps greater than zero, stop searching for other keyframes
449
- if self ._current_keyframe .guarantee_steps > 0 :
462
+ if self ._current_keyframe .get_effective_guarantee_steps ( max_sigma ) > 0 :
450
463
break
451
464
# if eval_c is outside the percent range, stop looking further
452
465
else : break
0 commit comments