Skip to content

Commit

Permalink
Add callback priority
Browse files Browse the repository at this point in the history
  • Loading branch information
pamparamm committed Sep 1, 2024
1 parent 696f411 commit 2c03c8b
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 13 deletions.
4 changes: 2 additions & 2 deletions __init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,5 @@
cb_manager = CallbackManager()
cb_manager.hijack_samplers()

cb_manager.register_callback(VectorscopeCC.PARAMS_NAME, VectorscopeCC.callback)
cb_manager.register_callback(DiffusionCG.PARAMS_NAME, DiffusionCG.callback)
cb_manager.register_callback(VectorscopeCC.PARAMS_NAME, VectorscopeCC.callback, 210)
cb_manager.register_callback(DiffusionCG.PARAMS_NAME, DiffusionCG.callback, 211)
26 changes: 17 additions & 9 deletions callback_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

class CallbackManager:
def __init__(self):
self.callbacks: dict[str, Callable[[tuple], Callable]] = {}
self.callbacks: dict[str, tuple[Callable[[tuple], Callable], int]] = {}

def hijack_samplers(self):
if not hasattr(comfy.sample, "sample_original"):
Expand All @@ -16,25 +16,33 @@ def hijack_samplers(self):
comfy.sample.sample_custom_original = comfy.sample.sample_custom
comfy.sample.sample_custom = self.sample_wrapper(comfy.sample.sample_custom_original)

def register_callback(self, params_name: str, callback_func: Callable[[tuple], Callable]):
self.callbacks[params_name] = callback_func
def register_callback(self, params_name: str, callback_func: Callable[[tuple], Callable], priority: int):
self.callbacks[params_name] = callback_func, priority

def sample_wrapper(self, original_sample: Callable):
def sample(*args, **kwargs):
model = args[0]

original_callback = kwargs["callback"]
original_cb = kwargs["callback"]
original_cb_priority = 1000

callbacks = []
for params_name, cb in self.callbacks.items():

def add_cb(cb, priority):
if cb is not None:
callbacks.append((priority, cb))

for params_name, (cb_wrapper, priority) in self.callbacks.items():
params = model.model_options.get(params_name, None)
if params:
callbacks.append(cb(params))
callbacks.append(original_callback)
callbacks = [cb for cb in callbacks if cb is not None]
cb = cb_wrapper(params)
add_cb(cb, priority)
add_cb(original_cb, original_cb_priority)

callbacks.sort()

def callback(step: int, x0: torch.Tensor, x: torch.Tensor, total_steps: int):
for cb in callbacks:
for _, cb in callbacks:
cb(step, x0, x, total_steps)

kwargs["callback"] = callback
Expand Down
2 changes: 1 addition & 1 deletion cg_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,6 @@ def _callback(step: int, x0: Tensor, x: Tensor, total_steps: int):
latent[b][c] += -target[b][c].mean() * recenter_strength

if normalize:
latent[b][c] = normalize_tensor(target[b][c], dynamic_range)
latent[b][c] = normalize_tensor(latent[b][c], dynamic_range)

return _callback
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[project]
name = "comfyui-vectorscope-cc"
description = "ComfyUI port of Vectorscope CC and Diffusion Color Grading by Haoming02. Makes it possible to adjust Brightness/Contrast/Saturation/Hue during image generation."
version = "1.1.0"
version = "1.1.1"
license = { text = "MIT License" }

[project.urls]
Expand Down

0 comments on commit 2c03c8b

Please sign in to comment.