From 8d80584f6a2797268b9b57ec84d6c76e8a27891c Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 12 Nov 2023 01:25:33 -0500 Subject: [PATCH 01/84] Remove useless argument from uni_pc sampler. --- comfy/extra_samplers/uni_pc.py | 2 +- comfy/samplers.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/comfy/extra_samplers/uni_pc.py b/comfy/extra_samplers/uni_pc.py index 1a7a8392902..08bf0fc9e67 100644 --- a/comfy/extra_samplers/uni_pc.py +++ b/comfy/extra_samplers/uni_pc.py @@ -858,7 +858,7 @@ def predict_eps_sigma(model, input, sigma_in, **kwargs): return (input - model(input, sigma_in, **kwargs)) / sigma -def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, extra_args=None, callback=None, disable=False, noise_mask=None, variant='bh1'): +def sample_unipc(model, noise, image, sigmas, max_denoise, extra_args=None, callback=None, disable=False, noise_mask=None, variant='bh1'): timesteps = sigmas.clone() if sigmas[-1] == 0: timesteps = sigmas[:] diff --git a/comfy/samplers.py b/comfy/samplers.py index a839ee9e2a2..b8836a29d26 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -511,11 +511,11 @@ def max_denoise(self, model_wrap, sigmas): class UNIPC(Sampler): def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False): - return uni_pc.sample_unipc(model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=self.max_denoise(model_wrap, sigmas), extra_args=extra_args, noise_mask=denoise_mask, callback=callback, disable=disable_pbar) + return uni_pc.sample_unipc(model_wrap, noise, latent_image, sigmas, max_denoise=self.max_denoise(model_wrap, sigmas), extra_args=extra_args, noise_mask=denoise_mask, callback=callback, disable=disable_pbar) class UNIPCBH2(Sampler): def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False): - return uni_pc.sample_unipc(model_wrap, noise, latent_image, sigmas, sampling_function=sampling_function, max_denoise=self.max_denoise(model_wrap, sigmas), extra_args=extra_args, noise_mask=denoise_mask, callback=callback, variant='bh2', disable=disable_pbar) + return uni_pc.sample_unipc(model_wrap, noise, latent_image, sigmas, max_denoise=self.max_denoise(model_wrap, sigmas), extra_args=extra_args, noise_mask=denoise_mask, callback=callback, variant='bh2', disable=disable_pbar) KSAMPLER_NAMES = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral", "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu", From 2c9dba8dc08eb35d29dab691c1f2808f6c9191ea Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 12 Nov 2023 03:45:10 -0500 Subject: [PATCH 02/84] sampling_function now has the model object as the argument. --- comfy/samplers.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/comfy/samplers.py b/comfy/samplers.py index b8836a29d26..a2c784a4a48 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -11,7 +11,7 @@ #The main sampling function shared by all the samplers #Returns denoised -def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None): +def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None): def get_area_and_mult(conds, x_in, timestep_in): area = (x_in.shape[2], x_in.shape[3], 0, 0) strength = 1.0 @@ -134,7 +134,7 @@ def cond_cat(c_list): return out - def calc_cond_uncond_batch(model_function, cond, uncond, x_in, timestep, max_total_area, model_options): + def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, max_total_area, model_options): out_cond = torch.zeros_like(x_in) out_count = torch.ones_like(x_in) * 1e-37 @@ -221,9 +221,9 @@ def calc_cond_uncond_batch(model_function, cond, uncond, x_in, timestep, max_tot c['transformer_options'] = transformer_options if 'model_function_wrapper' in model_options: - output = model_options['model_function_wrapper'](model_function, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks) + output = model_options['model_function_wrapper'](model.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks) else: - output = model_function(input_x, timestep_, **c).chunk(batch_chunks) + output = model.apply_model(input_x, timestep_, **c).chunk(batch_chunks) del input_x for o in range(batch_chunks): @@ -246,7 +246,7 @@ def calc_cond_uncond_batch(model_function, cond, uncond, x_in, timestep, max_tot if math.isclose(cond_scale, 1.0): uncond = None - cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, timestep, max_total_area, model_options) + cond, uncond = calc_cond_uncond_batch(model, cond, uncond, x, timestep, max_total_area, model_options) if "sampler_cfg_function" in model_options: args = {"cond": x - cond, "uncond": x - uncond, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep} return x - model_options["sampler_cfg_function"](args) @@ -258,7 +258,7 @@ def __init__(self, model): super().__init__() self.inner_model = model def apply_model(self, x, timestep, cond, uncond, cond_scale, model_options={}, seed=None): - out = sampling_function(self.inner_model.apply_model, x, timestep, uncond, cond, cond_scale, model_options=model_options, seed=seed) + out = sampling_function(self.inner_model, x, timestep, uncond, cond, cond_scale, model_options=model_options, seed=seed) return out def forward(self, *args, **kwargs): return self.apply_model(*args, **kwargs) From dd4ba68b6e93a562d9499eff34e50dbbbc8714e7 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 12 Nov 2023 04:02:16 -0500 Subject: [PATCH 03/84] Allow different models to estimate memory usage differently. --- comfy/model_base.py | 10 ++++++++++ comfy/model_management.py | 21 --------------------- comfy/model_patcher.py | 3 +++ comfy/sample.py | 2 +- comfy/samplers.py | 9 +++++---- 5 files changed, 19 insertions(+), 26 deletions(-) diff --git a/comfy/model_base.py b/comfy/model_base.py index 7ba253470f4..f6de0b258d1 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -157,6 +157,16 @@ def state_dict_for_saving(self, clip_state_dict, vae_state_dict): def set_inpaint(self): self.inpaint_model = True + def memory_required(self, input_shape): + area = input_shape[0] * input_shape[2] * input_shape[3] + if comfy.model_management.xformers_enabled() or comfy.model_management.pytorch_attention_flash_attention(): + #TODO: this needs to be tweaked + return (area / 20) * (1024 * 1024) + else: + #TODO: this formula might be too aggressive since I tweaked the sub-quad and split algorithms to use less memory. + return (((area * 0.6) / 0.9) + 1024) * (1024 * 1024) + + def unclip_adm(unclip_conditioning, device, noise_augmentor, noise_augment_merge=0.0): adm_inputs = [] weights = [] diff --git a/comfy/model_management.py b/comfy/model_management.py index 53582fc736d..799e52ba239 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -579,27 +579,6 @@ def get_free_memory(dev=None, torch_free_too=False): else: return mem_free_total -def batch_area_memory(area): - if xformers_enabled() or pytorch_attention_flash_attention(): - #TODO: these formulas are copied from maximum_batch_area below - return (area / 20) * (1024 * 1024) - else: - return (((area * 0.6) / 0.9) + 1024) * (1024 * 1024) - -def maximum_batch_area(): - global vram_state - if vram_state == VRAMState.NO_VRAM: - return 0 - - memory_free = get_free_memory() / (1024 * 1024) - if xformers_enabled() or pytorch_attention_flash_attention(): - #TODO: this needs to be tweaked - area = 20 * memory_free - else: - #TODO: this formula is because AMD sucks and has memory management issues which might be fixed in the future - area = ((memory_free - 1024) * 0.9) / (0.6) - return int(max(area, 0)) - def cpu_mode(): global cpu_state return cpu_state == CPUState.CPU diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 9dc09791add..1c36855de90 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -52,6 +52,9 @@ def is_clone(self, other): return True return False + def memory_required(self, input_shape): + return self.model.memory_required(input_shape=input_shape) + def set_model_sampler_cfg_function(self, sampler_cfg_function): if len(inspect.signature(sampler_cfg_function).parameters) == 3: self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"]) #Old way diff --git a/comfy/sample.py b/comfy/sample.py index b3fcd1658a5..4bfdb8ce55d 100644 --- a/comfy/sample.py +++ b/comfy/sample.py @@ -83,7 +83,7 @@ def prepare_sampling(model, noise_shape, positive, negative, noise_mask): real_model = None models, inference_memory = get_additional_models(positive, negative, model.model_dtype()) - comfy.model_management.load_models_gpu([model] + models, comfy.model_management.batch_area_memory(noise_shape[0] * noise_shape[2] * noise_shape[3]) + inference_memory) + comfy.model_management.load_models_gpu([model] + models, model.memory_required(noise_shape) + inference_memory) real_model = model.model return real_model, positive, negative, noise_mask, models diff --git a/comfy/samplers.py b/comfy/samplers.py index a2c784a4a48..5340dd019b4 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -134,7 +134,7 @@ def cond_cat(c_list): return out - def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, max_total_area, model_options): + def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options): out_cond = torch.zeros_like(x_in) out_count = torch.ones_like(x_in) * 1e-37 @@ -170,9 +170,11 @@ def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, max_total_area, to_batch_temp.reverse() to_batch = to_batch_temp[:1] + free_memory = model_management.get_free_memory(x_in.device) for i in range(1, len(to_batch_temp) + 1): batch_amount = to_batch_temp[:len(to_batch_temp)//i] - if (len(batch_amount) * first_shape[0] * first_shape[2] * first_shape[3] < max_total_area): + input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:] + if model.memory_required(input_shape) < free_memory: to_batch = batch_amount break @@ -242,11 +244,10 @@ def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, max_total_area, return out_cond, out_uncond - max_total_area = model_management.maximum_batch_area() if math.isclose(cond_scale, 1.0): uncond = None - cond, uncond = calc_cond_uncond_batch(model, cond, uncond, x, timestep, max_total_area, model_options) + cond, uncond = calc_cond_uncond_batch(model, cond, uncond, x, timestep, model_options) if "sampler_cfg_function" in model_options: args = {"cond": x - cond, "uncond": x - uncond, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep} return x - model_options["sampler_cfg_function"](args) From 4781819a85847a8cf180a41d0ee4cdf99979e5be Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 12 Nov 2023 04:26:16 -0500 Subject: [PATCH 04/84] Make memory estimation aware of model dtype. --- comfy/model_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/model_base.py b/comfy/model_base.py index f6de0b258d1..37bf24bb8c6 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -161,7 +161,7 @@ def memory_required(self, input_shape): area = input_shape[0] * input_shape[2] * input_shape[3] if comfy.model_management.xformers_enabled() or comfy.model_management.pytorch_attention_flash_attention(): #TODO: this needs to be tweaked - return (area / 20) * (1024 * 1024) + return (area / (comfy.model_management.dtype_size(self.get_dtype()) * 10)) * (1024 * 1024) else: #TODO: this formula might be too aggressive since I tweaked the sub-quad and split algorithms to use less memory. return (((area * 0.6) / 0.9) + 1024) * (1024 * 1024) From 4aeef781a3caecc694e3336ca9339e8e171ba4d4 Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Sun, 12 Nov 2023 19:49:23 +0000 Subject: [PATCH 05/84] Support number/text ids when importing API JSON (#1952) * support numeric/text ids --- web/scripts/app.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/web/scripts/app.js b/web/scripts/app.js index 61b88d44b85..d22b98c315f 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -1863,7 +1863,7 @@ export class ComfyApp { for (const id of ids) { const data = apiData[id]; const node = LiteGraph.createNode(data.class_type); - node.id = id; + node.id = isNaN(+id) ? id : +id; graph.add(node); } From f12ec55983fb13b3bcad33b05ff8043b22b36181 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 13 Nov 2023 00:42:34 -0500 Subject: [PATCH 06/84] Allow boolean widgets to have no options dict. --- web/scripts/widgets.js | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/web/scripts/widgets.js b/web/scripts/widgets.js index 2b674776937..36bc7ff7fd7 100644 --- a/web/scripts/widgets.js +++ b/web/scripts/widgets.js @@ -305,14 +305,23 @@ export const ComfyWidgets = { }; }, BOOLEAN(node, inputName, inputData) { - let defaultVal = inputData[1]["default"]; + let defaultVal = false; + let options = {}; + if (inputData[1]) { + if (inputData[1].default) + defaultVal = inputData[1].default; + if (inputData[1].label_on) + options["on"] = inputData[1].label_on; + if (inputData[1].label_off) + options["off"] = inputData[1].label_off; + } return { widget: node.addWidget( "toggle", inputName, defaultVal, () => {}, - {"on": inputData[1].label_on, "off": inputData[1].label_off} + options, ) }; }, From 7339479b10a622729222ae7d9a5e06db340a1b99 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 13 Nov 2023 12:27:44 -0500 Subject: [PATCH 07/84] Disable xformers when it can't load properly. --- comfy/model_management.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/comfy/model_management.py b/comfy/model_management.py index 799e52ba239..be4301aa4e3 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -133,6 +133,10 @@ def get_total_memory(dev=None, torch_total_too=False): import xformers import xformers.ops XFORMERS_IS_AVAILABLE = True + try: + XFORMERS_IS_AVAILABLE = xformers._has_cpp_library + except: + pass try: XFORMERS_VERSION = xformers.version.__version__ print("xformers version:", XFORMERS_VERSION) From eb0407e80657ab603a1251a653ad8b2e9e89c83c Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 13 Nov 2023 16:26:28 -0500 Subject: [PATCH 08/84] Update litegraph to latest. --- web/lib/litegraph.core.js | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/web/lib/litegraph.core.js b/web/lib/litegraph.core.js index e906590f5ef..0ca2038429e 100644 --- a/web/lib/litegraph.core.js +++ b/web/lib/litegraph.core.js @@ -2533,7 +2533,7 @@ var w = this.widgets[i]; if(!w) continue; - if(w.options && w.options.property && this.properties[ w.options.property ]) + if(w.options && w.options.property && (this.properties[ w.options.property ] != undefined)) w.value = JSON.parse( JSON.stringify( this.properties[ w.options.property ] ) ); } if (info.widgets_values) { @@ -4928,9 +4928,7 @@ LGraphNode.prototype.executeAction = function(action) this.title = o.title; this._bounding.set(o.bounding); this.color = o.color; - if (o.font_size) { - this.font_size = o.font_size; - } + this.font_size = o.font_size; }; LGraphGroup.prototype.serialize = function() { @@ -5714,10 +5712,10 @@ LGraphNode.prototype.executeAction = function(action) * @method enableWebGL **/ LGraphCanvas.prototype.enableWebGL = function() { - if (typeof GL === undefined) { + if (typeof GL === "undefined") { throw "litegl.js must be included to use a WebGL canvas"; } - if (typeof enableWebGLCanvas === undefined) { + if (typeof enableWebGLCanvas === "undefined") { throw "webglCanvas.js must be included to use this feature"; } @@ -7110,15 +7108,16 @@ LGraphNode.prototype.executeAction = function(action) } }; - LGraphCanvas.prototype.copyToClipboard = function() { + LGraphCanvas.prototype.copyToClipboard = function(nodes) { var clipboard_info = { nodes: [], links: [] }; var index = 0; var selected_nodes_array = []; - for (var i in this.selected_nodes) { - var node = this.selected_nodes[i]; + if (!nodes) nodes = this.selected_nodes; + for (var i in nodes) { + var node = nodes[i]; if (node.clonable === false) continue; node._relative_id = index; @@ -11702,7 +11701,7 @@ LGraphNode.prototype.executeAction = function(action) default: iS = 0; // try with first if no name set } - if (typeof options.node_from.outputs[iS] !== undefined){ + if (typeof options.node_from.outputs[iS] !== "undefined"){ if (iS!==false && iS>-1){ options.node_from.connectByType( iS, node, options.node_from.outputs[iS].type ); } @@ -11730,7 +11729,7 @@ LGraphNode.prototype.executeAction = function(action) default: iS = 0; // try with first if no name set } - if (typeof options.node_to.inputs[iS] !== undefined){ + if (typeof options.node_to.inputs[iS] !== "undefined"){ if (iS!==false && iS>-1){ // try connection options.node_to.connectByTypeOutput(iS,node,options.node_to.inputs[iS].type); From 61112c81b99d0e43c2d6031aae036eed8a39fdbb Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 13 Nov 2023 21:45:08 -0500 Subject: [PATCH 09/84] Add a node to flip the sigmas for unsampling. --- comfy_extras/nodes_custom_sampler.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/comfy_extras/nodes_custom_sampler.py b/comfy_extras/nodes_custom_sampler.py index 154ecd0d234..ff7407f4192 100644 --- a/comfy_extras/nodes_custom_sampler.py +++ b/comfy_extras/nodes_custom_sampler.py @@ -118,6 +118,24 @@ def get_sigmas(self, sigmas, step): sigmas2 = sigmas[step:] return (sigmas1, sigmas2) +class FlipSigmas: + @classmethod + def INPUT_TYPES(s): + return {"required": + {"sigmas": ("SIGMAS", ), + } + } + RETURN_TYPES = ("SIGMAS",) + CATEGORY = "sampling/custom_sampling/sigmas" + + FUNCTION = "get_sigmas" + + def get_sigmas(self, sigmas): + sigmas = sigmas.flip(0) + if sigmas[0] == 0: + sigmas[0] = 0.0001 + return (sigmas,) + class KSamplerSelect: @classmethod def INPUT_TYPES(s): @@ -243,4 +261,5 @@ def sample(self, model, add_noise, noise_seed, cfg, positive, negative, sampler, "SamplerDPMPP_SDE": SamplerDPMPP_SDE, "BasicScheduler": BasicScheduler, "SplitSigmas": SplitSigmas, + "FlipSigmas": FlipSigmas, } From 8509bd58b436eb56e1e251c627416b457626252a Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 13 Nov 2023 21:45:23 -0500 Subject: [PATCH 10/84] Reorganize custom_sampling nodes. --- comfy_extras/nodes_custom_sampler.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/comfy_extras/nodes_custom_sampler.py b/comfy_extras/nodes_custom_sampler.py index ff7407f4192..f0576946a58 100644 --- a/comfy_extras/nodes_custom_sampler.py +++ b/comfy_extras/nodes_custom_sampler.py @@ -16,7 +16,7 @@ def INPUT_TYPES(s): } } RETURN_TYPES = ("SIGMAS",) - CATEGORY = "sampling/custom_sampling" + CATEGORY = "sampling/custom_sampling/schedulers" FUNCTION = "get_sigmas" @@ -36,7 +36,7 @@ def INPUT_TYPES(s): } } RETURN_TYPES = ("SIGMAS",) - CATEGORY = "sampling/custom_sampling" + CATEGORY = "sampling/custom_sampling/schedulers" FUNCTION = "get_sigmas" @@ -54,7 +54,7 @@ def INPUT_TYPES(s): } } RETURN_TYPES = ("SIGMAS",) - CATEGORY = "sampling/custom_sampling" + CATEGORY = "sampling/custom_sampling/schedulers" FUNCTION = "get_sigmas" @@ -73,7 +73,7 @@ def INPUT_TYPES(s): } } RETURN_TYPES = ("SIGMAS",) - CATEGORY = "sampling/custom_sampling" + CATEGORY = "sampling/custom_sampling/schedulers" FUNCTION = "get_sigmas" @@ -92,7 +92,7 @@ def INPUT_TYPES(s): } } RETURN_TYPES = ("SIGMAS",) - CATEGORY = "sampling/custom_sampling" + CATEGORY = "sampling/custom_sampling/schedulers" FUNCTION = "get_sigmas" @@ -109,7 +109,7 @@ def INPUT_TYPES(s): } } RETURN_TYPES = ("SIGMAS","SIGMAS") - CATEGORY = "sampling/custom_sampling" + CATEGORY = "sampling/custom_sampling/sigmas" FUNCTION = "get_sigmas" @@ -144,7 +144,7 @@ def INPUT_TYPES(s): } } RETURN_TYPES = ("SAMPLER",) - CATEGORY = "sampling/custom_sampling" + CATEGORY = "sampling/custom_sampling/samplers" FUNCTION = "get_sampler" @@ -163,7 +163,7 @@ def INPUT_TYPES(s): } } RETURN_TYPES = ("SAMPLER",) - CATEGORY = "sampling/custom_sampling" + CATEGORY = "sampling/custom_sampling/samplers" FUNCTION = "get_sampler" @@ -187,7 +187,7 @@ def INPUT_TYPES(s): } } RETURN_TYPES = ("SAMPLER",) - CATEGORY = "sampling/custom_sampling" + CATEGORY = "sampling/custom_sampling/samplers" FUNCTION = "get_sampler" @@ -252,6 +252,7 @@ def sample(self, model, add_noise, noise_seed, cfg, positive, negative, sampler, NODE_CLASS_MAPPINGS = { "SamplerCustom": SamplerCustom, + "BasicScheduler": BasicScheduler, "KarrasScheduler": KarrasScheduler, "ExponentialScheduler": ExponentialScheduler, "PolyexponentialScheduler": PolyexponentialScheduler, @@ -259,7 +260,6 @@ def sample(self, model, add_noise, noise_seed, cfg, positive, negative, sampler, "KSamplerSelect": KSamplerSelect, "SamplerDPMPP_2M_SDE": SamplerDPMPP_2M_SDE, "SamplerDPMPP_SDE": SamplerDPMPP_SDE, - "BasicScheduler": BasicScheduler, "SplitSigmas": SplitSigmas, "FlipSigmas": FlipSigmas, } From 94cc718e9c42cb4de337293b66dd42fb594b9cae Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 14 Nov 2023 00:08:12 -0500 Subject: [PATCH 11/84] Add a way to add patches to the input block. --- comfy/ldm/modules/diffusionmodules/openaimodel.py | 5 +++++ comfy/model_patcher.py | 3 +++ 2 files changed, 8 insertions(+) diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py index 49c1e8cbb5a..cac0dfb6598 100644 --- a/comfy/ldm/modules/diffusionmodules/openaimodel.py +++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py @@ -624,6 +624,11 @@ def forward(self, x, timesteps=None, context=None, y=None, control=None, transfo transformer_options["block"] = ("input", id) h = forward_timestep_embed(module, h, emb, context, transformer_options) h = apply_control(h, control, 'input') + if "input_block_patch" in transformer_patches: + patch = transformer_patches["input_block_patch"] + for p in patch: + h = p(h, transformer_options) + hs.append(h) transformer_options["block"] = ("middle", 0) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 1c36855de90..02368433126 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -96,6 +96,9 @@ def set_model_attn1_output_patch(self, patch): def set_model_attn2_output_patch(self, patch): self.set_model_patch(patch, "attn2_output_patch") + def set_model_input_block_patch(self, patch): + self.set_model_patch(patch, "input_block_patch") + def set_model_output_block_patch(self, patch): self.set_model_patch(patch, "output_block_patch") From f2e49b1d575b3da4367ba4d60b95187f270d42c9 Mon Sep 17 00:00:00 2001 From: Jianqi Pan Date: Tue, 14 Nov 2023 14:32:05 +0900 Subject: [PATCH 12/84] fix: adaptation to older versions of pytroch --- comfy/sd1_clip.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index 7db7ee0f449..af621b2dcb5 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -175,7 +175,7 @@ def forward(self, tokens): else: precision_scope = lambda a, b: contextlib.nullcontext(a) - with precision_scope(model_management.get_autocast_device(device), torch.float32): + with precision_scope(model_management.get_autocast_device(device), dtype=torch.float32): attention_mask = None if self.enable_attention_masks: attention_mask = torch.zeros_like(tokens) From 420beeeb05ef59e887f8731f615f8a9ec6eb0a4c Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 14 Nov 2023 00:39:34 -0500 Subject: [PATCH 13/84] Clean up and refactor sampler code. This should make it much easier to write custom nodes with kdiffusion type samplers. --- comfy/samplers.py | 85 +++++++++++++++++----------- comfy_extras/nodes_custom_sampler.py | 6 +- 2 files changed, 54 insertions(+), 37 deletions(-) diff --git a/comfy/samplers.py b/comfy/samplers.py index 5340dd019b4..65c44791d02 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -522,42 +522,59 @@ def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=N "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu", "dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm"] -def ksampler(sampler_name, extra_options={}, inpaint_options={}): - class KSAMPLER(Sampler): - def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False): - extra_args["denoise_mask"] = denoise_mask - model_k = KSamplerX0Inpaint(model_wrap) - model_k.latent_image = latent_image - if inpaint_options.get("random", False): #TODO: Should this be the default? - generator = torch.manual_seed(extra_args.get("seed", 41) + 1) - model_k.noise = torch.randn(noise.shape, generator=generator, device="cpu").to(noise.dtype).to(noise.device) - else: - model_k.noise = noise +class KSAMPLER(Sampler): + def __init__(self, sampler_function, extra_options={}, inpaint_options={}): + self.sampler_function = sampler_function + self.extra_options = extra_options + self.inpaint_options = inpaint_options - if self.max_denoise(model_wrap, sigmas): - noise = noise * torch.sqrt(1.0 + sigmas[0] ** 2.0) - else: - noise = noise * sigmas[0] + def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False): + extra_args["denoise_mask"] = denoise_mask + model_k = KSamplerX0Inpaint(model_wrap) + model_k.latent_image = latent_image + if self.inpaint_options.get("random", False): #TODO: Should this be the default? + generator = torch.manual_seed(extra_args.get("seed", 41) + 1) + model_k.noise = torch.randn(noise.shape, generator=generator, device="cpu").to(noise.dtype).to(noise.device) + else: + model_k.noise = noise - k_callback = None - total_steps = len(sigmas) - 1 - if callback is not None: - k_callback = lambda x: callback(x["i"], x["denoised"], x["x"], total_steps) + if self.max_denoise(model_wrap, sigmas): + noise = noise * torch.sqrt(1.0 + sigmas[0] ** 2.0) + else: + noise = noise * sigmas[0] + + k_callback = None + total_steps = len(sigmas) - 1 + if callback is not None: + k_callback = lambda x: callback(x["i"], x["denoised"], x["x"], total_steps) + + if latent_image is not None: + noise += latent_image + samples = self.sampler_function(model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar, **self.extra_options) + return samples + + +def ksampler(sampler_name, extra_options={}, inpaint_options={}): + if sampler_name == "dpm_fast": + def dpm_fast_function(model, noise, sigmas, extra_args, callback, disable): sigma_min = sigmas[-1] if sigma_min == 0: sigma_min = sigmas[-2] + total_steps = len(sigmas) - 1 + return k_diffusion_sampling.sample_dpm_fast(model, noise, sigma_min, sigmas[0], total_steps, extra_args=extra_args, callback=callback, disable=disable) + sampler_function = dpm_fast_function + elif sampler_name == "dpm_adaptive": + def dpm_adaptive_function(model, noise, sigmas, extra_args, callback, disable): + sigma_min = sigmas[-1] + if sigma_min == 0: + sigma_min = sigmas[-2] + return k_diffusion_sampling.sample_dpm_adaptive(model, noise, sigma_min, sigmas[0], extra_args=extra_args, callback=callback, disable=disable) + sampler_function = dpm_adaptive_function + else: + sampler_function = getattr(k_diffusion_sampling, "sample_{}".format(sampler_name)) - if latent_image is not None: - noise += latent_image - if sampler_name == "dpm_fast": - samples = k_diffusion_sampling.sample_dpm_fast(model_k, noise, sigma_min, sigmas[0], total_steps, extra_args=extra_args, callback=k_callback, disable=disable_pbar) - elif sampler_name == "dpm_adaptive": - samples = k_diffusion_sampling.sample_dpm_adaptive(model_k, noise, sigma_min, sigmas[0], extra_args=extra_args, callback=k_callback, disable=disable_pbar) - else: - samples = getattr(k_diffusion_sampling, "sample_{}".format(sampler_name))(model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar, **extra_options) - return samples - return KSAMPLER + return KSAMPLER(sampler_function, extra_options, inpaint_options) def wrap_model(model): model_denoise = CFGNoisePredictor(model) @@ -618,11 +635,11 @@ def calculate_sigmas_scheduler(model, scheduler_name, steps): print("error invalid scheduler", self.scheduler) return sigmas -def sampler_class(name): +def sampler_object(name): if name == "uni_pc": - sampler = UNIPC + sampler = UNIPC() elif name == "uni_pc_bh2": - sampler = UNIPCBH2 + sampler = UNIPCBH2() elif name == "ddim": sampler = ksampler("euler", inpaint_options={"random": True}) else: @@ -687,6 +704,6 @@ def sample(self, noise, positive, negative, cfg, latent_image=None, start_step=N else: return torch.zeros_like(noise) - sampler = sampler_class(self.sampler) + sampler = sampler_object(self.sampler) - return sample(self.model, noise, positive, negative, cfg, self.device, sampler(), sigmas, self.model_options, latent_image=latent_image, denoise_mask=denoise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed) + return sample(self.model, noise, positive, negative, cfg, self.device, sampler, sigmas, self.model_options, latent_image=latent_image, denoise_mask=denoise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed) diff --git a/comfy_extras/nodes_custom_sampler.py b/comfy_extras/nodes_custom_sampler.py index f0576946a58..d3c1d4a23ee 100644 --- a/comfy_extras/nodes_custom_sampler.py +++ b/comfy_extras/nodes_custom_sampler.py @@ -149,7 +149,7 @@ def INPUT_TYPES(s): FUNCTION = "get_sampler" def get_sampler(self, sampler_name): - sampler = comfy.samplers.sampler_class(sampler_name)() + sampler = comfy.samplers.sampler_object(sampler_name) return (sampler, ) class SamplerDPMPP_2M_SDE: @@ -172,7 +172,7 @@ def get_sampler(self, solver_type, eta, s_noise, noise_device): sampler_name = "dpmpp_2m_sde" else: sampler_name = "dpmpp_2m_sde_gpu" - sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "solver_type": solver_type})() + sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "solver_type": solver_type}) return (sampler, ) @@ -196,7 +196,7 @@ def get_sampler(self, eta, s_noise, r, noise_device): sampler_name = "dpmpp_sde" else: sampler_name = "dpmpp_sde_gpu" - sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "r": r})() + sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "r": r}) return (sampler, ) class SamplerCustom: From c962884a5c987e95d6928565ddb44220b769808e Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 14 Nov 2023 11:38:36 -0500 Subject: [PATCH 14/84] Make bislerp work on GPU. --- comfy/utils.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/comfy/utils.py b/comfy/utils.py index 4b484d07ac9..1985012e0f1 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -307,13 +307,13 @@ def slerp(b1, b2, r): res[dot < 1e-5 - 1] = (b1 * (1.0-r) + b2 * r)[dot < 1e-5 - 1] return res - def generate_bilinear_data(length_old, length_new): - coords_1 = torch.arange(length_old).reshape((1,1,1,-1)).to(torch.float32) + def generate_bilinear_data(length_old, length_new, device): + coords_1 = torch.arange(length_old, dtype=torch.float32, device=device).reshape((1,1,1,-1)) coords_1 = torch.nn.functional.interpolate(coords_1, size=(1, length_new), mode="bilinear") ratios = coords_1 - coords_1.floor() coords_1 = coords_1.to(torch.int64) - coords_2 = torch.arange(length_old).reshape((1,1,1,-1)).to(torch.float32) + 1 + coords_2 = torch.arange(length_old, dtype=torch.float32, device=device).reshape((1,1,1,-1)) + 1 coords_2[:,:,:,-1] -= 1 coords_2 = torch.nn.functional.interpolate(coords_2, size=(1, length_new), mode="bilinear") coords_2 = coords_2.to(torch.int64) @@ -323,7 +323,7 @@ def generate_bilinear_data(length_old, length_new): h_new, w_new = (height, width) #linear w - ratios, coords_1, coords_2 = generate_bilinear_data(w, w_new) + ratios, coords_1, coords_2 = generate_bilinear_data(w, w_new, samples.device) coords_1 = coords_1.expand((n, c, h, -1)) coords_2 = coords_2.expand((n, c, h, -1)) ratios = ratios.expand((n, 1, h, -1)) @@ -336,7 +336,7 @@ def generate_bilinear_data(length_old, length_new): result = result.reshape(n, h, w_new, c).movedim(-1, 1) #linear h - ratios, coords_1, coords_2 = generate_bilinear_data(h, h_new) + ratios, coords_1, coords_2 = generate_bilinear_data(h, h_new, samples.device) coords_1 = coords_1.reshape((1,1,-1,1)).expand((n, c, -1, w_new)) coords_2 = coords_2.reshape((1,1,-1,1)).expand((n, c, -1, w_new)) ratios = ratios.reshape((1,1,-1,1)).expand((n, 1, -1, w_new)) From 728613bb3e9a42a3e05abf19b1b893eb6ef35081 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 14 Nov 2023 14:41:31 -0500 Subject: [PATCH 15/84] Fix last pr. --- comfy/sd1_clip.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index af621b2dcb5..58acb97fce7 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -173,7 +173,7 @@ def forward(self, tokens): if getattr(self.transformer, self.inner_name).final_layer_norm.weight.dtype != torch.float32: precision_scope = torch.autocast else: - precision_scope = lambda a, b: contextlib.nullcontext(a) + precision_scope = lambda a, dtype: contextlib.nullcontext(a) with precision_scope(model_management.get_autocast_device(device), dtype=torch.float32): attention_mask = None From 7b87c825a3e95b362b101a608bbae2bbf13e1850 Mon Sep 17 00:00:00 2001 From: 42lux Date: Wed, 15 Nov 2023 02:37:35 +0100 Subject: [PATCH 16/84] Added Colorschemes. Arc, North and Github. --- web/extensions/core/colorPalette.js | 207 ++++++++++++++++++++++++++++ 1 file changed, 207 insertions(+) diff --git a/web/extensions/core/colorPalette.js b/web/extensions/core/colorPalette.js index 3695b08e27f..b8d83613d4b 100644 --- a/web/extensions/core/colorPalette.js +++ b/web/extensions/core/colorPalette.js @@ -174,6 +174,213 @@ const colorPalettes = { "tr-odd-bg-color": "#073642", } }, + }, + "arc": { + "id": "arc", + "name": "Arc", + "colors": { + "node_slot": { + "BOOLEAN": "", + "CLIP": "#eacb8b", + "CLIP_VISION": "#A8DADC", + "CLIP_VISION_OUTPUT": "#ad7452", + "CONDITIONING": "#cf876f", + "CONTROL_NET": "#00d78d", + "CONTROL_NET_WEIGHTS": "", + "FLOAT": "", + "GLIGEN": "", + "IMAGE": "#80a1c0", + "IMAGEUPLOAD": "", + "INT": "", + "LATENT": "#b38ead", + "LATENT_KEYFRAME": "", + "MASK": "#a3bd8d", + "MODEL": "#8978a7", + "SAMPLER": "", + "SIGMAS": "", + "STRING": "", + "STYLE_MODEL": "#C2FFAE", + "T2I_ADAPTER_WEIGHTS": "", + "TAESD": "#DCC274", + "TIMESTEP_KEYFRAME": "", + "UPSCALE_MODEL": "", + "VAE": "#be616b" + }, + "litegraph_base": { + "BACKGROUND_IMAGE": "", + "CLEAR_BACKGROUND_COLOR": "#2b2f38", + "NODE_TITLE_COLOR": "#b2b7bd", + "NODE_SELECTED_TITLE_COLOR": "#FFF", + "NODE_TEXT_SIZE": 14, + "NODE_TEXT_COLOR": "#AAA", + "NODE_SUBTEXT_SIZE": 12, + "NODE_DEFAULT_COLOR": "#2b2f38", + "NODE_DEFAULT_BGCOLOR": "#242730", + "NODE_DEFAULT_BOXCOLOR": "#6e7581", + "NODE_DEFAULT_SHAPE": "box", + "NODE_BOX_OUTLINE_COLOR": "#FFF", + "DEFAULT_SHADOW_COLOR": "rgba(0,0,0,0.5)", + "DEFAULT_GROUP_FONT": 22, + "WIDGET_BGCOLOR": "#2b2f38", + "WIDGET_OUTLINE_COLOR": "#6e7581", + "WIDGET_TEXT_COLOR": "#DDD", + "WIDGET_SECONDARY_TEXT_COLOR": "#b2b7bd", + "LINK_COLOR": "#9A9", + "EVENT_LINK_COLOR": "#A86", + "CONNECTING_LINK_COLOR": "#AFA" + }, + "comfy_base": { + "fg-color": "#fff", + "bg-color": "#2b2f38", + "comfy-menu-bg": "#242730", + "comfy-input-bg": "#2b2f38", + "input-text": "#ddd", + "descrip-text": "#b2b7bd", + "drag-text": "#ccc", + "error-text": "#ff4444", + "border-color": "#6e7581", + "tr-even-bg-color": "#2b2f38", + "tr-odd-bg-color": "#242730" + } + }, + }, + "nord": { + "id": "nord", + "name": "Nord", + "colors": { + "node_slot": { + "BOOLEAN": "", + "CLIP": "#eacb8b", + "CLIP_VISION": "#A8DADC", + "CLIP_VISION_OUTPUT": "#ad7452", + "CONDITIONING": "#cf876f", + "CONTROL_NET": "#00d78d", + "CONTROL_NET_WEIGHTS": "", + "FLOAT": "", + "GLIGEN": "", + "IMAGE": "#80a1c0", + "IMAGEUPLOAD": "", + "INT": "", + "LATENT": "#b38ead", + "LATENT_KEYFRAME": "", + "MASK": "#a3bd8d", + "MODEL": "#8978a7", + "SAMPLER": "", + "SIGMAS": "", + "STRING": "", + "STYLE_MODEL": "#C2FFAE", + "T2I_ADAPTER_WEIGHTS": "", + "TAESD": "#DCC274", + "TIMESTEP_KEYFRAME": "", + "UPSCALE_MODEL": "", + "VAE": "#be616b" + }, + "litegraph_base": { + "BACKGROUND_IMAGE": "", + "CLEAR_BACKGROUND_COLOR": "#212732", + "NODE_TITLE_COLOR": "#999", + "NODE_SELECTED_TITLE_COLOR": "#e5eaf0", + "NODE_TEXT_SIZE": 14, + "NODE_TEXT_COLOR": "#bcc2c8", + "NODE_SUBTEXT_SIZE": 12, + "NODE_DEFAULT_COLOR": "#2e3440", + "NODE_DEFAULT_BGCOLOR": "#161b22", + "NODE_DEFAULT_BOXCOLOR": "#545d70", + "NODE_DEFAULT_SHAPE": "box", + "NODE_BOX_OUTLINE_COLOR": "#e5eaf0", + "DEFAULT_SHADOW_COLOR": "rgba(0,0,0,0.5)", + "DEFAULT_GROUP_FONT": 24, + "WIDGET_BGCOLOR": "#2e3440", + "WIDGET_OUTLINE_COLOR": "#545d70", + "WIDGET_TEXT_COLOR": "#bcc2c8", + "WIDGET_SECONDARY_TEXT_COLOR": "#999", + "LINK_COLOR": "#9A9", + "EVENT_LINK_COLOR": "#A86", + "CONNECTING_LINK_COLOR": "#AFA" + }, + "comfy_base": { + "fg-color": "#e5eaf0", + "bg-color": "#2e3440", + "comfy-menu-bg": "#161b22", + "comfy-input-bg": "#2e3440", + "input-text": "#bcc2c8", + "descrip-text": "#999", + "drag-text": "#ccc", + "error-text": "#ff4444", + "border-color": "#545d70", + "tr-even-bg-color": "#2e3440", + "tr-odd-bg-color": "#161b22" + } + }, + }, + "github": { + "id": "github", + "name": "Github", + "colors": { + "node_slot": { + "BOOLEAN": "", + "CLIP": "#eacb8b", + "CLIP_VISION": "#A8DADC", + "CLIP_VISION_OUTPUT": "#ad7452", + "CONDITIONING": "#cf876f", + "CONTROL_NET": "#00d78d", + "CONTROL_NET_WEIGHTS": "", + "FLOAT": "", + "GLIGEN": "", + "IMAGE": "#80a1c0", + "IMAGEUPLOAD": "", + "INT": "", + "LATENT": "#b38ead", + "LATENT_KEYFRAME": "", + "MASK": "#a3bd8d", + "MODEL": "#8978a7", + "SAMPLER": "", + "SIGMAS": "", + "STRING": "", + "STYLE_MODEL": "#C2FFAE", + "T2I_ADAPTER_WEIGHTS": "", + "TAESD": "#DCC274", + "TIMESTEP_KEYFRAME": "", + "UPSCALE_MODEL": "", + "VAE": "#be616b" + }, + "litegraph_base": { + "BACKGROUND_IMAGE": "", + "CLEAR_BACKGROUND_COLOR": "#040506", + "NODE_TITLE_COLOR": "#999", + "NODE_SELECTED_TITLE_COLOR": "#e5eaf0", + "NODE_TEXT_SIZE": 14, + "NODE_TEXT_COLOR": "#bcc2c8", + "NODE_SUBTEXT_SIZE": 12, + "NODE_DEFAULT_COLOR": "#161b22", + "NODE_DEFAULT_BGCOLOR": "#13171d", + "NODE_DEFAULT_BOXCOLOR": "#30363d", + "NODE_DEFAULT_SHAPE": "box", + "NODE_BOX_OUTLINE_COLOR": "#e5eaf0", + "DEFAULT_SHADOW_COLOR": "rgba(0,0,0,0.5)", + "DEFAULT_GROUP_FONT": 24, + "WIDGET_BGCOLOR": "#161b22", + "WIDGET_OUTLINE_COLOR": "#30363d", + "WIDGET_TEXT_COLOR": "#bcc2c8", + "WIDGET_SECONDARY_TEXT_COLOR": "#999", + "LINK_COLOR": "#9A9", + "EVENT_LINK_COLOR": "#A86", + "CONNECTING_LINK_COLOR": "#AFA" + }, + "comfy_base": { + "fg-color": "#e5eaf0", + "bg-color": "#161b22", + "comfy-menu-bg": "#13171d", + "comfy-input-bg": "#161b22", + "input-text": "#bcc2c8", + "descrip-text": "#999", + "drag-text": "#ccc", + "error-text": "#ff4444", + "border-color": "#30363d", + "tr-even-bg-color": "#161b22", + "tr-odd-bg-color": "#13171d" + } + }, } }; From 57eea0efbb07a48d4810b477b29d44ba5425a742 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 14 Nov 2023 23:45:36 -0500 Subject: [PATCH 17/84] heunpp2 sampler. --- comfy/k_diffusion/sampling.py | 58 +++++++++++++++++++++++++++++++++++ comfy/samplers.py | 2 +- 2 files changed, 59 insertions(+), 1 deletion(-) diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index dd6f7bbe598..761c2e0ef7c 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -750,3 +750,61 @@ def sample_lcm(model, x, sigmas, extra_args=None, callback=None, disable=None, n if sigmas[i + 1] > 0: x += sigmas[i + 1] * noise_sampler(sigmas[i], sigmas[i + 1]) return x + + + +@torch.no_grad() +def sample_heunpp2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.): + # From MIT licensed: https://github.com/Carzit/sd-webui-samplers-scheduler/ + extra_args = {} if extra_args is None else extra_args + s_in = x.new_ones([x.shape[0]]) + s_end = sigmas[-1] + for i in trange(len(sigmas) - 1, disable=disable): + gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0. + eps = torch.randn_like(x) * s_noise + sigma_hat = sigmas[i] * (gamma + 1) + if gamma > 0: + x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5 + denoised = model(x, sigma_hat * s_in, **extra_args) + d = to_d(x, sigma_hat, denoised) + if callback is not None: + callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised}) + dt = sigmas[i + 1] - sigma_hat + if sigmas[i + 1] == s_end: + # Euler method + x = x + d * dt + elif sigmas[i + 2] == s_end: + + # Heun's method + x_2 = x + d * dt + denoised_2 = model(x_2, sigmas[i + 1] * s_in, **extra_args) + d_2 = to_d(x_2, sigmas[i + 1], denoised_2) + + w = 2 * sigmas[0] + w2 = sigmas[i+1]/w + w1 = 1 - w2 + + d_prime = d * w1 + d_2 * w2 + + + x = x + d_prime * dt + + else: + # Heun++ + x_2 = x + d * dt + denoised_2 = model(x_2, sigmas[i + 1] * s_in, **extra_args) + d_2 = to_d(x_2, sigmas[i + 1], denoised_2) + dt_2 = sigmas[i + 2] - sigmas[i + 1] + + x_3 = x_2 + d_2 * dt_2 + denoised_3 = model(x_3, sigmas[i + 2] * s_in, **extra_args) + d_3 = to_d(x_3, sigmas[i + 2], denoised_3) + + w = 3 * sigmas[0] + w2 = sigmas[i + 1] / w + w3 = sigmas[i + 2] / w + w1 = 1 - w2 - w3 + + d_prime = w1 * d + w2 * d_2 + w3 * d_3 + x = x + d_prime * dt + return x diff --git a/comfy/samplers.py b/comfy/samplers.py index 65c44791d02..d8037d8ea8b 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -518,7 +518,7 @@ class UNIPCBH2(Sampler): def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False): return uni_pc.sample_unipc(model_wrap, noise, latent_image, sigmas, max_denoise=self.max_denoise(model_wrap, sigmas), extra_args=extra_args, noise_mask=denoise_mask, callback=callback, variant='bh2', disable=disable_pbar) -KSAMPLER_NAMES = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral", +KSAMPLER_NAMES = ["euler", "euler_ancestral", "heun", "heunpp2","dpm_2", "dpm_2_ancestral", "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu", "dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm"] From 7114cfec0eefe713340257c85a2b342e98fdcfb2 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 15 Nov 2023 15:55:02 -0500 Subject: [PATCH 18/84] Always clone graph data when loading to fix some load issues. --- web/scripts/app.js | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/web/scripts/app.js b/web/scripts/app.js index d22b98c315f..4507527f686 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -1489,16 +1489,18 @@ export class ComfyApp { let reset_invalid_values = false; if (!graphData) { - if (typeof structuredClone === "undefined") - { - graphData = JSON.parse(JSON.stringify(defaultGraph)); - }else - { - graphData = structuredClone(defaultGraph); - } + graphData = defaultGraph; reset_invalid_values = true; } + if (typeof structuredClone === "undefined") + { + graphData = JSON.parse(JSON.stringify(graphData)); + }else + { + graphData = structuredClone(graphData); + } + const missingNodeTypes = []; for (let n of graphData.nodes) { // Patch T2IAdapterLoader to ControlNetLoader since they are the same node now From dcec1047e6bb04880551a64cdb8f31dbde920ea0 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 16 Nov 2023 04:07:35 -0500 Subject: [PATCH 19/84] Invert the start and end percentages in the code. This doesn't affect how percentages behave in the frontend but breaks things if you relied on them in the backend. percent_to_sigma goes from 0 to 1.0 instead of 1.0 to 0 for less confusion. Make percent 0 return an extremely large sigma and percent 1.0 return a zero one to fix imprecision. --- comfy/controlnet.py | 4 ++-- comfy/model_sampling.py | 5 +++++ comfy/samplers.py | 2 ++ comfy_extras/nodes_model_advanced.py | 5 +++++ nodes.py | 6 +++--- 5 files changed, 17 insertions(+), 5 deletions(-) diff --git a/comfy/controlnet.py b/comfy/controlnet.py index 09868158287..433381df6ec 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -33,7 +33,7 @@ def __init__(self, device=None): self.cond_hint_original = None self.cond_hint = None self.strength = 1.0 - self.timestep_percent_range = (1.0, 0.0) + self.timestep_percent_range = (0.0, 1.0) self.timestep_range = None if device is None: @@ -42,7 +42,7 @@ def __init__(self, device=None): self.previous_controlnet = None self.global_average_pooling = False - def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(1.0, 0.0)): + def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0)): self.cond_hint_original = cond_hint self.strength = strength self.timestep_percent_range = timestep_percent_range diff --git a/comfy/model_sampling.py b/comfy/model_sampling.py index a2935d47d18..d5b1642ef3a 100644 --- a/comfy/model_sampling.py +++ b/comfy/model_sampling.py @@ -76,5 +76,10 @@ def sigma(self, timestep): return log_sigma.exp() def percent_to_sigma(self, percent): + if percent <= 0.0: + return torch.tensor(999999999.9) + if percent >= 1.0: + return torch.tensor(0.0) + percent = 1.0 - percent return self.sigma(torch.tensor(percent * 999.0)) diff --git a/comfy/samplers.py b/comfy/samplers.py index d8037d8ea8b..1d012a514a7 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -220,6 +220,8 @@ def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options): transformer_options["patches"] = patches transformer_options["cond_or_uncond"] = cond_or_uncond[:] + transformer_options["sigmas"] = timestep + c['transformer_options'] = transformer_options if 'model_function_wrapper' in model_options: diff --git a/comfy_extras/nodes_model_advanced.py b/comfy_extras/nodes_model_advanced.py index 399123eaa2e..c8c4b4a1e70 100644 --- a/comfy_extras/nodes_model_advanced.py +++ b/comfy_extras/nodes_model_advanced.py @@ -66,6 +66,11 @@ def sigma(self, timestep): return log_sigma.exp() def percent_to_sigma(self, percent): + if percent <= 0.0: + return torch.tensor(999999999.9) + if percent >= 1.0: + return torch.tensor(0.0) + percent = 1.0 - percent return self.sigma(torch.tensor(percent * 999.0)) diff --git a/nodes.py b/nodes.py index 2bbfd8fe874..e8cfb5e6ac2 100644 --- a/nodes.py +++ b/nodes.py @@ -248,8 +248,8 @@ def set_range(self, conditioning, start, end): c = [] for t in conditioning: d = t[1].copy() - d['start_percent'] = 1.0 - start - d['end_percent'] = 1.0 - end + d['start_percent'] = start + d['end_percent'] = end n = [t[0], d] c.append(n) return (c, ) @@ -685,7 +685,7 @@ def apply_controlnet(self, positive, negative, control_net, image, strength, sta if prev_cnet in cnets: c_net = cnets[prev_cnet] else: - c_net = control_net.copy().set_cond_hint(control_hint, strength, (1.0 - start_percent, 1.0 - end_percent)) + c_net = control_net.copy().set_cond_hint(control_hint, strength, (start_percent, end_percent)) c_net.set_previous_controlnet(prev_cnet) cnets[prev_cnet] = c_net From 7ea6bb038cf488224269565bf0e0bcc400f0a7e2 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 16 Nov 2023 12:57:12 -0500 Subject: [PATCH 20/84] Print warning when controlnet can't be applied instead of crashing. --- comfy/ldm/modules/diffusionmodules/openaimodel.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py index cac0dfb6598..504b79ede66 100644 --- a/comfy/ldm/modules/diffusionmodules/openaimodel.py +++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py @@ -255,7 +255,10 @@ def apply_control(h, control, name): if control is not None and name in control and len(control[name]) > 0: ctrl = control[name].pop() if ctrl is not None: - h += ctrl + try: + h += ctrl + except: + print("warning control could not be applied", h.shape, ctrl.shape) return h class UNetModel(nn.Module): From bd07ad1861949007139de7dd5c6bcdb77426919c Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 16 Nov 2023 13:23:25 -0500 Subject: [PATCH 21/84] Add PatchModelAddDownscale (Kohya Deep Shrink) node. By adding a downscale to the unet in the first timesteps this node lets you generate images at higher resolutions with less consistency issues. --- comfy_extras/nodes_model_downscale.py | 45 +++++++++++++++++++++++++++ nodes.py | 1 + 2 files changed, 46 insertions(+) create mode 100644 comfy_extras/nodes_model_downscale.py diff --git a/comfy_extras/nodes_model_downscale.py b/comfy_extras/nodes_model_downscale.py new file mode 100644 index 00000000000..f1b2d3ff2c5 --- /dev/null +++ b/comfy_extras/nodes_model_downscale.py @@ -0,0 +1,45 @@ +import torch + +class PatchModelAddDownscale: + @classmethod + def INPUT_TYPES(s): + return {"required": { "model": ("MODEL",), + "block_number": ("INT", {"default": 3, "min": 1, "max": 32, "step": 1}), + "downscale_factor": ("FLOAT", {"default": 2.0, "min": 0.1, "max": 9.0, "step": 0.001}), + "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), + "end_percent": ("FLOAT", {"default": 0.35, "min": 0.0, "max": 1.0, "step": 0.001}), + }} + RETURN_TYPES = ("MODEL",) + FUNCTION = "patch" + + CATEGORY = "_for_testing" + + def patch(self, model, block_number, downscale_factor, start_percent, end_percent): + sigma_start = model.model.model_sampling.percent_to_sigma(start_percent).item() + sigma_end = model.model.model_sampling.percent_to_sigma(end_percent).item() + + def input_block_patch(h, transformer_options): + if transformer_options["block"][1] == block_number: + sigma = transformer_options["sigmas"][0].item() + if sigma <= sigma_start and sigma >= sigma_end: + h = torch.nn.functional.interpolate(h, scale_factor=(1.0 / downscale_factor), mode="bicubic", align_corners=False) + return h + + def output_block_patch(h, hsp, transformer_options): + if h.shape[2] != hsp.shape[2]: + h = torch.nn.functional.interpolate(h, size=(hsp.shape[2], hsp.shape[3]), mode="bicubic", align_corners=False) + return h, hsp + + m = model.clone() + m.set_model_input_block_patch(input_block_patch) + m.set_model_output_block_patch(output_block_patch) + return (m, ) + +NODE_CLASS_MAPPINGS = { + "PatchModelAddDownscale": PatchModelAddDownscale, +} + +NODE_DISPLAY_NAME_MAPPINGS = { + # Sampling + "PatchModelAddDownscale": "PatchModelAddDownscale (Kohya Deep Shrink)", +} diff --git a/nodes.py b/nodes.py index e8cfb5e6ac2..f9d2d7f6c8b 100644 --- a/nodes.py +++ b/nodes.py @@ -1799,6 +1799,7 @@ def init_custom_nodes(): "nodes_custom_sampler.py", "nodes_hypertile.py", "nodes_model_advanced.py", + "nodes_model_downscale.py", ] for node_file in extras_files: From 9f00a18095e5f8ef114525bc19db035756501959 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 16 Nov 2023 14:59:54 -0500 Subject: [PATCH 22/84] Fix potential issues. --- comfy/model_patcher.py | 2 +- comfy/utils.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 02368433126..7f5ed45fee2 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -37,7 +37,7 @@ def model_size(self): return size def clone(self): - n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device) + n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device, weight_inplace_update=self.weight_inplace_update) n.patches = {} for k in self.patches: n.patches[k] = self.patches[k][:] diff --git a/comfy/utils.py b/comfy/utils.py index 1985012e0f1..f4c0ab41928 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -258,7 +258,7 @@ def set_attr(obj, attr, value): for name in attrs[:-1]: obj = getattr(obj, name) prev = getattr(obj, attrs[-1]) - setattr(obj, attrs[-1], torch.nn.Parameter(value)) + setattr(obj, attrs[-1], torch.nn.Parameter(value, requires_grad=False)) del prev def copy_to_param(obj, attr, value): From 7e3fe3ad28fad4dede2893d77093a086344b81b6 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 16 Nov 2023 15:26:28 -0500 Subject: [PATCH 23/84] Make deep shrink behave like it should. --- comfy/ldm/modules/diffusionmodules/openaimodel.py | 4 ++++ comfy/model_patcher.py | 3 +++ comfy_extras/nodes_model_downscale.py | 8 ++++++-- 3 files changed, 13 insertions(+), 2 deletions(-) diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py index 504b79ede66..10eb68d73b5 100644 --- a/comfy/ldm/modules/diffusionmodules/openaimodel.py +++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py @@ -633,6 +633,10 @@ def forward(self, x, timesteps=None, context=None, y=None, control=None, transfo h = p(h, transformer_options) hs.append(h) + if "input_block_patch_after_skip" in transformer_patches: + patch = transformer_patches["input_block_patch_after_skip"] + for p in patch: + h = p(h, transformer_options) transformer_options["block"] = ("middle", 0) h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 7f5ed45fee2..a3cffc3be9d 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -99,6 +99,9 @@ def set_model_attn2_output_patch(self, patch): def set_model_input_block_patch(self, patch): self.set_model_patch(patch, "input_block_patch") + def set_model_input_block_patch_after_skip(self, patch): + self.set_model_patch(patch, "input_block_patch_after_skip") + def set_model_output_block_patch(self, patch): self.set_model_patch(patch, "output_block_patch") diff --git a/comfy_extras/nodes_model_downscale.py b/comfy_extras/nodes_model_downscale.py index f1b2d3ff2c5..8850d094891 100644 --- a/comfy_extras/nodes_model_downscale.py +++ b/comfy_extras/nodes_model_downscale.py @@ -8,13 +8,14 @@ def INPUT_TYPES(s): "downscale_factor": ("FLOAT", {"default": 2.0, "min": 0.1, "max": 9.0, "step": 0.001}), "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), "end_percent": ("FLOAT", {"default": 0.35, "min": 0.0, "max": 1.0, "step": 0.001}), + "downscale_after_skip": ("BOOLEAN", {"default": True}), }} RETURN_TYPES = ("MODEL",) FUNCTION = "patch" CATEGORY = "_for_testing" - def patch(self, model, block_number, downscale_factor, start_percent, end_percent): + def patch(self, model, block_number, downscale_factor, start_percent, end_percent, downscale_after_skip): sigma_start = model.model.model_sampling.percent_to_sigma(start_percent).item() sigma_end = model.model.model_sampling.percent_to_sigma(end_percent).item() @@ -31,7 +32,10 @@ def output_block_patch(h, hsp, transformer_options): return h, hsp m = model.clone() - m.set_model_input_block_patch(input_block_patch) + if downscale_after_skip: + m.set_model_input_block_patch_after_skip(input_block_patch) + else: + m.set_model_input_block_patch(input_block_patch) m.set_model_output_block_patch(output_block_patch) return (m, ) From 107e78b1cb079f652408bece8b0045927dc9f1fd Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 16 Nov 2023 23:12:55 -0500 Subject: [PATCH 24/84] Add support for loading SSD1B diffusers unet version. Improve diffusers model detection. --- comfy/model_detection.py | 76 +++++++++++++++++++++++----------------- 1 file changed, 44 insertions(+), 32 deletions(-) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 4f4e0b3b7f0..d65d91e7cb5 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -186,17 +186,24 @@ def convert_config(unet_config): def unet_config_from_diffusers_unet(state_dict, dtype): match = {} - attention_resolutions = [] + transformer_depth = [] attn_res = 1 - for i in range(5): - k = "down_blocks.{}.attentions.1.transformer_blocks.0.attn2.to_k.weight".format(i) - if k in state_dict: - match["context_dim"] = state_dict[k].shape[1] - attention_resolutions.append(attn_res) + down_blocks = count_blocks(state_dict, "down_blocks.{}") + for i in range(down_blocks): + attn_blocks = count_blocks(state_dict, "down_blocks.{}.attentions.".format(i) + '{}') + for ab in range(attn_blocks): + transformer_count = count_blocks(state_dict, "down_blocks.{}.attentions.{}.transformer_blocks.".format(i, ab) + '{}') + transformer_depth.append(transformer_count) + if transformer_count > 0: + match["context_dim"] = state_dict["down_blocks.{}.attentions.{}.transformer_blocks.0.attn2.to_k.weight".format(i, ab)].shape[1] + attn_res *= 2 + if attn_blocks == 0: + transformer_depth.append(0) + transformer_depth.append(0) - match["attention_resolutions"] = attention_resolutions + match["transformer_depth"] = transformer_depth match["model_channels"] = state_dict["conv_in.weight"].shape[0] match["in_channels"] = state_dict["conv_in.weight"].shape[1] @@ -208,50 +215,55 @@ def unet_config_from_diffusers_unet(state_dict, dtype): SDXL = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, 'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, - 'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 2, 10], 'channel_mult': [1, 2, 4], - 'transformer_depth_middle': 10, 'use_linear_in_transformer': True, 'context_dim': 2048, "num_head_channels": 64} + 'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 2, 2, 10, 10], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': 10, + 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 0, 2, 2, 2, 10, 10, 10]} SDXL_refiner = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, 'num_classes': 'sequential', 'adm_in_channels': 2560, 'dtype': dtype, 'in_channels': 4, 'model_channels': 384, - 'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 4, 4, 0], 'channel_mult': [1, 2, 4, 4], - 'transformer_depth_middle': 4, 'use_linear_in_transformer': True, 'context_dim': 1280, "num_head_channels": 64} + 'num_res_blocks': [2, 2, 2, 2], 'transformer_depth': [0, 0, 4, 4, 4, 4, 0, 0], 'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 4, + 'use_linear_in_transformer': True, 'context_dim': 1280, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 0, 4, 4, 4, 4, 4, 4, 0, 0, 0]} SD21 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, - 'adm_in_channels': None, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2, - 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4], - 'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024, "num_head_channels": 64} + 'adm_in_channels': None, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': [2, 2, 2, 2], + 'transformer_depth': [1, 1, 1, 1, 1, 1, 0, 0], 'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1, 'use_linear_in_transformer': True, + 'context_dim': 1024, 'num_head_channels': 64, 'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0]} SD21_uncliph = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, 'num_classes': 'sequential', 'adm_in_channels': 2048, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, - 'num_res_blocks': 2, 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4], - 'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024, "num_head_channels": 64} + 'num_res_blocks': [2, 2, 2, 2], 'transformer_depth': [1, 1, 1, 1, 1, 1, 0, 0], 'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1, + 'use_linear_in_transformer': True, 'context_dim': 1024, 'num_head_channels': 64, 'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0]} SD21_unclipl = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, 'num_classes': 'sequential', 'adm_in_channels': 1536, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, - 'num_res_blocks': 2, 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4], - 'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024} + 'num_res_blocks': [2, 2, 2, 2], 'transformer_depth': [1, 1, 1, 1, 1, 1, 0, 0], 'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1, + 'use_linear_in_transformer': True, 'context_dim': 1024, 'num_head_channels': 64, 'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0]} - SD15 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, - 'adm_in_channels': None, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2, - 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4], - 'transformer_depth_middle': 1, 'use_linear_in_transformer': False, 'context_dim': 768, "num_heads": 8} + SD15 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, 'adm_in_channels': None, + 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': [2, 2, 2, 2], 'transformer_depth': [1, 1, 1, 1, 1, 1, 0, 0], + 'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1, 'use_linear_in_transformer': False, 'context_dim': 768, 'num_heads': 8, + 'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0]} SDXL_mid_cnet = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, - 'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, - 'num_res_blocks': 2, 'attention_resolutions': [4], 'transformer_depth': [0, 0, 1], 'channel_mult': [1, 2, 4], - 'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 2048, "num_head_channels": 64} + 'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, + 'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 0, 0, 1, 1], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': 1, + 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 0, 0, 0, 0, 1, 1, 1]} SDXL_small_cnet = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, - 'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, - 'num_res_blocks': 2, 'attention_resolutions': [], 'transformer_depth': [0, 0, 0], 'channel_mult': [1, 2, 4], - 'transformer_depth_middle': 0, 'use_linear_in_transformer': True, "num_head_channels": 64, 'context_dim': 1} + 'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, + 'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 0, 0, 0, 0], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': 0, + 'use_linear_in_transformer': True, 'num_head_channels': 64, 'context_dim': 1, 'transformer_depth_output': [0, 0, 0, 0, 0, 0, 0, 0, 0]} SDXL_diffusers_inpaint = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, - 'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 9, 'model_channels': 320, - 'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 2, 10], 'channel_mult': [1, 2, 4], - 'transformer_depth_middle': 10, 'use_linear_in_transformer': True, 'context_dim': 2048, "num_head_channels": 64} + 'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 9, 'model_channels': 320, + 'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 2, 2, 10, 10], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': 10, + 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 0, 2, 2, 2, 10, 10, 10]} + + SSD_1B = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, + 'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, + 'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 2, 2, 4, 4], 'transformer_depth_output': [0, 0, 0, 1, 1, 2, 10, 4, 4], + 'channel_mult': [1, 2, 4], 'transformer_depth_middle': -1, 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64} - supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint] + supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B] for unet_config in supported_models: matches = True From 0cf4e8693945d68000e37fe291f877eff9ef0aaa Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 17 Nov 2023 02:56:59 -0500 Subject: [PATCH 25/84] Add some command line arguments to store text encoder weights in fp8. Pytorch supports two variants of fp8: --fp8_e4m3fn-text-enc (the one that seems to give better results) --fp8_e5m2-text-enc --- comfy/cli_args.py | 7 +++++++ comfy/model_management.py | 15 +++++++++++++++ comfy/sd.py | 5 +---- 3 files changed, 23 insertions(+), 4 deletions(-) diff --git a/comfy/cli_args.py b/comfy/cli_args.py index e79b89c0f0d..72fce10872f 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -62,6 +62,13 @@ def __call__(self, parser, namespace, values, option_string=None): fpvae_group.add_argument("--fp32-vae", action="store_true", help="Run the VAE in full precision fp32.") fpvae_group.add_argument("--bf16-vae", action="store_true", help="Run the VAE in bf16.") +fpte_group = parser.add_mutually_exclusive_group() +fpte_group.add_argument("--fp8_e4m3fn-text-enc", action="store_true", help="Store text encoder weights in fp8 (e4m3fn variant).") +fpte_group.add_argument("--fp8_e5m2-text-enc", action="store_true", help="Store text encoder weights in fp8 (e5m2 variant).") +fpte_group.add_argument("--fp16-text-enc", action="store_true", help="Store text encoder weights in fp16.") +fpte_group.add_argument("--fp32-text-enc", action="store_true", help="Store text encoder weights in fp32.") + + parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.") parser.add_argument("--disable-ipex-optimize", action="store_true", help="Disables ipex.optimize when loading models with Intel GPUs.") diff --git a/comfy/model_management.py b/comfy/model_management.py index be4301aa4e3..d4acd8950ca 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -482,6 +482,21 @@ def text_encoder_device(): else: return torch.device("cpu") +def text_encoder_dtype(device=None): + if args.fp8_e4m3fn_text_enc: + return torch.float8_e4m3fn + elif args.fp8_e5m2_text_enc: + return torch.float8_e5m2 + elif args.fp16_text_enc: + return torch.float16 + elif args.fp32_text_enc: + return torch.float32 + + if should_use_fp16(device, prioritize_performance=False): + return torch.float16 + else: + return torch.float32 + def vae_device(): return get_torch_device() diff --git a/comfy/sd.py b/comfy/sd.py index 65d94f46ecc..c3cc8e72080 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -95,10 +95,7 @@ def __init__(self, target=None, embedding_directory=None, no_init=False): load_device = model_management.text_encoder_device() offload_device = model_management.text_encoder_offload_device() params['device'] = offload_device - if model_management.should_use_fp16(load_device, prioritize_performance=False): - params['dtype'] = torch.float16 - else: - params['dtype'] = torch.float32 + params['dtype'] = model_management.text_encoder_dtype(load_device) self.cond_stage_model = clip(**(params)) From 8a451234b3090db488fbee9740a5f6be2f989253 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 18 Nov 2023 04:44:17 -0500 Subject: [PATCH 26/84] Add ImageCrop node. --- comfy_extras/nodes_images.py | 29 +++++++++++++++++++++++++++++ nodes.py | 1 + 2 files changed, 30 insertions(+) create mode 100644 comfy_extras/nodes_images.py diff --git a/comfy_extras/nodes_images.py b/comfy_extras/nodes_images.py new file mode 100644 index 00000000000..2b8e93001af --- /dev/null +++ b/comfy_extras/nodes_images.py @@ -0,0 +1,29 @@ +import nodes +MAX_RESOLUTION = nodes.MAX_RESOLUTION + +class ImageCrop: + @classmethod + def INPUT_TYPES(s): + return {"required": { "image": ("IMAGE",), + "width": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}), + "height": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}), + "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), + "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), + }} + RETURN_TYPES = ("IMAGE",) + FUNCTION = "crop" + + CATEGORY = "image/transform" + + def crop(self, image, width, height, x, y): + x = min(x, image.shape[2] - 1) + y = min(y, image.shape[1] - 1) + to_x = width + x + to_y = height + y + img = image[:,y:to_y, x:to_x, :] + return (img,) + + +NODE_CLASS_MAPPINGS = { + "ImageCrop": ImageCrop, +} diff --git a/nodes.py b/nodes.py index f9d2d7f6c8b..2adc5e07371 100644 --- a/nodes.py +++ b/nodes.py @@ -1800,6 +1800,7 @@ def init_custom_nodes(): "nodes_hypertile.py", "nodes_model_advanced.py", "nodes_model_downscale.py", + "nodes_images.py", ] for node_file in extras_files: From d9d8702d8dd2337c64610633f5df2dcd402379a8 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 18 Nov 2023 23:20:29 -0500 Subject: [PATCH 27/84] percent_to_sigma now returns a float instead of a tensor. --- comfy/model_sampling.py | 6 +++--- comfy_extras/nodes_model_advanced.py | 6 +++--- comfy_extras/nodes_model_downscale.py | 4 ++-- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/comfy/model_sampling.py b/comfy/model_sampling.py index d5b1642ef3a..37a3ac725c6 100644 --- a/comfy/model_sampling.py +++ b/comfy/model_sampling.py @@ -77,9 +77,9 @@ def sigma(self, timestep): def percent_to_sigma(self, percent): if percent <= 0.0: - return torch.tensor(999999999.9) + return 999999999.9 if percent >= 1.0: - return torch.tensor(0.0) + return 0.0 percent = 1.0 - percent - return self.sigma(torch.tensor(percent * 999.0)) + return self.sigma(torch.tensor(percent * 999.0)).item() diff --git a/comfy_extras/nodes_model_advanced.py b/comfy_extras/nodes_model_advanced.py index c8c4b4a1e70..0f4ddd9c340 100644 --- a/comfy_extras/nodes_model_advanced.py +++ b/comfy_extras/nodes_model_advanced.py @@ -67,11 +67,11 @@ def sigma(self, timestep): def percent_to_sigma(self, percent): if percent <= 0.0: - return torch.tensor(999999999.9) + return 999999999.9 if percent >= 1.0: - return torch.tensor(0.0) + return 0.0 percent = 1.0 - percent - return self.sigma(torch.tensor(percent * 999.0)) + return self.sigma(torch.tensor(percent * 999.0)).item() def rescale_zero_terminal_snr_sigmas(sigmas): diff --git a/comfy_extras/nodes_model_downscale.py b/comfy_extras/nodes_model_downscale.py index 8850d094891..f65ef05e18b 100644 --- a/comfy_extras/nodes_model_downscale.py +++ b/comfy_extras/nodes_model_downscale.py @@ -16,8 +16,8 @@ def INPUT_TYPES(s): CATEGORY = "_for_testing" def patch(self, model, block_number, downscale_factor, start_percent, end_percent, downscale_after_skip): - sigma_start = model.model.model_sampling.percent_to_sigma(start_percent).item() - sigma_end = model.model.model_sampling.percent_to_sigma(end_percent).item() + sigma_start = model.model.model_sampling.percent_to_sigma(start_percent) + sigma_end = model.model.model_sampling.percent_to_sigma(end_percent) def input_block_patch(h, transformer_options): if transformer_options["block"][1] == block_number: From dba4f3b4fce575994ed718ac31888620e8d6e733 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 19 Nov 2023 06:09:01 -0500 Subject: [PATCH 28/84] Add a RepeatImageBatch node. --- comfy_extras/nodes_images.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/comfy_extras/nodes_images.py b/comfy_extras/nodes_images.py index 2b8e93001af..8cb322327b0 100644 --- a/comfy_extras/nodes_images.py +++ b/comfy_extras/nodes_images.py @@ -23,7 +23,22 @@ def crop(self, image, width, height, x, y): img = image[:,y:to_y, x:to_x, :] return (img,) +class RepeatImageBatch: + @classmethod + def INPUT_TYPES(s): + return {"required": { "image": ("IMAGE",), + "amount": ("INT", {"default": 1, "min": 1, "max": 64}), + }} + RETURN_TYPES = ("IMAGE",) + FUNCTION = "repeat" + + CATEGORY = "image/batch" + + def repeat(self, image, amount): + s = image.repeat((amount, 1,1,1)) + return (s,) NODE_CLASS_MAPPINGS = { "ImageCrop": ImageCrop, + "RepeatImageBatch": RepeatImageBatch, } From 31c5ea7b2c79f36d3ebc729acf946ba47b4e5785 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 20 Nov 2023 03:55:51 -0500 Subject: [PATCH 29/84] Add LatentInterpolate to interpolate between latents. --- comfy_extras/nodes_latent.py | 36 ++++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/comfy_extras/nodes_latent.py b/comfy_extras/nodes_latent.py index 001de39fceb..cedf39d6346 100644 --- a/comfy_extras/nodes_latent.py +++ b/comfy_extras/nodes_latent.py @@ -1,4 +1,5 @@ import comfy.utils +import torch def reshape_latent_to(target_shape, latent): if latent.shape[1:] != target_shape[1:]: @@ -67,8 +68,43 @@ def op(self, samples, multiplier): samples_out["samples"] = s1 * multiplier return (samples_out,) +class LatentInterpolate: + @classmethod + def INPUT_TYPES(s): + return {"required": { "samples1": ("LATENT",), + "samples2": ("LATENT",), + "ratio": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), + }} + + RETURN_TYPES = ("LATENT",) + FUNCTION = "op" + + CATEGORY = "latent/advanced" + + def op(self, samples1, samples2, ratio): + samples_out = samples1.copy() + + s1 = samples1["samples"] + s2 = samples2["samples"] + + s2 = reshape_latent_to(s1.shape, s2) + + m1 = torch.linalg.vector_norm(s1, dim=(1)) + m2 = torch.linalg.vector_norm(s2, dim=(1)) + + s1 = torch.nan_to_num(s1 / m1) + s2 = torch.nan_to_num(s2 / m2) + + t = (s1 * ratio + s2 * (1.0 - ratio)) + mt = torch.linalg.vector_norm(t, dim=(1)) + st = torch.nan_to_num(t / mt) + + samples_out["samples"] = st * (m1 * ratio + m2 * (1.0 - ratio)) + return (samples_out,) + NODE_CLASS_MAPPINGS = { "LatentAdd": LatentAdd, "LatentSubtract": LatentSubtract, "LatentMultiply": LatentMultiply, + "LatentInterpolate": LatentInterpolate, } From a03dde190ede39675736e746c3045ecfc4baa79b Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 20 Nov 2023 16:38:39 -0500 Subject: [PATCH 30/84] Cap maximum history size at 10000. Delete oldest entry when reached. --- execution.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/execution.py b/execution.py index 918c2bc5cc3..9a2ca5b9d04 100644 --- a/execution.py +++ b/execution.py @@ -681,6 +681,7 @@ def validate_prompt(prompt): return (True, None, list(good_outputs), node_errors) +MAXIMUM_HISTORY_SIZE = 10000 class PromptQueue: def __init__(self, server): @@ -713,6 +714,8 @@ def get(self): def task_done(self, item_id, outputs): with self.mutex: prompt = self.currently_running.pop(item_id) + if len(self.history) > MAXIMUM_HISTORY_SIZE: + self.history.pop(next(iter(self.history))) self.history[prompt[1]] = { "prompt": prompt, "outputs": {} } for o in outputs: self.history[prompt[1]]["outputs"][o] = outputs[o] From 2dd5b4dd78fc0a30f3d5baa0b99a6b10f002d917 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 20 Nov 2023 16:51:41 -0500 Subject: [PATCH 31/84] Only show last 200 elements in the UI history tab. --- execution.py | 14 ++++++++++++-- server.py | 5 ++++- web/scripts/api.js | 2 +- 3 files changed, 17 insertions(+), 4 deletions(-) diff --git a/execution.py b/execution.py index 9a2ca5b9d04..bca48a785c2 100644 --- a/execution.py +++ b/execution.py @@ -750,10 +750,20 @@ def delete_queue_item(self, function): return True return False - def get_history(self, prompt_id=None): + def get_history(self, prompt_id=None, max_items=None, offset=-1): with self.mutex: if prompt_id is None: - return copy.deepcopy(self.history) + out = {} + i = 0 + if offset < 0 and max_items is not None: + offset = len(self.history) - max_items + for k in self.history: + if i >= offset: + out[k] = self.history[k] + if max_items is not None and len(out) >= max_items: + break + i += 1 + return out elif prompt_id in self.history: return {prompt_id: copy.deepcopy(self.history[prompt_id])} else: diff --git a/server.py b/server.py index 11bd2a0fb44..1a8e92b8f96 100644 --- a/server.py +++ b/server.py @@ -431,7 +431,10 @@ async def get_object_info_node(request): @routes.get("/history") async def get_history(request): - return web.json_response(self.prompt_queue.get_history()) + max_items = request.rel_url.query.get("max_items", None) + if max_items is not None: + max_items = int(max_items) + return web.json_response(self.prompt_queue.get_history(max_items=max_items)) @routes.get("/history/{prompt_id}") async def get_history(request): diff --git a/web/scripts/api.js b/web/scripts/api.js index b1d245d73ff..de56b23108b 100644 --- a/web/scripts/api.js +++ b/web/scripts/api.js @@ -256,7 +256,7 @@ class ComfyApi extends EventTarget { */ async getHistory() { try { - const res = await this.fetchApi("/history"); + const res = await this.fetchApi("/history?max_items=200"); return { History: Object.values(await res.json()) }; } catch (error) { console.error(error); From ce67dcbcdabe2edf1497e37ecf1b6f976a3ecdf6 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 20 Nov 2023 22:27:36 -0500 Subject: [PATCH 32/84] Make it easy for models to process the unet state dict on load. --- comfy/model_base.py | 1 + comfy/supported_models_base.py | 3 +++ 2 files changed, 4 insertions(+) diff --git a/comfy/model_base.py b/comfy/model_base.py index 37bf24bb8c6..772e2693493 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -121,6 +121,7 @@ def load_model_weights(self, sd, unet_prefix=""): if k.startswith(unet_prefix): to_load[k[len(unet_prefix):]] = sd.pop(k) + to_load = self.model_config.process_unet_state_dict(to_load) m, u = self.diffusion_model.load_state_dict(to_load, strict=False) if len(m) > 0: print("unet missing:", m) diff --git a/comfy/supported_models_base.py b/comfy/supported_models_base.py index 88a1d7fde49..6dfae034303 100644 --- a/comfy/supported_models_base.py +++ b/comfy/supported_models_base.py @@ -53,6 +53,9 @@ def get_model(self, state_dict, prefix="", device=None): def process_clip_state_dict(self, state_dict): return state_dict + def process_unet_state_dict(self, state_dict): + return state_dict + def process_clip_state_dict_for_saving(self, state_dict): replace_prefix = {"": "cond_stage_model."} return utils.state_dict_prefix_replace(state_dict, replace_prefix) From 6ff06fa7960524749d8e584100a0e50594485f29 Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Tue, 21 Nov 2023 06:33:58 +0000 Subject: [PATCH 33/84] Animated image output support (#2008) * Refactor multiline widget into generic DOM widget * wip webp preview * webp support * fix check * fix sizing * show image when zoomed out * Swap webp checkto generic animated image flag * remove duplicate * Fix falsy check --- web/scripts/app.js | 78 +++++---- web/scripts/domWidget.js | 312 +++++++++++++++++++++++++++++++++ web/scripts/ui/imagePreview.js | 97 ++++++++++ web/scripts/widgets.js | 168 ++---------------- web/style.css | 15 ++ 5 files changed, 483 insertions(+), 187 deletions(-) create mode 100644 web/scripts/domWidget.js create mode 100644 web/scripts/ui/imagePreview.js diff --git a/web/scripts/app.js b/web/scripts/app.js index 4507527f686..601e486e6e4 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -4,7 +4,10 @@ import { ComfyUI, $el } from "./ui.js"; import { api } from "./api.js"; import { defaultGraph } from "./defaultGraph.js"; import { getPngMetadata, getWebpMetadata, importA1111, getLatentMetadata } from "./pnginfo.js"; +import { addDomClippingSetting } from "./domWidget.js"; +import { createImageHost, calculateImageGrid } from "./ui/imagePreview.js" +export const ANIM_PREVIEW_WIDGET = "$$comfy_animation_preview" function sanitizeNodeName(string) { let entityMap = { @@ -405,7 +408,9 @@ export class ComfyApp { return shiftY; } - node.prototype.setSizeForImage = function () { + node.prototype.setSizeForImage = function (force) { + if(!force && this.animatedImages) return; + if (this.inputHeight) { this.setSize(this.size); return; @@ -422,13 +427,20 @@ export class ComfyApp { let imagesChanged = false const output = app.nodeOutputs[this.id + ""]; - if (output && output.images) { + if (output?.images) { + this.animatedImages = output?.animated?.find(Boolean); if (this.images !== output.images) { this.images = output.images; imagesChanged = true; - imgURLs = imgURLs.concat(output.images.map(params => { - return api.apiURL("/view?" + new URLSearchParams(params).toString() + app.getPreviewFormatParam()); - })) + imgURLs = imgURLs.concat( + output.images.map((params) => { + return api.apiURL( + "/view?" + + new URLSearchParams(params).toString() + + (this.animatedImages ? "" : app.getPreviewFormatParam()) + ); + }) + ); } } @@ -507,7 +519,34 @@ export class ComfyApp { return true; } - if (this.imgs && this.imgs.length) { + if (this.imgs?.length) { + const widgetIdx = this.widgets?.findIndex((w) => w.name === ANIM_PREVIEW_WIDGET); + + if(this.animatedImages) { + // Instead of using the canvas we'll use a IMG + if(widgetIdx > -1) { + // Replace content + const widget = this.widgets[widgetIdx]; + widget.options.host.updateImages(this.imgs); + } else { + const host = createImageHost(this); + this.setSizeForImage(true); + const widget = this.addDOMWidget(ANIM_PREVIEW_WIDGET, "img", host.el, { + host, + getHeight: host.getHeight, + onDraw: host.onDraw, + hideOnZoom: false + }); + widget.serializeValue = () => undefined; + widget.options.host.updateImages(this.imgs); + } + return; + } + + if (widgetIdx > -1) { + this.widgets.splice(widgetIdx, 1); + } + const canvas = app.graph.list_of_graphcanvas[0]; const mouse = canvas.graph_mouse; if (!canvas.pointer_is_down && this.pointerDown) { @@ -547,31 +586,7 @@ export class ComfyApp { } else { cell_padding = 0; - let best = 0; - let w = this.imgs[0].naturalWidth; - let h = this.imgs[0].naturalHeight; - - // compact style - for (let c = 1; c <= numImages; c++) { - const rows = Math.ceil(numImages / c); - const cW = dw / c; - const cH = dh / rows; - const scaleX = cW / w; - const scaleY = cH / h; - - const scale = Math.min(scaleX, scaleY, 1); - const imageW = w * scale; - const imageH = h * scale; - const area = imageW * imageH * numImages; - - if (area > best) { - best = area; - cellWidth = imageW; - cellHeight = imageH; - cols = c; - shiftX = c * ((cW - imageW) / 2); - } - } + ({ cellWidth, cellHeight, cols, shiftX } = calculateImageGrid(this.imgs, dw, dh)); } let anyHovered = false; @@ -1272,6 +1287,7 @@ export class ComfyApp { canvasEl.tabIndex = "1"; document.body.prepend(canvasEl); + addDomClippingSetting(); this.#addProcessMouseHandler(); this.#addProcessKeyHandler(); this.#addConfigureHandler(); diff --git a/web/scripts/domWidget.js b/web/scripts/domWidget.js new file mode 100644 index 00000000000..16f4e192eea --- /dev/null +++ b/web/scripts/domWidget.js @@ -0,0 +1,312 @@ +import { app, ANIM_PREVIEW_WIDGET } from "./app.js"; + +const SIZE = Symbol(); + +function intersect(a, b) { + const x = Math.max(a.x, b.x); + const num1 = Math.min(a.x + a.width, b.x + b.width); + const y = Math.max(a.y, b.y); + const num2 = Math.min(a.y + a.height, b.y + b.height); + if (num1 >= x && num2 >= y) return [x, y, num1 - x, num2 - y]; + else return null; +} + +function getClipPath(node, element, elRect) { + const selectedNode = Object.values(app.canvas.selected_nodes)[0]; + if (selectedNode && selectedNode !== node) { + const MARGIN = 7; + const scale = app.canvas.ds.scale; + + const intersection = intersect( + { x: elRect.x / scale, y: elRect.y / scale, width: elRect.width / scale, height: elRect.height / scale }, + { + x: selectedNode.pos[0] + app.canvas.ds.offset[0] - MARGIN, + y: selectedNode.pos[1] + app.canvas.ds.offset[1] - LiteGraph.NODE_TITLE_HEIGHT - MARGIN, + width: selectedNode.size[0] + MARGIN + MARGIN, + height: selectedNode.size[1] + LiteGraph.NODE_TITLE_HEIGHT + MARGIN + MARGIN, + } + ); + + if (!intersection) { + return ""; + } + + const widgetRect = element.getBoundingClientRect(); + const clipX = intersection[0] - widgetRect.x / scale + "px"; + const clipY = intersection[1] - widgetRect.y / scale + "px"; + const clipWidth = intersection[2] + "px"; + const clipHeight = intersection[3] + "px"; + const path = `polygon(0% 0%, 0% 100%, ${clipX} 100%, ${clipX} ${clipY}, calc(${clipX} + ${clipWidth}) ${clipY}, calc(${clipX} + ${clipWidth}) calc(${clipY} + ${clipHeight}), ${clipX} calc(${clipY} + ${clipHeight}), ${clipX} 100%, 100% 100%, 100% 0%)`; + return path; + } + return ""; +} + +function computeSize(size) { + if (this.widgets?.[0].last_y == null) return; + + let y = this.widgets[0].last_y; + let freeSpace = size[1] - y; + + let widgetHeight = 0; + let dom = []; + for (const w of this.widgets) { + if (w.type === "converted-widget") { + // Ignore + delete w.computedHeight; + } else if (w.computeSize) { + widgetHeight += w.computeSize()[1] + 4; + } else if (w.element) { + // Extract DOM widget size info + const styles = getComputedStyle(w.element); + let minHeight = w.options.getMinHeight?.() ?? parseInt(styles.getPropertyValue("--comfy-widget-min-height")); + let maxHeight = w.options.getMaxHeight?.() ?? parseInt(styles.getPropertyValue("--comfy-widget-max-height")); + + let prefHeight = w.options.getHeight?.() ?? styles.getPropertyValue("--comfy-widget-height"); + if (prefHeight.endsWith?.("%")) { + prefHeight = size[1] * (parseFloat(prefHeight.substring(0, prefHeight.length - 1)) / 100); + } else { + prefHeight = parseInt(prefHeight); + if (isNaN(minHeight)) { + minHeight = prefHeight; + } + } + if (isNaN(minHeight)) { + minHeight = 50; + } + if (!isNaN(maxHeight)) { + if (!isNaN(prefHeight)) { + prefHeight = Math.min(prefHeight, maxHeight); + } else { + prefHeight = maxHeight; + } + } + dom.push({ + minHeight, + prefHeight, + w, + }); + } else { + widgetHeight += LiteGraph.NODE_WIDGET_HEIGHT + 4; + } + } + + freeSpace -= widgetHeight; + + // Calculate sizes with all widgets at their min height + const prefGrow = []; // Nodes that want to grow to their prefd size + const canGrow = []; // Nodes that can grow to auto size + let growBy = 0; + for (const d of dom) { + freeSpace -= d.minHeight; + if (isNaN(d.prefHeight)) { + canGrow.push(d); + d.w.computedHeight = d.minHeight; + } else { + const diff = d.prefHeight - d.minHeight; + if (diff > 0) { + prefGrow.push(d); + growBy += diff; + d.diff = diff; + } else { + d.w.computedHeight = d.minHeight; + } + } + } + + if (this.imgs && !this.widgets.find((w) => w.name === ANIM_PREVIEW_WIDGET)) { + // Allocate space for image + freeSpace -= 220; + } + + if (freeSpace < 0) { + // Not enough space for all widgets so we need to grow + size[1] -= freeSpace; + this.graph.setDirtyCanvas(true); + } else { + // Share the space between each + const growDiff = freeSpace - growBy; + if (growDiff > 0) { + // All pref sizes can be fulfilled + freeSpace = growDiff; + for (const d of prefGrow) { + d.w.computedHeight = d.prefHeight; + } + } else { + // We need to grow evenly + const shared = -growDiff / prefGrow.length; + for (const d of prefGrow) { + d.w.computedHeight = d.prefHeight - shared; + } + freeSpace = 0; + } + + if (freeSpace > 0 && canGrow.length) { + // Grow any that are auto height + const shared = freeSpace / canGrow.length; + for (const d of canGrow) { + d.w.computedHeight += shared; + } + } + } + + // Position each of the widgets + for (const w of this.widgets) { + w.y = y; + if (w.computedHeight) { + y += w.computedHeight; + } else if (w.computeSize) { + y += w.computeSize()[1] + 4; + } else { + y += LiteGraph.NODE_WIDGET_HEIGHT + 4; + } + } +} + +// Override the compute visible nodes function to allow us to hide/show DOM elements when the node goes offscreen +const elementWidgets = new Set(); +const computeVisibleNodes = LGraphCanvas.prototype.computeVisibleNodes; +LGraphCanvas.prototype.computeVisibleNodes = function () { + const visibleNodes = computeVisibleNodes.apply(this, arguments); + for (const node of app.graph._nodes) { + if (elementWidgets.has(node)) { + const hidden = visibleNodes.indexOf(node) === -1; + for (const w of node.widgets) { + if (w.element) { + w.element.hidden = hidden; + if (hidden) { + w.options.onHide?.(w); + } + } + } + } + } + + return visibleNodes; +}; + +let enableDomClipping = true; + +export function addDomClippingSetting() { + app.ui.settings.addSetting({ + id: "Comfy.DOMClippingEnabled", + name: "Enable DOM element clipping (enabling may reduce performance)", + type: "boolean", + defaultValue: enableDomClipping, + onChange(value) { + console.log("enableDomClipping", enableDomClipping); + enableDomClipping = !!value; + }, + }); +} + +LGraphNode.prototype.addDOMWidget = function (name, type, element, options) { + options = { hideOnZoom: true, selectOn: ["focus", "click"], ...options }; + + if (!element.parentElement) { + document.body.append(element); + } + + let mouseDownHandler; + if (element.blur) { + mouseDownHandler = (event) => { + if (!element.contains(event.target)) { + element.blur(); + } + }; + document.addEventListener("mousedown", mouseDownHandler); + } + + const widget = { + type, + name, + get value() { + return options.getValue?.() ?? undefined; + }, + set value(v) { + options.setValue?.(v); + widget.callback?.(widget.value); + }, + draw: function (ctx, node, widgetWidth, y, widgetHeight) { + if (widget.computedHeight == null) { + computeSize.call(node, node.size); + } + + const hidden = + (!!options.hideOnZoom && app.canvas.ds.scale < 0.5) || + widget.computedHeight <= 0 || + widget.type === "converted-widget"; + element.hidden = hidden; + element.style.display = hidden ? "none" : null; + if (hidden) { + widget.options.onHide?.(widget); + return; + } + + const margin = 10; + const elRect = ctx.canvas.getBoundingClientRect(); + const transform = new DOMMatrix() + .scaleSelf(elRect.width / ctx.canvas.width, elRect.height / ctx.canvas.height) + .multiplySelf(ctx.getTransform()) + .translateSelf(margin, margin + y); + + const scale = new DOMMatrix().scaleSelf(transform.a, transform.d); + + Object.assign(element.style, { + transformOrigin: "0 0", + transform: scale, + left: `${transform.a + transform.e}px`, + top: `${transform.d + transform.f}px`, + width: `${widgetWidth - margin * 2}px`, + height: `${(widget.computedHeight ?? 50) - margin * 2}px`, + position: "absolute", + zIndex: app.graph._nodes.indexOf(node), + }); + + if (enableDomClipping) { + element.style.clipPath = getClipPath(node, element, elRect); + element.style.willChange = "clip-path"; + } + + this.options.onDraw?.(widget); + }, + element, + options, + onRemove() { + if (mouseDownHandler) { + document.removeEventListener("mousedown", mouseDownHandler); + } + element.remove(); + }, + }; + + for (const evt of options.selectOn) { + element.addEventListener(evt, () => { + app.canvas.selectNode(this); + app.canvas.bringToFront(this); + }); + } + + this.addCustomWidget(widget); + elementWidgets.add(this); + + const onRemoved = this.onRemoved; + this.onRemoved = function () { + element.remove(); + elementWidgets.delete(this); + onRemoved?.apply(this, arguments); + }; + + if (!this[SIZE]) { + this[SIZE] = true; + const onResize = this.onResize; + this.onResize = function (size) { + options.beforeResize?.call(widget, this); + computeSize.call(this, size); + onResize?.apply(this, arguments); + options.afterResize?.call(widget, this); + }; + } + + return widget; +}; diff --git a/web/scripts/ui/imagePreview.js b/web/scripts/ui/imagePreview.js new file mode 100644 index 00000000000..2a7f66b8f3b --- /dev/null +++ b/web/scripts/ui/imagePreview.js @@ -0,0 +1,97 @@ +import { $el } from "../ui.js"; + +export function calculateImageGrid(imgs, dw, dh) { + let best = 0; + let w = imgs[0].naturalWidth; + let h = imgs[0].naturalHeight; + const numImages = imgs.length; + + let cellWidth, cellHeight, cols, rows, shiftX; + // compact style + for (let c = 1; c <= numImages; c++) { + const r = Math.ceil(numImages / c); + const cW = dw / c; + const cH = dh / r; + const scaleX = cW / w; + const scaleY = cH / h; + + const scale = Math.min(scaleX, scaleY, 1); + const imageW = w * scale; + const imageH = h * scale; + const area = imageW * imageH * numImages; + + if (area > best) { + best = area; + cellWidth = imageW; + cellHeight = imageH; + cols = c; + rows = r; + shiftX = c * ((cW - imageW) / 2); + } + } + + return { cellWidth, cellHeight, cols, rows, shiftX }; +} + +export function createImageHost(node) { + const el = $el("div.comfy-img-preview"); + let currentImgs; + let first = true; + + function updateSize() { + let w = null; + let h = null; + + if (currentImgs) { + let elH = el.clientHeight; + if (first) { + first = false; + // On first run, if we are small then grow a bit + if (elH < 190) { + elH = 190; + } + el.style.setProperty("--comfy-widget-min-height", elH); + } else { + el.style.setProperty("--comfy-widget-min-height", null); + } + + const nw = node.size[0]; + ({ cellWidth: w, cellHeight: h } = calculateImageGrid(currentImgs, nw - 20, elH)); + w += "px"; + h += "px"; + + el.style.setProperty("--comfy-img-preview-width", w); + el.style.setProperty("--comfy-img-preview-height", h); + } + } + return { + el, + updateImages(imgs) { + if (imgs !== currentImgs) { + if (currentImgs == null) { + requestAnimationFrame(() => { + updateSize(); + }); + } + el.replaceChildren(...imgs); + currentImgs = imgs; + node.onResize(node.size); + node.graph.setDirtyCanvas(true, true); + } + }, + getHeight() { + updateSize(); + }, + onDraw() { + // Element from point uses a hittest find elements so we need to toggle pointer events + el.style.pointerEvents = "all"; + const over = document.elementFromPoint(app.canvas.mouse[0], app.canvas.mouse[1]); + el.style.pointerEvents = "none"; + + if(!over) return; + // Set the overIndex so Open Image etc work + const idx = currentImgs.indexOf(over); + node.overIndex = idx; + }, + }; +} diff --git a/web/scripts/widgets.js b/web/scripts/widgets.js index 36bc7ff7fd7..ccddc0bc44b 100644 --- a/web/scripts/widgets.js +++ b/web/scripts/widgets.js @@ -1,4 +1,5 @@ import { api } from "./api.js" +import "./domWidget.js"; function getNumberDefaults(inputData, defaultStep, precision, enable_rounding) { let defaultVal = inputData[1]["default"]; @@ -97,166 +98,21 @@ function seedWidget(node, inputName, inputData, app) { seed.widget.linkedWidgets = [seedControl]; return seed; } - -const MultilineSymbol = Symbol(); -const MultilineResizeSymbol = Symbol(); - function addMultilineWidget(node, name, opts, app) { - const MIN_SIZE = 50; - - function computeSize(size) { - if (node.widgets[0].last_y == null) return; - - let y = node.widgets[0].last_y; - let freeSpace = size[1] - y; - - // Compute the height of all non customtext widgets - let widgetHeight = 0; - const multi = []; - for (let i = 0; i < node.widgets.length; i++) { - const w = node.widgets[i]; - if (w.type === "customtext") { - multi.push(w); - } else { - if (w.computeSize) { - widgetHeight += w.computeSize()[1] + 4; - } else { - widgetHeight += LiteGraph.NODE_WIDGET_HEIGHT + 4; - } - } - } - - // See how large each text input can be - freeSpace -= widgetHeight; - freeSpace /= multi.length + (!!node.imgs?.length); - - if (freeSpace < MIN_SIZE) { - // There isnt enough space for all the widgets, increase the size of the node - freeSpace = MIN_SIZE; - node.size[1] = y + widgetHeight + freeSpace * (multi.length + (!!node.imgs?.length)); - node.graph.setDirtyCanvas(true); - } - - // Position each of the widgets - for (const w of node.widgets) { - w.y = y; - if (w.type === "customtext") { - y += freeSpace; - w.computedHeight = freeSpace - multi.length*4; - } else if (w.computeSize) { - y += w.computeSize()[1] + 4; - } else { - y += LiteGraph.NODE_WIDGET_HEIGHT + 4; - } - } - - node.inputHeight = freeSpace; - } - - const widget = { - type: "customtext", - name, - get value() { - return this.inputEl.value; + const inputEl = document.createElement("textarea"); + inputEl.className = "comfy-multiline-input"; + inputEl.value = opts.defaultVal; + inputEl.placeholder = opts.placeholder || ""; + + const widget = node.addDOMWidget(name, "customtext", inputEl, { + getValue() { + return inputEl.value; }, - set value(x) { - this.inputEl.value = x; - }, - draw: function (ctx, _, widgetWidth, y, widgetHeight) { - if (!this.parent.inputHeight) { - // If we are initially offscreen when created we wont have received a resize event - // Calculate it here instead - computeSize(node.size); - } - const visible = app.canvas.ds.scale > 0.5 && this.type === "customtext"; - const margin = 10; - const elRect = ctx.canvas.getBoundingClientRect(); - const transform = new DOMMatrix() - .scaleSelf(elRect.width / ctx.canvas.width, elRect.height / ctx.canvas.height) - .multiplySelf(ctx.getTransform()) - .translateSelf(margin, margin + y); - - const scale = new DOMMatrix().scaleSelf(transform.a, transform.d) - Object.assign(this.inputEl.style, { - transformOrigin: "0 0", - transform: scale, - left: `${transform.a + transform.e}px`, - top: `${transform.d + transform.f}px`, - width: `${widgetWidth - (margin * 2)}px`, - height: `${this.parent.inputHeight - (margin * 2)}px`, - position: "absolute", - background: (!node.color)?'':node.color, - color: (!node.color)?'':'white', - zIndex: app.graph._nodes.indexOf(node), - }); - this.inputEl.hidden = !visible; + setValue(v) { + inputEl.value = v; }, - }; - widget.inputEl = document.createElement("textarea"); - widget.inputEl.className = "comfy-multiline-input"; - widget.inputEl.value = opts.defaultVal; - widget.inputEl.placeholder = opts.placeholder || ""; - document.addEventListener("mousedown", function (event) { - if (!widget.inputEl.contains(event.target)) { - widget.inputEl.blur(); - } }); - widget.parent = node; - document.body.appendChild(widget.inputEl); - - node.addCustomWidget(widget); - - app.canvas.onDrawBackground = function () { - // Draw node isnt fired once the node is off the screen - // if it goes off screen quickly, the input may not be removed - // this shifts it off screen so it can be moved back if the node is visible. - for (let n in app.graph._nodes) { - n = graph._nodes[n]; - for (let w in n.widgets) { - let wid = n.widgets[w]; - if (Object.hasOwn(wid, "inputEl")) { - wid.inputEl.style.left = -8000 + "px"; - wid.inputEl.style.position = "absolute"; - } - } - } - }; - - node.onRemoved = function () { - // When removing this node we need to remove the input from the DOM - for (let y in this.widgets) { - if (this.widgets[y].inputEl) { - this.widgets[y].inputEl.remove(); - } - } - }; - - widget.onRemove = () => { - widget.inputEl?.remove(); - - // Restore original size handler if we are the last - if (!--node[MultilineSymbol]) { - node.onResize = node[MultilineResizeSymbol]; - delete node[MultilineSymbol]; - delete node[MultilineResizeSymbol]; - } - }; - - if (node[MultilineSymbol]) { - node[MultilineSymbol]++; - } else { - node[MultilineSymbol] = 1; - const onResize = (node[MultilineResizeSymbol] = node.onResize); - - node.onResize = function (size) { - computeSize(size); - - // Call original resizer handler - if (onResize) { - onResize.apply(this, arguments); - } - }; - } + widget.inputEl = inputEl; return { minWidth: 400, minHeight: 200, widget }; } diff --git a/web/style.css b/web/style.css index 692fa31d672..378fe0a48b9 100644 --- a/web/style.css +++ b/web/style.css @@ -409,6 +409,21 @@ dialog::backdrop { width: calc(100% - 10px); } +.comfy-img-preview { + pointer-events: none; + overflow: hidden; + display: flex; + flex-wrap: wrap; + align-content: flex-start; + justify-content: center; +} + +.comfy-img-preview img { + object-fit: contain; + width: var(--comfy-img-preview-width); + height: var(--comfy-img-preview-height); +} + /* Search box */ .litegraph.litesearchbox { From 89e31abc46df00d10d48b8a4e36256fefd5973ed Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Tue, 21 Nov 2023 17:54:01 +0000 Subject: [PATCH 34/84] Fix clipping of collapsed nodes --- web/scripts/domWidget.js | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/web/scripts/domWidget.js b/web/scripts/domWidget.js index 16f4e192eea..2f73e573e13 100644 --- a/web/scripts/domWidget.js +++ b/web/scripts/domWidget.js @@ -17,13 +17,14 @@ function getClipPath(node, element, elRect) { const MARGIN = 7; const scale = app.canvas.ds.scale; + const bounding = selectedNode.getBounding(); const intersection = intersect( { x: elRect.x / scale, y: elRect.y / scale, width: elRect.width / scale, height: elRect.height / scale }, { x: selectedNode.pos[0] + app.canvas.ds.offset[0] - MARGIN, y: selectedNode.pos[1] + app.canvas.ds.offset[1] - LiteGraph.NODE_TITLE_HEIGHT - MARGIN, - width: selectedNode.size[0] + MARGIN + MARGIN, - height: selectedNode.size[1] + LiteGraph.NODE_TITLE_HEIGHT + MARGIN + MARGIN, + width: bounding[2] + MARGIN + MARGIN, + height: bounding[3] + MARGIN + MARGIN, } ); From cd4fc77d5f83867cdfb806f0c96c65ce8a84322c Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 21 Nov 2023 12:54:19 -0500 Subject: [PATCH 35/84] Add taesd and taesdxl to VAELoader node. They will show up if both the taesd_encoder and taesd_decoder or taesdxl model files are present in the models/vae_approx directory. --- comfy/sd.py | 17 ++++++++++---- comfy/taesd/taesd.py | 19 +++++++++++---- latent_preview.py | 5 +--- nodes.py | 55 +++++++++++++++++++++++++++++++++++++++++--- 4 files changed, 79 insertions(+), 17 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index c3cc8e72080..0f83cc5814d 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -23,6 +23,7 @@ import comfy.lora import comfy.t2i_adapter.adapter import comfy.supported_models_base +import comfy.taesd.taesd def load_model_weights(model, sd): m, u = model.load_state_dict(sd, strict=False) @@ -154,10 +155,16 @@ def __init__(self, sd=None, device=None, config=None): if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format sd = diffusers_convert.convert_vae_state_dict(sd) + self.memory_used_encode = lambda shape: (2078 * shape[2] * shape[3]) * 1.7 #These are for AutoencoderKL and need tweaking + self.memory_used_decode = lambda shape: (2562 * shape[2] * shape[3] * 64) * 1.7 + if config is None: - #default SD1.x/SD2.x VAE parameters - ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0} - self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=4) + if "taesd_decoder.1.weight" in sd: + self.first_stage_model = comfy.taesd.taesd.TAESD() + else: + #default SD1.x/SD2.x VAE parameters + ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0} + self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=4) else: self.first_stage_model = AutoencoderKL(**(config['params'])) self.first_stage_model = self.first_stage_model.eval() @@ -206,7 +213,7 @@ def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64): def decode(self, samples_in): self.first_stage_model = self.first_stage_model.to(self.device) try: - memory_used = (2562 * samples_in.shape[2] * samples_in.shape[3] * 64) * 1.7 + memory_used = self.memory_used_decode(samples_in.shape) model_management.free_memory(memory_used, self.device) free_memory = model_management.get_free_memory(self.device) batch_number = int(free_memory / memory_used) @@ -234,7 +241,7 @@ def encode(self, pixel_samples): self.first_stage_model = self.first_stage_model.to(self.device) pixel_samples = pixel_samples.movedim(-1,1) try: - memory_used = (2078 * pixel_samples.shape[2] * pixel_samples.shape[3]) * 1.7 #NOTE: this constant along with the one in the decode above are estimated from the mem usage for the VAE and could change. + memory_used = self.memory_used_encode(pixel_samples.shape) #NOTE: this constant along with the one in the decode above are estimated from the mem usage for the VAE and could change. model_management.free_memory(memory_used, self.device) free_memory = model_management.get_free_memory(self.device) batch_number = int(free_memory / memory_used) diff --git a/comfy/taesd/taesd.py b/comfy/taesd/taesd.py index 8df1f160915..46f3097a2a1 100644 --- a/comfy/taesd/taesd.py +++ b/comfy/taesd/taesd.py @@ -46,15 +46,16 @@ class TAESD(nn.Module): latent_magnitude = 3 latent_shift = 0.5 - def __init__(self, encoder_path="taesd_encoder.pth", decoder_path="taesd_decoder.pth"): + def __init__(self, encoder_path=None, decoder_path=None): """Initialize pretrained TAESD on the given device from the given checkpoints.""" super().__init__() - self.encoder = Encoder() - self.decoder = Decoder() + self.taesd_encoder = Encoder() + self.taesd_decoder = Decoder() + self.vae_scale = torch.nn.Parameter(torch.tensor(1.0)) if encoder_path is not None: - self.encoder.load_state_dict(comfy.utils.load_torch_file(encoder_path, safe_load=True)) + self.taesd_encoder.load_state_dict(comfy.utils.load_torch_file(encoder_path, safe_load=True)) if decoder_path is not None: - self.decoder.load_state_dict(comfy.utils.load_torch_file(decoder_path, safe_load=True)) + self.taesd_decoder.load_state_dict(comfy.utils.load_torch_file(decoder_path, safe_load=True)) @staticmethod def scale_latents(x): @@ -65,3 +66,11 @@ def scale_latents(x): def unscale_latents(x): """[0, 1] -> raw latents""" return x.sub(TAESD.latent_shift).mul(2 * TAESD.latent_magnitude) + + def decode(self, x): + x_sample = self.taesd_decoder(x * self.vae_scale) + x_sample = x_sample.sub(0.5).mul(2) + return x_sample + + def encode(self, x): + return self.taesd_encoder(x * 0.5 + 0.5) / self.vae_scale diff --git a/latent_preview.py b/latent_preview.py index 6e758a1a9d1..61754751efe 100644 --- a/latent_preview.py +++ b/latent_preview.py @@ -22,10 +22,7 @@ def __init__(self, taesd): self.taesd = taesd def decode_latent_to_preview(self, x0): - x_sample = self.taesd.decoder(x0[:1])[0].detach() - # x_sample = self.taesd.unscale_latents(x_sample).div(4).add(0.5) # returns value in [-2, 2] - x_sample = x_sample.sub(0.5).mul(2) - + x_sample = self.taesd.decode(x0[:1])[0].detach() x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0) x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2) x_sample = x_sample.astype(np.uint8) diff --git a/nodes.py b/nodes.py index 2adc5e07371..2de468da709 100644 --- a/nodes.py +++ b/nodes.py @@ -573,9 +573,55 @@ def load_lora(self, model, clip, lora_name, strength_model, strength_clip): return (model_lora, clip_lora) class VAELoader: + @staticmethod + def vae_list(): + vaes = folder_paths.get_filename_list("vae") + approx_vaes = folder_paths.get_filename_list("vae_approx") + sdxl_taesd_enc = False + sdxl_taesd_dec = False + sd1_taesd_enc = False + sd1_taesd_dec = False + + for v in approx_vaes: + if v.startswith("taesd_decoder."): + sd1_taesd_dec = True + elif v.startswith("taesd_encoder."): + sd1_taesd_enc = True + elif v.startswith("taesdxl_decoder."): + sdxl_taesd_dec = True + elif v.startswith("taesdxl_encoder."): + sdxl_taesd_enc = True + if sd1_taesd_dec and sd1_taesd_enc: + vaes.append("taesd") + if sdxl_taesd_dec and sdxl_taesd_enc: + vaes.append("taesdxl") + return vaes + + @staticmethod + def load_taesd(name): + sd = {} + approx_vaes = folder_paths.get_filename_list("vae_approx") + + encoder = next(filter(lambda a: a.startswith("{}_encoder.".format(name)), approx_vaes)) + decoder = next(filter(lambda a: a.startswith("{}_decoder.".format(name)), approx_vaes)) + + enc = comfy.utils.load_torch_file(folder_paths.get_full_path("vae_approx", encoder)) + for k in enc: + sd["taesd_encoder.{}".format(k)] = enc[k] + + dec = comfy.utils.load_torch_file(folder_paths.get_full_path("vae_approx", decoder)) + for k in dec: + sd["taesd_decoder.{}".format(k)] = dec[k] + + if name == "taesd": + sd["vae_scale"] = torch.tensor(0.18215) + elif name == "taesdxl": + sd["vae_scale"] = torch.tensor(0.13025) + return sd + @classmethod def INPUT_TYPES(s): - return {"required": { "vae_name": (folder_paths.get_filename_list("vae"), )}} + return {"required": { "vae_name": (s.vae_list(), )}} RETURN_TYPES = ("VAE",) FUNCTION = "load_vae" @@ -583,8 +629,11 @@ def INPUT_TYPES(s): #TODO: scale factor? def load_vae(self, vae_name): - vae_path = folder_paths.get_full_path("vae", vae_name) - sd = comfy.utils.load_torch_file(vae_path) + if vae_name in ["taesd", "taesdxl"]: + sd = self.load_taesd(vae_name) + else: + vae_path = folder_paths.get_full_path("vae", vae_name) + sd = comfy.utils.load_torch_file(vae_path) vae = comfy.sd.VAE(sd=sd) return (vae,) From 6a491ebe2729c675322491e255a72d5ac0ef5bf6 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 21 Nov 2023 16:29:18 -0500 Subject: [PATCH 36/84] Allow model config to preprocess the vae state dict on load. --- comfy/sd.py | 1 + comfy/supported_models_base.py | 3 +++ 2 files changed, 4 insertions(+) diff --git a/comfy/sd.py b/comfy/sd.py index 0f83cc5814d..c006a0362c7 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -448,6 +448,7 @@ class WeightsLoader(torch.nn.Module): if output_vae: vae_sd = comfy.utils.state_dict_prefix_replace(sd, {"first_stage_model.": ""}, filter_keys=True) + vae_sd = model_config.process_vae_state_dict(vae_sd) vae = VAE(sd=vae_sd) if output_clip: diff --git a/comfy/supported_models_base.py b/comfy/supported_models_base.py index 6dfae034303..b073eb4fc58 100644 --- a/comfy/supported_models_base.py +++ b/comfy/supported_models_base.py @@ -56,6 +56,9 @@ def process_clip_state_dict(self, state_dict): def process_unet_state_dict(self, state_dict): return state_dict + def process_vae_state_dict(self, state_dict): + return state_dict + def process_clip_state_dict_for_saving(self, state_dict): replace_prefix = {"": "cond_stage_model."} return utils.state_dict_prefix_replace(state_dict, replace_prefix) From 72741105a687c67137eb5d7a38840b8373d82e61 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 21 Nov 2023 17:18:49 -0500 Subject: [PATCH 37/84] Remove useless code. --- .../modules/diffusionmodules/openaimodel.py | 31 +++++++------------ 1 file changed, 11 insertions(+), 20 deletions(-) diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py index 10eb68d73b5..e8f35a540fa 100644 --- a/comfy/ldm/modules/diffusionmodules/openaimodel.py +++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py @@ -28,25 +28,6 @@ def forward(self, x, emb): Apply the module to `x` given `emb` timestep embeddings. """ - -class TimestepEmbedSequential(nn.Sequential, TimestepBlock): - """ - A sequential module that passes timestep embeddings to the children that - support it as an extra input. - """ - - def forward(self, x, emb, context=None, transformer_options={}, output_shape=None): - for layer in self: - if isinstance(layer, TimestepBlock): - x = layer(x, emb) - elif isinstance(layer, SpatialTransformer): - x = layer(x, context, transformer_options) - elif isinstance(layer, Upsample): - x = layer(x, output_shape=output_shape) - else: - x = layer(x) - return x - #This is needed because accelerate makes a copy of transformer_options which breaks "current_index" def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, output_shape=None): for layer in ts: @@ -54,13 +35,23 @@ def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, out x = layer(x, emb) elif isinstance(layer, SpatialTransformer): x = layer(x, context, transformer_options) - transformer_options["current_index"] += 1 + if "current_index" in transformer_options: + transformer_options["current_index"] += 1 elif isinstance(layer, Upsample): x = layer(x, output_shape=output_shape) else: x = layer(x) return x +class TimestepEmbedSequential(nn.Sequential, TimestepBlock): + """ + A sequential module that passes timestep embeddings to the children that + support it as an extra input. + """ + + def forward(self, *args, **kwargs): + return forward_timestep_embed(self, *args, **kwargs) + class Upsample(nn.Module): """ An upsampling layer with an optional convolution. From c3ae99a749fa1e9a6dbb96c69c65c6fcf2507af3 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 22 Nov 2023 03:23:16 -0500 Subject: [PATCH 38/84] Allow controlling downscale and upscale methods in PatchModelAddDownscale. --- comfy/utils.py | 6 ++++-- comfy_extras/nodes_model_downscale.py | 10 +++++++--- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/comfy/utils.py b/comfy/utils.py index f4c0ab41928..294bbb425ff 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -318,7 +318,9 @@ def generate_bilinear_data(length_old, length_new, device): coords_2 = torch.nn.functional.interpolate(coords_2, size=(1, length_new), mode="bilinear") coords_2 = coords_2.to(torch.int64) return ratios, coords_1, coords_2 - + + orig_dtype = samples.dtype + samples = samples.float() n,c,h,w = samples.shape h_new, w_new = (height, width) @@ -347,7 +349,7 @@ def generate_bilinear_data(length_old, length_new, device): result = slerp(pass_1, pass_2, ratios) result = result.reshape(n, h_new, w_new, c).movedim(-1, 1) - return result + return result.to(orig_dtype) def lanczos(samples, width, height): images = [Image.fromarray(np.clip(255. * image.movedim(0, -1).cpu().numpy(), 0, 255).astype(np.uint8)) for image in samples] diff --git a/comfy_extras/nodes_model_downscale.py b/comfy_extras/nodes_model_downscale.py index f65ef05e18b..48bcc689273 100644 --- a/comfy_extras/nodes_model_downscale.py +++ b/comfy_extras/nodes_model_downscale.py @@ -1,6 +1,8 @@ import torch +import comfy.utils class PatchModelAddDownscale: + upscale_methods = ["bicubic", "nearest-exact", "bilinear", "area", "bislerp"] @classmethod def INPUT_TYPES(s): return {"required": { "model": ("MODEL",), @@ -9,13 +11,15 @@ def INPUT_TYPES(s): "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), "end_percent": ("FLOAT", {"default": 0.35, "min": 0.0, "max": 1.0, "step": 0.001}), "downscale_after_skip": ("BOOLEAN", {"default": True}), + "downscale_method": (s.upscale_methods,), + "upscale_method": (s.upscale_methods,), }} RETURN_TYPES = ("MODEL",) FUNCTION = "patch" CATEGORY = "_for_testing" - def patch(self, model, block_number, downscale_factor, start_percent, end_percent, downscale_after_skip): + def patch(self, model, block_number, downscale_factor, start_percent, end_percent, downscale_after_skip, downscale_method, upscale_method): sigma_start = model.model.model_sampling.percent_to_sigma(start_percent) sigma_end = model.model.model_sampling.percent_to_sigma(end_percent) @@ -23,12 +27,12 @@ def input_block_patch(h, transformer_options): if transformer_options["block"][1] == block_number: sigma = transformer_options["sigmas"][0].item() if sigma <= sigma_start and sigma >= sigma_end: - h = torch.nn.functional.interpolate(h, scale_factor=(1.0 / downscale_factor), mode="bicubic", align_corners=False) + h = comfy.utils.common_upscale(h, round(h.shape[-1] * (1.0 / downscale_factor)), round(h.shape[-2] * (1.0 / downscale_factor)), downscale_method, "disabled") return h def output_block_patch(h, hsp, transformer_options): if h.shape[2] != hsp.shape[2]: - h = torch.nn.functional.interpolate(h, size=(hsp.shape[2], hsp.shape[3]), mode="bicubic", align_corners=False) + h = comfy.utils.common_upscale(h, hsp.shape[-1], hsp.shape[-2], upscale_method, "disabled") return h, hsp m = model.clone() From ab7d4f784892c275e888d71aa80a3a2ed59d9b83 Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Wed, 22 Nov 2023 13:53:30 +0000 Subject: [PATCH 39/84] Handle collapsing to hide element --- web/scripts/domWidget.js | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/web/scripts/domWidget.js b/web/scripts/domWidget.js index 16f4e192eea..0f8a2eb0179 100644 --- a/web/scripts/domWidget.js +++ b/web/scripts/domWidget.js @@ -233,6 +233,7 @@ LGraphNode.prototype.addDOMWidget = function (name, type, element, options) { } const hidden = + node.flags?.collapsed || (!!options.hideOnZoom && app.canvas.ds.scale < 0.5) || widget.computedHeight <= 0 || widget.type === "converted-widget"; @@ -290,6 +291,15 @@ LGraphNode.prototype.addDOMWidget = function (name, type, element, options) { this.addCustomWidget(widget); elementWidgets.add(this); + const collapse = this.collapse; + this.collapse = function() { + collapse.apply(this, arguments); + if(this.flags?.collapsed) { + element.hidden = true; + element.style.display = "none"; + } + } + const onRemoved = this.onRemoved; this.onRemoved = function () { element.remove(); From 70d2ea0faa28e1727f7535466ac5378e786b32cb Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Wed, 22 Nov 2023 17:52:20 +0000 Subject: [PATCH 40/84] Control filter list (#2009) * Add control_filter_list to filter items after queue * fix regex * backwards compatibility * formatting * revert * Add and fix test --- tests-ui/tests/widgetInputs.test.js | 96 ++++++++++++++++++++++++++--- web/extensions/core/widgetInputs.js | 8 ++- web/scripts/widgets.js | 56 ++++++++++++++--- 3 files changed, 141 insertions(+), 19 deletions(-) diff --git a/tests-ui/tests/widgetInputs.test.js b/tests-ui/tests/widgetInputs.test.js index 022e5492667..e1873105acc 100644 --- a/tests-ui/tests/widgetInputs.test.js +++ b/tests-ui/tests/widgetInputs.test.js @@ -14,10 +14,10 @@ const lg = require("../utils/litegraph"); * @param { InstanceType } graph * @param { InstanceType } input * @param { string } widgetType - * @param { boolean } hasControlWidget + * @param { number } controlWidgetCount * @returns */ -async function connectPrimitiveAndReload(ez, graph, input, widgetType, hasControlWidget) { +async function connectPrimitiveAndReload(ez, graph, input, widgetType, controlWidgetCount = 0) { // Connect to primitive and ensure its still connected after let primitive = ez.PrimitiveNode(); primitive.outputs[0].connectTo(input); @@ -33,13 +33,17 @@ async function connectPrimitiveAndReload(ez, graph, input, widgetType, hasContro expect(valueWidget.widget.type).toBe(widgetType); // Check if control_after_generate should be added - if (hasControlWidget) { + if (controlWidgetCount) { const controlWidget = primitive.widgets.control_after_generate; expect(controlWidget.widget.type).toBe("combo"); + if(widgetType === "combo") { + const filterWidget = primitive.widgets.control_filter_list; + expect(filterWidget.widget.type).toBe("string"); + } } // Ensure we dont have other widgets - expect(primitive.node.widgets).toHaveLength(1 + +!!hasControlWidget); + expect(primitive.node.widgets).toHaveLength(1 + controlWidgetCount); }); return primitive; @@ -55,8 +59,8 @@ describe("widget inputs", () => { }); [ - { name: "int", type: "INT", widget: "number", control: true }, - { name: "float", type: "FLOAT", widget: "number", control: true }, + { name: "int", type: "INT", widget: "number", control: 1 }, + { name: "float", type: "FLOAT", widget: "number", control: 1 }, { name: "text", type: "STRING" }, { name: "customtext", @@ -64,7 +68,7 @@ describe("widget inputs", () => { opt: { multiline: true }, }, { name: "toggle", type: "BOOLEAN" }, - { name: "combo", type: ["a", "b", "c"], control: true }, + { name: "combo", type: ["a", "b", "c"], control: 2 }, ].forEach((c) => { test(`widget conversion + primitive works on ${c.name}`, async () => { const { ez, graph } = await start({ @@ -106,7 +110,7 @@ describe("widget inputs", () => { n.widgets.ckpt_name.convertToInput(); expect(n.inputs.length).toEqual(inputCount + 1); - const primitive = await connectPrimitiveAndReload(ez, graph, n.inputs.ckpt_name, "combo", true); + const primitive = await connectPrimitiveAndReload(ez, graph, n.inputs.ckpt_name, "combo", 2); // Disconnect & reconnect primitive.outputs[0].connections[0].disconnect(); @@ -226,7 +230,7 @@ describe("widget inputs", () => { // Reload and ensure it still only has 1 converted widget if (!assertNotNullOrUndefined(input)) return; - await connectPrimitiveAndReload(ez, graph, input, "number", true); + await connectPrimitiveAndReload(ez, graph, input, "number", 1); n = graph.find(n); expect(n.widgets).toHaveLength(1); w = n.widgets.example; @@ -258,7 +262,7 @@ describe("widget inputs", () => { // Reload and ensure it still only has 1 converted widget if (assertNotNullOrUndefined(input)) { - await connectPrimitiveAndReload(ez, graph, input, "number", true); + await connectPrimitiveAndReload(ez, graph, input, "number", 1); n = graph.find(n); expect(n.widgets).toHaveLength(1); expect(n.widgets.example.isConvertedToInput).toBeTruthy(); @@ -316,4 +320,76 @@ describe("widget inputs", () => { n1.outputs[0].connectTo(n2.inputs[0]); expect(() => n1.outputs[0].connectTo(n3.inputs[0])).toThrow(); }); + + test("combo primitive can filter list when control_after_generate called", async () => { + const { ez } = await start({ + mockNodeDefs: { + ...makeNodeDef("TestNode1", { example: [["A", "B", "C", "D", "AA", "BB", "CC", "DD", "AAA", "BBB"], {}] }), + }, + }); + + const n1 = ez.TestNode1(); + n1.widgets.example.convertToInput(); + const p = ez.PrimitiveNode() + p.outputs[0].connectTo(n1.inputs[0]); + + const value = p.widgets.value; + const control = p.widgets.control_after_generate.widget; + const filter = p.widgets.control_filter_list; + + expect(p.widgets.length).toBe(3); + control.value = "increment"; + expect(value.value).toBe("A"); + + // Manually trigger after queue when set to increment + control["afterQueued"](); + expect(value.value).toBe("B"); + + // Filter to items containing D + filter.value = "D"; + control["afterQueued"](); + expect(value.value).toBe("D"); + control["afterQueued"](); + expect(value.value).toBe("DD"); + + // Check decrement + value.value = "BBB"; + control.value = "decrement"; + filter.value = "B"; + control["afterQueued"](); + expect(value.value).toBe("BB"); + control["afterQueued"](); + expect(value.value).toBe("B"); + + // Check regex works + value.value = "BBB"; + filter.value = "/[AB]|^C$/"; + control["afterQueued"](); + expect(value.value).toBe("AAA"); + control["afterQueued"](); + expect(value.value).toBe("BB"); + control["afterQueued"](); + expect(value.value).toBe("AA"); + control["afterQueued"](); + expect(value.value).toBe("C"); + control["afterQueued"](); + expect(value.value).toBe("B"); + control["afterQueued"](); + expect(value.value).toBe("A"); + + // Check random + control.value = "randomize"; + filter.value = "/D/"; + for(let i = 0; i < 100; i++) { + control["afterQueued"](); + expect(value.value === "D" || value.value === "DD").toBeTruthy(); + } + + // Ensure it doesnt apply when fixed + control.value = "fixed"; + value.value = "B"; + filter.value = "C"; + control["afterQueued"](); + expect(value.value).toBe("B"); + }); }); diff --git a/web/extensions/core/widgetInputs.js b/web/extensions/core/widgetInputs.js index bad3ac3a74c..5c8fbc9b2d3 100644 --- a/web/extensions/core/widgetInputs.js +++ b/web/extensions/core/widgetInputs.js @@ -1,4 +1,4 @@ -import { ComfyWidgets, addValueControlWidget } from "../../scripts/widgets.js"; +import { ComfyWidgets, addValueControlWidgets } from "../../scripts/widgets.js"; import { app } from "../../scripts/app.js"; const CONVERTED_TYPE = "converted-widget"; @@ -467,7 +467,11 @@ app.registerExtension({ if (!control_value) { control_value = "fixed"; } - addValueControlWidget(this, widget, control_value); + addValueControlWidgets(this, widget, control_value); + let filter = this.widgets_values?.[2]; + if(filter && this.widgets.length === 3) { + this.widgets[2].value = filter; + } } // When our value changes, update other widgets to reflect our changes diff --git a/web/scripts/widgets.js b/web/scripts/widgets.js index ccddc0bc44b..fbc1d0fc324 100644 --- a/web/scripts/widgets.js +++ b/web/scripts/widgets.js @@ -24,17 +24,58 @@ function getNumberDefaults(inputData, defaultStep, precision, enable_rounding) { } export function addValueControlWidget(node, targetWidget, defaultValue = "randomize", values) { - const valueControl = node.addWidget("combo", "control_after_generate", defaultValue, function (v) { }, { + const widgets = addValueControlWidgets(node, targetWidget, defaultValue, values, { + addFilterList: false, + }); + return widgets[0]; +} + +export function addValueControlWidgets(node, targetWidget, defaultValue = "randomize", values, options) { + if (!options) options = {}; + + const widgets = []; + const valueControl = node.addWidget("combo", "control_after_generate", defaultValue, function (v) { }, { values: ["fixed", "increment", "decrement", "randomize"], serialize: false, // Don't include this in prompt. }); - valueControl.afterQueued = () => { + widgets.push(valueControl); + + const isCombo = targetWidget.type === "combo"; + let comboFilter; + if (isCombo && options.addFilterList !== false) { + comboFilter = node.addWidget("string", "control_filter_list", "", function (v) {}, { + serialize: false, // Don't include this in prompt. + }); + widgets.push(comboFilter); + } + valueControl.afterQueued = () => { var v = valueControl.value; - if (targetWidget.type == "combo" && v !== "fixed") { - let current_index = targetWidget.options.values.indexOf(targetWidget.value); - let current_length = targetWidget.options.values.length; + if (isCombo && v !== "fixed") { + let values = targetWidget.options.values; + const filter = comboFilter?.value; + if (filter) { + let check; + if (filter.startsWith("/") && filter.endsWith("/")) { + try { + const regex = new RegExp(filter.substring(1, filter.length - 1)); + check = (item) => regex.test(item); + } catch (error) { + console.error("Error constructing RegExp filter for node " + node.id, filter, error); + } + } + if (!check) { + const lower = filter.toLocaleLowerCase(); + check = (item) => item.toLocaleLowerCase().includes(lower); + } + values = values.filter(item => check(item)); + if (!values.length && targetWidget.options.values.length) { + console.warn("Filter for node " + node.id + " has filtered out all items", filter); + } + } + let current_index = values.indexOf(targetWidget.value); + let current_length = values.length; switch (v) { case "increment": @@ -51,7 +92,7 @@ export function addValueControlWidget(node, targetWidget, defaultValue = "random current_index = Math.max(0, current_index); current_index = Math.min(current_length - 1, current_index); if (current_index >= 0) { - let value = targetWidget.options.values[current_index]; + let value = values[current_index]; targetWidget.value = value; targetWidget.callback(value); } @@ -88,7 +129,8 @@ export function addValueControlWidget(node, targetWidget, defaultValue = "random targetWidget.callback(targetWidget.value); } } - return valueControl; + + return widgets; }; function seedWidget(node, inputName, inputData, app) { From 32447f0c392be6a6b64fbac09fd7e7f33eb451f8 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 22 Nov 2023 17:23:37 -0500 Subject: [PATCH 41/84] Add sampling_settings so models can specify specific sampling settings. --- comfy/model_sampling.py | 2 +- comfy/supported_models_base.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfy/model_sampling.py b/comfy/model_sampling.py index 37a3ac725c6..9e2a1c1afa6 100644 --- a/comfy/model_sampling.py +++ b/comfy/model_sampling.py @@ -24,7 +24,7 @@ def __init__(self, model_config=None): super().__init__() beta_schedule = "linear" if model_config is not None: - beta_schedule = model_config.beta_schedule + beta_schedule = model_config.sampling_settings.get("beta_schedule", beta_schedule) self._register_schedule(given_betas=None, beta_schedule=beta_schedule, timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3) self.sigma_data = 1.0 diff --git a/comfy/supported_models_base.py b/comfy/supported_models_base.py index b073eb4fc58..3412cfea030 100644 --- a/comfy/supported_models_base.py +++ b/comfy/supported_models_base.py @@ -19,7 +19,7 @@ class BASE: clip_prefix = [] clip_vision_prefix = None noise_aug_config = None - beta_schedule = "linear" + sampling_settings = {} latent_format = latent_formats.LatentFormat @classmethod From 410bf0777197c7005fe13aa4f6717d6dc63e2b22 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 22 Nov 2023 18:16:02 -0500 Subject: [PATCH 42/84] Make VAE memory estimation take dtype into account. --- comfy/sd.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index c006a0362c7..a8df3bdd449 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -155,8 +155,8 @@ def __init__(self, sd=None, device=None, config=None): if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format sd = diffusers_convert.convert_vae_state_dict(sd) - self.memory_used_encode = lambda shape: (2078 * shape[2] * shape[3]) * 1.7 #These are for AutoencoderKL and need tweaking - self.memory_used_decode = lambda shape: (2562 * shape[2] * shape[3] * 64) * 1.7 + self.memory_used_encode = lambda shape, dtype: (1767 * shape[2] * shape[3]) * model_management.dtype_size(dtype) #These are for AutoencoderKL and need tweaking (should be lower) + self.memory_used_decode = lambda shape, dtype: (2178 * shape[2] * shape[3] * 64) * model_management.dtype_size(dtype) if config is None: if "taesd_decoder.1.weight" in sd: @@ -213,7 +213,7 @@ def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64): def decode(self, samples_in): self.first_stage_model = self.first_stage_model.to(self.device) try: - memory_used = self.memory_used_decode(samples_in.shape) + memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype) model_management.free_memory(memory_used, self.device) free_memory = model_management.get_free_memory(self.device) batch_number = int(free_memory / memory_used) @@ -241,7 +241,7 @@ def encode(self, pixel_samples): self.first_stage_model = self.first_stage_model.to(self.device) pixel_samples = pixel_samples.movedim(-1,1) try: - memory_used = self.memory_used_encode(pixel_samples.shape) #NOTE: this constant along with the one in the decode above are estimated from the mem usage for the VAE and could change. + memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype) model_management.free_memory(memory_used, self.device) free_memory = model_management.get_free_memory(self.device) batch_number = int(free_memory / memory_used) From d03d8aa2e348c6ba3333150eb18aa76f5180a7f0 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 23 Nov 2023 01:09:15 -0500 Subject: [PATCH 43/84] Fix loading groups. --- web/lib/litegraph.core.js | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/web/lib/litegraph.core.js b/web/lib/litegraph.core.js index 0ca2038429e..f571edb30b8 100644 --- a/web/lib/litegraph.core.js +++ b/web/lib/litegraph.core.js @@ -4928,7 +4928,9 @@ LGraphNode.prototype.executeAction = function(action) this.title = o.title; this._bounding.set(o.bounding); this.color = o.color; - this.font_size = o.font_size; + if (o.font_size) { + this.font_size = o.font_size; + } }; LGraphGroup.prototype.serialize = function() { From 87031a1945278abe6b8a8058dfe6f38a5138655c Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 23 Nov 2023 11:59:11 -0500 Subject: [PATCH 44/84] Update readme with link to LCM example page. --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index d622c907209..f87c0404f74 100644 --- a/README.md +++ b/README.md @@ -30,6 +30,7 @@ This ui will let you design and execute advanced stable diffusion pipelines usin - [unCLIP Models](https://comfyanonymous.github.io/ComfyUI_examples/unclip/) - [GLIGEN](https://comfyanonymous.github.io/ComfyUI_examples/gligen/) - [Model Merging](https://comfyanonymous.github.io/ComfyUI_examples/model_merging/) +- [LCM models and Loras](https://comfyanonymous.github.io/ComfyUI_examples/lcm/) - Latent previews with [TAESD](#how-to-show-high-quality-previews) - Starts up very fast. - Works fully offline: will never download anything. From a657f96c5cd9d72725352d6b00def82d9ce5d556 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 23 Nov 2023 13:55:29 -0500 Subject: [PATCH 45/84] Add a node to save animated webp. --- comfy_extras/nodes_images.py | 76 ++++++++++++++++++++++++++++++++++++ web/scripts/pnginfo.js | 4 +- 2 files changed, 79 insertions(+), 1 deletion(-) diff --git a/comfy_extras/nodes_images.py b/comfy_extras/nodes_images.py index 8cb322327b0..18c579190e4 100644 --- a/comfy_extras/nodes_images.py +++ b/comfy_extras/nodes_images.py @@ -1,4 +1,12 @@ import nodes +import folder_paths +from comfy.cli_args import args + +from PIL import Image +import numpy as np +import json +import os + MAX_RESOLUTION = nodes.MAX_RESOLUTION class ImageCrop: @@ -38,7 +46,75 @@ def repeat(self, image, amount): s = image.repeat((amount, 1,1,1)) return (s,) +class SaveAnimatedWEBP: + def __init__(self): + self.output_dir = folder_paths.get_output_directory() + self.type = "output" + self.prefix_append = "" + + methods = {"default": 4, "fastest": 0, "slowest": 6} + @classmethod + def INPUT_TYPES(s): + return {"required": + {"images": ("IMAGE", ), + "filename_prefix": ("STRING", {"default": "ComfyUI"}), + "fps": ("FLOAT", {"default": 6.0, "min": 0.01, "max": 1000.0, "step": 0.01}), + "lossless": ("BOOLEAN", {"default": True}), + "quality": ("INT", {"default": 80, "min": 0, "max": 100}), + "method": (list(s.methods.keys()),), + # "num_frames": ("INT", {"default": 0, "min": 0, "max": 8192}), + }, + "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, + } + + RETURN_TYPES = () + FUNCTION = "save_images" + + OUTPUT_NODE = True + + CATEGORY = "_for_testing" + + def save_images(self, images, fps, filename_prefix, lossless, quality, method, num_frames=0, prompt=None, extra_pnginfo=None): + method = self.methods.get(method, "aoeu") + filename_prefix += self.prefix_append + full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0]) + results = list() + pil_images = [] + for image in images: + i = 255. * image.cpu().numpy() + img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8)) + pil_images.append(img) + + metadata = None + if not args.disable_metadata: + metadata = pil_images[0].getexif() + if prompt is not None: + metadata[0x0110] = "prompt:{}".format(json.dumps(prompt)) + if extra_pnginfo is not None: + inital_exif = 0x010f + for x in extra_pnginfo: + metadata[inital_exif] = "{}:{}".format(x, json.dumps(extra_pnginfo[x])) + inital_exif -= 1 + + if num_frames == 0: + num_frames = len(pil_images) + + c = len(pil_images) + for i in range(0, c, num_frames): + file = f"{filename}_{counter:05}_.webp" + pil_images[i].save(os.path.join(full_output_folder, file), save_all=True, duration=int(1000.0/fps), append_images=pil_images[i + 1:i + num_frames], exif=metadata, lossless=lossless, quality=quality, method=method) + results.append({ + "filename": file, + "subfolder": subfolder, + "type": self.type + }) + counter += 1 + + animated = num_frames != 1 + return { "ui": { "images": results, "animated": (animated,) } } + NODE_CLASS_MAPPINGS = { "ImageCrop": ImageCrop, "RepeatImageBatch": RepeatImageBatch, + "SaveAnimatedWEBP": SaveAnimatedWEBP, } diff --git a/web/scripts/pnginfo.js b/web/scripts/pnginfo.js index 491caed79f5..f8cbe7a3cd9 100644 --- a/web/scripts/pnginfo.js +++ b/web/scripts/pnginfo.js @@ -50,7 +50,6 @@ export function getPngMetadata(file) { function parseExifData(exifData) { // Check for the correct TIFF header (0x4949 for little-endian or 0x4D4D for big-endian) const isLittleEndian = new Uint16Array(exifData.slice(0, 2))[0] === 0x4949; - console.log(exifData); // Function to read 16-bit and 32-bit integers from binary data function readInt(offset, isLittleEndian, length) { @@ -126,6 +125,9 @@ export function getWebpMetadata(file) { const chunk_length = dataView.getUint32(offset + 4, true); const chunk_type = String.fromCharCode(...webp.slice(offset, offset + 4)); if (chunk_type === "EXIF") { + if (String.fromCharCode(...webp.slice(offset + 8, offset + 8 + 6)) == "Exif\0\0") { + offset += 6; + } let data = parseExifData(webp.slice(offset + 8, offset + 8 + chunk_length)); for (var key in data) { var value = data[key]; From 4d2437e68165cf12989dafe1ef0a26c3a0abc7f5 Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Thu, 23 Nov 2023 19:43:55 +0000 Subject: [PATCH 46/84] Call widget onRemove to remove element --- web/scripts/app.js | 1 + 1 file changed, 1 insertion(+) diff --git a/web/scripts/app.js b/web/scripts/app.js index 601e486e6e4..180416ef971 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -544,6 +544,7 @@ export class ComfyApp { } if (widgetIdx > -1) { + this.widgets[widgetIdx].onRemove?.(); this.widgets.splice(widgetIdx, 1); } From 022033a0e75901c7c357ab96e1c804fd5da05770 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 23 Nov 2023 15:06:35 -0500 Subject: [PATCH 47/84] Fix SaveAnimatedWEBP not working when metadata is disabled. --- comfy_extras/nodes_images.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/comfy_extras/nodes_images.py b/comfy_extras/nodes_images.py index 18c579190e4..8c6ae538711 100644 --- a/comfy_extras/nodes_images.py +++ b/comfy_extras/nodes_images.py @@ -75,7 +75,7 @@ def INPUT_TYPES(s): CATEGORY = "_for_testing" def save_images(self, images, fps, filename_prefix, lossless, quality, method, num_frames=0, prompt=None, extra_pnginfo=None): - method = self.methods.get(method, "aoeu") + method = self.methods.get(method) filename_prefix += self.prefix_append full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0]) results = list() @@ -85,9 +85,8 @@ def save_images(self, images, fps, filename_prefix, lossless, quality, method, n img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8)) pil_images.append(img) - metadata = None + metadata = pil_images[0].getexif() if not args.disable_metadata: - metadata = pil_images[0].getexif() if prompt is not None: metadata[0x0110] = "prompt:{}".format(json.dumps(prompt)) if extra_pnginfo is not None: From 1964bf1e78dda9c6c7cf1b561068b835639aa166 Mon Sep 17 00:00:00 2001 From: Enrico Fasoli Date: Thu, 23 Nov 2023 22:24:58 +0100 Subject: [PATCH 48/84] fix: folder handling issues --- folder_paths.py | 5 ++++- nodes.py | 2 +- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/folder_paths.py b/folder_paths.py index 4a38deec06f..7046255e422 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -38,7 +38,10 @@ filename_list_cache = {} if not os.path.exists(input_directory): - os.makedirs(input_directory) + try: + os.makedirs(input_directory) + except: + print("Failed to create input directory") def set_output_directory(output_dir): global output_directory diff --git a/nodes.py b/nodes.py index 2de468da709..27b8b1c1b80 100644 --- a/nodes.py +++ b/nodes.py @@ -1808,7 +1808,7 @@ def load_custom_nodes(): node_paths = folder_paths.get_folder_paths("custom_nodes") node_import_times = [] for custom_node_path in node_paths: - possible_modules = os.listdir(custom_node_path) + possible_modules = os.listdir(os.path.realpath(custom_node_path)) if "__pycache__" in possible_modules: possible_modules.remove("__pycache__") From 871cc20e13e9ef2629e3b5faa6af64207e86d6d2 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 23 Nov 2023 19:41:33 -0500 Subject: [PATCH 49/84] Support SVD img2vid model. --- comfy/cldm/cldm.py | 1 + comfy/ldm/modules/attention.py | 271 ++++++++++++-- .../modules/diffusionmodules/openaimodel.py | 348 +++++++++++++++--- comfy/ldm/modules/diffusionmodules/util.py | 69 +++- comfy/ldm/modules/temporal_ae.py | 244 ++++++++++++ comfy/model_base.py | 56 ++- comfy/model_detection.py | 18 +- comfy/model_sampling.py | 46 ++- comfy/sd.py | 10 +- comfy/supported_models.py | 36 +- comfy_extras/nodes_model_advanced.py | 31 ++ 11 files changed, 1030 insertions(+), 100 deletions(-) create mode 100644 comfy/ldm/modules/temporal_ae.py diff --git a/comfy/cldm/cldm.py b/comfy/cldm/cldm.py index 9a63202ab07..76a525b378a 100644 --- a/comfy/cldm/cldm.py +++ b/comfy/cldm/cldm.py @@ -54,6 +54,7 @@ def __init__( transformer_depth_output=None, device=None, operations=comfy.ops, + **kwargs, ): super().__init__() assert use_spatial_transformer == True, "use_spatial_transformer has to be true" diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index 016795a5974..947e2008cbd 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -5,8 +5,10 @@ from torch import nn, einsum from einops import rearrange, repeat from typing import Optional, Any +from functools import partial -from .diffusionmodules.util import checkpoint + +from .diffusionmodules.util import checkpoint, AlphaBlender, timestep_embedding from .sub_quadratic_attention import efficient_dot_product_attention from comfy import model_management @@ -370,21 +372,45 @@ def forward(self, x, context=None, value=None, mask=None): class BasicTransformerBlock(nn.Module): - def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, - disable_self_attn=False, dtype=None, device=None, operations=comfy.ops): + def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, ff_in=False, inner_dim=None, + disable_self_attn=False, disable_temporal_crossattention=False, switch_temporal_ca_to_sa=False, dtype=None, device=None, operations=comfy.ops): super().__init__() + + self.ff_in = ff_in or inner_dim is not None + if inner_dim is None: + inner_dim = dim + + self.is_res = inner_dim == dim + + if self.ff_in: + self.norm_in = nn.LayerNorm(dim, dtype=dtype, device=device) + self.ff_in = FeedForward(dim, dim_out=inner_dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device, operations=operations) + self.disable_self_attn = disable_self_attn - self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout, + self.attn1 = CrossAttention(query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout, context_dim=context_dim if self.disable_self_attn else None, dtype=dtype, device=device, operations=operations) # is a self-attention if not self.disable_self_attn - self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device, operations=operations) - self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim, - heads=n_heads, dim_head=d_head, dropout=dropout, dtype=dtype, device=device, operations=operations) # is self-attn if context is none - self.norm1 = nn.LayerNorm(dim, dtype=dtype, device=device) - self.norm2 = nn.LayerNorm(dim, dtype=dtype, device=device) - self.norm3 = nn.LayerNorm(dim, dtype=dtype, device=device) + self.ff = FeedForward(inner_dim, dim_out=dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device, operations=operations) + + if disable_temporal_crossattention: + if switch_temporal_ca_to_sa: + raise ValueError + else: + self.attn2 = None + else: + context_dim_attn2 = None + if not switch_temporal_ca_to_sa: + context_dim_attn2 = context_dim + + self.attn2 = CrossAttention(query_dim=inner_dim, context_dim=context_dim_attn2, + heads=n_heads, dim_head=d_head, dropout=dropout, dtype=dtype, device=device, operations=operations) # is self-attn if context is none + self.norm2 = nn.LayerNorm(inner_dim, dtype=dtype, device=device) + + self.norm1 = nn.LayerNorm(inner_dim, dtype=dtype, device=device) + self.norm3 = nn.LayerNorm(inner_dim, dtype=dtype, device=device) self.checkpoint = checkpoint self.n_heads = n_heads self.d_head = d_head + self.switch_temporal_ca_to_sa = switch_temporal_ca_to_sa def forward(self, x, context=None, transformer_options={}): return checkpoint(self._forward, (x, context, transformer_options), self.parameters(), self.checkpoint) @@ -418,6 +444,12 @@ def _forward(self, x, context=None, transformer_options={}): else: transformer_patches_replace = {} + if self.ff_in: + x_skip = x + x = self.ff_in(self.norm_in(x)) + if self.is_res: + x += x_skip + n = self.norm1(x) if self.disable_self_attn: context_attn1 = context @@ -465,31 +497,34 @@ def _forward(self, x, context=None, transformer_options={}): for p in patch: x = p(x, extra_options) - n = self.norm2(x) - - context_attn2 = context - value_attn2 = None - if "attn2_patch" in transformer_patches: - patch = transformer_patches["attn2_patch"] - value_attn2 = context_attn2 - for p in patch: - n, context_attn2, value_attn2 = p(n, context_attn2, value_attn2, extra_options) - - attn2_replace_patch = transformer_patches_replace.get("attn2", {}) - block_attn2 = transformer_block - if block_attn2 not in attn2_replace_patch: - block_attn2 = block - - if block_attn2 in attn2_replace_patch: - if value_attn2 is None: + if self.attn2 is not None: + n = self.norm2(x) + if self.switch_temporal_ca_to_sa: + context_attn2 = n + else: + context_attn2 = context + value_attn2 = None + if "attn2_patch" in transformer_patches: + patch = transformer_patches["attn2_patch"] value_attn2 = context_attn2 - n = self.attn2.to_q(n) - context_attn2 = self.attn2.to_k(context_attn2) - value_attn2 = self.attn2.to_v(value_attn2) - n = attn2_replace_patch[block_attn2](n, context_attn2, value_attn2, extra_options) - n = self.attn2.to_out(n) - else: - n = self.attn2(n, context=context_attn2, value=value_attn2) + for p in patch: + n, context_attn2, value_attn2 = p(n, context_attn2, value_attn2, extra_options) + + attn2_replace_patch = transformer_patches_replace.get("attn2", {}) + block_attn2 = transformer_block + if block_attn2 not in attn2_replace_patch: + block_attn2 = block + + if block_attn2 in attn2_replace_patch: + if value_attn2 is None: + value_attn2 = context_attn2 + n = self.attn2.to_q(n) + context_attn2 = self.attn2.to_k(context_attn2) + value_attn2 = self.attn2.to_v(value_attn2) + n = attn2_replace_patch[block_attn2](n, context_attn2, value_attn2, extra_options) + n = self.attn2.to_out(n) + else: + n = self.attn2(n, context=context_attn2, value=value_attn2) if "attn2_output_patch" in transformer_patches: patch = transformer_patches["attn2_output_patch"] @@ -497,7 +532,12 @@ def _forward(self, x, context=None, transformer_options={}): n = p(n, extra_options) x += n - x = self.ff(self.norm3(x)) + x + if self.is_res: + x_skip = x + x = self.ff(self.norm3(x)) + if self.is_res: + x += x_skip + return x @@ -565,3 +605,164 @@ def forward(self, x, context=None, transformer_options={}): x = self.proj_out(x) return x + x_in + +class SpatialVideoTransformer(SpatialTransformer): + def __init__( + self, + in_channels, + n_heads, + d_head, + depth=1, + dropout=0.0, + use_linear=False, + context_dim=None, + use_spatial_context=False, + timesteps=None, + merge_strategy: str = "fixed", + merge_factor: float = 0.5, + time_context_dim=None, + ff_in=False, + checkpoint=False, + time_depth=1, + disable_self_attn=False, + disable_temporal_crossattention=False, + max_time_embed_period: int = 10000, + dtype=None, device=None, operations=comfy.ops + ): + super().__init__( + in_channels, + n_heads, + d_head, + depth=depth, + dropout=dropout, + use_checkpoint=checkpoint, + context_dim=context_dim, + use_linear=use_linear, + disable_self_attn=disable_self_attn, + dtype=dtype, device=device, operations=operations + ) + self.time_depth = time_depth + self.depth = depth + self.max_time_embed_period = max_time_embed_period + + time_mix_d_head = d_head + n_time_mix_heads = n_heads + + time_mix_inner_dim = int(time_mix_d_head * n_time_mix_heads) + + inner_dim = n_heads * d_head + if use_spatial_context: + time_context_dim = context_dim + + self.time_stack = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + n_time_mix_heads, + time_mix_d_head, + dropout=dropout, + context_dim=time_context_dim, + # timesteps=timesteps, + checkpoint=checkpoint, + ff_in=ff_in, + inner_dim=time_mix_inner_dim, + disable_self_attn=disable_self_attn, + disable_temporal_crossattention=disable_temporal_crossattention, + dtype=dtype, device=device, operations=operations + ) + for _ in range(self.depth) + ] + ) + + assert len(self.time_stack) == len(self.transformer_blocks) + + self.use_spatial_context = use_spatial_context + self.in_channels = in_channels + + time_embed_dim = self.in_channels * 4 + self.time_pos_embed = nn.Sequential( + operations.Linear(self.in_channels, time_embed_dim, dtype=dtype, device=device), + nn.SiLU(), + operations.Linear(time_embed_dim, self.in_channels, dtype=dtype, device=device), + ) + + self.time_mixer = AlphaBlender( + alpha=merge_factor, merge_strategy=merge_strategy + ) + + def forward( + self, + x: torch.Tensor, + context: Optional[torch.Tensor] = None, + time_context: Optional[torch.Tensor] = None, + timesteps: Optional[int] = None, + image_only_indicator: Optional[torch.Tensor] = None, + transformer_options={} + ) -> torch.Tensor: + _, _, h, w = x.shape + x_in = x + spatial_context = None + if exists(context): + spatial_context = context + + if self.use_spatial_context: + assert ( + context.ndim == 3 + ), f"n dims of spatial context should be 3 but are {context.ndim}" + + if time_context is None: + time_context = context + time_context_first_timestep = time_context[::timesteps] + time_context = repeat( + time_context_first_timestep, "b ... -> (b n) ...", n=h * w + ) + elif time_context is not None and not self.use_spatial_context: + time_context = repeat(time_context, "b ... -> (b n) ...", n=h * w) + if time_context.ndim == 2: + time_context = rearrange(time_context, "b c -> b 1 c") + + x = self.norm(x) + if not self.use_linear: + x = self.proj_in(x) + x = rearrange(x, "b c h w -> b (h w) c") + if self.use_linear: + x = self.proj_in(x) + + num_frames = torch.arange(timesteps, device=x.device) + num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps) + num_frames = rearrange(num_frames, "b t -> (b t)") + t_emb = timestep_embedding(num_frames, self.in_channels, repeat_only=False, max_period=self.max_time_embed_period).to(x.dtype) + emb = self.time_pos_embed(t_emb) + emb = emb[:, None, :] + + for it_, (block, mix_block) in enumerate( + zip(self.transformer_blocks, self.time_stack) + ): + transformer_options["block_index"] = it_ + x = block( + x, + context=spatial_context, + transformer_options=transformer_options, + ) + + x_mix = x + x_mix = x_mix + emb + + B, S, C = x_mix.shape + x_mix = rearrange(x_mix, "(b t) s c -> (b s) t c", t=timesteps) + x_mix = mix_block(x_mix, context=time_context) #TODO: transformer_options + x_mix = rearrange( + x_mix, "(b s) t c -> (b t) s c", s=S, b=B // timesteps, c=C, t=timesteps + ) + + x = self.time_mixer(x_spatial=x, x_temporal=x_mix, image_only_indicator=image_only_indicator) + + if self.use_linear: + x = self.proj_out(x) + x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w) + if not self.use_linear: + x = self.proj_out(x) + out = x + x_in + return out + + diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py index e8f35a540fa..a497ed34478 100644 --- a/comfy/ldm/modules/diffusionmodules/openaimodel.py +++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py @@ -5,6 +5,8 @@ import torch as th import torch.nn as nn import torch.nn.functional as F +from einops import rearrange +from functools import partial from .util import ( checkpoint, @@ -12,8 +14,9 @@ zero_module, normalization, timestep_embedding, + AlphaBlender, ) -from ..attention import SpatialTransformer +from ..attention import SpatialTransformer, SpatialVideoTransformer, default from comfy.ldm.util import exists import comfy.ops @@ -29,10 +32,15 @@ def forward(self, x, emb): """ #This is needed because accelerate makes a copy of transformer_options which breaks "current_index" -def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, output_shape=None): +def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, output_shape=None, time_context=None, num_video_frames=None, image_only_indicator=None): for layer in ts: - if isinstance(layer, TimestepBlock): + if isinstance(layer, VideoResBlock): + x = layer(x, emb, num_video_frames, image_only_indicator) + elif isinstance(layer, TimestepBlock): x = layer(x, emb) + elif isinstance(layer, SpatialVideoTransformer): + x = layer(x, context, time_context, num_video_frames, image_only_indicator, transformer_options) + transformer_options["current_index"] += 1 elif isinstance(layer, SpatialTransformer): x = layer(x, context, transformer_options) if "current_index" in transformer_options: @@ -145,6 +153,9 @@ def __init__( use_checkpoint=False, up=False, down=False, + kernel_size=3, + exchange_temb_dims=False, + skip_t_emb=False, dtype=None, device=None, operations=comfy.ops @@ -157,11 +168,17 @@ def __init__( self.use_conv = use_conv self.use_checkpoint = use_checkpoint self.use_scale_shift_norm = use_scale_shift_norm + self.exchange_temb_dims = exchange_temb_dims + + if isinstance(kernel_size, list): + padding = [k // 2 for k in kernel_size] + else: + padding = kernel_size // 2 self.in_layers = nn.Sequential( nn.GroupNorm(32, channels, dtype=dtype, device=device), nn.SiLU(), - operations.conv_nd(dims, channels, self.out_channels, 3, padding=1, dtype=dtype, device=device), + operations.conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding, dtype=dtype, device=device), ) self.updown = up or down @@ -175,19 +192,24 @@ def __init__( else: self.h_upd = self.x_upd = nn.Identity() - self.emb_layers = nn.Sequential( - nn.SiLU(), - operations.Linear( - emb_channels, - 2 * self.out_channels if use_scale_shift_norm else self.out_channels, dtype=dtype, device=device - ), - ) + self.skip_t_emb = skip_t_emb + if self.skip_t_emb: + self.emb_layers = None + self.exchange_temb_dims = False + else: + self.emb_layers = nn.Sequential( + nn.SiLU(), + operations.Linear( + emb_channels, + 2 * self.out_channels if use_scale_shift_norm else self.out_channels, dtype=dtype, device=device + ), + ) self.out_layers = nn.Sequential( nn.GroupNorm(32, self.out_channels, dtype=dtype, device=device), nn.SiLU(), nn.Dropout(p=dropout), zero_module( - operations.conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1, dtype=dtype, device=device) + operations.conv_nd(dims, self.out_channels, self.out_channels, kernel_size, padding=padding, dtype=dtype, device=device) ), ) @@ -195,7 +217,7 @@ def __init__( self.skip_connection = nn.Identity() elif use_conv: self.skip_connection = operations.conv_nd( - dims, channels, self.out_channels, 3, padding=1, dtype=dtype, device=device + dims, channels, self.out_channels, kernel_size, padding=padding, dtype=dtype, device=device ) else: self.skip_connection = operations.conv_nd(dims, channels, self.out_channels, 1, dtype=dtype, device=device) @@ -221,19 +243,110 @@ def _forward(self, x, emb): h = in_conv(h) else: h = self.in_layers(x) - emb_out = self.emb_layers(emb).type(h.dtype) - while len(emb_out.shape) < len(h.shape): - emb_out = emb_out[..., None] + + emb_out = None + if not self.skip_t_emb: + emb_out = self.emb_layers(emb).type(h.dtype) + while len(emb_out.shape) < len(h.shape): + emb_out = emb_out[..., None] if self.use_scale_shift_norm: out_norm, out_rest = self.out_layers[0], self.out_layers[1:] - scale, shift = th.chunk(emb_out, 2, dim=1) - h = out_norm(h) * (1 + scale) + shift + h = out_norm(h) + if emb_out is not None: + scale, shift = th.chunk(emb_out, 2, dim=1) + h *= (1 + scale) + h += shift h = out_rest(h) else: - h = h + emb_out + if emb_out is not None: + if self.exchange_temb_dims: + emb_out = rearrange(emb_out, "b t c ... -> b c t ...") + h = h + emb_out h = self.out_layers(h) return self.skip_connection(x) + h + +class VideoResBlock(ResBlock): + def __init__( + self, + channels: int, + emb_channels: int, + dropout: float, + video_kernel_size=3, + merge_strategy: str = "fixed", + merge_factor: float = 0.5, + out_channels=None, + use_conv: bool = False, + use_scale_shift_norm: bool = False, + dims: int = 2, + use_checkpoint: bool = False, + up: bool = False, + down: bool = False, + dtype=None, + device=None, + operations=comfy.ops + ): + super().__init__( + channels, + emb_channels, + dropout, + out_channels=out_channels, + use_conv=use_conv, + use_scale_shift_norm=use_scale_shift_norm, + dims=dims, + use_checkpoint=use_checkpoint, + up=up, + down=down, + dtype=dtype, + device=device, + operations=operations + ) + + self.time_stack = ResBlock( + default(out_channels, channels), + emb_channels, + dropout=dropout, + dims=3, + out_channels=default(out_channels, channels), + use_scale_shift_norm=False, + use_conv=False, + up=False, + down=False, + kernel_size=video_kernel_size, + use_checkpoint=use_checkpoint, + exchange_temb_dims=True, + dtype=dtype, + device=device, + operations=operations + ) + self.time_mixer = AlphaBlender( + alpha=merge_factor, + merge_strategy=merge_strategy, + rearrange_pattern="b t -> b 1 t 1 1", + ) + + def forward( + self, + x: th.Tensor, + emb: th.Tensor, + num_video_frames: int, + image_only_indicator = None, + ) -> th.Tensor: + x = super().forward(x, emb) + + x_mix = rearrange(x, "(b t) c h w -> b c t h w", t=num_video_frames) + x = rearrange(x, "(b t) c h w -> b c t h w", t=num_video_frames) + + x = self.time_stack( + x, rearrange(emb, "(b t) ... -> b t ...", t=num_video_frames) + ) + x = self.time_mixer( + x_spatial=x_mix, x_temporal=x, image_only_indicator=image_only_indicator + ) + x = rearrange(x, "b c t h w -> (b t) c h w") + return x + + class Timestep(nn.Module): def __init__(self, dim): super().__init__() @@ -310,6 +423,16 @@ def __init__( adm_in_channels=None, transformer_depth_middle=None, transformer_depth_output=None, + use_temporal_resblock=False, + use_temporal_attention=False, + time_context_dim=None, + extra_ff_mix_layer=False, + use_spatial_context=False, + merge_strategy=None, + merge_factor=0.0, + video_kernel_size=None, + disable_temporal_crossattention=False, + max_ddpm_temb_period=10000, device=None, operations=comfy.ops, ): @@ -364,8 +487,12 @@ def __init__( self.num_heads = num_heads self.num_head_channels = num_head_channels self.num_heads_upsample = num_heads_upsample + self.use_temporal_resblocks = use_temporal_resblock self.predict_codebook_ids = n_embed is not None + self.default_num_video_frames = None + self.default_image_only_indicator = None + time_embed_dim = model_channels * 4 self.time_embed = nn.Sequential( operations.Linear(model_channels, time_embed_dim, dtype=self.dtype, device=device), @@ -402,13 +529,104 @@ def __init__( input_block_chans = [model_channels] ch = model_channels ds = 1 + + def get_attention_layer( + ch, + num_heads, + dim_head, + depth=1, + context_dim=None, + use_checkpoint=False, + disable_self_attn=False, + ): + if use_temporal_attention: + return SpatialVideoTransformer( + ch, + num_heads, + dim_head, + depth=depth, + context_dim=context_dim, + time_context_dim=time_context_dim, + dropout=dropout, + ff_in=extra_ff_mix_layer, + use_spatial_context=use_spatial_context, + merge_strategy=merge_strategy, + merge_factor=merge_factor, + checkpoint=use_checkpoint, + use_linear=use_linear_in_transformer, + disable_self_attn=disable_self_attn, + disable_temporal_crossattention=disable_temporal_crossattention, + max_time_embed_period=max_ddpm_temb_period, + dtype=self.dtype, device=device, operations=operations + ) + else: + return SpatialTransformer( + ch, num_heads, dim_head, depth=depth, context_dim=context_dim, + disable_self_attn=disable_self_attn, use_linear=use_linear_in_transformer, + use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations + ) + + def get_resblock( + merge_factor, + merge_strategy, + video_kernel_size, + ch, + time_embed_dim, + dropout, + out_channels, + dims, + use_checkpoint, + use_scale_shift_norm, + down=False, + up=False, + dtype=None, + device=None, + operations=comfy.ops + ): + if self.use_temporal_resblocks: + return VideoResBlock( + merge_factor=merge_factor, + merge_strategy=merge_strategy, + video_kernel_size=video_kernel_size, + channels=ch, + emb_channels=time_embed_dim, + dropout=dropout, + out_channels=out_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=down, + up=up, + dtype=dtype, + device=device, + operations=operations + ) + else: + return ResBlock( + channels=ch, + emb_channels=time_embed_dim, + dropout=dropout, + out_channels=out_channels, + use_checkpoint=use_checkpoint, + dims=dims, + use_scale_shift_norm=use_scale_shift_norm, + down=down, + up=up, + dtype=dtype, + device=device, + operations=operations + ) + for level, mult in enumerate(channel_mult): for nr in range(self.num_res_blocks[level]): layers = [ - ResBlock( - ch, - time_embed_dim, - dropout, + get_resblock( + merge_factor=merge_factor, + merge_strategy=merge_strategy, + video_kernel_size=video_kernel_size, + ch=ch, + time_embed_dim=time_embed_dim, + dropout=dropout, out_channels=mult * model_channels, dims=dims, use_checkpoint=use_checkpoint, @@ -435,11 +653,9 @@ def __init__( disabled_sa = False if not exists(num_attention_blocks) or nr < num_attention_blocks[level]: - layers.append(SpatialTransformer( + layers.append(get_attention_layer( ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim, - disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer, - use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations - ) + disable_self_attn=disabled_sa, use_checkpoint=use_checkpoint) ) self.input_blocks.append(TimestepEmbedSequential(*layers)) self._feature_size += ch @@ -448,10 +664,13 @@ def __init__( out_ch = ch self.input_blocks.append( TimestepEmbedSequential( - ResBlock( - ch, - time_embed_dim, - dropout, + get_resblock( + merge_factor=merge_factor, + merge_strategy=merge_strategy, + video_kernel_size=video_kernel_size, + ch=ch, + time_embed_dim=time_embed_dim, + dropout=dropout, out_channels=out_ch, dims=dims, use_checkpoint=use_checkpoint, @@ -481,10 +700,14 @@ def __init__( #num_heads = 1 dim_head = ch // num_heads if use_spatial_transformer else num_head_channels mid_block = [ - ResBlock( - ch, - time_embed_dim, - dropout, + get_resblock( + merge_factor=merge_factor, + merge_strategy=merge_strategy, + video_kernel_size=video_kernel_size, + ch=ch, + time_embed_dim=time_embed_dim, + dropout=dropout, + out_channels=None, dims=dims, use_checkpoint=use_checkpoint, use_scale_shift_norm=use_scale_shift_norm, @@ -493,15 +716,18 @@ def __init__( operations=operations )] if transformer_depth_middle >= 0: - mid_block += [SpatialTransformer( # always uses a self-attn + mid_block += [get_attention_layer( # always uses a self-attn ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim, - disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer, - use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations + disable_self_attn=disable_middle_self_attn, use_checkpoint=use_checkpoint ), - ResBlock( - ch, - time_embed_dim, - dropout, + get_resblock( + merge_factor=merge_factor, + merge_strategy=merge_strategy, + video_kernel_size=video_kernel_size, + ch=ch, + time_embed_dim=time_embed_dim, + dropout=dropout, + out_channels=None, dims=dims, use_checkpoint=use_checkpoint, use_scale_shift_norm=use_scale_shift_norm, @@ -517,10 +743,13 @@ def __init__( for i in range(self.num_res_blocks[level] + 1): ich = input_block_chans.pop() layers = [ - ResBlock( - ch + ich, - time_embed_dim, - dropout, + get_resblock( + merge_factor=merge_factor, + merge_strategy=merge_strategy, + video_kernel_size=video_kernel_size, + ch=ch + ich, + time_embed_dim=time_embed_dim, + dropout=dropout, out_channels=model_channels * mult, dims=dims, use_checkpoint=use_checkpoint, @@ -548,19 +777,21 @@ def __init__( if not exists(num_attention_blocks) or i < num_attention_blocks[level]: layers.append( - SpatialTransformer( + get_attention_layer( ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim, - disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer, - use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations + disable_self_attn=disabled_sa, use_checkpoint=use_checkpoint ) ) if level and i == self.num_res_blocks[level]: out_ch = ch layers.append( - ResBlock( - ch, - time_embed_dim, - dropout, + get_resblock( + merge_factor=merge_factor, + merge_strategy=merge_strategy, + video_kernel_size=video_kernel_size, + ch=ch, + time_embed_dim=time_embed_dim, + dropout=dropout, out_channels=out_ch, dims=dims, use_checkpoint=use_checkpoint, @@ -602,6 +833,10 @@ def forward(self, x, timesteps=None, context=None, y=None, control=None, transfo transformer_options["current_index"] = 0 transformer_patches = transformer_options.get("patches", {}) + num_video_frames = kwargs.get("num_video_frames", self.default_num_video_frames) + image_only_indicator = kwargs.get("image_only_indicator", self.default_image_only_indicator) + time_context = kwargs.get("time_context", None) + assert (y is not None) == ( self.num_classes is not None ), "must specify y if and only if the model is class-conditional" @@ -616,7 +851,7 @@ def forward(self, x, timesteps=None, context=None, y=None, control=None, transfo h = x.type(self.dtype) for id, module in enumerate(self.input_blocks): transformer_options["block"] = ("input", id) - h = forward_timestep_embed(module, h, emb, context, transformer_options) + h = forward_timestep_embed(module, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator) h = apply_control(h, control, 'input') if "input_block_patch" in transformer_patches: patch = transformer_patches["input_block_patch"] @@ -630,9 +865,10 @@ def forward(self, x, timesteps=None, context=None, y=None, control=None, transfo h = p(h, transformer_options) transformer_options["block"] = ("middle", 0) - h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options) + h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator) h = apply_control(h, control, 'middle') + for id, module in enumerate(self.output_blocks): transformer_options["block"] = ("output", id) hsp = hs.pop() @@ -649,7 +885,7 @@ def forward(self, x, timesteps=None, context=None, y=None, control=None, transfo output_shape = hs[-1].shape else: output_shape = None - h = forward_timestep_embed(module, h, emb, context, transformer_options, output_shape) + h = forward_timestep_embed(module, h, emb, context, transformer_options, output_shape, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator) h = h.type(x.dtype) if self.predict_codebook_ids: return self.id_predictor(h) diff --git a/comfy/ldm/modules/diffusionmodules/util.py b/comfy/ldm/modules/diffusionmodules/util.py index 0298ca99d4d..704bbe57450 100644 --- a/comfy/ldm/modules/diffusionmodules/util.py +++ b/comfy/ldm/modules/diffusionmodules/util.py @@ -13,11 +13,78 @@ import torch import torch.nn as nn import numpy as np -from einops import repeat +from einops import repeat, rearrange from comfy.ldm.util import instantiate_from_config import comfy.ops +class AlphaBlender(nn.Module): + strategies = ["learned", "fixed", "learned_with_images"] + + def __init__( + self, + alpha: float, + merge_strategy: str = "learned_with_images", + rearrange_pattern: str = "b t -> (b t) 1 1", + ): + super().__init__() + self.merge_strategy = merge_strategy + self.rearrange_pattern = rearrange_pattern + + assert ( + merge_strategy in self.strategies + ), f"merge_strategy needs to be in {self.strategies}" + + if self.merge_strategy == "fixed": + self.register_buffer("mix_factor", torch.Tensor([alpha])) + elif ( + self.merge_strategy == "learned" + or self.merge_strategy == "learned_with_images" + ): + self.register_parameter( + "mix_factor", torch.nn.Parameter(torch.Tensor([alpha])) + ) + else: + raise ValueError(f"unknown merge strategy {self.merge_strategy}") + + def get_alpha(self, image_only_indicator: torch.Tensor) -> torch.Tensor: + # skip_time_mix = rearrange(repeat(skip_time_mix, 'b -> (b t) () () ()', t=t), '(b t) 1 ... -> b 1 t ...', t=t) + if self.merge_strategy == "fixed": + # make shape compatible + # alpha = repeat(self.mix_factor, '1 -> b () t () ()', t=t, b=bs) + alpha = self.mix_factor + elif self.merge_strategy == "learned": + alpha = torch.sigmoid(self.mix_factor) + # make shape compatible + # alpha = repeat(alpha, '1 -> s () ()', s = t * bs) + elif self.merge_strategy == "learned_with_images": + assert image_only_indicator is not None, "need image_only_indicator ..." + alpha = torch.where( + image_only_indicator.bool(), + torch.ones(1, 1, device=image_only_indicator.device), + rearrange(torch.sigmoid(self.mix_factor), "... -> ... 1"), + ) + alpha = rearrange(alpha, self.rearrange_pattern) + # make shape compatible + # alpha = repeat(alpha, '1 -> s () ()', s = t * bs) + else: + raise NotImplementedError() + return alpha + + def forward( + self, + x_spatial, + x_temporal, + image_only_indicator=None, + ) -> torch.Tensor: + alpha = self.get_alpha(image_only_indicator) + x = ( + alpha.to(x_spatial.dtype) * x_spatial + + (1.0 - alpha).to(x_spatial.dtype) * x_temporal + ) + return x + + def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): if schedule == "linear": betas = ( diff --git a/comfy/ldm/modules/temporal_ae.py b/comfy/ldm/modules/temporal_ae.py new file mode 100644 index 00000000000..11ae049f3be --- /dev/null +++ b/comfy/ldm/modules/temporal_ae.py @@ -0,0 +1,244 @@ +import functools +from typing import Callable, Iterable, Union + +import torch +from einops import rearrange, repeat + +import comfy.ops + +from .diffusionmodules.model import ( + AttnBlock, + Decoder, + ResnetBlock, +) +from .diffusionmodules.openaimodel import ResBlock, timestep_embedding +from .attention import BasicTransformerBlock + +def partialclass(cls, *args, **kwargs): + class NewCls(cls): + __init__ = functools.partialmethod(cls.__init__, *args, **kwargs) + + return NewCls + + +class VideoResBlock(ResnetBlock): + def __init__( + self, + out_channels, + *args, + dropout=0.0, + video_kernel_size=3, + alpha=0.0, + merge_strategy="learned", + **kwargs, + ): + super().__init__(out_channels=out_channels, dropout=dropout, *args, **kwargs) + if video_kernel_size is None: + video_kernel_size = [3, 1, 1] + self.time_stack = ResBlock( + channels=out_channels, + emb_channels=0, + dropout=dropout, + dims=3, + use_scale_shift_norm=False, + use_conv=False, + up=False, + down=False, + kernel_size=video_kernel_size, + use_checkpoint=False, + skip_t_emb=True, + ) + + self.merge_strategy = merge_strategy + if self.merge_strategy == "fixed": + self.register_buffer("mix_factor", torch.Tensor([alpha])) + elif self.merge_strategy == "learned": + self.register_parameter( + "mix_factor", torch.nn.Parameter(torch.Tensor([alpha])) + ) + else: + raise ValueError(f"unknown merge strategy {self.merge_strategy}") + + def get_alpha(self, bs): + if self.merge_strategy == "fixed": + return self.mix_factor + elif self.merge_strategy == "learned": + return torch.sigmoid(self.mix_factor) + else: + raise NotImplementedError() + + def forward(self, x, temb, skip_video=False, timesteps=None): + b, c, h, w = x.shape + if timesteps is None: + timesteps = b + + x = super().forward(x, temb) + + if not skip_video: + x_mix = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps) + + x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps) + + x = self.time_stack(x, temb) + + alpha = self.get_alpha(bs=b // timesteps) + x = alpha * x + (1.0 - alpha) * x_mix + + x = rearrange(x, "b c t h w -> (b t) c h w") + return x + + +class AE3DConv(torch.nn.Conv2d): + def __init__(self, in_channels, out_channels, video_kernel_size=3, *args, **kwargs): + super().__init__(in_channels, out_channels, *args, **kwargs) + if isinstance(video_kernel_size, Iterable): + padding = [int(k // 2) for k in video_kernel_size] + else: + padding = int(video_kernel_size // 2) + + self.time_mix_conv = torch.nn.Conv3d( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=video_kernel_size, + padding=padding, + ) + + def forward(self, input, timesteps=None, skip_video=False): + if timesteps is None: + timesteps = input.shape[0] + x = super().forward(input) + if skip_video: + return x + x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps) + x = self.time_mix_conv(x) + return rearrange(x, "b c t h w -> (b t) c h w") + + +class AttnVideoBlock(AttnBlock): + def __init__( + self, in_channels: int, alpha: float = 0, merge_strategy: str = "learned" + ): + super().__init__(in_channels) + # no context, single headed, as in base class + self.time_mix_block = BasicTransformerBlock( + dim=in_channels, + n_heads=1, + d_head=in_channels, + checkpoint=False, + ff_in=True, + ) + + time_embed_dim = self.in_channels * 4 + self.video_time_embed = torch.nn.Sequential( + comfy.ops.Linear(self.in_channels, time_embed_dim), + torch.nn.SiLU(), + comfy.ops.Linear(time_embed_dim, self.in_channels), + ) + + self.merge_strategy = merge_strategy + if self.merge_strategy == "fixed": + self.register_buffer("mix_factor", torch.Tensor([alpha])) + elif self.merge_strategy == "learned": + self.register_parameter( + "mix_factor", torch.nn.Parameter(torch.Tensor([alpha])) + ) + else: + raise ValueError(f"unknown merge strategy {self.merge_strategy}") + + def forward(self, x, timesteps=None, skip_time_block=False): + if skip_time_block: + return super().forward(x) + + if timesteps is None: + timesteps = x.shape[0] + + x_in = x + x = self.attention(x) + h, w = x.shape[2:] + x = rearrange(x, "b c h w -> b (h w) c") + + x_mix = x + num_frames = torch.arange(timesteps, device=x.device) + num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps) + num_frames = rearrange(num_frames, "b t -> (b t)") + t_emb = timestep_embedding(num_frames, self.in_channels, repeat_only=False) + emb = self.video_time_embed(t_emb) # b, n_channels + emb = emb[:, None, :] + x_mix = x_mix + emb + + alpha = self.get_alpha() + x_mix = self.time_mix_block(x_mix, timesteps=timesteps) + x = alpha * x + (1.0 - alpha) * x_mix # alpha merge + + x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w) + x = self.proj_out(x) + + return x_in + x + + def get_alpha( + self, + ): + if self.merge_strategy == "fixed": + return self.mix_factor + elif self.merge_strategy == "learned": + return torch.sigmoid(self.mix_factor) + else: + raise NotImplementedError(f"unknown merge strategy {self.merge_strategy}") + + + +def make_time_attn( + in_channels, + attn_type="vanilla", + attn_kwargs=None, + alpha: float = 0, + merge_strategy: str = "learned", +): + return partialclass( + AttnVideoBlock, in_channels, alpha=alpha, merge_strategy=merge_strategy + ) + + +class Conv2DWrapper(torch.nn.Conv2d): + def forward(self, input: torch.Tensor, **kwargs) -> torch.Tensor: + return super().forward(input) + + +class VideoDecoder(Decoder): + available_time_modes = ["all", "conv-only", "attn-only"] + + def __init__( + self, + *args, + video_kernel_size: Union[int, list] = 3, + alpha: float = 0.0, + merge_strategy: str = "learned", + time_mode: str = "conv-only", + **kwargs, + ): + self.video_kernel_size = video_kernel_size + self.alpha = alpha + self.merge_strategy = merge_strategy + self.time_mode = time_mode + assert ( + self.time_mode in self.available_time_modes + ), f"time_mode parameter has to be in {self.available_time_modes}" + + if self.time_mode != "attn-only": + kwargs["conv_out_op"] = partialclass(AE3DConv, video_kernel_size=self.video_kernel_size) + if self.time_mode not in ["conv-only", "only-last-conv"]: + kwargs["attn_op"] = partialclass(make_time_attn, alpha=self.alpha, merge_strategy=self.merge_strategy) + if self.time_mode not in ["attn-only", "only-last-conv"]: + kwargs["resnet_op"] = partialclass(VideoResBlock, video_kernel_size=self.video_kernel_size, alpha=self.alpha, merge_strategy=self.merge_strategy) + + super().__init__(*args, **kwargs) + + def get_last_layer(self, skip_time_mix=False, **kwargs): + if self.time_mode == "attn-only": + raise NotImplementedError("TODO") + else: + return ( + self.conv_out.time_mix_conv.weight + if not skip_time_mix + else self.conv_out.weight + ) diff --git a/comfy/model_base.py b/comfy/model_base.py index 772e2693493..34274c4aeee 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -10,17 +10,22 @@ class ModelType(Enum): EPS = 1 V_PREDICTION = 2 + V_PREDICTION_EDM = 3 -from comfy.model_sampling import EPS, V_PREDICTION, ModelSamplingDiscrete +from comfy.model_sampling import EPS, V_PREDICTION, ModelSamplingDiscrete, ModelSamplingContinuousEDM + def model_sampling(model_config, model_type): + s = ModelSamplingDiscrete + if model_type == ModelType.EPS: c = EPS elif model_type == ModelType.V_PREDICTION: c = V_PREDICTION - - s = ModelSamplingDiscrete + elif model_type == ModelType.V_PREDICTION_EDM: + c = V_PREDICTION + s = ModelSamplingContinuousEDM class ModelSampling(s, c): pass @@ -262,3 +267,48 @@ def encode_adm(self, **kwargs): out.append(self.embedder(torch.Tensor([target_width]))) flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0).repeat(clip_pooled.shape[0], 1) return torch.cat((clip_pooled.to(flat.device), flat), dim=1) + +class SVD_img2vid(BaseModel): + def __init__(self, model_config, model_type=ModelType.V_PREDICTION_EDM, device=None): + super().__init__(model_config, model_type, device=device) + self.embedder = Timestep(256) + + def encode_adm(self, **kwargs): + fps_id = kwargs.get("fps", 6) - 1 + motion_bucket_id = kwargs.get("motion_bucket_id", 127) + augmentation = kwargs.get("augmentation_level", 0) + + out = [] + out.append(self.embedder(torch.Tensor([fps_id]))) + out.append(self.embedder(torch.Tensor([motion_bucket_id]))) + out.append(self.embedder(torch.Tensor([augmentation]))) + + flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0) + return flat + + def extra_conds(self, **kwargs): + out = {} + adm = self.encode_adm(**kwargs) + if adm is not None: + out['y'] = comfy.conds.CONDRegular(adm) + + latent_image = kwargs.get("concat_latent_image", None) + noise = kwargs.get("noise", None) + device = kwargs["device"] + + if latent_image is None: + latent_image = torch.zeros_like(noise) + + if latent_image.shape[1:] != noise.shape[1:]: + latent_image = utils.common_upscale(latent_image, noise.shape[-1], noise.shape[-2], "bilinear", "center") + + latent_image = utils.repeat_to_batch_size(latent_image, noise.shape[0]) + + out['c_concat'] = comfy.conds.CONDNoiseShape(latent_image) + + if "time_conditioning" in kwargs: + out["time_context"] = comfy.conds.CONDCrossAttn(kwargs["time_conditioning"]) + + out['image_only_indicator'] = comfy.conds.CONDConstant(torch.zeros((1,), device=device)) + out['num_video_frames'] = comfy.conds.CONDConstant(noise.shape[0]) + return out diff --git a/comfy/model_detection.py b/comfy/model_detection.py index d65d91e7cb5..45d603a0c63 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -24,7 +24,8 @@ def calculate_transformer_depth(prefix, state_dict_keys, state_dict): last_transformer_depth = count_blocks(state_dict_keys, transformer_prefix + '{}') context_dim = state_dict['{}0.attn2.to_k.weight'.format(transformer_prefix)].shape[1] use_linear_in_transformer = len(state_dict['{}1.proj_in.weight'.format(prefix)].shape) == 2 - return last_transformer_depth, context_dim, use_linear_in_transformer + time_stack = '{}1.time_stack.0.attn1.to_q.weight'.format(prefix) in state_dict or '{}1.time_mix_blocks.0.attn1.to_q.weight'.format(prefix) in state_dict + return last_transformer_depth, context_dim, use_linear_in_transformer, time_stack return None def detect_unet_config(state_dict, key_prefix, dtype): @@ -57,6 +58,7 @@ def detect_unet_config(state_dict, key_prefix, dtype): context_dim = None use_linear_in_transformer = False + video_model = False current_res = 1 count = 0 @@ -99,6 +101,7 @@ def detect_unet_config(state_dict, key_prefix, dtype): if context_dim is None: context_dim = out[1] use_linear_in_transformer = out[2] + video_model = out[3] else: transformer_depth.append(0) @@ -127,6 +130,19 @@ def detect_unet_config(state_dict, key_prefix, dtype): unet_config["transformer_depth_middle"] = transformer_depth_middle unet_config['use_linear_in_transformer'] = use_linear_in_transformer unet_config["context_dim"] = context_dim + + if video_model: + unet_config["extra_ff_mix_layer"] = True + unet_config["use_spatial_context"] = True + unet_config["merge_strategy"] = "learned_with_images" + unet_config["merge_factor"] = 0.0 + unet_config["video_kernel_size"] = [3, 1, 1] + unet_config["use_temporal_resblock"] = True + unet_config["use_temporal_attention"] = True + else: + unet_config["use_temporal_resblock"] = False + unet_config["use_temporal_attention"] = False + return unet_config def model_config_from_unet_config(unet_config): diff --git a/comfy/model_sampling.py b/comfy/model_sampling.py index 9e2a1c1afa6..fac5c995e41 100644 --- a/comfy/model_sampling.py +++ b/comfy/model_sampling.py @@ -1,7 +1,7 @@ import torch import numpy as np from comfy.ldm.modules.diffusionmodules.util import make_beta_schedule - +import math class EPS: def calculate_input(self, sigma, noise): @@ -83,3 +83,47 @@ def percent_to_sigma(self, percent): percent = 1.0 - percent return self.sigma(torch.tensor(percent * 999.0)).item() + +class ModelSamplingContinuousEDM(torch.nn.Module): + def __init__(self, model_config=None): + super().__init__() + self.sigma_data = 1.0 + + if model_config is not None: + sampling_settings = model_config.sampling_settings + else: + sampling_settings = {} + + sigma_min = sampling_settings.get("sigma_min", 0.002) + sigma_max = sampling_settings.get("sigma_max", 120.0) + self.set_sigma_range(sigma_min, sigma_max) + + def set_sigma_range(self, sigma_min, sigma_max): + sigmas = torch.linspace(math.log(sigma_min), math.log(sigma_max), 1000).exp() + + self.register_buffer('sigmas', sigmas) #for compatibility with some schedulers + self.register_buffer('log_sigmas', sigmas.log()) + + @property + def sigma_min(self): + return self.sigmas[0] + + @property + def sigma_max(self): + return self.sigmas[-1] + + def timestep(self, sigma): + return 0.25 * sigma.log() + + def sigma(self, timestep): + return (timestep / 0.25).exp() + + def percent_to_sigma(self, percent): + if percent <= 0.0: + return 999999999.9 + if percent >= 1.0: + return 0.0 + percent = 1.0 - percent + + log_sigma_min = math.log(self.sigma_min) + return math.exp((math.log(self.sigma_max) - log_sigma_min) * percent + log_sigma_min) diff --git a/comfy/sd.py b/comfy/sd.py index a8df3bdd449..7f85540c4eb 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -159,7 +159,15 @@ def __init__(self, sd=None, device=None, config=None): self.memory_used_decode = lambda shape, dtype: (2178 * shape[2] * shape[3] * 64) * model_management.dtype_size(dtype) if config is None: - if "taesd_decoder.1.weight" in sd: + if "decoder.mid.block_1.mix_factor" in sd: + encoder_config = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0} + decoder_config = encoder_config.copy() + decoder_config["video_kernel_size"] = [3, 1, 1] + decoder_config["alpha"] = 0.0 + self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"}, + encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': encoder_config}, + decoder_config={'target': "comfy.ldm.modules.temporal_ae.VideoDecoder", 'params': decoder_config}) + elif "taesd_decoder.1.weight" in sd: self.first_stage_model = comfy.taesd.taesd.TAESD() else: #default SD1.x/SD2.x VAE parameters diff --git a/comfy/supported_models.py b/comfy/supported_models.py index fdd4ea4f5c2..7e2ac677d51 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -17,6 +17,7 @@ class SD15(supported_models_base.BASE): "model_channels": 320, "use_linear_in_transformer": False, "adm_in_channels": None, + "use_temporal_attention": False, } unet_extra_config = { @@ -56,6 +57,7 @@ class SD20(supported_models_base.BASE): "model_channels": 320, "use_linear_in_transformer": True, "adm_in_channels": None, + "use_temporal_attention": False, } latent_format = latent_formats.SD15 @@ -88,6 +90,7 @@ class SD21UnclipL(SD20): "model_channels": 320, "use_linear_in_transformer": True, "adm_in_channels": 1536, + "use_temporal_attention": False, } clip_vision_prefix = "embedder.model.visual." @@ -100,6 +103,7 @@ class SD21UnclipH(SD20): "model_channels": 320, "use_linear_in_transformer": True, "adm_in_channels": 2048, + "use_temporal_attention": False, } clip_vision_prefix = "embedder.model.visual." @@ -112,6 +116,7 @@ class SDXLRefiner(supported_models_base.BASE): "context_dim": 1280, "adm_in_channels": 2560, "transformer_depth": [0, 0, 4, 4, 4, 4, 0, 0], + "use_temporal_attention": False, } latent_format = latent_formats.SDXL @@ -148,7 +153,8 @@ class SDXL(supported_models_base.BASE): "use_linear_in_transformer": True, "transformer_depth": [0, 0, 2, 2, 10, 10], "context_dim": 2048, - "adm_in_channels": 2816 + "adm_in_channels": 2816, + "use_temporal_attention": False, } latent_format = latent_formats.SDXL @@ -203,8 +209,34 @@ class SSD1B(SDXL): "use_linear_in_transformer": True, "transformer_depth": [0, 0, 2, 2, 4, 4], "context_dim": 2048, - "adm_in_channels": 2816 + "adm_in_channels": 2816, + "use_temporal_attention": False, } +class SVD_img2vid(supported_models_base.BASE): + unet_config = { + "model_channels": 320, + "in_channels": 8, + "use_linear_in_transformer": True, + "transformer_depth": [1, 1, 1, 1, 1, 1, 0, 0], + "context_dim": 1024, + "adm_in_channels": 768, + "use_temporal_attention": True, + "use_temporal_resblock": True + } + + clip_vision_prefix = "conditioner.embedders.0.open_clip.model.visual." + + latent_format = latent_formats.SD15 + + sampling_settings = {"sigma_max": 700.0, "sigma_min": 0.002} + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.SVD_img2vid(self, device=device) + return out + + def clip_target(self): + return None models = [SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B] +models += [SVD_img2vid] diff --git a/comfy_extras/nodes_model_advanced.py b/comfy_extras/nodes_model_advanced.py index 0f4ddd9c340..6991c983728 100644 --- a/comfy_extras/nodes_model_advanced.py +++ b/comfy_extras/nodes_model_advanced.py @@ -128,6 +128,36 @@ class ModelSamplingAdvanced(sampling_base, sampling_type): m.add_object_patch("model_sampling", model_sampling) return (m, ) +class ModelSamplingContinuousEDM: + @classmethod + def INPUT_TYPES(s): + return {"required": { "model": ("MODEL",), + "sampling": (["v_prediction", "eps"],), + "sigma_max": ("FLOAT", {"default": 120.0, "min": 0.0, "max": 1000.0, "step":0.001, "round": False}), + "sigma_min": ("FLOAT", {"default": 0.002, "min": 0.0, "max": 1000.0, "step":0.001, "round": False}), + }} + + RETURN_TYPES = ("MODEL",) + FUNCTION = "patch" + + CATEGORY = "advanced/model" + + def patch(self, model, sampling, sigma_max, sigma_min): + m = model.clone() + + if sampling == "eps": + sampling_type = comfy.model_sampling.EPS + elif sampling == "v_prediction": + sampling_type = comfy.model_sampling.V_PREDICTION + + class ModelSamplingAdvanced(comfy.model_sampling.ModelSamplingContinuousEDM, sampling_type): + pass + + model_sampling = ModelSamplingAdvanced() + model_sampling.set_sigma_range(sigma_min, sigma_max) + m.add_object_patch("model_sampling", model_sampling) + return (m, ) + class RescaleCFG: @classmethod def INPUT_TYPES(s): @@ -169,5 +199,6 @@ def rescale_cfg(args): NODE_CLASS_MAPPINGS = { "ModelSamplingDiscrete": ModelSamplingDiscrete, + "ModelSamplingContinuousEDM": ModelSamplingContinuousEDM, "RescaleCFG": RescaleCFG, } From 42dfae63312f443d13841a0c4a5de467f5c354c9 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 23 Nov 2023 19:43:09 -0500 Subject: [PATCH 50/84] Nodes to properly use the SDV img2vid checkpoint. The img2vid model is conditioned on clip vision output only which means there's no CLIP model which is why I added a ImageOnlyCheckpointLoader to load it. Note that the unClipCheckpointLoader can also load it because it also has a CLIP_VISION output. SDV_img2vid_Conditioning is the node used to pass the right conditioning to the img2vid model. VideoLinearCFGGuidance applies a linearly decreasing CFG scale to each video frame from the cfg set in the sampler node to min_cfg. SDV_img2vid_Conditioning can be found in conditioning->video_models ImageOnlyCheckpointLoader can be found in loaders->video_models VideoLinearCFGGuidance can be found in sampling->video_models --- comfy_extras/nodes_video_model.py | 89 +++++++++++++++++++++++++++++++ nodes.py | 1 + 2 files changed, 90 insertions(+) create mode 100644 comfy_extras/nodes_video_model.py diff --git a/comfy_extras/nodes_video_model.py b/comfy_extras/nodes_video_model.py new file mode 100644 index 00000000000..92bd883aebc --- /dev/null +++ b/comfy_extras/nodes_video_model.py @@ -0,0 +1,89 @@ +import nodes +import torch +import comfy.utils +import comfy.sd +import folder_paths + + +class ImageOnlyCheckpointLoader: + @classmethod + def INPUT_TYPES(s): + return {"required": { "ckpt_name": (folder_paths.get_filename_list("checkpoints"), ), + }} + RETURN_TYPES = ("MODEL", "CLIP_VISION", "VAE") + FUNCTION = "load_checkpoint" + + CATEGORY = "loaders/video_models" + + def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True): + ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name) + out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=False, output_clipvision=True, embedding_directory=folder_paths.get_folder_paths("embeddings")) + return (out[0], out[3], out[2]) + + +class SDV_img2vid_Conditioning: + @classmethod + def INPUT_TYPES(s): + return {"required": { "clip_vision": ("CLIP_VISION",), + "init_image": ("IMAGE",), + "vae": ("VAE",), + "width": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}), + "height": ("INT", {"default": 576, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}), + "video_frames": ("INT", {"default": 14, "min": 1, "max": 4096}), + "motion_bucket_id": ("INT", {"default": 127, "min": 1, "max": 1023}), + "fps": ("INT", {"default": 6, "min": 1, "max": 1024}), + "augmentation_level": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 10.0, "step": 0.01}) + }} + RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") + RETURN_NAMES = ("positive", "negative", "latent") + + FUNCTION = "encode" + + CATEGORY = "conditioning/video_models" + + def encode(self, clip_vision, init_image, vae, width, height, video_frames, motion_bucket_id, fps, augmentation_level): + output = clip_vision.encode_image(init_image) + pooled = output.image_embeds.unsqueeze(0) + pixels = comfy.utils.common_upscale(init_image.movedim(-1,1), width, height, "bilinear", "center").movedim(1,-1) + encode_pixels = pixels[:,:,:,:3] + if augmentation_level > 0: + encode_pixels += torch.randn_like(pixels) * augmentation_level + t = vae.encode(encode_pixels) + positive = [[pooled, {"motion_bucket_id": motion_bucket_id, "fps": fps, "augmentation_level": augmentation_level, "concat_latent_image": t}]] + negative = [[torch.zeros_like(pooled), {"motion_bucket_id": motion_bucket_id, "fps": fps, "augmentation_level": augmentation_level, "concat_latent_image": torch.zeros_like(t)}]] + latent = torch.zeros([video_frames, 4, height // 8, width // 8]) + return (positive, negative, {"samples":latent}) + +class VideoLinearCFGGuidance: + @classmethod + def INPUT_TYPES(s): + return {"required": { "model": ("MODEL",), + "min_cfg": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.5, "round": 0.01}), + }} + RETURN_TYPES = ("MODEL",) + FUNCTION = "patch" + + CATEGORY = "sampling/video_models" + + def patch(self, model, min_cfg): + def linear_cfg(args): + cond = args["cond"] + uncond = args["uncond"] + cond_scale = args["cond_scale"] + + scale = torch.linspace(min_cfg, cond_scale, cond.shape[0], device=cond.device).reshape((cond.shape[0], 1, 1, 1)) + return uncond + scale * (cond - uncond) + + m = model.clone() + m.set_model_sampler_cfg_function(linear_cfg) + return (m, ) + +NODE_CLASS_MAPPINGS = { + "ImageOnlyCheckpointLoader": ImageOnlyCheckpointLoader, + "SDV_img2vid_Conditioning": SDV_img2vid_Conditioning, + "VideoLinearCFGGuidance": VideoLinearCFGGuidance, +} + +NODE_DISPLAY_NAME_MAPPINGS = { + "ImageOnlyCheckpointLoader": "Image Only Checkpoint Loader (img2vid model)", +} diff --git a/nodes.py b/nodes.py index 2de468da709..bb24bc6e897 100644 --- a/nodes.py +++ b/nodes.py @@ -1850,6 +1850,7 @@ def init_custom_nodes(): "nodes_model_advanced.py", "nodes_model_downscale.py", "nodes_images.py", + "nodes_video_model.py", ] for node_file in extras_files: From 02ffbb2de3e33d9d64d38c13e70e860d9af90101 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 23 Nov 2023 23:20:07 -0500 Subject: [PATCH 51/84] Fix typo. --- comfy_extras/nodes_video_model.py | 4 ++-- web/scripts/app.js | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/comfy_extras/nodes_video_model.py b/comfy_extras/nodes_video_model.py index 92bd883aebc..26a717a3836 100644 --- a/comfy_extras/nodes_video_model.py +++ b/comfy_extras/nodes_video_model.py @@ -21,7 +21,7 @@ def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True): return (out[0], out[3], out[2]) -class SDV_img2vid_Conditioning: +class SVD_img2vid_Conditioning: @classmethod def INPUT_TYPES(s): return {"required": { "clip_vision": ("CLIP_VISION",), @@ -80,7 +80,7 @@ def linear_cfg(args): NODE_CLASS_MAPPINGS = { "ImageOnlyCheckpointLoader": ImageOnlyCheckpointLoader, - "SDV_img2vid_Conditioning": SDV_img2vid_Conditioning, + "SVD_img2vid_Conditioning": SVD_img2vid_Conditioning, "VideoLinearCFGGuidance": VideoLinearCFGGuidance, } diff --git a/web/scripts/app.js b/web/scripts/app.js index 180416ef971..cd20c40fd0a 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -1523,6 +1523,7 @@ export class ComfyApp { // Patch T2IAdapterLoader to ControlNetLoader since they are the same node now if (n.type == "T2IAdapterLoader") n.type = "ControlNetLoader"; if (n.type == "ConditioningAverage ") n.type = "ConditioningAverage"; //typo fix + if (n.type == "SDV_img2vid_Conditioning") n.type = "SVD_img2vid_Conditioning"; //typo fix // Find missing node types if (!(n.type in LiteGraph.registered_node_types)) { From c782cf3ea95021b0d9fa95014b13e7c32f20fd6e Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 24 Nov 2023 00:27:08 -0500 Subject: [PATCH 52/84] Add to Readme that Stable Video Diffusion is supported. --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index f87c0404f74..9d7e317907f 100644 --- a/README.md +++ b/README.md @@ -11,7 +11,7 @@ This ui will let you design and execute advanced stable diffusion pipelines usin ## Features - Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything. -- Fully supports SD1.x, SD2.x and SDXL +- Fully supports SD1.x, SD2.x, [SDXL](https://comfyanonymous.github.io/ComfyUI_examples/sdxl/) and [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/) - Asynchronous Queue system - Many optimizations: Only re-executes the parts of the workflow that changes between executions. - Command line option: ```--lowvram``` to make it work on GPUs with less than 3GB vram (enabled automatically on GPUs with low vram) From 982338b9bb41301000ddac46d67103af9d0582cd Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 24 Nov 2023 02:08:08 -0500 Subject: [PATCH 53/84] Fix issue loading webp files in UI. --- web/scripts/ui.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/web/scripts/ui.js b/web/scripts/ui.js index 6f01aa5b245..8a58d30b3a7 100644 --- a/web/scripts/ui.js +++ b/web/scripts/ui.js @@ -599,7 +599,7 @@ export class ComfyUI { const fileInput = $el("input", { id: "comfy-file-input", type: "file", - accept: ".json,image/png,.latent,.safetensors", + accept: ".json,image/png,.latent,.safetensors,image/webp", style: {display: "none"}, parent: document.body, onchange: () => { From 3e5ea74ad356e849ea27f1d766a7b6d90a5acfda Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 24 Nov 2023 03:55:35 -0500 Subject: [PATCH 54/84] Make buggy xformers fall back on pytorch attention. --- comfy/ldm/modules/attention.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index 947e2008cbd..d511dda16e8 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -278,9 +278,20 @@ def attention_split(q, k, v, heads, mask=None): ) return r1 +BROKEN_XFORMERS = False +try: + x_vers = xformers.__version__ + #I think 0.0.23 is also broken (q with bs bigger than 65535 gives CUDA error) + BROKEN_XFORMERS = x_vers.startswith("0.0.21") or x_vers.startswith("0.0.22") or x_vers.startswith("0.0.23") +except: + pass + def attention_xformers(q, k, v, heads, mask=None): b, _, dim_head = q.shape dim_head //= heads + if BROKEN_XFORMERS: + if b * heads > 65535: + return attention_pytorch(q, k, v, heads, mask) q, k, v = map( lambda t: t.unsqueeze(3) From eff24ea6aa4f53870f575ec34371b7db940c1cfc Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 24 Nov 2023 11:12:10 -0500 Subject: [PATCH 55/84] Add a node to save animated PNG files. These work in ffpmeg unlike webp. --- comfy_extras/nodes_images.py | 56 ++++++++++++++++++++++++++++++++++++ 1 file changed, 56 insertions(+) diff --git a/comfy_extras/nodes_images.py b/comfy_extras/nodes_images.py index 8c6ae538711..450c8dc40dd 100644 --- a/comfy_extras/nodes_images.py +++ b/comfy_extras/nodes_images.py @@ -3,6 +3,8 @@ from comfy.cli_args import args from PIL import Image +from PIL.PngImagePlugin import PngInfo + import numpy as np import json import os @@ -112,8 +114,62 @@ def save_images(self, images, fps, filename_prefix, lossless, quality, method, n animated = num_frames != 1 return { "ui": { "images": results, "animated": (animated,) } } +class SaveAnimatedPNG: + def __init__(self): + self.output_dir = folder_paths.get_output_directory() + self.type = "output" + self.prefix_append = "" + + @classmethod + def INPUT_TYPES(s): + return {"required": + {"images": ("IMAGE", ), + "filename_prefix": ("STRING", {"default": "ComfyUI"}), + "fps": ("FLOAT", {"default": 12.0, "min": 0.01, "max": 1000.0, "step": 0.01}), + "compress_level": ("INT", {"default": 4, "min": 0, "max": 9}) + }, + "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, + } + + RETURN_TYPES = () + FUNCTION = "save_images" + + OUTPUT_NODE = True + + CATEGORY = "_for_testing" + + def save_images(self, images, fps, compress_level, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None): + filename_prefix += self.prefix_append + full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0]) + results = list() + pil_images = [] + for image in images: + i = 255. * image.cpu().numpy() + img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8)) + pil_images.append(img) + + metadata = None + if not args.disable_metadata: + metadata = PngInfo() + if prompt is not None: + metadata.add_text("prompt", json.dumps(prompt)) + if extra_pnginfo is not None: + for x in extra_pnginfo: + metadata.add_text(x, json.dumps(extra_pnginfo[x])) + + file = f"{filename}_{counter:05}_.png" + pil_images[0].save(os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=compress_level, save_all=True, duration=int(1000.0/fps), append_images=pil_images[1:]) + results.append({ + "filename": file, + "subfolder": subfolder, + "type": self.type + }) + + return { "ui": { "images": results, "animated": (True,)} } + NODE_CLASS_MAPPINGS = { "ImageCrop": ImageCrop, "RepeatImageBatch": RepeatImageBatch, "SaveAnimatedWEBP": SaveAnimatedWEBP, + "SaveAnimatedPNG": SaveAnimatedPNG, } From 916e9c998c5952a30e7795ccfda74186a82a2a06 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 24 Nov 2023 11:19:23 -0500 Subject: [PATCH 56/84] Use same default fps as webp node. --- comfy_extras/nodes_images.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy_extras/nodes_images.py b/comfy_extras/nodes_images.py index 450c8dc40dd..4c86b2df651 100644 --- a/comfy_extras/nodes_images.py +++ b/comfy_extras/nodes_images.py @@ -125,7 +125,7 @@ def INPUT_TYPES(s): return {"required": {"images": ("IMAGE", ), "filename_prefix": ("STRING", {"default": "ComfyUI"}), - "fps": ("FLOAT", {"default": 12.0, "min": 0.01, "max": 1000.0, "step": 0.01}), + "fps": ("FLOAT", {"default": 6.0, "min": 0.01, "max": 1000.0, "step": 0.01}), "compress_level": ("INT", {"default": 4, "min": 0, "max": 9}) }, "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, From 8ad5d494d52883e02f5745603dfd06f1a49c040b Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 24 Nov 2023 18:14:17 -0500 Subject: [PATCH 57/84] Fix APNG not working in ffmpeg. --- comfy_extras/nodes_images.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/comfy_extras/nodes_images.py b/comfy_extras/nodes_images.py index 4c86b2df651..4b6cd3d1b7f 100644 --- a/comfy_extras/nodes_images.py +++ b/comfy_extras/nodes_images.py @@ -152,10 +152,10 @@ def save_images(self, images, fps, compress_level, filename_prefix="ComfyUI", pr if not args.disable_metadata: metadata = PngInfo() if prompt is not None: - metadata.add_text("prompt", json.dumps(prompt)) + metadata.add(b"tEXt", "prompt".encode("latin-1", "strict") + b"\0" + json.dumps(prompt).encode("latin-1", "strict"), after_idat=True) if extra_pnginfo is not None: for x in extra_pnginfo: - metadata.add_text(x, json.dumps(extra_pnginfo[x])) + metadata.add(b"tEXt", x.encode("latin-1", "strict") + b"\0" + json.dumps(extra_pnginfo[x]).encode("latin-1", "strict"), after_idat=True) file = f"{filename}_{counter:05}_.png" pil_images[0].save(os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=compress_level, save_all=True, duration=int(1000.0/fps), append_images=pil_images[1:]) From e020ab61f97fd8bccc31e7eebd23acd5dd9e2ecd Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 24 Nov 2023 18:24:19 -0500 Subject: [PATCH 58/84] Fix output APNG not working with ffmpeg. --- comfy_extras/nodes_images.py | 4 ++-- web/scripts/pnginfo.js | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/comfy_extras/nodes_images.py b/comfy_extras/nodes_images.py index 4b6cd3d1b7f..5ad2235a523 100644 --- a/comfy_extras/nodes_images.py +++ b/comfy_extras/nodes_images.py @@ -152,10 +152,10 @@ def save_images(self, images, fps, compress_level, filename_prefix="ComfyUI", pr if not args.disable_metadata: metadata = PngInfo() if prompt is not None: - metadata.add(b"tEXt", "prompt".encode("latin-1", "strict") + b"\0" + json.dumps(prompt).encode("latin-1", "strict"), after_idat=True) + metadata.add(b"comf", "prompt".encode("latin-1", "strict") + b"\0" + json.dumps(prompt).encode("latin-1", "strict"), after_idat=True) if extra_pnginfo is not None: for x in extra_pnginfo: - metadata.add(b"tEXt", x.encode("latin-1", "strict") + b"\0" + json.dumps(extra_pnginfo[x]).encode("latin-1", "strict"), after_idat=True) + metadata.add(b"comf", x.encode("latin-1", "strict") + b"\0" + json.dumps(extra_pnginfo[x]).encode("latin-1", "strict"), after_idat=True) file = f"{filename}_{counter:05}_.png" pil_images[0].save(os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=compress_level, save_all=True, duration=int(1000.0/fps), append_images=pil_images[1:]) diff --git a/web/scripts/pnginfo.js b/web/scripts/pnginfo.js index f8cbe7a3cd9..83a4ebc86c4 100644 --- a/web/scripts/pnginfo.js +++ b/web/scripts/pnginfo.js @@ -24,7 +24,7 @@ export function getPngMetadata(file) { const length = dataView.getUint32(offset); // Get the chunk type const type = String.fromCharCode(...pngData.slice(offset + 4, offset + 8)); - if (type === "tEXt") { + if (type === "tEXt" || type == "comf") { // Get the keyword let keyword_end = offset + 8; while (pngData[keyword_end] !== 0) { From 5d6dfce5481f67bcfb30b1b39ad6eb78022653af Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 24 Nov 2023 20:35:29 -0500 Subject: [PATCH 59/84] Fix importing diffusers unets. --- comfy/model_detection.py | 30 ++++++++++++++++++++---------- 1 file changed, 20 insertions(+), 10 deletions(-) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 45d603a0c63..c682c3e1a18 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -232,52 +232,62 @@ def unet_config_from_diffusers_unet(state_dict, dtype): SDXL = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, 'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 2, 2, 10, 10], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': 10, - 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 0, 2, 2, 2, 10, 10, 10]} + 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 0, 2, 2, 2, 10, 10, 10], + 'use_temporal_attention': False, 'use_temporal_resblock': False} SDXL_refiner = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, 'num_classes': 'sequential', 'adm_in_channels': 2560, 'dtype': dtype, 'in_channels': 4, 'model_channels': 384, 'num_res_blocks': [2, 2, 2, 2], 'transformer_depth': [0, 0, 4, 4, 4, 4, 0, 0], 'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 4, - 'use_linear_in_transformer': True, 'context_dim': 1280, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 0, 4, 4, 4, 4, 4, 4, 0, 0, 0]} + 'use_linear_in_transformer': True, 'context_dim': 1280, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 0, 4, 4, 4, 4, 4, 4, 0, 0, 0], + 'use_temporal_attention': False, 'use_temporal_resblock': False} SD21 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, 'adm_in_channels': None, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': [2, 2, 2, 2], 'transformer_depth': [1, 1, 1, 1, 1, 1, 0, 0], 'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1, 'use_linear_in_transformer': True, - 'context_dim': 1024, 'num_head_channels': 64, 'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0]} + 'context_dim': 1024, 'num_head_channels': 64, 'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0], + 'use_temporal_attention': False, 'use_temporal_resblock': False} SD21_uncliph = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, 'num_classes': 'sequential', 'adm_in_channels': 2048, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': [2, 2, 2, 2], 'transformer_depth': [1, 1, 1, 1, 1, 1, 0, 0], 'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1, - 'use_linear_in_transformer': True, 'context_dim': 1024, 'num_head_channels': 64, 'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0]} + 'use_linear_in_transformer': True, 'context_dim': 1024, 'num_head_channels': 64, 'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0], + 'use_temporal_attention': False, 'use_temporal_resblock': False} SD21_unclipl = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, 'num_classes': 'sequential', 'adm_in_channels': 1536, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': [2, 2, 2, 2], 'transformer_depth': [1, 1, 1, 1, 1, 1, 0, 0], 'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1, - 'use_linear_in_transformer': True, 'context_dim': 1024, 'num_head_channels': 64, 'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0]} + 'use_linear_in_transformer': True, 'context_dim': 1024, 'num_head_channels': 64, 'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0], + 'use_temporal_attention': False, 'use_temporal_resblock': False} SD15 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, 'adm_in_channels': None, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': [2, 2, 2, 2], 'transformer_depth': [1, 1, 1, 1, 1, 1, 0, 0], 'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1, 'use_linear_in_transformer': False, 'context_dim': 768, 'num_heads': 8, - 'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0]} + 'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0], + 'use_temporal_attention': False, 'use_temporal_resblock': False} SDXL_mid_cnet = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, 'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 0, 0, 1, 1], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': 1, - 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 0, 0, 0, 0, 1, 1, 1]} + 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 0, 0, 0, 0, 1, 1, 1], + 'use_temporal_attention': False, 'use_temporal_resblock': False} SDXL_small_cnet = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, 'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 0, 0, 0, 0], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': 0, - 'use_linear_in_transformer': True, 'num_head_channels': 64, 'context_dim': 1, 'transformer_depth_output': [0, 0, 0, 0, 0, 0, 0, 0, 0]} + 'use_linear_in_transformer': True, 'num_head_channels': 64, 'context_dim': 1, 'transformer_depth_output': [0, 0, 0, 0, 0, 0, 0, 0, 0], + 'use_temporal_attention': False, 'use_temporal_resblock': False} SDXL_diffusers_inpaint = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, 'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 9, 'model_channels': 320, 'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 2, 2, 10, 10], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': 10, - 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 0, 2, 2, 2, 10, 10, 10]} + 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 0, 2, 2, 2, 10, 10, 10], + 'use_temporal_attention': False, 'use_temporal_resblock': False} SSD_1B = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, 'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 2, 2, 4, 4], 'transformer_depth_output': [0, 0, 0, 1, 1, 2, 10, 4, 4], - 'channel_mult': [1, 2, 4], 'transformer_depth_middle': -1, 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64} + 'channel_mult': [1, 2, 4], 'transformer_depth_middle': -1, 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, + 'use_temporal_attention': False, 'use_temporal_resblock': False} supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B] From 5b37270d3ad2227a30e15101a8d528ca77bd589d Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 25 Nov 2023 02:26:50 -0500 Subject: [PATCH 60/84] Add a lora loader node for models with no CLIP. --- nodes.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/nodes.py b/nodes.py index bb24bc6e897..df40f809456 100644 --- a/nodes.py +++ b/nodes.py @@ -572,6 +572,19 @@ def load_lora(self, model, clip, lora_name, strength_model, strength_clip): model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, lora, strength_model, strength_clip) return (model_lora, clip_lora) +class LoraLoaderModelOnly(LoraLoader): + @classmethod + def INPUT_TYPES(s): + return {"required": { "model": ("MODEL",), + "lora_name": (folder_paths.get_filename_list("loras"), ), + "strength_model": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}), + }} + RETURN_TYPES = ("MODEL",) + FUNCTION = "load_lora_model_only" + + def load_lora_model_only(self, model, lora_name, strength_model): + return (self.load_lora(model, None, lora_name, strength_model, 0)[0],) + class VAELoader: @staticmethod def vae_list(): @@ -1703,6 +1716,7 @@ def expand_image(self, image, left, top, right, bottom, feathering): "ConditioningZeroOut": ConditioningZeroOut, "ConditioningSetTimestepRange": ConditioningSetTimestepRange, + "LoraLoaderModelOnly": LoraLoaderModelOnly, } NODE_DISPLAY_NAME_MAPPINGS = { From 50dc39d6ec5420f35b81f965c106b6710ff48e6e Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 26 Nov 2023 03:13:56 -0500 Subject: [PATCH 61/84] Clean up the extra_options dict for the transformer patches. Now everything in transformer_options gets put in extra_options. --- comfy/ldm/modules/attention.py | 31 ++++++------------- .../modules/diffusionmodules/openaimodel.py | 11 ++++--- 2 files changed, 16 insertions(+), 26 deletions(-) diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index d511dda16e8..7dc1a1b5ceb 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -430,31 +430,20 @@ def _forward(self, x, context=None, transformer_options={}): extra_options = {} block = None block_index = 0 - if "current_index" in transformer_options: - extra_options["transformer_index"] = transformer_options["current_index"] - if "block_index" in transformer_options: - block_index = transformer_options["block_index"] - extra_options["block_index"] = block_index - if "original_shape" in transformer_options: - extra_options["original_shape"] = transformer_options["original_shape"] - if "block" in transformer_options: - block = transformer_options["block"] - extra_options["block"] = block - if "cond_or_uncond" in transformer_options: - extra_options["cond_or_uncond"] = transformer_options["cond_or_uncond"] - if "patches" in transformer_options: - transformer_patches = transformer_options["patches"] - else: - transformer_patches = {} + transformer_patches = {} + transformer_patches_replace = {} + + for k in transformer_options: + if k == "patches": + transformer_patches = transformer_options[k] + elif k == "patches_replace": + transformer_patches_replace = transformer_options[k] + else: + extra_options[k] = transformer_options[k] extra_options["n_heads"] = self.n_heads extra_options["dim_head"] = self.d_head - if "patches_replace" in transformer_options: - transformer_patches_replace = transformer_options["patches_replace"] - else: - transformer_patches_replace = {} - if self.ff_in: x_skip = x x = self.ff_in(self.norm_in(x)) diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py index a497ed34478..48264892c26 100644 --- a/comfy/ldm/modules/diffusionmodules/openaimodel.py +++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py @@ -31,7 +31,7 @@ def forward(self, x, emb): Apply the module to `x` given `emb` timestep embeddings. """ -#This is needed because accelerate makes a copy of transformer_options which breaks "current_index" +#This is needed because accelerate makes a copy of transformer_options which breaks "transformer_index" def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, output_shape=None, time_context=None, num_video_frames=None, image_only_indicator=None): for layer in ts: if isinstance(layer, VideoResBlock): @@ -40,11 +40,12 @@ def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, out x = layer(x, emb) elif isinstance(layer, SpatialVideoTransformer): x = layer(x, context, time_context, num_video_frames, image_only_indicator, transformer_options) - transformer_options["current_index"] += 1 + if "transformer_index" in transformer_options: + transformer_options["transformer_index"] += 1 elif isinstance(layer, SpatialTransformer): x = layer(x, context, transformer_options) - if "current_index" in transformer_options: - transformer_options["current_index"] += 1 + if "transformer_index" in transformer_options: + transformer_options["transformer_index"] += 1 elif isinstance(layer, Upsample): x = layer(x, output_shape=output_shape) else: @@ -830,7 +831,7 @@ def forward(self, x, timesteps=None, context=None, y=None, control=None, transfo :return: an [N x C x ...] Tensor of outputs. """ transformer_options["original_shape"] = list(x.shape) - transformer_options["current_index"] = 0 + transformer_options["transformer_index"] = 0 transformer_patches = transformer_options.get("patches", {}) num_video_frames = kwargs.get("num_video_frames", self.default_num_video_frames) From 39e75862b248a20e8233ccee743ba5b2e977cdcf Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 26 Nov 2023 03:43:02 -0500 Subject: [PATCH 62/84] Fix regression from last commit. --- comfy/ldm/modules/attention.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index 7dc1a1b5ceb..f684523823d 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -428,8 +428,8 @@ def forward(self, x, context=None, transformer_options={}): def _forward(self, x, context=None, transformer_options={}): extra_options = {} - block = None - block_index = 0 + block = transformer_options.get("block", None) + block_index = transformer_options.get("block_index", 0) transformer_patches = {} transformer_patches_replace = {} From 6aa1bcd601dfdcb4485ea31947ffbf992a5b54fc Mon Sep 17 00:00:00 2001 From: Jack Bauer <2308123+dmx974@users.noreply.github.com> Date: Sun, 26 Nov 2023 17:23:11 +0400 Subject: [PATCH 63/84] Remove hard coded max_items in history API --- web/scripts/api.js | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/web/scripts/api.js b/web/scripts/api.js index de56b23108b..9aa7528af04 100644 --- a/web/scripts/api.js +++ b/web/scripts/api.js @@ -254,9 +254,9 @@ class ComfyApi extends EventTarget { * Gets the prompt execution history * @returns Prompt history including node outputs */ - async getHistory() { + async getHistory(max_items=200) { try { - const res = await this.fetchApi("/history?max_items=200"); + const res = await this.fetchApi(`/history?max_items=${max_items}`); return { History: Object.values(await res.json()) }; } catch (error) { console.error(error); From edd6f75d3ad243e6c2d38f2d94191da40d12b2f3 Mon Sep 17 00:00:00 2001 From: David Jeske Date: Sun, 26 Nov 2023 13:10:31 -0700 Subject: [PATCH 64/84] better error for invalid output paths --- folder_paths.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/folder_paths.py b/folder_paths.py index 4a38deec06f..5479fd7b2b1 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -228,8 +228,12 @@ def compute_vars(input, image_width, image_height): full_output_folder = os.path.join(output_dir, subfolder) if os.path.commonpath((output_dir, os.path.abspath(full_output_folder))) != output_dir: - print("Saving image outside the output folder is not allowed.") - return {} + err = "**** ERROR: Saving image outside the output folder is not allowed." + \ + "\n full_output_folder: " + os.path.abspath(full_output_folder) + \ + "\n output_dir: " + output_dir + \ + "\n commonpath: " + os.path.commonpath((output_dir, os.path.abspath(full_output_folder))) + print(err) + raise Exception(err) try: counter = max(filter(lambda a: a[1][:-1] == filename and a[1][-1] == "_", map(map_filename, os.listdir(full_output_folder))))[0] + 1 From 34eccd863bb41f48346de178a55be308dc36e5e5 Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Mon, 27 Nov 2023 14:00:15 +0000 Subject: [PATCH 65/84] Add simple undo redo history --- web/extensions/core/undoRedo.js | 150 ++++++++++++++++++++++++++++++++ 1 file changed, 150 insertions(+) create mode 100644 web/extensions/core/undoRedo.js diff --git a/web/extensions/core/undoRedo.js b/web/extensions/core/undoRedo.js new file mode 100644 index 00000000000..1c1d785a8e9 --- /dev/null +++ b/web/extensions/core/undoRedo.js @@ -0,0 +1,150 @@ +import { app } from "../../scripts/app.js"; + +const MAX_HISTORY = 50; + +let undo = []; +let redo = []; +let activeState = null; +let isOurLoad = false; +function checkState() { + const currentState = app.graph.serialize(); + if (!graphEqual(activeState, currentState)) { + undo.push(activeState); + if(undo.length > MAX_HISTORY) { + undo.shift(); + } + activeState = clone(currentState); + redo.length = 0; + } +} + +const loadGraphData = app.loadGraphData; +app.loadGraphData = async function () { + const v = await loadGraphData.apply(this, arguments); + if (isOurLoad) { + isOurLoad = false; + } else { + checkState(); + } + return v; +}; + +function clone(obj) { + try { + if (typeof structuredClone !== "undefined") { + return structuredClone(obj); + } + } catch (error) { + // structuredClone is stricter than using JSON.parse/stringify so fallback to that + } + + return JSON.parse(JSON.stringify(obj)); +} + +function graphEqual(a, b, root = true) { + if (a === b) return true; + + if (typeof a == "object" && a && typeof b == "object" && b) { + const keys = Object.getOwnPropertyNames(a); + + if (keys.length != Object.getOwnPropertyNames(b).length) { + return false; + } + + for (const key of keys) { + let av = a[key]; + let bv = b[key]; + if (root && key === "nodes") { + // Nodes need to be sorted as the order changes when selecting nodes + av = [...av].sort((a, b) => a.id - b.id); + bv = [...bv].sort((a, b) => a.id - b.id); + } + if (!graphEqual(av, bv, false)) { + return false; + } + } + + return true; + } + + return false; +} + +const undoRedo = async (e) => { + if (e.ctrlKey || e.metaKey) { + if (e.key === "y") { + const prevState = redo.pop(); + if (prevState) { + undo.push(activeState); + isOurLoad = true; + await app.loadGraphData(prevState); + activeState = prevState; + } + return true; + } else if (e.key === "z") { + const prevState = undo.pop(); + if (prevState) { + redo.push(activeState); + isOurLoad = true; + await app.loadGraphData(prevState); + activeState = prevState; + } + return true; + } + } +}; + +const bindInput = (activeEl) => { + if (activeEl?.tagName !== "CANVAS" && activeEl?.tagName !== "BODY") { + for (const evt of ["change", "input", "blur"]) { + if (`on${evt}` in activeEl) { + const listener = () => { + checkState(); + activeEl.removeEventListener(evt, listener); + }; + activeEl.addEventListener(evt, listener); + return true; + } + } + } +}; + +window.addEventListener( + "keydown", + (e) => { + requestAnimationFrame(async () => { + const activeEl = document.activeElement; + if (activeEl?.tagName === "INPUT" || activeEl?.type === "textarea") { + // Ignore events on inputs, they have their native history + return; + } + + // Check if this is a ctrl+z ctrl+y + if (await undoRedo(e)) return; + + // If our active element is some type of input then handle changes after they're done + if (bindInput(activeEl)) return; + checkState(); + }); + }, + true +); + +// Handle clicking DOM elements (e.g. widgets) +window.addEventListener("mouseup", () => { + checkState(); +}); + +// Handle litegraph clicks +const processMouseUp = LGraphCanvas.prototype.processMouseUp; +LGraphCanvas.prototype.processMouseUp = function (e) { + const v = processMouseUp.apply(this, arguments); + checkState(); + return v; +}; +const processMouseDown = LGraphCanvas.prototype.processMouseDown; +LGraphCanvas.prototype.processMouseDown = function (e) { + const v = processMouseDown.apply(this, arguments); + checkState(); + return v; +}; From 9be0b30cf1f69384e72823f5112072b15f1f431d Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Mon, 27 Nov 2023 14:02:50 +0000 Subject: [PATCH 66/84] fix formatting --- web/extensions/core/undoRedo.js | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/web/extensions/core/undoRedo.js b/web/extensions/core/undoRedo.js index 1c1d785a8e9..c6613b0f02d 100644 --- a/web/extensions/core/undoRedo.js +++ b/web/extensions/core/undoRedo.js @@ -10,9 +10,9 @@ function checkState() { const currentState = app.graph.serialize(); if (!graphEqual(activeState, currentState)) { undo.push(activeState); - if(undo.length > MAX_HISTORY) { - undo.shift(); - } + if (undo.length > MAX_HISTORY) { + undo.shift(); + } activeState = clone(currentState); redo.length = 0; } From be71bb5e13d716c541a5372a518e9d512073fe18 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 27 Nov 2023 14:04:16 -0500 Subject: [PATCH 67/84] Tweak memory inference calculations a bit. --- comfy/model_base.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/comfy/model_base.py b/comfy/model_base.py index 34274c4aeee..3d6879ae631 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -164,12 +164,13 @@ def set_inpaint(self): self.inpaint_model = True def memory_required(self, input_shape): - area = input_shape[0] * input_shape[2] * input_shape[3] if comfy.model_management.xformers_enabled() or comfy.model_management.pytorch_attention_flash_attention(): #TODO: this needs to be tweaked - return (area / (comfy.model_management.dtype_size(self.get_dtype()) * 10)) * (1024 * 1024) + area = max(input_shape[0], 3) * input_shape[2] * input_shape[3] + return (area * comfy.model_management.dtype_size(self.get_dtype()) / 60) * (1024 * 1024) else: #TODO: this formula might be too aggressive since I tweaked the sub-quad and split algorithms to use less memory. + area = input_shape[0] * input_shape[2] * input_shape[3] return (((area * 0.6) / 0.9) + 1024) * (1024 * 1024) From 13fdee6abf7a7b072ad0f1ebbaa76aca13ddd2a8 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 27 Nov 2023 14:55:40 -0500 Subject: [PATCH 68/84] Try to free memory for both cond+uncond before inference. --- comfy/model_base.py | 4 ++-- comfy/sample.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/comfy/model_base.py b/comfy/model_base.py index 3d6879ae631..786c9cf47ba 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -166,8 +166,8 @@ def set_inpaint(self): def memory_required(self, input_shape): if comfy.model_management.xformers_enabled() or comfy.model_management.pytorch_attention_flash_attention(): #TODO: this needs to be tweaked - area = max(input_shape[0], 3) * input_shape[2] * input_shape[3] - return (area * comfy.model_management.dtype_size(self.get_dtype()) / 60) * (1024 * 1024) + area = input_shape[0] * input_shape[2] * input_shape[3] + return (area * comfy.model_management.dtype_size(self.get_dtype()) / 50) * (1024 * 1024) else: #TODO: this formula might be too aggressive since I tweaked the sub-quad and split algorithms to use less memory. area = input_shape[0] * input_shape[2] * input_shape[3] diff --git a/comfy/sample.py b/comfy/sample.py index 4bfdb8ce55d..034db97ee88 100644 --- a/comfy/sample.py +++ b/comfy/sample.py @@ -83,7 +83,7 @@ def prepare_sampling(model, noise_shape, positive, negative, noise_mask): real_model = None models, inference_memory = get_additional_models(positive, negative, model.model_dtype()) - comfy.model_management.load_models_gpu([model] + models, model.memory_required(noise_shape) + inference_memory) + comfy.model_management.load_models_gpu([model] + models, model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:])) + inference_memory) real_model = model.model return real_model, positive, negative, noise_mask, models From 488de0b4df524589c11a9bd0e2b3663d03003342 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 27 Nov 2023 16:32:03 -0500 Subject: [PATCH 69/84] ModelSamplingDiscreteLCM -> ModelSamplingDiscreteDistilled --- comfy_extras/nodes_model_advanced.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/comfy_extras/nodes_model_advanced.py b/comfy_extras/nodes_model_advanced.py index 6991c983728..20261aadea6 100644 --- a/comfy_extras/nodes_model_advanced.py +++ b/comfy_extras/nodes_model_advanced.py @@ -17,7 +17,9 @@ def calculate_denoised(self, sigma, model_output, model_input): return c_out * x0 + c_skip * model_input -class ModelSamplingDiscreteLCM(torch.nn.Module): +class ModelSamplingDiscreteDistilled(torch.nn.Module): + original_timesteps = 50 + def __init__(self): super().__init__() self.sigma_data = 1.0 @@ -29,13 +31,12 @@ def __init__(self): alphas = 1.0 - betas alphas_cumprod = torch.cumprod(alphas, dim=0) - original_timesteps = 50 - self.skip_steps = timesteps // original_timesteps + self.skip_steps = timesteps // self.original_timesteps - alphas_cumprod_valid = torch.zeros((original_timesteps), dtype=torch.float32) - for x in range(original_timesteps): - alphas_cumprod_valid[original_timesteps - 1 - x] = alphas_cumprod[timesteps - 1 - x * self.skip_steps] + alphas_cumprod_valid = torch.zeros((self.original_timesteps), dtype=torch.float32) + for x in range(self.original_timesteps): + alphas_cumprod_valid[self.original_timesteps - 1 - x] = alphas_cumprod[timesteps - 1 - x * self.skip_steps] sigmas = ((1 - alphas_cumprod_valid) / alphas_cumprod_valid) ** 0.5 self.set_sigmas(sigmas) @@ -116,7 +117,7 @@ def patch(self, model, sampling, zsnr): sampling_type = comfy.model_sampling.V_PREDICTION elif sampling == "lcm": sampling_type = LCM - sampling_base = ModelSamplingDiscreteLCM + sampling_base = ModelSamplingDiscreteDistilled class ModelSamplingAdvanced(sampling_base, sampling_type): pass From f30b992b18078415f7c31c6c2f5ad1513db0bf5e Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 27 Nov 2023 16:41:33 -0500 Subject: [PATCH 70/84] .sigma and .timestep now return tensors on the same device as the input. --- comfy/model_sampling.py | 6 +++--- comfy_extras/nodes_model_advanced.py | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/comfy/model_sampling.py b/comfy/model_sampling.py index fac5c995e41..69c8b1f01fc 100644 --- a/comfy/model_sampling.py +++ b/comfy/model_sampling.py @@ -65,15 +65,15 @@ def sigma_max(self): def timestep(self, sigma): log_sigma = sigma.log() dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None] - return dists.abs().argmin(dim=0).view(sigma.shape) + return dists.abs().argmin(dim=0).view(sigma.shape).to(sigma.device) def sigma(self, timestep): - t = torch.clamp(timestep.float(), min=0, max=(len(self.sigmas) - 1)) + t = torch.clamp(timestep.float().to(self.log_sigmas.device), min=0, max=(len(self.sigmas) - 1)) low_idx = t.floor().long() high_idx = t.ceil().long() w = t.frac() log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx] - return log_sigma.exp() + return log_sigma.exp().to(timestep.device) def percent_to_sigma(self, percent): if percent <= 0.0: diff --git a/comfy_extras/nodes_model_advanced.py b/comfy_extras/nodes_model_advanced.py index 20261aadea6..efcdf1932e4 100644 --- a/comfy_extras/nodes_model_advanced.py +++ b/comfy_extras/nodes_model_advanced.py @@ -56,15 +56,15 @@ def sigma_max(self): def timestep(self, sigma): log_sigma = sigma.log() dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None] - return dists.abs().argmin(dim=0).view(sigma.shape) * self.skip_steps + (self.skip_steps - 1) + return (dists.abs().argmin(dim=0).view(sigma.shape) * self.skip_steps + (self.skip_steps - 1)).to(sigma.device) def sigma(self, timestep): - t = torch.clamp(((timestep - (self.skip_steps - 1)) / self.skip_steps).float(), min=0, max=(len(self.sigmas) - 1)) + t = torch.clamp(((timestep.float().to(self.log_sigmas.device) - (self.skip_steps - 1)) / self.skip_steps).float(), min=0, max=(len(self.sigmas) - 1)) low_idx = t.floor().long() high_idx = t.ceil().long() w = t.frac() log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx] - return log_sigma.exp() + return log_sigma.exp().to(timestep.device) def percent_to_sigma(self, percent): if percent <= 0.0: From c45d1b9b67a98c9ff9743b93caf8303286a430c3 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 27 Nov 2023 17:32:07 -0500 Subject: [PATCH 71/84] Add a function to load a unet from a state dict. --- comfy/sd.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index 7f85540c4eb..53c79e1c57a 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -481,20 +481,18 @@ class WeightsLoader(torch.nn.Module): return (model_patcher, clip, vae, clipvision) -def load_unet(unet_path): #load unet in diffusers format - sd = comfy.utils.load_torch_file(unet_path) +def load_unet_state_dict(sd): #load unet in diffusers format parameters = comfy.utils.calculate_parameters(sd) unet_dtype = model_management.unet_dtype(model_params=parameters) if "input_blocks.0.0.weight" in sd: #ldm model_config = model_detection.model_config_from_unet(sd, "", unet_dtype) if model_config is None: - raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path)) + return None new_sd = sd else: #diffusers model_config = model_detection.model_config_from_diffusers_unet(sd, unet_dtype) if model_config is None: - print("ERROR UNSUPPORTED UNET", unet_path) return None diffusers_keys = comfy.utils.unet_to_diffusers(model_config.unet_config) @@ -514,6 +512,14 @@ def load_unet(unet_path): #load unet in diffusers format print("left over keys in unet:", left_over) return comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device) +def load_unet(unet_path): + sd = comfy.utils.load_torch_file(unet_path) + model = load_unet_state_dict(sd) + if model is None: + print("ERROR UNSUPPORTED UNET", unet_path) + raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path)) + return model + def save_checkpoint(output_path, model, clip, vae, metadata=None): model_management.load_models_gpu([model, clip.load_model()]) sd = model.model.state_dict_for_saving(clip.get_sd(), vae.get_sd()) From 798a34d009cd78f02bd4c0b30f1c9fd6a594d345 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 28 Nov 2023 04:57:59 -0500 Subject: [PATCH 72/84] Lower compress level for image preview. --- nodes.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/nodes.py b/nodes.py index df40f809456..8b4a9b11932 100644 --- a/nodes.py +++ b/nodes.py @@ -1337,6 +1337,7 @@ def __init__(self): self.output_dir = folder_paths.get_output_directory() self.type = "output" self.prefix_append = "" + self.compress_level = 4 @classmethod def INPUT_TYPES(s): @@ -1370,7 +1371,7 @@ def save_images(self, images, filename_prefix="ComfyUI", prompt=None, extra_pngi metadata.add_text(x, json.dumps(extra_pnginfo[x])) file = f"{filename}_{counter:05}_.png" - img.save(os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=4) + img.save(os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=self.compress_level) results.append({ "filename": file, "subfolder": subfolder, @@ -1385,6 +1386,7 @@ def __init__(self): self.output_dir = folder_paths.get_temp_directory() self.type = "temp" self.prefix_append = "_temp_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for x in range(5)) + self.compress_level = 1 @classmethod def INPUT_TYPES(s): From 983ebc579212e209f52dff014b79bfe1932c0959 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 28 Nov 2023 04:58:32 -0500 Subject: [PATCH 73/84] Use smart model management for VAE to decrease latency. --- comfy/sd.py | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index 53c79e1c57a..f4f84d0a032 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -187,10 +187,12 @@ def __init__(self, sd=None, device=None, config=None): if device is None: device = model_management.vae_device() self.device = device - self.offload_device = model_management.vae_offload_device() + offload_device = model_management.vae_offload_device() self.vae_dtype = model_management.vae_dtype() self.first_stage_model.to(self.vae_dtype) + self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device) + def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16): steps = samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x, tile_y, overlap) steps += samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x // 2, tile_y * 2, overlap) @@ -219,10 +221,9 @@ def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64): return samples def decode(self, samples_in): - self.first_stage_model = self.first_stage_model.to(self.device) try: memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype) - model_management.free_memory(memory_used, self.device) + model_management.load_models_gpu([self.patcher], memory_required=memory_used) free_memory = model_management.get_free_memory(self.device) batch_number = int(free_memory / memory_used) batch_number = max(1, batch_number) @@ -235,22 +236,19 @@ def decode(self, samples_in): print("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.") pixel_samples = self.decode_tiled_(samples_in) - self.first_stage_model = self.first_stage_model.to(self.offload_device) pixel_samples = pixel_samples.cpu().movedim(1,-1) return pixel_samples def decode_tiled(self, samples, tile_x=64, tile_y=64, overlap = 16): - self.first_stage_model = self.first_stage_model.to(self.device) + model_management.load_model_gpu(self.patcher) output = self.decode_tiled_(samples, tile_x, tile_y, overlap) - self.first_stage_model = self.first_stage_model.to(self.offload_device) return output.movedim(1,-1) def encode(self, pixel_samples): - self.first_stage_model = self.first_stage_model.to(self.device) pixel_samples = pixel_samples.movedim(-1,1) try: memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype) - model_management.free_memory(memory_used, self.device) + model_management.load_models_gpu([self.patcher], memory_required=memory_used) free_memory = model_management.get_free_memory(self.device) batch_number = int(free_memory / memory_used) batch_number = max(1, batch_number) @@ -263,14 +261,12 @@ def encode(self, pixel_samples): print("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.") samples = self.encode_tiled_(pixel_samples) - self.first_stage_model = self.first_stage_model.to(self.offload_device) return samples def encode_tiled(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64): - self.first_stage_model = self.first_stage_model.to(self.device) + model_management.load_model_gpu(self.patcher) pixel_samples = pixel_samples.movedim(-1,1) samples = self.encode_tiled_(pixel_samples, tile_x=tile_x, tile_y=tile_y, overlap=overlap) - self.first_stage_model = self.first_stage_model.to(self.offload_device) return samples def get_sd(self): From 21063fa35b53683f6ca01ccf1a5d5b509f702ba7 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 28 Nov 2023 11:01:05 -0500 Subject: [PATCH 74/84] Lower compress level of png sent on websocket. --- server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server.py b/server.py index 1a8e92b8f96..9b1e3269d7f 100644 --- a/server.py +++ b/server.py @@ -576,7 +576,7 @@ async def send_image(self, image_data, sid=None): bytesIO = BytesIO() header = struct.pack(">I", type_num) bytesIO.write(header) - image.save(bytesIO, format=image_type, quality=95, compress_level=4) + image.save(bytesIO, format=image_type, quality=95, compress_level=1) preview_bytes = bytesIO.getvalue() await self.send_bytes(BinaryEventTypes.PREVIEW_IMAGE, preview_bytes, sid=sid) From 57d7f4464f2a40521666cc8436711f73bf728a97 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 28 Nov 2023 13:35:32 -0500 Subject: [PATCH 75/84] Add SDTurboScheduler node. --- comfy_extras/nodes_custom_sampler.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/comfy_extras/nodes_custom_sampler.py b/comfy_extras/nodes_custom_sampler.py index d3c1d4a23ee..008d0b8d6be 100644 --- a/comfy_extras/nodes_custom_sampler.py +++ b/comfy_extras/nodes_custom_sampler.py @@ -81,6 +81,25 @@ def get_sigmas(self, steps, sigma_max, sigma_min, rho): sigmas = k_diffusion_sampling.get_sigmas_polyexponential(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, rho=rho) return (sigmas, ) +class SDTurboScheduler: + @classmethod + def INPUT_TYPES(s): + return {"required": + {"model": ("MODEL",), + "steps": ("INT", {"default": 1, "min": 1, "max": 10}), + } + } + RETURN_TYPES = ("SIGMAS",) + CATEGORY = "sampling/custom_sampling/schedulers" + + FUNCTION = "get_sigmas" + + def get_sigmas(self, model, steps): + timesteps = torch.flip(torch.arange(1, 11) * 100 - 1, (0,))[:steps] + sigmas = model.model.model_sampling.sigma(timesteps) + sigmas = torch.cat([sigmas, sigmas.new_zeros([1])]) + return (sigmas, ) + class VPScheduler: @classmethod def INPUT_TYPES(s): @@ -257,6 +276,7 @@ def sample(self, model, add_noise, noise_seed, cfg, positive, negative, sampler, "ExponentialScheduler": ExponentialScheduler, "PolyexponentialScheduler": PolyexponentialScheduler, "VPScheduler": VPScheduler, + "SDTurboScheduler": SDTurboScheduler, "KSamplerSelect": KSamplerSelect, "SamplerDPMPP_2M_SDE": SamplerDPMPP_2M_SDE, "SamplerDPMPP_SDE": SamplerDPMPP_SDE, From b911eefc4278b6069390d01a6ac9010ae6eecbac Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 28 Nov 2023 14:20:56 -0500 Subject: [PATCH 76/84] Limit gc.collect() to once every 10 seconds. --- main.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/main.py b/main.py index 1100a07f42a..3997fbefcb3 100644 --- a/main.py +++ b/main.py @@ -88,6 +88,7 @@ def cuda_malloc_warning(): def prompt_worker(q, server): e = execution.PromptExecutor(server) + last_gc_collect = 0 while True: item, item_id = q.get() execution_start_time = time.perf_counter() @@ -97,9 +98,14 @@ def prompt_worker(q, server): if server.client_id is not None: server.send_sync("executing", { "node": None, "prompt_id": prompt_id }, server.client_id) - print("Prompt executed in {:.2f} seconds".format(time.perf_counter() - execution_start_time)) - gc.collect() - comfy.model_management.soft_empty_cache() + current_time = time.perf_counter() + execution_time = current_time - execution_start_time + print("Prompt executed in {:.2f} seconds".format(execution_time)) + if (current_time - last_gc_collect) > 10.0: + gc.collect() + comfy.model_management.soft_empty_cache() + last_gc_collect = current_time + print("gc collect") async def run(server, address='', port=8188, verbose=True, call_on_start=None): await asyncio.gather(server.start(address, port, verbose, call_on_start), server.publish_loop()) From 777f6b15225197898a5f49742682a2be859072d7 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 28 Nov 2023 14:45:00 -0500 Subject: [PATCH 77/84] Add to README that SDXL Turbo is supported. --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 9d7e317907f..af1f2281158 100644 --- a/README.md +++ b/README.md @@ -31,6 +31,7 @@ This ui will let you design and execute advanced stable diffusion pipelines usin - [GLIGEN](https://comfyanonymous.github.io/ComfyUI_examples/gligen/) - [Model Merging](https://comfyanonymous.github.io/ComfyUI_examples/model_merging/) - [LCM models and Loras](https://comfyanonymous.github.io/ComfyUI_examples/lcm/) +- [SDXL Turbo](https://comfyanonymous.github.io/ComfyUI_examples/sdturbo/) - Latent previews with [TAESD](#how-to-show-high-quality-previews) - Starts up very fast. - Works fully offline: will never download anything. From 7f469203b7b4547f1d0f7113d18095334fa06a4d Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Thu, 30 Nov 2023 19:13:27 +0000 Subject: [PATCH 78/84] Group nodes (#1776) * setup ui unit tests * Refactoring, adding connections * Few tweaks * Fix type * Add general test * Refactored and extended test * move to describe * for groups * wip group nodes * Relink nodes Fixed widget values Convert to nodes * Reconnect on convert back * add via node menu + canvas refactor * Add ws event handling * fix using wrong node on widget serialize * allow reroute pipe fix control_after_generate configure * allow multiple images * Add test for converted widgets on missing nodes + fix crash * tidy * mores tests + refactor * throw earlier to get less confusing error * support outputs * more test * add ci action * use lts node * Fix? * Prevent connecting non matching combos * update * accidently removed npm i * Disable logging extension * fix naming allow control_after_generate custom name allow convert from reroutes * group node tests * Add executing info, custom node icon Tidy * internal reroute just works * Fix crash on virtual nodes e.g. note * Save group nodes to templates * Fix template nodes not being stored * Fix aborting convert * tidy * Fix reconnecting output links on convert to group * Fix links on convert to nodes * Handle missing internal nodes * Trigger callback on text change * Apply value on connect * Fix converted widgets not reconnecting * Group node updates - persist internal ids in current session - copy widget values when converting to nodes - fix issue serializing converted inputs * Resolve issue with sanitized node name * Fix internal id * allow outputs to be used internally and externally * order widgets on group node various fixes * fix imageupload widget requiring a specific name * groupnode imageupload test give widget unique name * Fix issue with external node links * Add VAE model * Fix internal node id check * fix potential crash * wip widget input support * more wip group widget inputs * Group node refactor Support for primitives/converted widgets * Fix convert to nodes with internal reroutes * fix applying primitive * Fix control widget values * fix test --- .vscode/settings.json | 9 + tests-ui/setup.js | 1 + tests-ui/tests/groupNode.test.js | 818 ++++++++++++++++++++ tests-ui/tests/widgetInputs.test.js | 4 +- tests-ui/utils/ezgraph.js | 46 +- tests-ui/utils/index.js | 60 +- tests-ui/utils/setup.js | 20 +- web/extensions/core/groupNode.js | 1054 ++++++++++++++++++++++++++ web/extensions/core/nodeTemplates.js | 57 +- web/extensions/core/widgetInputs.js | 225 +++--- web/scripts/app.js | 322 ++++---- web/scripts/domWidget.js | 3 +- web/scripts/ui.js | 8 +- web/scripts/widgets.js | 144 +++- 14 files changed, 2416 insertions(+), 355 deletions(-) create mode 100644 .vscode/settings.json create mode 100644 tests-ui/tests/groupNode.test.js create mode 100644 web/extensions/core/groupNode.js diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 00000000000..202121e10fc --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,9 @@ +{ + "path-intellisense.mappings": { + "../": "${workspaceFolder}/web/extensions/core" + }, + "[python]": { + "editor.defaultFormatter": "ms-python.autopep8" + }, + "python.formatting.provider": "none" +} diff --git a/tests-ui/setup.js b/tests-ui/setup.js index 0f368ab22f9..8bbd9dcdf20 100644 --- a/tests-ui/setup.js +++ b/tests-ui/setup.js @@ -20,6 +20,7 @@ async function setup() { // Modify the response data to add some checkpoints const objectInfo = JSON.parse(data); objectInfo.CheckpointLoaderSimple.input.required.ckpt_name[0] = ["model1.safetensors", "model2.ckpt"]; + objectInfo.VAELoader.input.required.vae_name[0] = ["vae1.safetensors", "vae2.ckpt"]; data = JSON.stringify(objectInfo, undefined, "\t"); diff --git a/tests-ui/tests/groupNode.test.js b/tests-ui/tests/groupNode.test.js new file mode 100644 index 00000000000..ce54c11542c --- /dev/null +++ b/tests-ui/tests/groupNode.test.js @@ -0,0 +1,818 @@ +// @ts-check +/// + +const { start, createDefaultWorkflow } = require("../utils"); +const lg = require("../utils/litegraph"); + +describe("group node", () => { + beforeEach(() => { + lg.setup(global); + }); + + afterEach(() => { + lg.teardown(global); + }); + + /** + * + * @param {*} app + * @param {*} graph + * @param {*} name + * @param {*} nodes + * @returns { Promise> } + */ + async function convertToGroup(app, graph, name, nodes) { + // Select the nodes we are converting + for (const n of nodes) { + n.select(true); + } + + expect(Object.keys(app.canvas.selected_nodes).sort((a, b) => +a - +b)).toEqual( + nodes.map((n) => n.id + "").sort((a, b) => +a - +b) + ); + + global.prompt = jest.fn().mockImplementation(() => name); + const groupNode = await nodes[0].menu["Convert to Group Node"].call(false); + + // Check group name was requested + expect(window.prompt).toHaveBeenCalled(); + + // Ensure old nodes are removed + for (const n of nodes) { + expect(n.isRemoved).toBeTruthy(); + } + + expect(groupNode.type).toEqual("workflow/" + name); + + return graph.find(groupNode); + } + + /** + * @param { Record | number[] } idMap + * @param { Record> } valueMap + */ + function getOutput(idMap = {}, valueMap = {}) { + if (idMap instanceof Array) { + idMap = idMap.reduce((p, n) => { + p[n] = n + ""; + return p; + }, {}); + } + const expected = { + 1: { inputs: { ckpt_name: "model1.safetensors", ...valueMap?.[1] }, class_type: "CheckpointLoaderSimple" }, + 2: { inputs: { text: "positive", clip: ["1", 1], ...valueMap?.[2] }, class_type: "CLIPTextEncode" }, + 3: { inputs: { text: "negative", clip: ["1", 1], ...valueMap?.[3] }, class_type: "CLIPTextEncode" }, + 4: { inputs: { width: 512, height: 512, batch_size: 1, ...valueMap?.[4] }, class_type: "EmptyLatentImage" }, + 5: { + inputs: { + seed: 0, + steps: 20, + cfg: 8, + sampler_name: "euler", + scheduler: "normal", + denoise: 1, + model: ["1", 0], + positive: ["2", 0], + negative: ["3", 0], + latent_image: ["4", 0], + ...valueMap?.[5], + }, + class_type: "KSampler", + }, + 6: { inputs: { samples: ["5", 0], vae: ["1", 2], ...valueMap?.[6] }, class_type: "VAEDecode" }, + 7: { inputs: { filename_prefix: "ComfyUI", images: ["6", 0], ...valueMap?.[7] }, class_type: "SaveImage" }, + }; + + // Map old IDs to new at the top level + const mapped = {}; + for (const oldId in idMap) { + mapped[idMap[oldId]] = expected[oldId]; + delete expected[oldId]; + } + Object.assign(mapped, expected); + + // Map old IDs to new inside links + for (const k in mapped) { + for (const input in mapped[k].inputs) { + const v = mapped[k].inputs[input]; + if (v instanceof Array) { + if (v[0] in idMap) { + v[0] = idMap[v[0]] + ""; + } + } + } + } + + return mapped; + } + + test("can be created from selected nodes", async () => { + const { ez, graph, app } = await start(); + const nodes = createDefaultWorkflow(ez, graph); + const group = await convertToGroup(app, graph, "test", [nodes.pos, nodes.neg, nodes.empty]); + + // Ensure links are now to the group node + expect(group.inputs).toHaveLength(2); + expect(group.outputs).toHaveLength(3); + + expect(group.inputs.map((i) => i.input.name)).toEqual(["clip", "CLIPTextEncode clip"]); + expect(group.outputs.map((i) => i.output.name)).toEqual(["LATENT", "CONDITIONING", "CLIPTextEncode CONDITIONING"]); + + // ckpt clip to both clip inputs on the group + expect(nodes.ckpt.outputs.CLIP.connections.map((t) => [t.targetNode.id, t.targetInput.index])).toEqual([ + [group.id, 0], + [group.id, 1], + ]); + + // group conditioning to sampler + expect(group.outputs["CONDITIONING"].connections.map((t) => [t.targetNode.id, t.targetInput.index])).toEqual([ + [nodes.sampler.id, 1], + ]); + // group conditioning 2 to sampler + expect( + group.outputs["CLIPTextEncode CONDITIONING"].connections.map((t) => [t.targetNode.id, t.targetInput.index]) + ).toEqual([[nodes.sampler.id, 2]]); + // group latent to sampler + expect(group.outputs["LATENT"].connections.map((t) => [t.targetNode.id, t.targetInput.index])).toEqual([ + [nodes.sampler.id, 3], + ]); + }); + + test("maintains all output links on conversion", async () => { + const { ez, graph, app } = await start(); + const nodes = createDefaultWorkflow(ez, graph); + const save2 = ez.SaveImage(...nodes.decode.outputs); + const save3 = ez.SaveImage(...nodes.decode.outputs); + // Ensure an output with multiple links maintains them on convert to group + const group = await convertToGroup(app, graph, "test", [nodes.sampler, nodes.decode]); + expect(group.outputs[0].connections.length).toBe(3); + expect(group.outputs[0].connections[0].targetNode.id).toBe(nodes.save.id); + expect(group.outputs[0].connections[1].targetNode.id).toBe(save2.id); + expect(group.outputs[0].connections[2].targetNode.id).toBe(save3.id); + + // and they're still linked when converting back to nodes + const newNodes = group.menu["Convert to nodes"].call(); + const decode = graph.find(newNodes.find((n) => n.type === "VAEDecode")); + expect(decode.outputs[0].connections.length).toBe(3); + expect(decode.outputs[0].connections[0].targetNode.id).toBe(nodes.save.id); + expect(decode.outputs[0].connections[1].targetNode.id).toBe(save2.id); + expect(decode.outputs[0].connections[2].targetNode.id).toBe(save3.id); + }); + test("can be be converted back to nodes", async () => { + const { ez, graph, app } = await start(); + const nodes = createDefaultWorkflow(ez, graph); + const toConvert = [nodes.pos, nodes.neg, nodes.empty, nodes.sampler]; + const group = await convertToGroup(app, graph, "test", toConvert); + + // Edit some values to ensure they are set back onto the converted nodes + expect(group.widgets["text"].value).toBe("positive"); + group.widgets["text"].value = "pos"; + expect(group.widgets["CLIPTextEncode text"].value).toBe("negative"); + group.widgets["CLIPTextEncode text"].value = "neg"; + expect(group.widgets["width"].value).toBe(512); + group.widgets["width"].value = 1024; + expect(group.widgets["sampler_name"].value).toBe("euler"); + group.widgets["sampler_name"].value = "ddim"; + expect(group.widgets["control_after_generate"].value).toBe("randomize"); + group.widgets["control_after_generate"].value = "fixed"; + + /** @type { Array } */ + group.menu["Convert to nodes"].call(); + + // ensure widget values are set + const pos = graph.find(nodes.pos.id); + expect(pos.node.type).toBe("CLIPTextEncode"); + expect(pos.widgets["text"].value).toBe("pos"); + const neg = graph.find(nodes.neg.id); + expect(neg.node.type).toBe("CLIPTextEncode"); + expect(neg.widgets["text"].value).toBe("neg"); + const empty = graph.find(nodes.empty.id); + expect(empty.node.type).toBe("EmptyLatentImage"); + expect(empty.widgets["width"].value).toBe(1024); + const sampler = graph.find(nodes.sampler.id); + expect(sampler.node.type).toBe("KSampler"); + expect(sampler.widgets["sampler_name"].value).toBe("ddim"); + expect(sampler.widgets["control_after_generate"].value).toBe("fixed"); + + // validate links + expect(nodes.ckpt.outputs.CLIP.connections.map((t) => [t.targetNode.id, t.targetInput.index])).toEqual([ + [pos.id, 0], + [neg.id, 0], + ]); + + expect(pos.outputs["CONDITIONING"].connections.map((t) => [t.targetNode.id, t.targetInput.index])).toEqual([ + [nodes.sampler.id, 1], + ]); + + expect(neg.outputs["CONDITIONING"].connections.map((t) => [t.targetNode.id, t.targetInput.index])).toEqual([ + [nodes.sampler.id, 2], + ]); + + expect(empty.outputs["LATENT"].connections.map((t) => [t.targetNode.id, t.targetInput.index])).toEqual([ + [nodes.sampler.id, 3], + ]); + }); + test("it can embed reroutes as inputs", async () => { + const { ez, graph, app } = await start(); + const nodes = createDefaultWorkflow(ez, graph); + + // Add and connect a reroute to the clip text encodes + const reroute = ez.Reroute(); + nodes.ckpt.outputs.CLIP.connectTo(reroute.inputs[0]); + reroute.outputs[0].connectTo(nodes.pos.inputs[0]); + reroute.outputs[0].connectTo(nodes.neg.inputs[0]); + + // Convert to group and ensure we only have 1 input of the correct type + const group = await convertToGroup(app, graph, "test", [nodes.pos, nodes.neg, nodes.empty, reroute]); + expect(group.inputs).toHaveLength(1); + expect(group.inputs[0].input.type).toEqual("CLIP"); + + expect((await graph.toPrompt()).output).toEqual(getOutput()); + }); + test("it can embed reroutes as outputs", async () => { + const { ez, graph, app } = await start(); + const nodes = createDefaultWorkflow(ez, graph); + + // Add a reroute with no output so we output IMAGE even though its used internally + const reroute = ez.Reroute(); + nodes.decode.outputs.IMAGE.connectTo(reroute.inputs[0]); + + // Convert to group and ensure there is an IMAGE output + const group = await convertToGroup(app, graph, "test", [nodes.decode, nodes.save, reroute]); + expect(group.outputs).toHaveLength(1); + expect(group.outputs[0].output.type).toEqual("IMAGE"); + expect((await graph.toPrompt()).output).toEqual(getOutput([nodes.decode.id, nodes.save.id])); + }); + test("it can embed reroutes as pipes", async () => { + const { ez, graph, app } = await start(); + const nodes = createDefaultWorkflow(ez, graph); + + // Use reroutes as a pipe + const rerouteModel = ez.Reroute(); + const rerouteClip = ez.Reroute(); + const rerouteVae = ez.Reroute(); + nodes.ckpt.outputs.MODEL.connectTo(rerouteModel.inputs[0]); + nodes.ckpt.outputs.CLIP.connectTo(rerouteClip.inputs[0]); + nodes.ckpt.outputs.VAE.connectTo(rerouteVae.inputs[0]); + + const group = await convertToGroup(app, graph, "test", [rerouteModel, rerouteClip, rerouteVae]); + + expect(group.outputs).toHaveLength(3); + expect(group.outputs.map((o) => o.output.type)).toEqual(["MODEL", "CLIP", "VAE"]); + + expect(group.outputs).toHaveLength(3); + expect(group.outputs.map((o) => o.output.type)).toEqual(["MODEL", "CLIP", "VAE"]); + + group.outputs[0].connectTo(nodes.sampler.inputs.model); + group.outputs[1].connectTo(nodes.pos.inputs.clip); + group.outputs[1].connectTo(nodes.neg.inputs.clip); + }); + test("can handle reroutes used internally", async () => { + const { ez, graph, app } = await start(); + const nodes = createDefaultWorkflow(ez, graph); + + let reroutes = []; + let prevNode = nodes.ckpt; + for(let i = 0; i < 5; i++) { + const reroute = ez.Reroute(); + prevNode.outputs[0].connectTo(reroute.inputs[0]); + prevNode = reroute; + reroutes.push(reroute); + } + prevNode.outputs[0].connectTo(nodes.sampler.inputs.model); + + const group = await convertToGroup(app, graph, "test", [...reroutes, ...Object.values(nodes)]); + expect((await graph.toPrompt()).output).toEqual(getOutput()); + + group.menu["Convert to nodes"].call(); + expect((await graph.toPrompt()).output).toEqual(getOutput()); + }); + test("creates with widget values from inner nodes", async () => { + const { ez, graph, app } = await start(); + const nodes = createDefaultWorkflow(ez, graph); + + nodes.ckpt.widgets.ckpt_name.value = "model2.ckpt"; + nodes.pos.widgets.text.value = "hello"; + nodes.neg.widgets.text.value = "world"; + nodes.empty.widgets.width.value = 256; + nodes.empty.widgets.height.value = 1024; + nodes.sampler.widgets.seed.value = 1; + nodes.sampler.widgets.control_after_generate.value = "increment"; + nodes.sampler.widgets.steps.value = 8; + nodes.sampler.widgets.cfg.value = 4.5; + nodes.sampler.widgets.sampler_name.value = "uni_pc"; + nodes.sampler.widgets.scheduler.value = "karras"; + nodes.sampler.widgets.denoise.value = 0.9; + + const group = await convertToGroup(app, graph, "test", [ + nodes.ckpt, + nodes.pos, + nodes.neg, + nodes.empty, + nodes.sampler, + ]); + + expect(group.widgets["ckpt_name"].value).toEqual("model2.ckpt"); + expect(group.widgets["text"].value).toEqual("hello"); + expect(group.widgets["CLIPTextEncode text"].value).toEqual("world"); + expect(group.widgets["width"].value).toEqual(256); + expect(group.widgets["height"].value).toEqual(1024); + expect(group.widgets["seed"].value).toEqual(1); + expect(group.widgets["control_after_generate"].value).toEqual("increment"); + expect(group.widgets["steps"].value).toEqual(8); + expect(group.widgets["cfg"].value).toEqual(4.5); + expect(group.widgets["sampler_name"].value).toEqual("uni_pc"); + expect(group.widgets["scheduler"].value).toEqual("karras"); + expect(group.widgets["denoise"].value).toEqual(0.9); + + expect((await graph.toPrompt()).output).toEqual( + getOutput([nodes.ckpt.id, nodes.pos.id, nodes.neg.id, nodes.empty.id, nodes.sampler.id], { + [nodes.ckpt.id]: { ckpt_name: "model2.ckpt" }, + [nodes.pos.id]: { text: "hello" }, + [nodes.neg.id]: { text: "world" }, + [nodes.empty.id]: { width: 256, height: 1024 }, + [nodes.sampler.id]: { + seed: 1, + steps: 8, + cfg: 4.5, + sampler_name: "uni_pc", + scheduler: "karras", + denoise: 0.9, + }, + }) + ); + }); + test("group inputs can be reroutes", async () => { + const { ez, graph, app } = await start(); + const nodes = createDefaultWorkflow(ez, graph); + const group = await convertToGroup(app, graph, "test", [nodes.pos, nodes.neg]); + + const reroute = ez.Reroute(); + nodes.ckpt.outputs.CLIP.connectTo(reroute.inputs[0]); + + reroute.outputs[0].connectTo(group.inputs[0]); + reroute.outputs[0].connectTo(group.inputs[1]); + + expect((await graph.toPrompt()).output).toEqual(getOutput([nodes.pos.id, nodes.neg.id])); + }); + test("group outputs can be reroutes", async () => { + const { ez, graph, app } = await start(); + const nodes = createDefaultWorkflow(ez, graph); + const group = await convertToGroup(app, graph, "test", [nodes.pos, nodes.neg]); + + const reroute1 = ez.Reroute(); + const reroute2 = ez.Reroute(); + group.outputs[0].connectTo(reroute1.inputs[0]); + group.outputs[1].connectTo(reroute2.inputs[0]); + + reroute1.outputs[0].connectTo(nodes.sampler.inputs.positive); + reroute2.outputs[0].connectTo(nodes.sampler.inputs.negative); + + expect((await graph.toPrompt()).output).toEqual(getOutput([nodes.pos.id, nodes.neg.id])); + }); + test("groups can connect to each other", async () => { + const { ez, graph, app } = await start(); + const nodes = createDefaultWorkflow(ez, graph); + const group1 = await convertToGroup(app, graph, "test", [nodes.pos, nodes.neg]); + const group2 = await convertToGroup(app, graph, "test2", [nodes.empty, nodes.sampler]); + + group1.outputs[0].connectTo(group2.inputs["positive"]); + group1.outputs[1].connectTo(group2.inputs["negative"]); + + expect((await graph.toPrompt()).output).toEqual( + getOutput([nodes.pos.id, nodes.neg.id, nodes.empty.id, nodes.sampler.id]) + ); + }); + test("displays generated image on group node", async () => { + const { ez, graph, app } = await start(); + const nodes = createDefaultWorkflow(ez, graph); + let group = await convertToGroup(app, graph, "test", [ + nodes.pos, + nodes.neg, + nodes.empty, + nodes.sampler, + nodes.decode, + nodes.save, + ]); + + const { api } = require("../../web/scripts/api"); + + api.dispatchEvent(new CustomEvent("execution_start", {})); + api.dispatchEvent(new CustomEvent("executing", { detail: `${nodes.save.id}` })); + // Event should be forwarded to group node id + expect(+app.runningNodeId).toEqual(group.id); + expect(group.node["imgs"]).toBeFalsy(); + api.dispatchEvent( + new CustomEvent("executed", { + detail: { + node: `${nodes.save.id}`, + output: { + images: [ + { + filename: "test.png", + type: "output", + }, + ], + }, + }, + }) + ); + + // Trigger paint + group.node.onDrawBackground?.(app.canvas.ctx, app.canvas.canvas); + + expect(group.node["images"]).toEqual([ + { + filename: "test.png", + type: "output", + }, + ]); + + // Reload + const workflow = JSON.stringify((await graph.toPrompt()).workflow); + await app.loadGraphData(JSON.parse(workflow)); + group = graph.find(group); + + // Trigger inner nodes to get created + group.node["getInnerNodes"](); + + // Check it works for internal node ids + api.dispatchEvent(new CustomEvent("execution_start", {})); + api.dispatchEvent(new CustomEvent("executing", { detail: `${group.id}:5` })); + // Event should be forwarded to group node id + expect(+app.runningNodeId).toEqual(group.id); + expect(group.node["imgs"]).toBeFalsy(); + api.dispatchEvent( + new CustomEvent("executed", { + detail: { + node: `${group.id}:5`, + output: { + images: [ + { + filename: "test2.png", + type: "output", + }, + ], + }, + }, + }) + ); + + // Trigger paint + group.node.onDrawBackground?.(app.canvas.ctx, app.canvas.canvas); + + expect(group.node["images"]).toEqual([ + { + filename: "test2.png", + type: "output", + }, + ]); + }); + test("allows widgets to be converted to inputs", async () => { + const { ez, graph, app } = await start(); + const nodes = createDefaultWorkflow(ez, graph); + const group = await convertToGroup(app, graph, "test", [nodes.pos, nodes.neg]); + group.widgets[0].convertToInput(); + + const primitive = ez.PrimitiveNode(); + primitive.outputs[0].connectTo(group.inputs["text"]); + primitive.widgets[0].value = "hello"; + + expect((await graph.toPrompt()).output).toEqual( + getOutput([nodes.pos.id, nodes.neg.id], { + [nodes.pos.id]: { text: "hello" }, + }) + ); + }); + test("can be copied", async () => { + const { ez, graph, app } = await start(); + const nodes = createDefaultWorkflow(ez, graph); + + const group1 = await convertToGroup(app, graph, "test", [ + nodes.pos, + nodes.neg, + nodes.empty, + nodes.sampler, + nodes.decode, + nodes.save, + ]); + + group1.widgets["text"].value = "hello"; + group1.widgets["width"].value = 256; + group1.widgets["seed"].value = 1; + + // Clone the node + group1.menu.Clone.call(); + expect(app.graph._nodes).toHaveLength(3); + const group2 = graph.find(app.graph._nodes[2]); + expect(group2.node.type).toEqual("workflow/test"); + expect(group2.id).not.toEqual(group1.id); + + // Reconnect ckpt + nodes.ckpt.outputs.MODEL.connectTo(group2.inputs["model"]); + nodes.ckpt.outputs.CLIP.connectTo(group2.inputs["clip"]); + nodes.ckpt.outputs.CLIP.connectTo(group2.inputs["CLIPTextEncode clip"]); + nodes.ckpt.outputs.VAE.connectTo(group2.inputs["vae"]); + + group2.widgets["text"].value = "world"; + group2.widgets["width"].value = 1024; + group2.widgets["seed"].value = 100; + + let i = 0; + expect((await graph.toPrompt()).output).toEqual({ + ...getOutput([nodes.empty.id, nodes.pos.id, nodes.neg.id, nodes.sampler.id, nodes.decode.id, nodes.save.id], { + [nodes.empty.id]: { width: 256 }, + [nodes.pos.id]: { text: "hello" }, + [nodes.sampler.id]: { seed: 1 }, + }), + ...getOutput( + { + [nodes.empty.id]: `${group2.id}:${i++}`, + [nodes.pos.id]: `${group2.id}:${i++}`, + [nodes.neg.id]: `${group2.id}:${i++}`, + [nodes.sampler.id]: `${group2.id}:${i++}`, + [nodes.decode.id]: `${group2.id}:${i++}`, + [nodes.save.id]: `${group2.id}:${i++}`, + }, + { + [nodes.empty.id]: { width: 1024 }, + [nodes.pos.id]: { text: "world" }, + [nodes.sampler.id]: { seed: 100 }, + } + ), + }); + + graph.arrange(); + }); + test("is embedded in workflow", async () => { + let { ez, graph, app } = await start(); + const nodes = createDefaultWorkflow(ez, graph); + let group = await convertToGroup(app, graph, "test", [nodes.pos, nodes.neg]); + const workflow = JSON.stringify((await graph.toPrompt()).workflow); + + // Clear the environment + ({ ez, graph, app } = await start({ + resetEnv: true, + })); + // Ensure the node isnt registered + expect(() => ez["workflow/test"]).toThrow(); + + // Reload the workflow + await app.loadGraphData(JSON.parse(workflow)); + + // Ensure the node is found + group = graph.find(group); + + // Generate prompt and ensure it is as expected + expect((await graph.toPrompt()).output).toEqual( + getOutput({ + [nodes.pos.id]: `${group.id}:0`, + [nodes.neg.id]: `${group.id}:1`, + }) + ); + }); + test("shows missing node error on missing internal node when loading graph data", async () => { + const { graph } = await start(); + + const dialogShow = jest.spyOn(graph.app.ui.dialog, "show"); + await graph.app.loadGraphData({ + last_node_id: 3, + last_link_id: 1, + nodes: [ + { + id: 3, + type: "workflow/testerror", + }, + ], + links: [], + groups: [], + config: {}, + extra: { + groupNodes: { + testerror: { + nodes: [ + { + type: "NotKSampler", + }, + { + type: "NotVAEDecode", + }, + ], + }, + }, + }, + }); + + expect(dialogShow).toBeCalledTimes(1); + const call = dialogShow.mock.calls[0][0].innerHTML; + expect(call).toContain("the following node types were not found"); + expect(call).toContain("NotKSampler"); + expect(call).toContain("NotVAEDecode"); + expect(call).toContain("workflow/testerror"); + }); + test("maintains widget inputs on conversion back to nodes", async () => { + const { ez, graph, app } = await start(); + let pos = ez.CLIPTextEncode({ text: "positive" }); + pos.node.title = "Positive"; + let neg = ez.CLIPTextEncode({ text: "negative" }); + neg.node.title = "Negative"; + pos.widgets.text.convertToInput(); + neg.widgets.text.convertToInput(); + + let primitive = ez.PrimitiveNode(); + primitive.outputs[0].connectTo(pos.inputs.text); + primitive.outputs[0].connectTo(neg.inputs.text); + + const group = await convertToGroup(app, graph, "test", [pos, neg, primitive]); + // This will use a primitive widget named 'value' + expect(group.widgets.length).toBe(1); + expect(group.widgets["value"].value).toBe("positive"); + + const newNodes = group.menu["Convert to nodes"].call(); + pos = graph.find(newNodes.find((n) => n.title === "Positive")); + neg = graph.find(newNodes.find((n) => n.title === "Negative")); + primitive = graph.find(newNodes.find((n) => n.type === "PrimitiveNode")); + + expect(pos.inputs).toHaveLength(2); + expect(neg.inputs).toHaveLength(2); + expect(primitive.outputs[0].connections).toHaveLength(2); + + expect((await graph.toPrompt()).output).toEqual({ + 1: { inputs: { text: "positive" }, class_type: "CLIPTextEncode" }, + 2: { inputs: { text: "positive" }, class_type: "CLIPTextEncode" }, + }); + }); + test("adds widgets in node execution order", async () => { + const { ez, graph, app } = await start(); + const scale = ez.LatentUpscale(); + const save = ez.SaveImage(); + const empty = ez.EmptyLatentImage(); + const decode = ez.VAEDecode(); + + scale.outputs.LATENT.connectTo(decode.inputs.samples); + decode.outputs.IMAGE.connectTo(save.inputs.images); + empty.outputs.LATENT.connectTo(scale.inputs.samples); + + const group = await convertToGroup(app, graph, "test", [scale, save, empty, decode]); + const widgets = group.widgets.map((w) => w.widget.name); + expect(widgets).toStrictEqual([ + "width", + "height", + "batch_size", + "upscale_method", + "LatentUpscale width", + "LatentUpscale height", + "crop", + "filename_prefix", + ]); + }); + test("adds output for external links when converting to group", async () => { + const { ez, graph, app } = await start(); + const img = ez.EmptyLatentImage(); + let decode = ez.VAEDecode(...img.outputs); + const preview1 = ez.PreviewImage(...decode.outputs); + const preview2 = ez.PreviewImage(...decode.outputs); + + const group = await convertToGroup(app, graph, "test", [img, decode, preview1]); + + // Ensure we have an output connected to the 2nd preview node + expect(group.outputs.length).toBe(1); + expect(group.outputs[0].connections.length).toBe(1); + expect(group.outputs[0].connections[0].targetNode.id).toBe(preview2.id); + + // Convert back and ensure bothe previews are still connected + group.menu["Convert to nodes"].call(); + decode = graph.find(decode); + expect(decode.outputs[0].connections.length).toBe(2); + expect(decode.outputs[0].connections[0].targetNode.id).toBe(preview1.id); + expect(decode.outputs[0].connections[1].targetNode.id).toBe(preview2.id); + }); + test("adds output for external links when converting to group when nodes are not in execution order", async () => { + const { ez, graph, app } = await start(); + const sampler = ez.KSampler(); + const ckpt = ez.CheckpointLoaderSimple(); + const empty = ez.EmptyLatentImage(); + const pos = ez.CLIPTextEncode(ckpt.outputs.CLIP, { text: "positive" }); + const neg = ez.CLIPTextEncode(ckpt.outputs.CLIP, { text: "negative" }); + const decode1 = ez.VAEDecode(sampler.outputs.LATENT, ckpt.outputs.VAE); + const save = ez.SaveImage(decode1.outputs.IMAGE); + ckpt.outputs.MODEL.connectTo(sampler.inputs.model); + pos.outputs.CONDITIONING.connectTo(sampler.inputs.positive); + neg.outputs.CONDITIONING.connectTo(sampler.inputs.negative); + empty.outputs.LATENT.connectTo(sampler.inputs.latent_image); + + const encode = ez.VAEEncode(decode1.outputs.IMAGE); + const vae = ez.VAELoader(); + const decode2 = ez.VAEDecode(encode.outputs.LATENT, vae.outputs.VAE); + const preview = ez.PreviewImage(decode2.outputs.IMAGE); + vae.outputs.VAE.connectTo(encode.inputs.vae); + + const group = await convertToGroup(app, graph, "test", [vae, decode1, encode, sampler]); + + expect(group.outputs.length).toBe(3); + expect(group.outputs[0].output.name).toBe("VAE"); + expect(group.outputs[0].output.type).toBe("VAE"); + expect(group.outputs[1].output.name).toBe("IMAGE"); + expect(group.outputs[1].output.type).toBe("IMAGE"); + expect(group.outputs[2].output.name).toBe("LATENT"); + expect(group.outputs[2].output.type).toBe("LATENT"); + + expect(group.outputs[0].connections.length).toBe(1); + expect(group.outputs[0].connections[0].targetNode.id).toBe(decode2.id); + expect(group.outputs[0].connections[0].targetInput.index).toBe(1); + + expect(group.outputs[1].connections.length).toBe(1); + expect(group.outputs[1].connections[0].targetNode.id).toBe(save.id); + expect(group.outputs[1].connections[0].targetInput.index).toBe(0); + + expect(group.outputs[2].connections.length).toBe(1); + expect(group.outputs[2].connections[0].targetNode.id).toBe(decode2.id); + expect(group.outputs[2].connections[0].targetInput.index).toBe(0); + + expect((await graph.toPrompt()).output).toEqual({ + ...getOutput({ 1: ckpt.id, 2: pos.id, 3: neg.id, 4: empty.id, 5: sampler.id, 6: decode1.id, 7: save.id }), + [vae.id]: { inputs: { vae_name: "vae1.safetensors" }, class_type: vae.node.type }, + [encode.id]: { inputs: { pixels: ["6", 0], vae: [vae.id + "", 0] }, class_type: encode.node.type }, + [decode2.id]: { inputs: { samples: [encode.id + "", 0], vae: [vae.id + "", 0] }, class_type: decode2.node.type }, + [preview.id]: { inputs: { images: [decode2.id + "", 0] }, class_type: preview.node.type }, + }); + }); + test("works with IMAGEUPLOAD widget", async () => { + const { ez, graph, app } = await start(); + const img = ez.LoadImage(); + const preview1 = ez.PreviewImage(img.outputs[0]); + + const group = await convertToGroup(app, graph, "test", [img, preview1]); + const widget = group.widgets["upload"]; + expect(widget).toBeTruthy(); + expect(widget.widget.type).toBe("button"); + }); + test("internal primitive populates widgets for all linked inputs", async () => { + const { ez, graph, app } = await start(); + const img = ez.LoadImage(); + const scale1 = ez.ImageScale(img.outputs[0]); + const scale2 = ez.ImageScale(img.outputs[0]); + ez.PreviewImage(scale1.outputs[0]); + ez.PreviewImage(scale2.outputs[0]); + + scale1.widgets.width.convertToInput(); + scale2.widgets.height.convertToInput(); + + const primitive = ez.PrimitiveNode(); + primitive.outputs[0].connectTo(scale1.inputs.width); + primitive.outputs[0].connectTo(scale2.inputs.height); + + const group = await convertToGroup(app, graph, "test", [img, primitive, scale1, scale2]); + group.widgets.value.value = 100; + expect((await graph.toPrompt()).output).toEqual({ + 1: { + inputs: { image: img.widgets.image.value, upload: "image" }, + class_type: "LoadImage", + }, + 2: { + inputs: { upscale_method: "nearest-exact", width: 100, height: 512, crop: "disabled", image: ["1", 0] }, + class_type: "ImageScale", + }, + 3: { + inputs: { upscale_method: "nearest-exact", width: 512, height: 100, crop: "disabled", image: ["1", 0] }, + class_type: "ImageScale", + }, + 4: { inputs: { images: ["2", 0] }, class_type: "PreviewImage" }, + 5: { inputs: { images: ["3", 0] }, class_type: "PreviewImage" }, + }); + }); + test("primitive control widgets values are copied on convert", async () => { + const { ez, graph, app } = await start(); + const sampler = ez.KSampler(); + sampler.widgets.seed.convertToInput(); + sampler.widgets.sampler_name.convertToInput(); + + let p1 = ez.PrimitiveNode(); + let p2 = ez.PrimitiveNode(); + p1.outputs[0].connectTo(sampler.inputs.seed); + p2.outputs[0].connectTo(sampler.inputs.sampler_name); + + p1.widgets.control_after_generate.value = "increment"; + p2.widgets.control_after_generate.value = "decrement"; + p2.widgets.control_filter_list.value = "/.*/"; + + p2.node.title = "p2"; + + const group = await convertToGroup(app, graph, "test", [sampler, p1, p2]); + expect(group.widgets.control_after_generate.value).toBe("increment"); + expect(group.widgets["p2 control_after_generate"].value).toBe("decrement"); + expect(group.widgets["p2 control_filter_list"].value).toBe("/.*/"); + + group.widgets.control_after_generate.value = "fixed"; + group.widgets["p2 control_after_generate"].value = "randomize"; + group.widgets["p2 control_filter_list"].value = "/.+/"; + + group.menu["Convert to nodes"].call(); + p1 = graph.find(p1); + p2 = graph.find(p2); + + expect(p1.widgets.control_after_generate.value).toBe("fixed"); + expect(p2.widgets.control_after_generate.value).toBe("randomize"); + expect(p2.widgets.control_filter_list.value).toBe("/.+/"); + }); +}); diff --git a/tests-ui/tests/widgetInputs.test.js b/tests-ui/tests/widgetInputs.test.js index e1873105acc..8e191adf043 100644 --- a/tests-ui/tests/widgetInputs.test.js +++ b/tests-ui/tests/widgetInputs.test.js @@ -202,8 +202,8 @@ describe("widget inputs", () => { }); expect(dialogShow).toBeCalledTimes(1); - expect(dialogShow.mock.calls[0][0]).toContain("the following node types were not found"); - expect(dialogShow.mock.calls[0][0]).toContain("TestNode"); + expect(dialogShow.mock.calls[0][0].innerHTML).toContain("the following node types were not found"); + expect(dialogShow.mock.calls[0][0].innerHTML).toContain("TestNode"); }); test("defaultInput widgets can be converted back to inputs", async () => { diff --git a/tests-ui/utils/ezgraph.js b/tests-ui/utils/ezgraph.js index 0e81fd47beb..898b82db051 100644 --- a/tests-ui/utils/ezgraph.js +++ b/tests-ui/utils/ezgraph.js @@ -150,7 +150,7 @@ export class EzNodeMenuItem { if (selectNode) { this.node.select(); } - this.item.callback.call(this.node.node, undefined, undefined, undefined, undefined, this.node.node); + return this.item.callback.call(this.node.node, undefined, undefined, undefined, undefined, this.node.node); } } @@ -240,8 +240,12 @@ export class EzNode { return this.#makeLookupArray(() => this.app.canvas.getNodeMenuOptions(this.node), "content", EzNodeMenuItem); } - select() { - this.app.canvas.selectNode(this.node); + get isRemoved() { + return !this.app.graph.getNodeById(this.id); + } + + select(addToSelection = false) { + this.app.canvas.selectNode(this.node, addToSelection); } // /** @@ -275,12 +279,17 @@ export class EzNode { if (!s) return p; const name = s[nameProperty]; + const item = new ctor(this, i, s); // @ts-ignore - if (!name || name in p) { - throw new Error(`Unable to store ${nodeProperty} ${name} on array as name conflicts.`); + p.push(item); + if (name) { + // @ts-ignore + if (name in p) { + throw new Error(`Unable to store ${nodeProperty} ${name} on array as name conflicts.`); + } } // @ts-ignore - p.push((p[name] = new ctor(this, i, s))); + p[name] = item; return p; }, Object.assign([], { $: this })); } @@ -348,6 +357,19 @@ export class EzGraph { }, 10); }); } + + /** + * @returns { Promise<{ + * workflow: {}, + * output: Record + * }>}> } + */ + toPrompt() { + // @ts-ignore + return this.app.graphToPrompt(); + } } export const Ez = { @@ -356,12 +378,12 @@ export const Ez = { * @example * const { ez, graph } = Ez.graph(app); * graph.clear(); - * const [model, clip, vae] = ez.CheckpointLoaderSimple(); - * const [pos] = ez.CLIPTextEncode(clip, { text: "positive" }); - * const [neg] = ez.CLIPTextEncode(clip, { text: "negative" }); - * const [latent] = ez.KSampler(model, pos, neg, ...ez.EmptyLatentImage()); - * const [image] = ez.VAEDecode(latent, vae); - * const saveNode = ez.SaveImage(image).node; + * const [model, clip, vae] = ez.CheckpointLoaderSimple().outputs; + * const [pos] = ez.CLIPTextEncode(clip, { text: "positive" }).outputs; + * const [neg] = ez.CLIPTextEncode(clip, { text: "negative" }).outputs; + * const [latent] = ez.KSampler(model, pos, neg, ...ez.EmptyLatentImage().outputs).outputs; + * const [image] = ez.VAEDecode(latent, vae).outputs; + * const saveNode = ez.SaveImage(image); * console.log(saveNode); * graph.arrange(); * @param { app } app diff --git a/tests-ui/utils/index.js b/tests-ui/utils/index.js index 01c58b21f5c..eeccdb3d921 100644 --- a/tests-ui/utils/index.js +++ b/tests-ui/utils/index.js @@ -1,21 +1,28 @@ const { mockApi } = require("./setup"); const { Ez } = require("./ezgraph"); +const lg = require("./litegraph"); /** * - * @param { Parameters[0] } config + * @param { Parameters[0] & { resetEnv?: boolean } } config * @returns */ export async function start(config = undefined) { + if(config?.resetEnv) { + jest.resetModules(); + jest.resetAllMocks(); + lg.setup(global); + } + mockApi(config); const { app } = require("../../web/scripts/app"); await app.setup(); - return Ez.graph(app, global["LiteGraph"], global["LGraphCanvas"]); + return { ...Ez.graph(app, global["LiteGraph"], global["LGraphCanvas"]), app }; } /** - * @param { ReturnType["graph"] } graph - * @param { (hasReloaded: boolean) => (Promise | void) } cb + * @param { ReturnType["graph"] } graph + * @param { (hasReloaded: boolean) => (Promise | void) } cb */ export async function checkBeforeAndAfterReload(graph, cb) { await cb(false); @@ -24,10 +31,10 @@ export async function checkBeforeAndAfterReload(graph, cb) { } /** - * @param { string } name - * @param { Record } input + * @param { string } name + * @param { Record } input * @param { (string | string[])[] | Record } output - * @returns { Record } + * @returns { Record } */ export function makeNodeDef(name, input, output = {}) { const nodeDef = { @@ -37,19 +44,19 @@ export function makeNodeDef(name, input, output = {}) { output_name: [], output_is_list: [], input: { - required: {} + required: {}, }, }; - for(const k in input) { + for (const k in input) { nodeDef.input.required[k] = typeof input[k] === "string" ? [input[k], {}] : [...input[k]]; } - if(output instanceof Array) { + if (output instanceof Array) { output = output.reduce((p, c) => { p[c] = c; return p; - }, {}) + }, {}); } - for(const k in output) { + for (const k in output) { nodeDef.output.push(output[k]); nodeDef.output_name.push(k); nodeDef.output_is_list.push(false); @@ -68,4 +75,31 @@ export function assertNotNullOrUndefined(x) { expect(x).not.toEqual(null); expect(x).not.toEqual(undefined); return true; -} \ No newline at end of file +} + +/** + * + * @param { ReturnType["ez"] } ez + * @param { ReturnType["graph"] } graph + */ +export function createDefaultWorkflow(ez, graph) { + graph.clear(); + const ckpt = ez.CheckpointLoaderSimple(); + + const pos = ez.CLIPTextEncode(ckpt.outputs.CLIP, { text: "positive" }); + const neg = ez.CLIPTextEncode(ckpt.outputs.CLIP, { text: "negative" }); + + const empty = ez.EmptyLatentImage(); + const sampler = ez.KSampler( + ckpt.outputs.MODEL, + pos.outputs.CONDITIONING, + neg.outputs.CONDITIONING, + empty.outputs.LATENT + ); + + const decode = ez.VAEDecode(sampler.outputs.LATENT, ckpt.outputs.VAE); + const save = ez.SaveImage(decode.outputs.IMAGE); + graph.arrange(); + + return { ckpt, pos, neg, empty, sampler, decode, save }; +} diff --git a/tests-ui/utils/setup.js b/tests-ui/utils/setup.js index 17e8ac1ad28..dd150214a34 100644 --- a/tests-ui/utils/setup.js +++ b/tests-ui/utils/setup.js @@ -30,16 +30,20 @@ export function mockApi({ mockExtensions, mockNodeDefs } = {}) { mockNodeDefs = JSON.parse(fs.readFileSync(path.resolve("./data/object_info.json"))); } + const events = new EventTarget(); + const mockApi = { + addEventListener: events.addEventListener.bind(events), + removeEventListener: events.removeEventListener.bind(events), + dispatchEvent: events.dispatchEvent.bind(events), + getSystemStats: jest.fn(), + getExtensions: jest.fn(() => mockExtensions), + getNodeDefs: jest.fn(() => mockNodeDefs), + init: jest.fn(), + apiURL: jest.fn((x) => "../../web/" + x), + }; jest.mock("../../web/scripts/api", () => ({ get api() { - return { - addEventListener: jest.fn(), - getSystemStats: jest.fn(), - getExtensions: jest.fn(() => mockExtensions), - getNodeDefs: jest.fn(() => mockNodeDefs), - init: jest.fn(), - apiURL: jest.fn((x) => "../../web/" + x), - }; + return mockApi; }, })); } diff --git a/web/extensions/core/groupNode.js b/web/extensions/core/groupNode.js new file mode 100644 index 00000000000..450b4f5f35c --- /dev/null +++ b/web/extensions/core/groupNode.js @@ -0,0 +1,1054 @@ +import { app } from "../../scripts/app.js"; +import { api } from "../../scripts/api.js"; +import { getWidgetType } from "../../scripts/widgets.js"; +import { mergeIfValid } from "./widgetInputs.js"; + +const GROUP = Symbol(); + +const Workflow = { + InUse: { + Free: 0, + Registered: 1, + InWorkflow: 2, + }, + isInUseGroupNode(name) { + const id = `workflow/${name}`; + // Check if lready registered/in use in this workflow + if (app.graph.extra?.groupNodes?.[name]) { + if (app.graph._nodes.find((n) => n.type === id)) { + return Workflow.InUse.InWorkflow; + } else { + return Workflow.InUse.Registered; + } + } + return Workflow.InUse.Free; + }, + storeGroupNode(name, data) { + let extra = app.graph.extra; + if (!extra) app.graph.extra = extra = {}; + let groupNodes = extra.groupNodes; + if (!groupNodes) extra.groupNodes = groupNodes = {}; + groupNodes[name] = data; + }, +}; + +class GroupNodeBuilder { + constructor(nodes) { + this.nodes = nodes; + } + + build() { + const name = this.getName(); + if (!name) return; + + // Sort the nodes so they are in execution order + // this allows for widgets to be in the correct order when reconstructing + this.sortNodes(); + + this.nodeData = this.getNodeData(); + Workflow.storeGroupNode(name, this.nodeData); + + return { name, nodeData: this.nodeData }; + } + + getName() { + const name = prompt("Enter group name"); + if (!name) return; + const used = Workflow.isInUseGroupNode(name); + switch (used) { + case Workflow.InUse.InWorkflow: + alert( + "An in use group node with this name already exists embedded in this workflow, please remove any instances or use a new name." + ); + return; + case Workflow.InUse.Registered: + if ( + !confirm( + "An group node with this name already exists embedded in this workflow, are you sure you want to overwrite it?" + ) + ) { + return; + } + break; + } + return name; + } + + sortNodes() { + // Gets the builders nodes in graph execution order + const nodesInOrder = app.graph.computeExecutionOrder(false); + this.nodes = this.nodes + .map((node) => ({ index: nodesInOrder.indexOf(node), node })) + .sort((a, b) => a.index - b.index || a.node.id - b.node.id) + .map(({ node }) => node); + } + + getNodeData() { + const storeLinkTypes = (config) => { + // Store link types for dynamically typed nodes e.g. reroutes + for (const link of config.links) { + const origin = app.graph.getNodeById(link[4]); + const type = origin.outputs[link[1]].type; + link.push(type); + } + }; + + const storeExternalLinks = (config) => { + // Store any external links to the group in the config so when rebuilding we add extra slots + config.external = []; + for (let i = 0; i < this.nodes.length; i++) { + const node = this.nodes[i]; + if (!node.outputs?.length) continue; + for (let slot = 0; slot < node.outputs.length; slot++) { + let hasExternal = false; + const output = node.outputs[slot]; + let type = output.type; + if (!output.links?.length) continue; + for (const l of output.links) { + const link = app.graph.links[l]; + if (!link) continue; + if (type === "*") type = link.type; + + if (!app.canvas.selected_nodes[link.target_id]) { + hasExternal = true; + break; + } + } + if (hasExternal) { + config.external.push([i, slot, type]); + } + } + } + }; + + // Use the built in copyToClipboard function to generate the node data we need + const backup = localStorage.getItem("litegrapheditor_clipboard"); + try { + app.canvas.copyToClipboard(this.nodes); + const config = JSON.parse(localStorage.getItem("litegrapheditor_clipboard")); + + storeLinkTypes(config); + storeExternalLinks(config); + + return config; + } finally { + localStorage.setItem("litegrapheditor_clipboard", backup); + } + } +} + +export class GroupNodeConfig { + constructor(name, nodeData) { + this.name = name; + this.nodeData = nodeData; + this.getLinks(); + + this.inputCount = 0; + this.oldToNewOutputMap = {}; + this.newToOldOutputMap = {}; + this.oldToNewInputMap = {}; + this.oldToNewWidgetMap = {}; + this.newToOldWidgetMap = {}; + this.primitiveDefs = {}; + this.widgetToPrimitive = {}; + this.primitiveToWidget = {}; + } + + async registerType(source = "workflow") { + this.nodeDef = { + output: [], + output_name: [], + output_is_list: [], + name: source + "/" + this.name, + display_name: this.name, + category: "group nodes" + ("/" + source), + input: { required: {} }, + + [GROUP]: this, + }; + + this.inputs = []; + const seenInputs = {}; + const seenOutputs = {}; + for (let i = 0; i < this.nodeData.nodes.length; i++) { + const node = this.nodeData.nodes[i]; + node.index = i; + this.processNode(node, seenInputs, seenOutputs); + } + await app.registerNodeDef("workflow/" + this.name, this.nodeDef); + } + + getLinks() { + this.linksFrom = {}; + this.linksTo = {}; + this.externalFrom = {}; + + // Extract links for easy lookup + for (const l of this.nodeData.links) { + const [sourceNodeId, sourceNodeSlot, targetNodeId, targetNodeSlot] = l; + + // Skip links outside the copy config + if (sourceNodeId == null) continue; + + if (!this.linksFrom[sourceNodeId]) { + this.linksFrom[sourceNodeId] = {}; + } + this.linksFrom[sourceNodeId][sourceNodeSlot] = l; + + if (!this.linksTo[targetNodeId]) { + this.linksTo[targetNodeId] = {}; + } + this.linksTo[targetNodeId][targetNodeSlot] = l; + } + + if (this.nodeData.external) { + for (const ext of this.nodeData.external) { + if (!this.externalFrom[ext[0]]) { + this.externalFrom[ext[0]] = { [ext[1]]: ext[2] }; + } else { + this.externalFrom[ext[0]][ext[1]] = ext[2]; + } + } + } + } + + processNode(node, seenInputs, seenOutputs) { + const def = this.getNodeDef(node); + if (!def) return; + + const inputs = { ...def.input?.required, ...def.input?.optional }; + + this.inputs.push(this.processNodeInputs(node, seenInputs, inputs)); + if (def.output?.length) this.processNodeOutputs(node, seenOutputs, def); + } + + getNodeDef(node) { + const def = globalDefs[node.type]; + if (def) return def; + + const linksFrom = this.linksFrom[node.index]; + if (node.type === "PrimitiveNode") { + // Skip as its not linked + if (!linksFrom) return; + + let type = linksFrom["0"][5]; + if (type === "COMBO") { + // Use the array items + const source = node.outputs[0].widget.name; + const fromTypeName = this.nodeData.nodes[linksFrom["0"][2]].type; + const fromType = globalDefs[fromTypeName]; + const input = fromType.input.required[source] ?? fromType.input.optional[source]; + type = input[0]; + } + + const def = (this.primitiveDefs[node.index] = { + input: { + required: { + value: [type, {}], + }, + }, + output: [type], + output_name: [], + output_is_list: [], + }); + return def; + } else if (node.type === "Reroute") { + const linksTo = this.linksTo[node.index]; + if (linksTo && linksFrom && !this.externalFrom[node.index]?.[0]) { + // Being used internally + return null; + } + + let rerouteType = "*"; + if (linksFrom) { + const [, , id, slot] = linksFrom["0"]; + rerouteType = this.nodeData.nodes[id].inputs[slot].type; + } else if (linksTo) { + const [id, slot] = linksTo["0"]; + rerouteType = this.nodeData.nodes[id].outputs[slot].type; + } else { + // Reroute used as a pipe + for (const l of this.nodeData.links) { + if (l[2] === node.index) { + rerouteType = l[5]; + break; + } + } + if (rerouteType === "*") { + // Check for an external link + const t = this.externalFrom[node.index]?.[0]; + if (t) { + rerouteType = t; + } + } + } + + return { + input: { + required: { + [rerouteType]: [rerouteType, {}], + }, + }, + output: [rerouteType], + output_name: [], + output_is_list: [], + }; + } + + console.warn("Skipping virtual node " + node.type + " when building group node " + this.name); + } + + getInputConfig(node, inputName, seenInputs, config, extra) { + let name = node.inputs?.find((inp) => inp.name === inputName)?.label ?? inputName; + let prefix = ""; + // Special handling for primitive to include the title if it is set rather than just "value" + if ((node.type === "PrimitiveNode" && node.title) || name in seenInputs) { + prefix = `${node.title ?? node.type} `; + name = `${prefix}${inputName}`; + if (name in seenInputs) { + name = `${prefix}${seenInputs[name]} ${inputName}`; + } + } + seenInputs[name] = (seenInputs[name] ?? 1) + 1; + + if (inputName === "seed" || inputName === "noise_seed") { + if (!extra) extra = {}; + extra.control_after_generate = `${prefix}control_after_generate`; + } + if (config[0] === "IMAGEUPLOAD") { + if (!extra) extra = {}; + extra.widget = `${prefix}${config[1]?.widget ?? "image"}`; + } + + if (extra) { + config = [config[0], { ...config[1], ...extra }]; + } + + return { name, config }; + } + + processWidgetInputs(inputs, node, inputNames, seenInputs) { + const slots = []; + const converted = new Map(); + const widgetMap = (this.oldToNewWidgetMap[node.index] = {}); + for (const inputName of inputNames) { + let widgetType = getWidgetType(inputs[inputName], inputName); + if (widgetType) { + const convertedIndex = node.inputs?.findIndex( + (inp) => inp.name === inputName && inp.widget?.name === inputName + ); + if (convertedIndex > -1) { + // This widget has been converted to a widget + // We need to store this in the correct position so link ids line up + converted.set(convertedIndex, inputName); + widgetMap[inputName] = null; + } else { + // Normal widget + const { name, config } = this.getInputConfig(node, inputName, seenInputs, inputs[inputName]); + this.nodeDef.input.required[name] = config; + widgetMap[inputName] = name; + this.newToOldWidgetMap[name] = { node, inputName }; + } + } else { + // Normal input + slots.push(inputName); + } + } + return { converted, slots }; + } + + checkPrimitiveConnection(link, inputName, inputs) { + const sourceNode = this.nodeData.nodes[link[0]]; + if (sourceNode.type === "PrimitiveNode") { + // Merge link configurations + const [sourceNodeId, _, targetNodeId, __] = link; + const primitiveDef = this.primitiveDefs[sourceNodeId]; + const targetWidget = inputs[inputName]; + const primitiveConfig = primitiveDef.input.required.value; + const output = { widget: primitiveConfig }; + const config = mergeIfValid(output, targetWidget, false, null, primitiveConfig); + primitiveConfig[1] = config?.customConfig ?? inputs[inputName][1] ? { ...inputs[inputName][1] } : {}; + + let name = this.oldToNewWidgetMap[sourceNodeId]["value"]; + name = name.substr(0, name.length - 6); + primitiveConfig[1].control_after_generate = true; + primitiveConfig[1].control_prefix = name; + + let toPrimitive = this.widgetToPrimitive[targetNodeId]; + if (!toPrimitive) { + toPrimitive = this.widgetToPrimitive[targetNodeId] = {}; + } + if (toPrimitive[inputName]) { + toPrimitive[inputName].push(sourceNodeId); + } + toPrimitive[inputName] = sourceNodeId; + + let toWidget = this.primitiveToWidget[sourceNodeId]; + if (!toWidget) { + toWidget = this.primitiveToWidget[sourceNodeId] = []; + } + toWidget.push({ nodeId: targetNodeId, inputName }); + } + } + + processInputSlots(inputs, node, slots, linksTo, inputMap, seenInputs) { + for (let i = 0; i < slots.length; i++) { + const inputName = slots[i]; + if (linksTo[i]) { + this.checkPrimitiveConnection(linksTo[i], inputName, inputs); + // This input is linked so we can skip it + continue; + } + + const { name, config } = this.getInputConfig(node, inputName, seenInputs, inputs[inputName]); + this.nodeDef.input.required[name] = config; + inputMap[i] = this.inputCount++; + } + } + + processConvertedWidgets(inputs, node, slots, converted, linksTo, inputMap, seenInputs) { + // Add converted widgets sorted into their index order (ordered as they were converted) so link ids match up + const convertedSlots = [...converted.keys()].sort().map((k) => converted.get(k)); + for (let i = 0; i < convertedSlots.length; i++) { + const inputName = convertedSlots[i]; + if (linksTo[slots.length + i]) { + this.checkPrimitiveConnection(linksTo[slots.length + i], inputName, inputs); + // This input is linked so we can skip it + continue; + } + + const { name, config } = this.getInputConfig(node, inputName, seenInputs, inputs[inputName], { + defaultInput: true, + }); + this.nodeDef.input.required[name] = config; + inputMap[slots.length + i] = this.inputCount++; + } + } + + processNodeInputs(node, seenInputs, inputs) { + const inputMapping = []; + + const inputNames = Object.keys(inputs); + if (!inputNames.length) return; + + const { converted, slots } = this.processWidgetInputs(inputs, node, inputNames, seenInputs); + const linksTo = this.linksTo[node.index] ?? {}; + const inputMap = (this.oldToNewInputMap[node.index] = {}); + this.processInputSlots(inputs, node, slots, linksTo, inputMap, seenInputs); + this.processConvertedWidgets(inputs, node, slots, converted, linksTo, inputMap, seenInputs); + + return inputMapping; + } + + processNodeOutputs(node, seenOutputs, def) { + const oldToNew = (this.oldToNewOutputMap[node.index] = {}); + + // Add outputs + for (let outputId = 0; outputId < def.output.length; outputId++) { + const linksFrom = this.linksFrom[node.index]; + if (linksFrom?.[outputId] && !this.externalFrom[node.index]?.[outputId]) { + // This output is linked internally so we can skip it + continue; + } + + oldToNew[outputId] = this.nodeDef.output.length; + this.newToOldOutputMap[this.nodeDef.output.length] = { node, slot: outputId }; + this.nodeDef.output.push(def.output[outputId]); + this.nodeDef.output_is_list.push(def.output_is_list[outputId]); + + let label = def.output_name?.[outputId] ?? def.output[outputId]; + const output = node.outputs.find((o) => o.name === label); + if (output?.label) { + label = output.label; + } + let name = label; + if (name in seenOutputs) { + const prefix = `${node.title ?? node.type} `; + name = `${prefix}${label}`; + if (name in seenOutputs) { + name = `${prefix}${node.index} ${label}`; + } + } + seenOutputs[name] = 1; + + this.nodeDef.output_name.push(name); + } + } + + static async registerFromWorkflow(groupNodes, missingNodeTypes) { + for (const g in groupNodes) { + const groupData = groupNodes[g]; + + let hasMissing = false; + for (const n of groupData.nodes) { + // Find missing node types + if (!(n.type in LiteGraph.registered_node_types)) { + missingNodeTypes.push(n.type); + hasMissing = true; + } + } + + if (hasMissing) continue; + + const config = new GroupNodeConfig(g, groupData); + await config.registerType(); + } + } +} + +export class GroupNodeHandler { + node; + groupData; + + constructor(node) { + this.node = node; + this.groupData = node.constructor?.nodeData?.[GROUP]; + + this.node.setInnerNodes = (innerNodes) => { + this.innerNodes = innerNodes; + + for (let innerNodeIndex = 0; innerNodeIndex < this.innerNodes.length; innerNodeIndex++) { + const innerNode = this.innerNodes[innerNodeIndex]; + + for (const w of innerNode.widgets ?? []) { + if (w.type === "converted-widget") { + w.serializeValue = w.origSerializeValue; + } + } + + innerNode.index = innerNodeIndex; + innerNode.getInputNode = (slot) => { + // Check if this input is internal or external + const externalSlot = this.groupData.oldToNewInputMap[innerNode.index]?.[slot]; + if (externalSlot != null) { + return this.node.getInputNode(externalSlot); + } + + // Internal link + const innerLink = this.groupData.linksTo[innerNode.index]?.[slot]; + if (!innerLink) return null; + + const inputNode = innerNodes[innerLink[0]]; + // Primitives will already apply their values + if (inputNode.type === "PrimitiveNode") return null; + + return inputNode; + }; + + innerNode.getInputLink = (slot) => { + const externalSlot = this.groupData.oldToNewInputMap[innerNode.index]?.[slot]; + if (externalSlot != null) { + // The inner node is connected via the group node inputs + const linkId = this.node.inputs[externalSlot].link; + let link = app.graph.links[linkId]; + + // Use the outer link, but update the target to the inner node + link = { + ...link, + target_id: innerNode.id, + target_slot: +slot, + }; + return link; + } + + let link = this.groupData.linksTo[innerNode.index]?.[slot]; + if (!link) return null; + // Use the inner link, but update the origin node to be inner node id + link = { + origin_id: innerNodes[link[0]].id, + origin_slot: link[1], + target_id: innerNode.id, + target_slot: +slot, + }; + return link; + }; + } + }; + + this.node.updateLink = (link) => { + // Replace the group node reference with the internal node + link = { ...link }; + const output = this.groupData.newToOldOutputMap[link.origin_slot]; + let innerNode = this.innerNodes[output.node.index]; + let l; + while (innerNode.type === "Reroute") { + l = innerNode.getInputLink(0); + innerNode = innerNode.getInputNode(0); + } + + link.origin_id = innerNode.id; + link.origin_slot = l?.origin_slot ?? output.slot; + return link; + }; + + this.node.getInnerNodes = () => { + if (!this.innerNodes) { + this.node.setInnerNodes( + this.groupData.nodeData.nodes.map((n, i) => { + const innerNode = LiteGraph.createNode(n.type); + innerNode.configure(n); + innerNode.id = `${this.node.id}:${i}`; + return innerNode; + }) + ); + } + + this.updateInnerWidgets(); + + return this.innerNodes; + }; + + this.node.convertToNodes = () => { + const addInnerNodes = () => { + const backup = localStorage.getItem("litegrapheditor_clipboard"); + // Clone the node data so we dont mutate it for other nodes + const c = { ...this.groupData.nodeData }; + c.nodes = [...c.nodes]; + const innerNodes = this.node.getInnerNodes(); + let ids = []; + for (let i = 0; i < c.nodes.length; i++) { + let id = innerNodes?.[i]?.id; + // Use existing IDs if they are set on the inner nodes + if (id == null || isNaN(id)) { + id = undefined; + } else { + ids.push(id); + } + c.nodes[i] = { ...c.nodes[i], id }; + } + localStorage.setItem("litegrapheditor_clipboard", JSON.stringify(c)); + app.canvas.pasteFromClipboard(); + localStorage.setItem("litegrapheditor_clipboard", backup); + + const [x, y] = this.node.pos; + let top; + let left; + // Configure nodes with current widget data + const selectedIds = ids.length ? ids : Object.keys(app.canvas.selected_nodes); + const newNodes = []; + for (let i = 0; i < selectedIds.length; i++) { + const id = selectedIds[i]; + const newNode = app.graph.getNodeById(id); + const innerNode = innerNodes[i]; + newNodes.push(newNode); + + if (left == null || newNode.pos[0] < left) { + left = newNode.pos[0]; + } + if (top == null || newNode.pos[1] < top) { + top = newNode.pos[1]; + } + + const map = this.groupData.oldToNewWidgetMap[innerNode.index]; + if (map) { + const widgets = Object.keys(map); + + for (const oldName of widgets) { + const newName = map[oldName]; + if (!newName) continue; + + const widgetIndex = this.node.widgets.findIndex((w) => w.name === newName); + if (widgetIndex === -1) continue; + + // Populate the main and any linked widgets + if (innerNode.type === "PrimitiveNode") { + for (let i = 0; i < newNode.widgets.length; i++) { + newNode.widgets[i].value = this.node.widgets[widgetIndex + i].value; + } + } else { + const outerWidget = this.node.widgets[widgetIndex]; + const newWidget = newNode.widgets.find((w) => w.name === oldName); + if (!newWidget) continue; + + newWidget.value = outerWidget.value; + for (let w = 0; w < outerWidget.linkedWidgets?.length; w++) { + newWidget.linkedWidgets[w].value = outerWidget.linkedWidgets[w].value; + } + } + } + } + } + + // Shift each node + for (const newNode of newNodes) { + newNode.pos = [newNode.pos[0] - (left - x), newNode.pos[1] - (top - y)]; + } + + return { newNodes, selectedIds }; + }; + + const reconnectInputs = (selectedIds) => { + for (const innerNodeIndex in this.groupData.oldToNewInputMap) { + const id = selectedIds[innerNodeIndex]; + const newNode = app.graph.getNodeById(id); + const map = this.groupData.oldToNewInputMap[innerNodeIndex]; + for (const innerInputId in map) { + const groupSlotId = map[innerInputId]; + if (groupSlotId == null) continue; + const slot = node.inputs[groupSlotId]; + if (slot.link == null) continue; + const link = app.graph.links[slot.link]; + // connect this node output to the input of another node + const originNode = app.graph.getNodeById(link.origin_id); + originNode.connect(link.origin_slot, newNode, +innerInputId); + } + } + }; + + const reconnectOutputs = () => { + for (let groupOutputId = 0; groupOutputId < node.outputs?.length; groupOutputId++) { + const output = node.outputs[groupOutputId]; + if (!output.links) continue; + const links = [...output.links]; + for (const l of links) { + const slot = this.groupData.newToOldOutputMap[groupOutputId]; + const link = app.graph.links[l]; + const targetNode = app.graph.getNodeById(link.target_id); + const newNode = app.graph.getNodeById(selectedIds[slot.node.index]); + newNode.connect(slot.slot, targetNode, link.target_slot); + } + } + }; + + const { newNodes, selectedIds } = addInnerNodes(); + reconnectInputs(selectedIds); + reconnectOutputs(selectedIds); + app.graph.remove(this.node); + + return newNodes; + }; + + const getExtraMenuOptions = this.node.getExtraMenuOptions; + this.node.getExtraMenuOptions = function (_, options) { + getExtraMenuOptions?.apply(this, arguments); + + let optionIndex = options.findIndex((o) => o.content === "Outputs"); + if (optionIndex === -1) optionIndex = options.length; + else optionIndex++; + options.splice(optionIndex, 0, null, { + content: "Convert to nodes", + callback: () => { + return this.convertToNodes(); + }, + }); + }; + + // Draw custom collapse icon to identity this as a group + const onDrawTitleBox = this.node.onDrawTitleBox; + this.node.onDrawTitleBox = function (ctx, height, size, scale) { + onDrawTitleBox?.apply(this, arguments); + + const fill = ctx.fillStyle; + ctx.beginPath(); + ctx.rect(11, -height + 11, 2, 2); + ctx.rect(14, -height + 11, 2, 2); + ctx.rect(17, -height + 11, 2, 2); + ctx.rect(11, -height + 14, 2, 2); + ctx.rect(14, -height + 14, 2, 2); + ctx.rect(17, -height + 14, 2, 2); + ctx.rect(11, -height + 17, 2, 2); + ctx.rect(14, -height + 17, 2, 2); + ctx.rect(17, -height + 17, 2, 2); + + ctx.fillStyle = this.boxcolor || LiteGraph.NODE_DEFAULT_BOXCOLOR; + ctx.fill(); + ctx.fillStyle = fill; + }; + + // Draw progress label + const onDrawForeground = node.onDrawForeground; + const groupData = this.groupData.nodeData; + node.onDrawForeground = function (ctx) { + const r = onDrawForeground?.apply?.(this, arguments); + if (+app.runningNodeId === this.id && this.runningInternalNodeId !== null) { + const n = groupData.nodes[this.runningInternalNodeId]; + const message = `Running ${n.title || n.type} (${this.runningInternalNodeId}/${groupData.nodes.length})`; + ctx.save(); + ctx.font = "12px sans-serif"; + const sz = ctx.measureText(message); + ctx.fillStyle = node.boxcolor || LiteGraph.NODE_DEFAULT_BOXCOLOR; + ctx.beginPath(); + ctx.roundRect(0, -LiteGraph.NODE_TITLE_HEIGHT - 20, sz.width + 12, 20, 5); + ctx.fill(); + + ctx.fillStyle = "#fff"; + ctx.fillText(message, 6, -LiteGraph.NODE_TITLE_HEIGHT - 6); + ctx.restore(); + } + }; + + // Flag this node as needing to be reset + const onExecutionStart = this.node.onExecutionStart; + this.node.onExecutionStart = function () { + this.resetExecution = true; + return onExecutionStart?.apply(this, arguments); + }; + + function handleEvent(type, getId, getEvent) { + const handler = ({ detail }) => { + const id = getId(detail); + if (!id) return; + const node = app.graph.getNodeById(id); + if (node) return; + + const innerNodeIndex = this.innerNodes?.findIndex((n) => n.id == id); + if (innerNodeIndex > -1) { + this.node.runningInternalNodeId = innerNodeIndex; + api.dispatchEvent(new CustomEvent(type, { detail: getEvent(detail, this.node.id + "", this.node) })); + } + }; + api.addEventListener(type, handler); + return handler; + } + + const executing = handleEvent.call( + this, + "executing", + (d) => d, + (d, id, node) => id + ); + + const executed = handleEvent.call( + this, + "executed", + (d) => d?.node, + (d, id, node) => ({ ...d, node: id, merge: !node.resetExecution }) + ); + + const onRemoved = node.onRemoved; + this.node.onRemoved = function () { + onRemoved?.apply(this, arguments); + api.removeEventListener("executing", executing); + api.removeEventListener("executed", executed); + }; + } + + updateInnerWidgets() { + for (const newWidgetName in this.groupData.newToOldWidgetMap) { + const newWidget = this.node.widgets.find((w) => w.name === newWidgetName); + if (!newWidget) continue; + + const newValue = newWidget.value; + const old = this.groupData.newToOldWidgetMap[newWidgetName]; + let innerNode = this.innerNodes[old.node.index]; + + if (innerNode.type === "PrimitiveNode") { + innerNode.primitiveValue = newValue; + const primitiveLinked = this.groupData.primitiveToWidget[old.node.index]; + for (const linked of primitiveLinked) { + const node = this.innerNodes[linked.nodeId]; + const widget = node.widgets.find((w) => w.name === linked.inputName); + + if (widget) { + widget.value = newValue; + } + } + continue; + } + + const widget = innerNode.widgets?.find((w) => w.name === old.inputName); + if (widget) { + widget.value = newValue; + } + } + } + + populatePrimitive(node, nodeId, oldName, i, linkedShift) { + // Converted widget, populate primitive if linked + const primitiveId = this.groupData.widgetToPrimitive[nodeId]?.[oldName]; + if (primitiveId == null) return; + const targetWidgetName = this.groupData.oldToNewWidgetMap[primitiveId]["value"]; + const targetWidgetIndex = this.node.widgets.findIndex((w) => w.name === targetWidgetName); + if (targetWidgetIndex > -1) { + const primitiveNode = this.innerNodes[primitiveId]; + let len = primitiveNode.widgets.length; + if (len - 1 !== this.node.widgets[targetWidgetIndex].linkedWidgets?.length) { + // Fallback handling for if some reason the primitive has a different number of widgets + // we dont want to overwrite random widgets, better to leave blank + len = 1; + } + for (let i = 0; i < len; i++) { + this.node.widgets[targetWidgetIndex + i].value = primitiveNode.widgets[i].value; + } + } + } + + populateWidgets() { + for (let nodeId = 0; nodeId < this.groupData.nodeData.nodes.length; nodeId++) { + const node = this.groupData.nodeData.nodes[nodeId]; + + if (!node.widgets_values?.length) continue; + + const map = this.groupData.oldToNewWidgetMap[nodeId]; + const widgets = Object.keys(map); + + let linkedShift = 0; + for (let i = 0; i < widgets.length; i++) { + const oldName = widgets[i]; + const newName = map[oldName]; + const widgetIndex = this.node.widgets.findIndex((w) => w.name === newName); + const mainWidget = this.node.widgets[widgetIndex]; + if (!newName) { + // New name will be null if its a converted widget + this.populatePrimitive(node, nodeId, oldName, i, linkedShift); + + // Find the inner widget and shift by the number of linked widgets as they will have been removed too + const innerWidget = this.innerNodes[nodeId].widgets?.find((w) => w.name === oldName); + linkedShift += innerWidget.linkedWidgets?.length ?? 0; + continue; + } + + if (widgetIndex === -1) { + continue; + } + + // Populate the main and any linked widget + mainWidget.value = node.widgets_values[i + linkedShift]; + for (let w = 0; w < mainWidget.linkedWidgets?.length; w++) { + this.node.widgets[widgetIndex + w + 1].value = node.widgets_values[i + ++linkedShift]; + } + } + } + } + + replaceNodes(nodes) { + let top; + let left; + + for (let i = 0; i < nodes.length; i++) { + const node = nodes[i]; + if (left == null || node.pos[0] < left) { + left = node.pos[0]; + } + if (top == null || node.pos[1] < top) { + top = node.pos[1]; + } + + this.linkOutputs(node, i); + app.graph.remove(node); + } + + this.linkInputs(); + this.node.pos = [left, top]; + } + + linkOutputs(originalNode, nodeId) { + if (!originalNode.outputs) return; + + for (const output of originalNode.outputs) { + if (!output.links) continue; + // Clone the links as they'll be changed if we reconnect + const links = [...output.links]; + for (const l of links) { + const link = app.graph.links[l]; + if (!link) continue; + + const targetNode = app.graph.getNodeById(link.target_id); + const newSlot = this.groupData.oldToNewOutputMap[nodeId]?.[link.origin_slot]; + if (newSlot != null) { + this.node.connect(newSlot, targetNode, link.target_slot); + } + } + } + } + + linkInputs() { + for (const link of this.groupData.nodeData.links ?? []) { + const [, originSlot, targetId, targetSlot, actualOriginId] = link; + const originNode = app.graph.getNodeById(actualOriginId); + if (!originNode) continue; // this node is in the group + originNode.connect(originSlot, this.node.id, this.groupData.oldToNewInputMap[targetId][targetSlot]); + } + } + + static getGroupData(node) { + return node.constructor?.nodeData?.[GROUP]; + } + + static isGroupNode(node) { + return !!node.constructor?.nodeData?.[GROUP]; + } + + static async fromNodes(nodes) { + // Process the nodes into the stored workflow group node data + const builder = new GroupNodeBuilder(nodes); + const res = builder.build(); + if (!res) return; + + const { name, nodeData } = res; + + // Convert this data into a LG node definition and register it + const config = new GroupNodeConfig(name, nodeData); + await config.registerType(); + + const groupNode = LiteGraph.createNode(`workflow/${name}`); + // Reuse the existing nodes for this instance + groupNode.setInnerNodes(builder.nodes); + groupNode[GROUP].populateWidgets(); + app.graph.add(groupNode); + + // Remove all converted nodes and relink them + groupNode[GROUP].replaceNodes(builder.nodes); + return groupNode; + } +} + +function addConvertToGroupOptions() { + function addOption(options, index) { + const selected = Object.values(app.canvas.selected_nodes ?? {}); + const disabled = selected.length < 2 || selected.find((n) => GroupNodeHandler.isGroupNode(n)); + options.splice(index + 1, null, { + content: `Convert to Group Node`, + disabled, + callback: async () => { + return await GroupNodeHandler.fromNodes(selected); + }, + }); + } + + // Add to canvas + const getCanvasMenuOptions = LGraphCanvas.prototype.getCanvasMenuOptions; + LGraphCanvas.prototype.getCanvasMenuOptions = function () { + const options = getCanvasMenuOptions.apply(this, arguments); + const index = options.findIndex((o) => o?.content === "Add Group") + 1 || opts.length; + addOption(options, index); + return options; + }; + + // Add to nodes + const getNodeMenuOptions = LGraphCanvas.prototype.getNodeMenuOptions; + LGraphCanvas.prototype.getNodeMenuOptions = function (node) { + const options = getNodeMenuOptions.apply(this, arguments); + if (!GroupNodeHandler.isGroupNode(node)) { + const index = options.findIndex((o) => o?.content === "Outputs") + 1 || opts.length - 1; + addOption(options, index); + } + return options; + }; +} + +const id = "Comfy.GroupNode"; +let globalDefs; +const ext = { + name: id, + setup() { + addConvertToGroupOptions(); + }, + async beforeConfigureGraph(graphData, missingNodeTypes) { + const nodes = graphData?.extra?.groupNodes; + if (nodes) { + await GroupNodeConfig.registerFromWorkflow(nodes, missingNodeTypes); + } + }, + addCustomNodeDefs(defs) { + // Store this so we can mutate it later with group nodes + globalDefs = defs; + }, + nodeCreated(node) { + if (GroupNodeHandler.isGroupNode(node)) { + node[GROUP] = new GroupNodeHandler(node); + } + }, +}; + +app.registerExtension(ext); diff --git a/web/extensions/core/nodeTemplates.js b/web/extensions/core/nodeTemplates.js index b6479f454da..2d4821742d1 100644 --- a/web/extensions/core/nodeTemplates.js +++ b/web/extensions/core/nodeTemplates.js @@ -1,5 +1,6 @@ import { app } from "../../scripts/app.js"; import { ComfyDialog, $el } from "../../scripts/ui.js"; +import { GroupNodeConfig, GroupNodeHandler } from "./groupNode.js"; // Adds the ability to save and add multiple nodes as a template // To save: @@ -34,7 +35,7 @@ class ManageTemplates extends ComfyDialog { type: "file", accept: ".json", multiple: true, - style: {display: "none"}, + style: { display: "none" }, parent: document.body, onchange: () => this.importAll(), }); @@ -109,13 +110,13 @@ class ManageTemplates extends ComfyDialog { return; } - const json = JSON.stringify({templates: this.templates}, null, 2); // convert the data to a JSON string - const blob = new Blob([json], {type: "application/json"}); + const json = JSON.stringify({ templates: this.templates }, null, 2); // convert the data to a JSON string + const blob = new Blob([json], { type: "application/json" }); const url = URL.createObjectURL(blob); const a = $el("a", { href: url, download: "node_templates.json", - style: {display: "none"}, + style: { display: "none" }, parent: document.body, }); a.click(); @@ -291,11 +292,11 @@ app.registerExtension({ setup() { const manage = new ManageTemplates(); - const clipboardAction = (cb) => { + const clipboardAction = async (cb) => { // We use the clipboard functions but dont want to overwrite the current user clipboard // Restore it after we've run our callback const old = localStorage.getItem("litegrapheditor_clipboard"); - cb(); + await cb(); localStorage.setItem("litegrapheditor_clipboard", old); }; @@ -309,13 +310,31 @@ app.registerExtension({ disabled: !Object.keys(app.canvas.selected_nodes || {}).length, callback: () => { const name = prompt("Enter name"); - if (!name || !name.trim()) return; + if (!name?.trim()) return; clipboardAction(() => { app.canvas.copyToClipboard(); + let data = localStorage.getItem("litegrapheditor_clipboard"); + data = JSON.parse(data); + const nodeIds = Object.keys(app.canvas.selected_nodes); + for (let i = 0; i < nodeIds.length; i++) { + const node = app.graph.getNodeById(nodeIds[i]); + const nodeData = node?.constructor.nodeData; + + let groupData = GroupNodeHandler.getGroupData(node); + if (groupData) { + groupData = groupData.nodeData; + if (!data.groupNodes) { + data.groupNodes = {}; + } + data.groupNodes[nodeData.name] = groupData; + data.nodes[i].type = nodeData.name; + } + } + manage.templates.push({ name, - data: localStorage.getItem("litegrapheditor_clipboard"), + data: JSON.stringify(data), }); manage.store(); }); @@ -323,15 +342,19 @@ app.registerExtension({ }); // Map each template to a menu item - const subItems = manage.templates.map((t) => ({ - content: t.name, - callback: () => { - clipboardAction(() => { - localStorage.setItem("litegrapheditor_clipboard", t.data); - app.canvas.pasteFromClipboard(); - }); - }, - })); + const subItems = manage.templates.map((t) => { + return { + content: t.name, + callback: () => { + clipboardAction(async () => { + const data = JSON.parse(t.data); + await GroupNodeConfig.registerFromWorkflow(data.groupNodes, {}); + localStorage.setItem("litegrapheditor_clipboard", t.data); + app.canvas.pasteFromClipboard(); + }); + }, + }; + }); subItems.push(null, { content: "Manage", diff --git a/web/extensions/core/widgetInputs.js b/web/extensions/core/widgetInputs.js index 5c8fbc9b2d3..b6fa411f7e1 100644 --- a/web/extensions/core/widgetInputs.js +++ b/web/extensions/core/widgetInputs.js @@ -121,6 +121,110 @@ function isValidCombo(combo, obj) { return true; } +export function mergeIfValid(output, config2, forceUpdate, recreateWidget, config1) { + if (!config1) { + config1 = output.widget[CONFIG] ?? output.widget[GET_CONFIG](); + } + + if (config1[0] instanceof Array) { + if (!isValidCombo(config1[0], config2[0])) return false; + } else if (config1[0] !== config2[0]) { + // Types dont match + console.log(`connection rejected: types dont match`, config1[0], config2[0]); + return false; + } + + const keys = new Set([...Object.keys(config1[1] ?? {}), ...Object.keys(config2[1] ?? {})]); + + let customConfig; + const getCustomConfig = () => { + if (!customConfig) { + if (typeof structuredClone === "undefined") { + customConfig = JSON.parse(JSON.stringify(config1[1] ?? {})); + } else { + customConfig = structuredClone(config1[1] ?? {}); + } + } + return customConfig; + }; + + const isNumber = config1[0] === "INT" || config1[0] === "FLOAT"; + for (const k of keys.values()) { + if (k !== "default" && k !== "forceInput" && k !== "defaultInput") { + let v1 = config1[1][k]; + let v2 = config2[1]?.[k]; + + if (v1 === v2 || (!v1 && !v2)) continue; + + if (isNumber) { + if (k === "min") { + const theirMax = config2[1]?.["max"]; + if (theirMax != null && v1 > theirMax) { + console.log("connection rejected: min > max", v1, theirMax); + return false; + } + getCustomConfig()[k] = v1 == null ? v2 : v2 == null ? v1 : Math.max(v1, v2); + continue; + } else if (k === "max") { + const theirMin = config2[1]?.["min"]; + if (theirMin != null && v1 < theirMin) { + console.log("connection rejected: max < min", v1, theirMin); + return false; + } + getCustomConfig()[k] = v1 == null ? v2 : v2 == null ? v1 : Math.min(v1, v2); + continue; + } else if (k === "step") { + let step; + if (v1 == null) { + // No current step + step = v2; + } else if (v2 == null) { + // No new step + step = v1; + } else { + if (v1 < v2) { + // Ensure v1 is larger for the mod + const a = v2; + v2 = v1; + v1 = a; + } + if (v1 % v2) { + console.log("connection rejected: steps not divisible", "current:", v1, "new:", v2); + return false; + } + + step = v1; + } + + getCustomConfig()[k] = step; + continue; + } + } + + console.log(`connection rejected: config ${k} values dont match`, v1, v2); + return false; + } + } + + if (customConfig || forceUpdate) { + if (customConfig) { + output.widget[CONFIG] = [config1[0], customConfig]; + } + + const widget = recreateWidget?.call(this); + // When deleting a node this can be null + if (widget) { + const min = widget.options.min; + const max = widget.options.max; + if (min != null && widget.value < min) widget.value = min; + if (max != null && widget.value > max) widget.value = max; + widget.callback(widget.value); + } + } + + return { customConfig }; +} + app.registerExtension({ name: "Comfy.WidgetInputs", async beforeRegisterNodeDef(nodeType, nodeData, app) { @@ -308,7 +412,7 @@ app.registerExtension({ this.isVirtualNode = true; } - applyToGraph() { + applyToGraph(extraLinks = []) { if (!this.outputs[0].links?.length) return; function get_links(node) { @@ -325,10 +429,9 @@ app.registerExtension({ return links; } - let links = get_links(this); + let links = [...get_links(this).map((l) => app.graph.links[l]), ...extraLinks]; // For each output link copy our value over the original widget value - for (const l of links) { - const linkInfo = app.graph.links[l]; + for (const linkInfo of links) { const node = this.graph.getNodeById(linkInfo.target_id); const input = node.inputs[linkInfo.target_slot]; const widgetName = input.widget.name; @@ -405,7 +508,12 @@ app.registerExtension({ } if (this.outputs[slot].links?.length) { - return this.#isValidConnection(input); + const valid = this.#isValidConnection(input); + if (valid) { + // On connect of additional outputs, copy our value to their widget + this.applyToGraph([{ target_id: target_node.id, target_slot }]); + } + return valid; } } @@ -462,12 +570,12 @@ app.registerExtension({ } } - if (widget.type === "number" || widget.type === "combo") { + if (!inputData?.[1]?.control_after_generate && (widget.type === "number" || widget.type === "combo")) { let control_value = this.widgets_values?.[1]; if (!control_value) { control_value = "fixed"; } - addValueControlWidgets(this, widget, control_value); + addValueControlWidgets(this, widget, control_value, undefined, inputData); let filter = this.widgets_values?.[2]; if(filter && this.widgets.length === 3) { this.widgets[2].value = filter; @@ -507,6 +615,7 @@ app.registerExtension({ this.#removeWidgets(); this.#onFirstConnection(true); for (let i = 0; i < this.widgets?.length; i++) this.widgets[i].value = values[i]; + return this.widgets[0]; } #mergeWidgetConfig() { @@ -547,108 +656,8 @@ app.registerExtension({ #isValidConnection(input, forceUpdate) { // Only allow connections where the configs match const output = this.outputs[0]; - const config1 = output.widget[CONFIG] ?? output.widget[GET_CONFIG](); const config2 = input.widget[GET_CONFIG](); - - if (config1[0] instanceof Array) { - if (!isValidCombo(config1[0], config2[0])) return false; - } else if (config1[0] !== config2[0]) { - // Types dont match - console.log(`connection rejected: types dont match`, config1[0], config2[0]); - return false; - } - - const keys = new Set([...Object.keys(config1[1] ?? {}), ...Object.keys(config2[1] ?? {})]); - - let customConfig; - const getCustomConfig = () => { - if (!customConfig) { - if (typeof structuredClone === "undefined") { - customConfig = JSON.parse(JSON.stringify(config1[1] ?? {})); - } else { - customConfig = structuredClone(config1[1] ?? {}); - } - } - return customConfig; - }; - - const isNumber = config1[0] === "INT" || config1[0] === "FLOAT"; - for (const k of keys.values()) { - if (k !== "default" && k !== "forceInput" && k !== "defaultInput") { - let v1 = config1[1][k]; - let v2 = config2[1][k]; - - if (v1 === v2 || (!v1 && !v2)) continue; - - if (isNumber) { - if (k === "min") { - const theirMax = config2[1]["max"]; - if (theirMax != null && v1 > theirMax) { - console.log("connection rejected: min > max", v1, theirMax); - return false; - } - getCustomConfig()[k] = v1 == null ? v2 : v2 == null ? v1 : Math.max(v1, v2); - continue; - } else if (k === "max") { - const theirMin = config2[1]["min"]; - if (theirMin != null && v1 < theirMin) { - console.log("connection rejected: max < min", v1, theirMin); - return false; - } - getCustomConfig()[k] = v1 == null ? v2 : v2 == null ? v1 : Math.min(v1, v2); - continue; - } else if (k === "step") { - let step; - if (v1 == null) { - // No current step - step = v2; - } else if (v2 == null) { - // No new step - step = v1; - } else { - if (v1 < v2) { - // Ensure v1 is larger for the mod - const a = v2; - v2 = v1; - v1 = a; - } - if (v1 % v2) { - console.log("connection rejected: steps not divisible", "current:", v1, "new:", v2); - return false; - } - - step = v1; - } - - getCustomConfig()[k] = step; - continue; - } - } - - console.log(`connection rejected: config ${k} values dont match`, v1, v2); - return false; - } - } - - if (customConfig || forceUpdate) { - if (customConfig) { - output.widget[CONFIG] = [config1[0], customConfig]; - } - - this.#recreateWidget(); - - const widget = this.widgets[0]; - // When deleting a node this can be null - if (widget) { - const min = widget.options.min; - const max = widget.options.max; - if (min != null && widget.value < min) widget.value = min; - if (max != null && widget.value > max) widget.value = max; - widget.callback(widget.value); - } - } - - return true; + return !!mergeIfValid.call(this, output, config2, forceUpdate, this.#recreateWidget); } #removeWidgets() { diff --git a/web/scripts/app.js b/web/scripts/app.js index cd20c40fd0a..e9cfb277dd4 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -1,5 +1,5 @@ import { ComfyLogging } from "./logging.js"; -import { ComfyWidgets } from "./widgets.js"; +import { ComfyWidgets, getWidgetType } from "./widgets.js"; import { ComfyUI, $el } from "./ui.js"; import { api } from "./api.js"; import { defaultGraph } from "./defaultGraph.js"; @@ -779,7 +779,7 @@ export class ComfyApp { * Adds a handler on paste that extracts and loads images or workflows from pasted JSON data */ #addPasteHandler() { - document.addEventListener("paste", (e) => { + document.addEventListener("paste", async (e) => { // ctrl+shift+v is used to paste nodes with connections // this is handled by litegraph if(this.shiftDown) return; @@ -827,7 +827,7 @@ export class ComfyApp { } if (workflow && workflow.version && workflow.nodes && workflow.extra) { - this.loadGraphData(workflow); + await this.loadGraphData(workflow); } else { if (e.target.type === "text" || e.target.type === "textarea") { @@ -1177,7 +1177,19 @@ export class ComfyApp { }); api.addEventListener("executed", ({ detail }) => { - this.nodeOutputs[detail.node] = detail.output; + const output = this.nodeOutputs[detail.node]; + if (detail.merge && output) { + for (const k in detail.output ?? {}) { + const v = output[k]; + if (v instanceof Array) { + output[k] = v.concat(detail.output[k]); + } else { + output[k] = detail.output[k]; + } + } + } else { + this.nodeOutputs[detail.node] = detail.output; + } const node = this.graph.getNodeById(detail.node); if (node) { if (node.onExecuted) @@ -1292,6 +1304,7 @@ export class ComfyApp { this.#addProcessMouseHandler(); this.#addProcessKeyHandler(); this.#addConfigureHandler(); + this.#addApiUpdateHandlers(); this.graph = new LGraph(); @@ -1328,7 +1341,7 @@ export class ComfyApp { const json = localStorage.getItem("workflow"); if (json) { const workflow = JSON.parse(json); - this.loadGraphData(workflow); + await this.loadGraphData(workflow); restored = true; } } catch (err) { @@ -1337,7 +1350,7 @@ export class ComfyApp { // We failed to restore a workflow so load the default if (!restored) { - this.loadGraphData(); + await this.loadGraphData(); } // Save current workflow automatically @@ -1345,7 +1358,6 @@ export class ComfyApp { this.#addDrawNodeHandler(); this.#addDrawGroupsHandler(); - this.#addApiUpdateHandlers(); this.#addDropHandler(); this.#addCopyHandler(); this.#addPasteHandler(); @@ -1365,11 +1377,81 @@ export class ComfyApp { await this.#invokeExtensionsAsync("registerCustomNodes"); } + async registerNodeDef(nodeId, nodeData) { + const self = this; + const node = Object.assign( + function ComfyNode() { + var inputs = nodeData["input"]["required"]; + if (nodeData["input"]["optional"] != undefined) { + inputs = Object.assign({}, nodeData["input"]["required"], nodeData["input"]["optional"]); + } + const config = { minWidth: 1, minHeight: 1 }; + for (const inputName in inputs) { + const inputData = inputs[inputName]; + const type = inputData[0]; + + let widgetCreated = true; + const widgetType = getWidgetType(inputData, inputName); + if(widgetType) { + if(widgetType === "COMBO") { + Object.assign(config, self.widgets.COMBO(this, inputName, inputData, app) || {}); + } else { + Object.assign(config, self.widgets[widgetType](this, inputName, inputData, app) || {}); + } + } else { + // Node connection inputs + this.addInput(inputName, type); + widgetCreated = false; + } + + if(widgetCreated && inputData[1]?.forceInput && config?.widget) { + if (!config.widget.options) config.widget.options = {}; + config.widget.options.forceInput = inputData[1].forceInput; + } + if(widgetCreated && inputData[1]?.defaultInput && config?.widget) { + if (!config.widget.options) config.widget.options = {}; + config.widget.options.defaultInput = inputData[1].defaultInput; + } + } + + for (const o in nodeData["output"]) { + let output = nodeData["output"][o]; + if(output instanceof Array) output = "COMBO"; + const outputName = nodeData["output_name"][o] || output; + const outputShape = nodeData["output_is_list"][o] ? LiteGraph.GRID_SHAPE : LiteGraph.CIRCLE_SHAPE ; + this.addOutput(outputName, output, { shape: outputShape }); + } + + const s = this.computeSize(); + s[0] = Math.max(config.minWidth, s[0] * 1.5); + s[1] = Math.max(config.minHeight, s[1]); + this.size = s; + this.serialize_widgets = true; + + app.#invokeExtensionsAsync("nodeCreated", this); + }, + { + title: nodeData.display_name || nodeData.name, + comfyClass: nodeData.name, + nodeData + } + ); + node.prototype.comfyClass = nodeData.name; + + this.#addNodeContextMenuHandler(node); + this.#addDrawBackgroundHandler(node, app); + this.#addNodeKeyHandler(node); + + await this.#invokeExtensionsAsync("beforeRegisterNodeDef", node, nodeData); + LiteGraph.registerNodeType(nodeId, node); + node.category = nodeData.category; + } + async registerNodesFromDefs(defs) { await this.#invokeExtensionsAsync("addCustomNodeDefs", defs); // Generate list of known widgets - const widgets = Object.assign( + this.widgets = Object.assign( {}, ComfyWidgets, ...(await this.#invokeExtensionsAsync("getCustomWidgets")).filter(Boolean) @@ -1377,75 +1459,7 @@ export class ComfyApp { // Register a node for each definition for (const nodeId in defs) { - const nodeData = defs[nodeId]; - const node = Object.assign( - function ComfyNode() { - var inputs = nodeData["input"]["required"]; - if (nodeData["input"]["optional"] != undefined){ - inputs = Object.assign({}, nodeData["input"]["required"], nodeData["input"]["optional"]) - } - const config = { minWidth: 1, minHeight: 1 }; - for (const inputName in inputs) { - const inputData = inputs[inputName]; - const type = inputData[0]; - - let widgetCreated = true; - if (Array.isArray(type)) { - // Enums - Object.assign(config, widgets.COMBO(this, inputName, inputData, app) || {}); - } else if (`${type}:${inputName}` in widgets) { - // Support custom widgets by Type:Name - Object.assign(config, widgets[`${type}:${inputName}`](this, inputName, inputData, app) || {}); - } else if (type in widgets) { - // Standard type widgets - Object.assign(config, widgets[type](this, inputName, inputData, app) || {}); - } else { - // Node connection inputs - this.addInput(inputName, type); - widgetCreated = false; - } - - if(widgetCreated && inputData[1]?.forceInput && config?.widget) { - if (!config.widget.options) config.widget.options = {}; - config.widget.options.forceInput = inputData[1].forceInput; - } - if(widgetCreated && inputData[1]?.defaultInput && config?.widget) { - if (!config.widget.options) config.widget.options = {}; - config.widget.options.defaultInput = inputData[1].defaultInput; - } - } - - for (const o in nodeData["output"]) { - let output = nodeData["output"][o]; - if(output instanceof Array) output = "COMBO"; - const outputName = nodeData["output_name"][o] || output; - const outputShape = nodeData["output_is_list"][o] ? LiteGraph.GRID_SHAPE : LiteGraph.CIRCLE_SHAPE ; - this.addOutput(outputName, output, { shape: outputShape }); - } - - const s = this.computeSize(); - s[0] = Math.max(config.minWidth, s[0] * 1.5); - s[1] = Math.max(config.minHeight, s[1]); - this.size = s; - this.serialize_widgets = true; - - app.#invokeExtensionsAsync("nodeCreated", this); - }, - { - title: nodeData.display_name || nodeData.name, - comfyClass: nodeData.name, - nodeData - } - ); - node.prototype.comfyClass = nodeData.name; - - this.#addNodeContextMenuHandler(node); - this.#addDrawBackgroundHandler(node, app); - this.#addNodeKeyHandler(node); - - await this.#invokeExtensionsAsync("beforeRegisterNodeDef", node, nodeData); - LiteGraph.registerNodeType(nodeId, node); - node.category = nodeData.category; + this.registerNodeDef(nodeId, defs[nodeId]); } } @@ -1488,9 +1502,14 @@ export class ComfyApp { showMissingNodesError(missingNodeTypes, hasAddedNodes = true) { this.ui.dialog.show( - `When loading the graph, the following node types were not found:
    ${Array.from(new Set(missingNodeTypes)).map( - (t) => `
  • ${t}
  • ` - ).join("")}
${hasAddedNodes ? "Nodes that have failed to load will show as red on the graph." : ""}` + $el("div", [ + $el("span", { textContent: "When loading the graph, the following node types were not found: " }), + $el( + "ul", + Array.from(new Set(missingNodeTypes)).map((t) => $el("li", { textContent: t })) + ), + ...(hasAddedNodes ? [$el("span", { textContent: "Nodes that have failed to load will show as red on the graph." })] : []), + ]) ); this.logging.addEntry("Comfy.App", "warn", { MissingNodes: missingNodeTypes, @@ -1501,7 +1520,7 @@ export class ComfyApp { * Populates the graph with the specified workflow data * @param {*} graphData A serialized graph object */ - loadGraphData(graphData) { + async loadGraphData(graphData) { this.clean(); let reset_invalid_values = false; @@ -1519,6 +1538,7 @@ export class ComfyApp { } const missingNodeTypes = []; + await this.#invokeExtensionsAsync("beforeConfigureGraph", graphData, missingNodeTypes); for (let n of graphData.nodes) { // Patch T2IAdapterLoader to ControlNetLoader since they are the same node now if (n.type == "T2IAdapterLoader") n.type = "ControlNetLoader"; @@ -1527,8 +1547,8 @@ export class ComfyApp { // Find missing node types if (!(n.type in LiteGraph.registered_node_types)) { - n.type = sanitizeNodeName(n.type); missingNodeTypes.push(n.type); + n.type = sanitizeNodeName(n.type); } } @@ -1627,92 +1647,98 @@ export class ComfyApp { * @returns The workflow and node links */ async graphToPrompt() { - for (const node of this.graph.computeExecutionOrder(false)) { - if (node.isVirtualNode) { - // Don't serialize frontend only nodes but let them make changes - if (node.applyToGraph) { - node.applyToGraph(); + for (const outerNode of this.graph.computeExecutionOrder(false)) { + const innerNodes = outerNode.getInnerNodes ? outerNode.getInnerNodes() : [outerNode]; + for (const node of innerNodes) { + if (node.isVirtualNode) { + // Don't serialize frontend only nodes but let them make changes + if (node.applyToGraph) { + node.applyToGraph(); + } } - continue; } } const workflow = this.graph.serialize(); const output = {}; // Process nodes in order of execution - for (const node of this.graph.computeExecutionOrder(false)) { - const n = workflow.nodes.find((n) => n.id === node.id); - - if (node.isVirtualNode) { - continue; - } + for (const outerNode of this.graph.computeExecutionOrder(false)) { + const innerNodes = outerNode.getInnerNodes ? outerNode.getInnerNodes() : [outerNode]; + for (const node of innerNodes) { + if (node.isVirtualNode) { + continue; + } - if (node.mode === 2 || node.mode === 4) { - // Don't serialize muted nodes - continue; - } + if (node.mode === 2 || node.mode === 4) { + // Don't serialize muted nodes + continue; + } - const inputs = {}; - const widgets = node.widgets; + const inputs = {}; + const widgets = node.widgets; - // Store all widget values - if (widgets) { - for (const i in widgets) { - const widget = widgets[i]; - if (!widget.options || widget.options.serialize !== false) { - inputs[widget.name] = widget.serializeValue ? await widget.serializeValue(n, i) : widget.value; + // Store all widget values + if (widgets) { + for (const i in widgets) { + const widget = widgets[i]; + if (!widget.options || widget.options.serialize !== false) { + inputs[widget.name] = widget.serializeValue ? await widget.serializeValue(node, i) : widget.value; + } } } - } - // Store all node links - for (let i in node.inputs) { - let parent = node.getInputNode(i); - if (parent) { - let link = node.getInputLink(i); - while (parent.mode === 4 || parent.isVirtualNode) { - let found = false; - if (parent.isVirtualNode) { - link = parent.getInputLink(link.origin_slot); - if (link) { - parent = parent.getInputNode(link.target_slot); - if (parent) { - found = true; + // Store all node links + for (let i in node.inputs) { + let parent = node.getInputNode(i); + if (parent) { + let link = node.getInputLink(i); + while (parent.mode === 4 || parent.isVirtualNode) { + let found = false; + if (parent.isVirtualNode) { + link = parent.getInputLink(link.origin_slot); + if (link) { + parent = parent.getInputNode(link.target_slot); + if (parent) { + found = true; + } } - } - } else if (link && parent.mode === 4) { - let all_inputs = [link.origin_slot]; - if (parent.inputs) { - all_inputs = all_inputs.concat(Object.keys(parent.inputs)) - for (let parent_input in all_inputs) { - parent_input = all_inputs[parent_input]; - if (parent.inputs[parent_input]?.type === node.inputs[i].type) { - link = parent.getInputLink(parent_input); - if (link) { - parent = parent.getInputNode(parent_input); + } else if (link && parent.mode === 4) { + let all_inputs = [link.origin_slot]; + if (parent.inputs) { + all_inputs = all_inputs.concat(Object.keys(parent.inputs)) + for (let parent_input in all_inputs) { + parent_input = all_inputs[parent_input]; + if (parent.inputs[parent_input]?.type === node.inputs[i].type) { + link = parent.getInputLink(parent_input); + if (link) { + parent = parent.getInputNode(parent_input); + } + found = true; + break; } - found = true; - break; } } } - } - if (!found) { - break; + if (!found) { + break; + } } - } - if (link) { - inputs[node.inputs[i].name] = [String(link.origin_id), parseInt(link.origin_slot)]; + if (link) { + if (parent?.updateLink) { + link = parent.updateLink(link); + } + inputs[node.inputs[i].name] = [String(link.origin_id), parseInt(link.origin_slot)]; + } } } - } - output[String(node.id)] = { - inputs, - class_type: node.comfyClass, - }; + output[String(node.id)] = { + inputs, + class_type: node.comfyClass, + }; + } } // Remove inputs connected to removed nodes @@ -1832,7 +1858,7 @@ export class ComfyApp { const pngInfo = await getPngMetadata(file); if (pngInfo) { if (pngInfo.workflow) { - this.loadGraphData(JSON.parse(pngInfo.workflow)); + await this.loadGraphData(JSON.parse(pngInfo.workflow)); } else if (pngInfo.parameters) { importA1111(this.graph, pngInfo.parameters); } @@ -1848,21 +1874,21 @@ export class ComfyApp { } } else if (file.type === "application/json" || file.name?.endsWith(".json")) { const reader = new FileReader(); - reader.onload = () => { + reader.onload = async () => { const jsonContent = JSON.parse(reader.result); if (jsonContent?.templates) { this.loadTemplateData(jsonContent); } else if(this.isApiJson(jsonContent)) { this.loadApiJson(jsonContent); } else { - this.loadGraphData(jsonContent); + await this.loadGraphData(jsonContent); } }; reader.readAsText(file); } else if (file.name?.endsWith(".latent") || file.name?.endsWith(".safetensors")) { const info = await getLatentMetadata(file); if (info.workflow) { - this.loadGraphData(JSON.parse(info.workflow)); + await this.loadGraphData(JSON.parse(info.workflow)); } } } diff --git a/web/scripts/domWidget.js b/web/scripts/domWidget.js index 07da591cb7d..37d26f3c5ef 100644 --- a/web/scripts/domWidget.js +++ b/web/scripts/domWidget.js @@ -44,7 +44,7 @@ function getClipPath(node, element, elRect) { } function computeSize(size) { - if (this.widgets?.[0].last_y == null) return; + if (this.widgets?.[0]?.last_y == null) return; let y = this.widgets[0].last_y; let freeSpace = size[1] - y; @@ -195,7 +195,6 @@ export function addDomClippingSetting() { type: "boolean", defaultValue: enableDomClipping, onChange(value) { - console.log("enableDomClipping", enableDomClipping); enableDomClipping = !!value; }, }); diff --git a/web/scripts/ui.js b/web/scripts/ui.js index 8a58d30b3a7..ebaf86fe428 100644 --- a/web/scripts/ui.js +++ b/web/scripts/ui.js @@ -462,8 +462,8 @@ class ComfyList { return $el("div", {textContent: item.prompt[0] + ": "}, [ $el("button", { textContent: "Load", - onclick: () => { - app.loadGraphData(item.prompt[3].extra_pnginfo.workflow); + onclick: async () => { + await app.loadGraphData(item.prompt[3].extra_pnginfo.workflow); if (item.outputs) { app.nodeOutputs = item.outputs; } @@ -784,9 +784,9 @@ export class ComfyUI { } }), $el("button", { - id: "comfy-load-default-button", textContent: "Load Default", onclick: () => { + id: "comfy-load-default-button", textContent: "Load Default", onclick: async () => { if (!confirmClear.value || confirm("Load default workflow?")) { - app.loadGraphData() + await app.loadGraphData() } } }), diff --git a/web/scripts/widgets.js b/web/scripts/widgets.js index fbc1d0fc324..de5877e5448 100644 --- a/web/scripts/widgets.js +++ b/web/scripts/widgets.js @@ -23,29 +23,73 @@ function getNumberDefaults(inputData, defaultStep, precision, enable_rounding) { return { val: defaultVal, config: { min, max, step: 10.0 * step, round, precision } }; } -export function addValueControlWidget(node, targetWidget, defaultValue = "randomize", values) { - const widgets = addValueControlWidgets(node, targetWidget, defaultValue, values, { +export function getWidgetType(inputData, inputName) { + const type = inputData[0]; + + if (Array.isArray(type)) { + return "COMBO"; + } else if (`${type}:${inputName}` in ComfyWidgets) { + return `${type}:${inputName}`; + } else if (type in ComfyWidgets) { + return type; + } else { + return null; + } +} + +export function addValueControlWidget(node, targetWidget, defaultValue = "randomize", values, widgetName, inputData) { + let name = inputData[1]?.control_after_generate; + if(typeof name !== "string") { + name = widgetName; + } + const widgets = addValueControlWidgets(node, targetWidget, defaultValue, { addFilterList: false, - }); + controlAfterGenerateName: name + }, inputData); return widgets[0]; } -export function addValueControlWidgets(node, targetWidget, defaultValue = "randomize", values, options) { +export function addValueControlWidgets(node, targetWidget, defaultValue = "randomize", options, inputData) { + if (!defaultValue) defaultValue = "randomize"; if (!options) options = {}; - + + const getName = (defaultName, optionName) => { + let name = defaultName; + if (options[optionName]) { + name = options[optionName]; + } else if (typeof inputData?.[1]?.[defaultName] === "string") { + name = inputData?.[1]?.[defaultName]; + } else if (inputData?.[1]?.control_prefix) { + name = inputData?.[1]?.control_prefix + " " + name + } + return name; + } + const widgets = []; - const valueControl = node.addWidget("combo", "control_after_generate", defaultValue, function (v) { }, { - values: ["fixed", "increment", "decrement", "randomize"], - serialize: false, // Don't include this in prompt. - }); + const valueControl = node.addWidget( + "combo", + getName("control_after_generate", "controlAfterGenerateName"), + defaultValue, + function () {}, + { + values: ["fixed", "increment", "decrement", "randomize"], + serialize: false, // Don't include this in prompt. + } + ); widgets.push(valueControl); const isCombo = targetWidget.type === "combo"; let comboFilter; if (isCombo && options.addFilterList !== false) { - comboFilter = node.addWidget("string", "control_filter_list", "", function (v) {}, { - serialize: false, // Don't include this in prompt. - }); + comboFilter = node.addWidget( + "string", + getName("control_filter_list", "controlFilterListName"), + "", + function () {}, + { + serialize: false, // Don't include this in prompt. + } + ); widgets.push(comboFilter); } @@ -96,7 +140,8 @@ export function addValueControlWidgets(node, targetWidget, defaultValue = "rando targetWidget.value = value; targetWidget.callback(value); } - } else { //number + } else { + //number let min = targetWidget.options.min; let max = targetWidget.options.max; // limit to something that javascript can handle @@ -119,32 +164,54 @@ export function addValueControlWidgets(node, targetWidget, defaultValue = "rando default: break; } - /*check if values are over or under their respective - * ranges and set them to min or max.*/ - if (targetWidget.value < min) - targetWidget.value = min; + /*check if values are over or under their respective + * ranges and set them to min or max.*/ + if (targetWidget.value < min) targetWidget.value = min; if (targetWidget.value > max) targetWidget.value = max; targetWidget.callback(targetWidget.value); } - } - + }; return widgets; }; -function seedWidget(node, inputName, inputData, app) { - const seed = ComfyWidgets.INT(node, inputName, inputData, app); - const seedControl = addValueControlWidget(node, seed.widget, "randomize"); +function seedWidget(node, inputName, inputData, app, widgetName) { + const seed = createIntWidget(node, inputName, inputData, app, true); + const seedControl = addValueControlWidget(node, seed.widget, "randomize", undefined, widgetName, inputData); seed.widget.linkedWidgets = [seedControl]; return seed; } + +function createIntWidget(node, inputName, inputData, app, isSeedInput) { + const control = inputData[1]?.control_after_generate; + if (!isSeedInput && control) { + return seedWidget(node, inputName, inputData, app, typeof control === "string" ? control : undefined); + } + + let widgetType = isSlider(inputData[1]["display"], app); + const { val, config } = getNumberDefaults(inputData, 1, 0, true); + Object.assign(config, { precision: 0 }); + return { + widget: node.addWidget( + widgetType, + inputName, + val, + function (v) { + const s = this.options.step / 10; + this.value = Math.round(v / s) * s; + }, + config + ), + }; +} + function addMultilineWidget(node, name, opts, app) { const inputEl = document.createElement("textarea"); inputEl.className = "comfy-multiline-input"; inputEl.value = opts.defaultVal; - inputEl.placeholder = opts.placeholder || ""; + inputEl.placeholder = opts.placeholder || name; const widget = node.addDOMWidget(name, "customtext", inputEl, { getValue() { @@ -156,6 +223,10 @@ function addMultilineWidget(node, name, opts, app) { }); widget.inputEl = inputEl; + inputEl.addEventListener("input", () => { + widget.callback?.(widget.value); + }); + return { minWidth: 400, minHeight: 200, widget }; } @@ -186,21 +257,7 @@ export const ComfyWidgets = { }, config) }; }, INT(node, inputName, inputData, app) { - let widgetType = isSlider(inputData[1]["display"], app); - const { val, config } = getNumberDefaults(inputData, 1, 0, true); - Object.assign(config, { precision: 0 }); - return { - widget: node.addWidget( - widgetType, - inputName, - val, - function (v) { - const s = this.options.step / 10; - this.value = Math.round(v / s) * s; - }, - config - ), - }; + return createIntWidget(node, inputName, inputData, app); }, BOOLEAN(node, inputName, inputData) { let defaultVal = false; @@ -245,10 +302,14 @@ export const ComfyWidgets = { if (inputData[1] && inputData[1].default) { defaultValue = inputData[1].default; } - return { widget: node.addWidget("combo", inputName, defaultValue, () => {}, { values: type }) }; + const res = { widget: node.addWidget("combo", inputName, defaultValue, () => {}, { values: type }) }; + if (inputData[1]?.control_after_generate) { + res.widget.linkedWidgets = addValueControlWidgets(node, res.widget, undefined, undefined, inputData); + } + return res; }, IMAGEUPLOAD(node, inputName, inputData, app) { - const imageWidget = node.widgets.find((w) => w.name === "image"); + const imageWidget = node.widgets.find((w) => w.name === (inputData[1]?.widget ?? "image")); let uploadWidget; function showImage(name) { @@ -362,9 +423,10 @@ export const ComfyWidgets = { document.body.append(fileInput); // Create the button widget for selecting the files - uploadWidget = node.addWidget("button", "choose file to upload", "image", () => { + uploadWidget = node.addWidget("button", inputName, "image", () => { fileInput.click(); }); + uploadWidget.label = "choose file to upload"; uploadWidget.serialize = false; // Add handler to check if an image is being dragged over our node From 6b769bca01bf7de989ab4aaafd8db41a92a87094 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 30 Nov 2023 15:22:32 -0500 Subject: [PATCH 79/84] Do a garbage collect after the interval even if nothing is running. --- execution.py | 6 ++++-- main.py | 45 +++++++++++++++++++++++++++++---------------- 2 files changed, 33 insertions(+), 18 deletions(-) diff --git a/execution.py b/execution.py index bca48a785c2..7db1f095b10 100644 --- a/execution.py +++ b/execution.py @@ -700,10 +700,12 @@ def put(self, item): self.server.queue_updated() self.not_empty.notify() - def get(self): + def get(self, timeout=None): with self.not_empty: while len(self.queue) == 0: - self.not_empty.wait() + self.not_empty.wait(timeout=timeout) + if timeout is not None and len(self.queue) == 0: + return None item = heapq.heappop(self.queue) i = self.task_counter self.currently_running[i] = copy.deepcopy(item) diff --git a/main.py b/main.py index 3997fbefcb3..1f9c5f443c3 100644 --- a/main.py +++ b/main.py @@ -89,23 +89,36 @@ def cuda_malloc_warning(): def prompt_worker(q, server): e = execution.PromptExecutor(server) last_gc_collect = 0 + need_gc = False + gc_collect_interval = 10.0 + while True: - item, item_id = q.get() - execution_start_time = time.perf_counter() - prompt_id = item[1] - e.execute(item[2], prompt_id, item[3], item[4]) - q.task_done(item_id, e.outputs_ui) - if server.client_id is not None: - server.send_sync("executing", { "node": None, "prompt_id": prompt_id }, server.client_id) - - current_time = time.perf_counter() - execution_time = current_time - execution_start_time - print("Prompt executed in {:.2f} seconds".format(execution_time)) - if (current_time - last_gc_collect) > 10.0: - gc.collect() - comfy.model_management.soft_empty_cache() - last_gc_collect = current_time - print("gc collect") + timeout = None + if need_gc: + timeout = max(gc_collect_interval - (current_time - last_gc_collect), 0.0) + + queue_item = q.get(timeout=timeout) + if queue_item is not None: + item, item_id = queue_item + execution_start_time = time.perf_counter() + prompt_id = item[1] + e.execute(item[2], prompt_id, item[3], item[4]) + need_gc = True + q.task_done(item_id, e.outputs_ui) + if server.client_id is not None: + server.send_sync("executing", { "node": None, "prompt_id": prompt_id }, server.client_id) + + current_time = time.perf_counter() + execution_time = current_time - execution_start_time + print("Prompt executed in {:.2f} seconds".format(execution_time)) + + if need_gc: + current_time = time.perf_counter() + if (current_time - last_gc_collect) > gc_collect_interval: + gc.collect() + comfy.model_management.soft_empty_cache() + last_gc_collect = current_time + need_gc = False async def run(server, address='', port=8188, verbose=True, call_on_start=None): await asyncio.gather(server.start(address, port, verbose, call_on_start), server.publish_loop()) From c97be4db91d4a249c19afdf88fa1cf3268544e45 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 30 Nov 2023 19:27:03 -0500 Subject: [PATCH 80/84] Support SD2.1 turbo checkpoint. --- comfy/supported_models.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 7e2ac677d51..455323b9629 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -71,6 +71,10 @@ def model_type(self, state_dict, prefix=""): return model_base.ModelType.EPS def process_clip_state_dict(self, state_dict): + replace_prefix = {} + replace_prefix["conditioner.embedders.0.model."] = "cond_stage_model.model." #SD2 in sgm format + state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix) + state_dict = utils.transformers_convert(state_dict, "cond_stage_model.model.", "cond_stage_model.clip_h.transformer.text_model.", 24) return state_dict From 5d5c320054758413be00e98b26a28b39ee8f2acd Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 1 Dec 2023 02:03:34 -0500 Subject: [PATCH 81/84] Fix right click not working for some users. --- web/extensions/core/groupNode.js | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/web/extensions/core/groupNode.js b/web/extensions/core/groupNode.js index 450b4f5f35c..397c4c71393 100644 --- a/web/extensions/core/groupNode.js +++ b/web/extensions/core/groupNode.js @@ -1010,7 +1010,7 @@ function addConvertToGroupOptions() { const getCanvasMenuOptions = LGraphCanvas.prototype.getCanvasMenuOptions; LGraphCanvas.prototype.getCanvasMenuOptions = function () { const options = getCanvasMenuOptions.apply(this, arguments); - const index = options.findIndex((o) => o?.content === "Add Group") + 1 || opts.length; + const index = options.findIndex((o) => o?.content === "Add Group") + 1 || options.length; addOption(options, index); return options; }; @@ -1020,7 +1020,7 @@ function addConvertToGroupOptions() { LGraphCanvas.prototype.getNodeMenuOptions = function (node) { const options = getNodeMenuOptions.apply(this, arguments); if (!GroupNodeHandler.isGroupNode(node)) { - const index = options.findIndex((o) => o?.content === "Outputs") + 1 || opts.length - 1; + const index = options.findIndex((o) => o?.content === "Outputs") + 1 || options.length - 1; addOption(options, index); } return options; From ec7a00aa9644049c306dd0a2c02cb4f91f127286 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 1 Dec 2023 04:13:04 -0500 Subject: [PATCH 82/84] Fix extension widgets not working. --- web/extensions/core/groupNode.js | 3 +-- web/scripts/app.js | 18 ++++++++++++++++-- web/scripts/widgets.js | 14 -------------- 3 files changed, 17 insertions(+), 18 deletions(-) diff --git a/web/extensions/core/groupNode.js b/web/extensions/core/groupNode.js index 397c4c71393..4b4bf74fa08 100644 --- a/web/extensions/core/groupNode.js +++ b/web/extensions/core/groupNode.js @@ -1,6 +1,5 @@ import { app } from "../../scripts/app.js"; import { api } from "../../scripts/api.js"; -import { getWidgetType } from "../../scripts/widgets.js"; import { mergeIfValid } from "./widgetInputs.js"; const GROUP = Symbol(); @@ -332,7 +331,7 @@ export class GroupNodeConfig { const converted = new Map(); const widgetMap = (this.oldToNewWidgetMap[node.index] = {}); for (const inputName of inputNames) { - let widgetType = getWidgetType(inputs[inputName], inputName); + let widgetType = app.getWidgetType(inputs[inputName], inputName); if (widgetType) { const convertedIndex = node.inputs?.findIndex( (inp) => inp.name === inputName && inp.widget?.name === inputName diff --git a/web/scripts/app.js b/web/scripts/app.js index e9cfb277dd4..a72e30027e3 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -1,5 +1,5 @@ import { ComfyLogging } from "./logging.js"; -import { ComfyWidgets, getWidgetType } from "./widgets.js"; +import { ComfyWidgets } from "./widgets.js"; import { ComfyUI, $el } from "./ui.js"; import { api } from "./api.js"; import { defaultGraph } from "./defaultGraph.js"; @@ -1377,6 +1377,20 @@ export class ComfyApp { await this.#invokeExtensionsAsync("registerCustomNodes"); } + getWidgetType(inputData, inputName) { + const type = inputData[0]; + + if (Array.isArray(type)) { + return "COMBO"; + } else if (`${type}:${inputName}` in this.widgets) { + return `${type}:${inputName}`; + } else if (type in this.widgets) { + return type; + } else { + return null; + } + } + async registerNodeDef(nodeId, nodeData) { const self = this; const node = Object.assign( @@ -1391,7 +1405,7 @@ export class ComfyApp { const type = inputData[0]; let widgetCreated = true; - const widgetType = getWidgetType(inputData, inputName); + const widgetType = self.getWidgetType(inputData, inputName); if(widgetType) { if(widgetType === "COMBO") { Object.assign(config, self.widgets.COMBO(this, inputName, inputData, app) || {}); diff --git a/web/scripts/widgets.js b/web/scripts/widgets.js index de5877e5448..d599b85ba94 100644 --- a/web/scripts/widgets.js +++ b/web/scripts/widgets.js @@ -23,20 +23,6 @@ function getNumberDefaults(inputData, defaultStep, precision, enable_rounding) { return { val: defaultVal, config: { min, max, step: 10.0 * step, round, precision } }; } -export function getWidgetType(inputData, inputName) { - const type = inputData[0]; - - if (Array.isArray(type)) { - return "COMBO"; - } else if (`${type}:${inputName}` in ComfyWidgets) { - return `${type}:${inputName}`; - } else if (type in ComfyWidgets) { - return type; - } else { - return null; - } -} - export function addValueControlWidget(node, targetWidget, defaultValue = "randomize", values, widgetName, inputData) { let name = inputData[1]?.control_after_generate; if(typeof name !== "string") { From 8491280504d69f38d1bc72568f8f745c5dc41d74 Mon Sep 17 00:00:00 2001 From: pythongosssss <125205205+pythongosssss@users.noreply.github.com> Date: Fri, 1 Dec 2023 22:24:20 +0000 Subject: [PATCH 83/84] Add Extension tests (#2125) * Add test for extension hooks Add afterConfigureGraph callback * fix comment --- tests-ui/tests/extensions.test.js | 196 ++++++++++++++++++++++++++++++ tests-ui/utils/index.js | 7 +- web/scripts/app.js | 1 + 3 files changed, 201 insertions(+), 3 deletions(-) create mode 100644 tests-ui/tests/extensions.test.js diff --git a/tests-ui/tests/extensions.test.js b/tests-ui/tests/extensions.test.js new file mode 100644 index 00000000000..b82e55c328b --- /dev/null +++ b/tests-ui/tests/extensions.test.js @@ -0,0 +1,196 @@ +// @ts-check +/// +const { start } = require("../utils"); +const lg = require("../utils/litegraph"); + +describe("extensions", () => { + beforeEach(() => { + lg.setup(global); + }); + + afterEach(() => { + lg.teardown(global); + }); + + it("calls each extension hook", async () => { + const mockExtension = { + name: "TestExtension", + init: jest.fn(), + setup: jest.fn(), + addCustomNodeDefs: jest.fn(), + getCustomWidgets: jest.fn(), + beforeRegisterNodeDef: jest.fn(), + registerCustomNodes: jest.fn(), + loadedGraphNode: jest.fn(), + nodeCreated: jest.fn(), + beforeConfigureGraph: jest.fn(), + afterConfigureGraph: jest.fn(), + }; + + const { app, ez, graph } = await start({ + async preSetup(app) { + app.registerExtension(mockExtension); + }, + }); + + // Basic initialisation hooks should be called once, with app + expect(mockExtension.init).toHaveBeenCalledTimes(1); + expect(mockExtension.init).toHaveBeenCalledWith(app); + + // Adding custom node defs should be passed the full list of nodes + expect(mockExtension.addCustomNodeDefs).toHaveBeenCalledTimes(1); + expect(mockExtension.addCustomNodeDefs.mock.calls[0][1]).toStrictEqual(app); + const defs = mockExtension.addCustomNodeDefs.mock.calls[0][0]; + expect(defs).toHaveProperty("KSampler"); + expect(defs).toHaveProperty("LoadImage"); + + // Get custom widgets is called once and should return new widget types + expect(mockExtension.getCustomWidgets).toHaveBeenCalledTimes(1); + expect(mockExtension.getCustomWidgets).toHaveBeenCalledWith(app); + + // Before register node def will be called once per node type + const nodeNames = Object.keys(defs); + const nodeCount = nodeNames.length; + expect(mockExtension.beforeRegisterNodeDef).toHaveBeenCalledTimes(nodeCount); + for (let i = 0; i < nodeCount; i++) { + // It should be send the JS class and the original JSON definition + const nodeClass = mockExtension.beforeRegisterNodeDef.mock.calls[i][0]; + const nodeDef = mockExtension.beforeRegisterNodeDef.mock.calls[i][1]; + + expect(nodeClass.name).toBe("ComfyNode"); + expect(nodeClass.comfyClass).toBe(nodeNames[i]); + expect(nodeDef.name).toBe(nodeNames[i]); + expect(nodeDef).toHaveProperty("input"); + expect(nodeDef).toHaveProperty("output"); + } + + // Register custom nodes is called once after registerNode defs to allow adding other frontend nodes + expect(mockExtension.registerCustomNodes).toHaveBeenCalledTimes(1); + + // Before configure graph will be called here as the default graph is being loaded + expect(mockExtension.beforeConfigureGraph).toHaveBeenCalledTimes(1); + // it gets sent the graph data that is going to be loaded + const graphData = mockExtension.beforeConfigureGraph.mock.calls[0][0]; + + // A node created is fired for each node constructor that is called + expect(mockExtension.nodeCreated).toHaveBeenCalledTimes(graphData.nodes.length); + for (let i = 0; i < graphData.nodes.length; i++) { + expect(mockExtension.nodeCreated.mock.calls[i][0].type).toBe(graphData.nodes[i].type); + } + + // Each node then calls loadedGraphNode to allow them to be updated + expect(mockExtension.loadedGraphNode).toHaveBeenCalledTimes(graphData.nodes.length); + for (let i = 0; i < graphData.nodes.length; i++) { + expect(mockExtension.loadedGraphNode.mock.calls[i][0].type).toBe(graphData.nodes[i].type); + } + + // After configure is then called once all the setup is done + expect(mockExtension.afterConfigureGraph).toHaveBeenCalledTimes(1); + + expect(mockExtension.setup).toHaveBeenCalledTimes(1); + expect(mockExtension.setup).toHaveBeenCalledWith(app); + + // Ensure hooks are called in the correct order + const callOrder = [ + "init", + "addCustomNodeDefs", + "getCustomWidgets", + "beforeRegisterNodeDef", + "registerCustomNodes", + "beforeConfigureGraph", + "nodeCreated", + "loadedGraphNode", + "afterConfigureGraph", + "setup", + ]; + for (let i = 1; i < callOrder.length; i++) { + const fn1 = mockExtension[callOrder[i - 1]]; + const fn2 = mockExtension[callOrder[i]]; + expect(fn1.mock.invocationCallOrder[0]).toBeLessThan(fn2.mock.invocationCallOrder[0]); + } + + graph.clear(); + + // Ensure adding a new node calls the correct callback + ez.LoadImage(); + expect(mockExtension.loadedGraphNode).toHaveBeenCalledTimes(graphData.nodes.length); + expect(mockExtension.nodeCreated).toHaveBeenCalledTimes(graphData.nodes.length + 1); + expect(mockExtension.nodeCreated.mock.lastCall[0].type).toBe("LoadImage"); + + // Reload the graph to ensure correct hooks are fired + await graph.reload(); + + // These hooks should not be fired again + expect(mockExtension.init).toHaveBeenCalledTimes(1); + expect(mockExtension.addCustomNodeDefs).toHaveBeenCalledTimes(1); + expect(mockExtension.getCustomWidgets).toHaveBeenCalledTimes(1); + expect(mockExtension.registerCustomNodes).toHaveBeenCalledTimes(1); + expect(mockExtension.beforeRegisterNodeDef).toHaveBeenCalledTimes(nodeCount); + expect(mockExtension.setup).toHaveBeenCalledTimes(1); + + // These should be called again + expect(mockExtension.beforeConfigureGraph).toHaveBeenCalledTimes(2); + expect(mockExtension.nodeCreated).toHaveBeenCalledTimes(graphData.nodes.length + 2); + expect(mockExtension.loadedGraphNode).toHaveBeenCalledTimes(graphData.nodes.length + 1); + expect(mockExtension.afterConfigureGraph).toHaveBeenCalledTimes(2); + }); + + it("allows custom nodeDefs and widgets to be registered", async () => { + const widgetMock = jest.fn((node, inputName, inputData, app) => { + expect(node.constructor.comfyClass).toBe("TestNode"); + expect(inputName).toBe("test_input"); + expect(inputData[0]).toBe("CUSTOMWIDGET"); + expect(inputData[1]?.hello).toBe("world"); + expect(app).toStrictEqual(app); + + return { + widget: node.addWidget("button", inputName, "hello", () => {}), + }; + }); + + // Register our extension that adds a custom node + widget type + const mockExtension = { + name: "TestExtension", + addCustomNodeDefs: (nodeDefs) => { + nodeDefs["TestNode"] = { + output: [], + output_name: [], + output_is_list: [], + name: "TestNode", + display_name: "TestNode", + category: "Test", + input: { + required: { + test_input: ["CUSTOMWIDGET", { hello: "world" }], + }, + }, + }; + }, + getCustomWidgets: jest.fn(() => { + return { + CUSTOMWIDGET: widgetMock, + }; + }), + }; + + const { graph, ez } = await start({ + async preSetup(app) { + app.registerExtension(mockExtension); + }, + }); + + expect(mockExtension.getCustomWidgets).toBeCalledTimes(1); + + graph.clear(); + expect(widgetMock).toBeCalledTimes(0); + const node = ez.TestNode(); + expect(widgetMock).toBeCalledTimes(1); + + // Ensure our custom widget is created + expect(node.inputs.length).toBe(0); + expect(node.widgets.length).toBe(1); + const w = node.widgets[0].widget; + expect(w.name).toBe("test_input"); + expect(w.type).toBe("button"); + }); +}); diff --git a/tests-ui/utils/index.js b/tests-ui/utils/index.js index eeccdb3d921..3a018f566e4 100644 --- a/tests-ui/utils/index.js +++ b/tests-ui/utils/index.js @@ -4,11 +4,11 @@ const lg = require("./litegraph"); /** * - * @param { Parameters[0] & { resetEnv?: boolean } } config + * @param { Parameters[0] & { resetEnv?: boolean, preSetup?(app): Promise } } config * @returns */ -export async function start(config = undefined) { - if(config?.resetEnv) { +export async function start(config = {}) { + if(config.resetEnv) { jest.resetModules(); jest.resetAllMocks(); lg.setup(global); @@ -16,6 +16,7 @@ export async function start(config = undefined) { mockApi(config); const { app } = require("../../web/scripts/app"); + config.preSetup?.(app); await app.setup(); return { ...Ez.graph(app, global["LiteGraph"], global["LGraphCanvas"]), app }; } diff --git a/web/scripts/app.js b/web/scripts/app.js index a72e30027e3..861db16bddf 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -1654,6 +1654,7 @@ export class ComfyApp { if (missingNodeTypes.length) { this.showMissingNodesError(missingNodeTypes); } + await this.#invokeExtensionsAsync("afterConfigureGraph", missingNodeTypes); } /** From 2995a2472541cb10ce9f4934baef9d5993ed3306 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 1 Dec 2023 18:29:33 -0500 Subject: [PATCH 84/84] Update readme. --- README.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index af1f2281158..450a012bb8e 100644 --- a/README.md +++ b/README.md @@ -45,6 +45,7 @@ Workflow examples can be found on the [Examples page](https://comfyanonymous.git |---------------------------|--------------------------------------------------------------------------------------------------------------------| | Ctrl + Enter | Queue up current graph for generation | | Ctrl + Shift + Enter | Queue up current graph as first for generation | +| Ctrl + Z/Ctrl + Y | Undo/Redo | | Ctrl + S | Save workflow | | Ctrl + O | Load workflow | | Ctrl + A | Select all nodes | @@ -100,6 +101,7 @@ AMD users can install rocm and pytorch with pip if you don't have it already ins ```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm5.6``` This is the command to install the nightly with ROCm 5.7 that might have some performance improvements: + ```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm5.7``` ### NVIDIA @@ -192,7 +194,7 @@ To use a textual inversion concepts/embeddings in a text prompt put them in the Make sure you use the regular loaders/Load Checkpoint node to load checkpoints. It will auto pick the right settings depending on your GPU. -You can set this command line setting to disable the upcasting to fp32 in some cross attention operations which will increase your speed. Note that this will very likely give you black images on SD2.x models. If you use xformers this option does not do anything. +You can set this command line setting to disable the upcasting to fp32 in some cross attention operations which will increase your speed. Note that this will very likely give you black images on SD2.x models. If you use xformers or pytorch attention this option does not do anything. ```--dont-upcast-attention```