Merge branch 'master' into beta
jn-jairo committed Nov 10, 2023
2 parents 4716d45 + 002aefa commit 73d9e5a
Showing 6 changed files with 152 additions and 17 deletions.
15 changes: 14 additions & 1 deletion comfy/k_diffusion/sampling.py
@@ -717,7 +717,6 @@ def DDPMSampler_step(x, sigma, sigma_prev, noise, noise_sampler):
mu += ((1 - alpha) * (1. - alpha_cumprod_prev) / (1. - alpha_cumprod)).sqrt() * noise_sampler(sigma, sigma_prev)
return mu


def generic_step_sampler(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None, step_function=None):
extra_args = {} if extra_args is None else extra_args
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
@@ -737,3 +736,17 @@ def generic_step_sampler(model, x, sigmas, extra_args=None, callback=None, disab
def sample_ddpm(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None):
return generic_step_sampler(model, x, sigmas, extra_args, callback, disable, noise_sampler, DDPMSampler_step)

@torch.no_grad()
def sample_lcm(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None):
extra_args = {} if extra_args is None else extra_args
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})

x = denoised
if sigmas[i + 1] > 0:
x += sigmas[i + 1] * noise_sampler(sigmas[i], sigmas[i + 1])
return x
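For context, the added sample_lcm implements the latent-consistency update: each step takes the model's denoised prediction as the new latent outright and, while the schedule has not reached zero, re-noises it toward the next sigma. A minimal standalone sketch of that loop, with a dummy denoiser and plain torch.randn standing in for ComfyUI's model wrapper and noise sampler (both are illustrative stand-ins, not part of this commit):

import torch

def toy_lcm_loop(x, sigmas, denoise_fn):
    # Mirrors sample_lcm above: replace x with the denoised prediction,
    # then add fresh noise scaled by the next sigma while it is non-zero.
    for i in range(len(sigmas) - 1):
        denoised = denoise_fn(x, sigmas[i])
        x = denoised
        if sigmas[i + 1] > 0:
            x = x + sigmas[i + 1] * torch.randn_like(x)
    return x

x = torch.randn(1, 4, 8, 8)                            # toy latent
sigmas = torch.tensor([14.6, 3.2, 0.8, 0.0])           # assumed 3-step schedule
out = toy_lcm_loop(x, sigmas, lambda xt, s: xt * 0.0)  # dummy "perfect" denoiser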
2 changes: 1 addition & 1 deletion comfy/samplers.py
@@ -519,7 +519,7 @@ def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=N

KSAMPLER_NAMES = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral",
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu",
"dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm"]
"dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm"]

def ksampler(sampler_name, extra_options={}, inpaint_options={}):
class KSAMPLER(Sampler):
2 changes: 1 addition & 1 deletion comfy_extras/nodes_custom_sampler.py
@@ -188,7 +188,7 @@ def INPUT_TYPES(s):
{"model": ("MODEL",),
"add_noise": ("BOOLEAN", {"default": True}),
"noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
"cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.5, "round": 0.01}),
"cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}),
"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
"sampler": ("SAMPLER", ),
75 changes: 73 additions & 2 deletions comfy_extras/nodes_model_advanced.py
@@ -1,6 +1,72 @@
import folder_paths
import comfy.sd
import comfy.model_sampling
import torch

class LCM(comfy.model_sampling.EPS):
def calculate_denoised(self, sigma, model_output, model_input):
timestep = self.timestep(sigma).view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
x0 = model_input - model_output * sigma

sigma_data = 0.5
scaled_timestep = timestep * 10.0 #timestep_scaling

c_skip = sigma_data**2 / (scaled_timestep**2 + sigma_data**2)
c_out = scaled_timestep / (scaled_timestep**2 + sigma_data**2) ** 0.5

return c_out * x0 + c_skip * model_input
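For reference, these c_skip / c_out factors follow the consistency-model boundary condition: at timestep 0 the network contribution vanishes and the input passes through unchanged, while at high noise the x0 estimate dominates. A quick numeric check of the scalings, using only the constants from the code above (sigma_data = 0.5, timestep scaling = 10.0):

import torch

for t in torch.tensor([0.0, 250.0, 999.0]):
    st = t * 10.0                                 # scaled_timestep
    c_skip = 0.5 ** 2 / (st ** 2 + 0.5 ** 2)      # -> 1.0 at t = 0
    c_out = st / (st ** 2 + 0.5 ** 2) ** 0.5      # -> 0.0 at t = 0, ~1.0 at large t
    print(f"t={t.item():6.1f}  c_skip={c_skip.item():.6f}  c_out={c_out.item():.6f}")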

class ModelSamplingDiscreteLCM(torch.nn.Module):
def __init__(self):
super().__init__()
self.sigma_data = 1.0
timesteps = 1000
beta_start = 0.00085
beta_end = 0.012

betas = torch.linspace(beta_start**0.5, beta_end**0.5, timesteps, dtype=torch.float32) ** 2
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)

original_timesteps = 50
self.skip_steps = timesteps // original_timesteps


alphas_cumprod_valid = torch.zeros((original_timesteps), dtype=torch.float32)
for x in range(original_timesteps):
alphas_cumprod_valid[original_timesteps - 1 - x] = alphas_cumprod[timesteps - 1 - x * self.skip_steps]

sigmas = ((1 - alphas_cumprod_valid) / alphas_cumprod_valid) ** 0.5
self.set_sigmas(sigmas)

def set_sigmas(self, sigmas):
self.register_buffer('sigmas', sigmas)
self.register_buffer('log_sigmas', sigmas.log())

@property
def sigma_min(self):
return self.sigmas[0]

@property
def sigma_max(self):
return self.sigmas[-1]

def timestep(self, sigma):
log_sigma = sigma.log()
dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None]
return dists.abs().argmin(dim=0).view(sigma.shape) * self.skip_steps + (self.skip_steps - 1)

def sigma(self, timestep):
t = torch.clamp(((timestep - (self.skip_steps - 1)) / self.skip_steps).float(), min=0, max=(len(self.sigmas) - 1))
low_idx = t.floor().long()
high_idx = t.ceil().long()
w = t.frac()
log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]
return log_sigma.exp()

def percent_to_sigma(self, percent):
return self.sigma(torch.tensor(percent * 999.0))
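As a rough sanity check, the class above builds a 50-entry schedule by sampling every 20th of the 1000 DDPM timesteps, with timestep() and sigma() acting as the round-trip between sigmas and discrete timesteps. A short sketch that can be run in a ComfyUI checkout at this revision (the expected values in the comments are inferred from reading the code above, not from the commit itself):

import torch
from comfy_extras.nodes_model_advanced import ModelSamplingDiscreteLCM

ms = ModelSamplingDiscreteLCM()
print(ms.sigmas.shape)                         # torch.Size([50])
print(ms.sigma_min.item(), ms.sigma_max.item())
t = ms.timestep(ms.sigma_max.unsqueeze(0))     # highest sigma -> timestep 999
print(t, ms.sigma(t))                          # ...and back to (roughly) sigma_max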


def rescale_zero_terminal_snr_sigmas(sigmas):
@@ -26,7 +92,7 @@ class ModelSamplingDiscrete:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"sampling": (["eps", "v_prediction"],),
"sampling": (["eps", "v_prediction", "lcm"],),
"zsnr": ("BOOLEAN", {"default": False}),
}}

@@ -38,17 +104,22 @@ def INPUT_TYPES(s):
def patch(self, model, sampling, zsnr):
m = model.clone()

sampling_base = comfy.model_sampling.ModelSamplingDiscrete
if sampling == "eps":
sampling_type = comfy.model_sampling.EPS
elif sampling == "v_prediction":
sampling_type = comfy.model_sampling.V_PREDICTION
elif sampling == "lcm":
sampling_type = LCM
sampling_base = ModelSamplingDiscreteLCM

class ModelSamplingAdvanced(comfy.model_sampling.ModelSamplingDiscrete, sampling_type):
class ModelSamplingAdvanced(sampling_base, sampling_type):
pass

model_sampling = ModelSamplingAdvanced()
if zsnr:
model_sampling.set_sigmas(rescale_zero_terminal_snr_sigmas(model_sampling.sigmas))

m.add_object_patch("model_sampling", model_sampling)
return (m, )
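The patch now chooses both halves of the sampling object at runtime: sampling_base supplies the sigma schedule (ModelSamplingDiscrete, or ModelSamplingDiscreteLCM for "lcm") and sampling_type supplies the denoising rule (EPS, V_PREDICTION, or LCM), combined by defining a class that inherits from both. A toy sketch of this pattern — the class names below are invented stand-ins, not ComfyUI APIs:

class ScheduleA:                     # stand-in for ModelSamplingDiscrete
    def describe(self): return "base schedule"

class ScheduleB:                     # stand-in for ModelSamplingDiscreteLCM
    def describe(self): return "lcm schedule"

class EpsRule:                       # stand-in for EPS / V_PREDICTION / LCM
    def calculate(self): return "denoising rule"

def make_sampling(base, rule):
    class Combined(base, rule):      # same shape as ModelSamplingAdvanced above
        pass
    return Combined()

ms = make_sampling(ScheduleB, EpsRule)
print(ms.describe(), "+", ms.calculate())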

4 changes: 2 additions & 2 deletions nodes.py
@@ -1218,7 +1218,7 @@ def INPUT_TYPES(s):
{"model": ("MODEL",),
"seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
"steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
"cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.5, "round": 0.01}),
"cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}),
"sampler_name": (comfy.samplers.KSampler.SAMPLERS, ),
"scheduler": (comfy.samplers.KSampler.SCHEDULERS, ),
"positive": ("CONDITIONING", ),
@@ -1244,7 +1244,7 @@ def INPUT_TYPES(s):
"add_noise": (["enable", "disable"], ),
"noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
"steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
"cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.5, "round": 0.01}),
"cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}),
"sampler_name": (comfy.samplers.KSampler.SAMPLERS, ),
"scheduler": (comfy.samplers.KSampler.SCHEDULERS, ),
"positive": ("CONDITIONING", ),
71 changes: 61 additions & 10 deletions web/scripts/app.js
@@ -16,7 +16,7 @@ function sanitizeNodeName(string) {
'`': '',
'=': ''
};
return String(string).replace(/[&<>"'`=\/]/g, function fromEntityMap (s) {
return String(string).replace(/[&<>"'`=]/g, function fromEntityMap (s) {
return entityMap[s];
});
}
@@ -1474,6 +1474,17 @@ export class ComfyApp {
localStorage.setItem("litegrapheditor_clipboard", old);
}

showMissingNodesError(missingNodeTypes, hasAddedNodes = true) {
this.ui.dialog.show(
`When loading the graph, the following node types were not found: <ul>${Array.from(new Set(missingNodeTypes)).map(
(t) => `<li>${t}</li>`
).join("")}</ul>${hasAddedNodes ? "Nodes that have failed to load will show as red on the graph." : ""}`
);
this.logging.addEntry("Comfy.App", "warn", {
MissingNodes: missingNodeTypes,
});
}

/**
* Populates the graph with the specified workflow data
* @param {*} graphData A serialized graph object
@@ -1592,14 +1603,7 @@ }
}

if (missingNodeTypes.length) {
this.ui.dialog.show(
`When loading the graph, the following node types were not found: <ul>${Array.from(new Set(missingNodeTypes)).map(
(t) => `<li>${t}</li>`
).join("")}</ul>Nodes that have failed to load will show as red on the graph.`
);
this.logging.addEntry("Comfy.App", "warn", {
MissingNodes: missingNodeTypes,
});
this.showMissingNodesError(missingNodeTypes);
}
}

@@ -1830,9 +1834,11 @@
} else if (file.type === "application/json" || file.name?.endsWith(".json")) {
const reader = new FileReader();
reader.onload = () => {
var jsonContent = JSON.parse(reader.result);
const jsonContent = JSON.parse(reader.result);
if (jsonContent?.templates) {
this.loadTemplateData(jsonContent);
} else if(this.isApiJson(jsonContent)) {
this.loadApiJson(jsonContent);
} else {
this.loadGraphData(jsonContent);
}
@@ -1846,6 +1852,51 @@
}
}

isApiJson(data) {
return Object.values(data).every((v) => v.class_type);
}

loadApiJson(apiData) {
const missingNodeTypes = Object.values(apiData).filter((n) => !LiteGraph.registered_node_types[n.class_type]);
if (missingNodeTypes.length) {
this.showMissingNodesError(missingNodeTypes.map(t => t.class_type), false);
return;
}

const ids = Object.keys(apiData);
app.graph.clear();
for (const id of ids) {
const data = apiData[id];
const node = LiteGraph.createNode(data.class_type);
node.id = id;
graph.add(node);
}

for (const id of ids) {
const data = apiData[id];
const node = app.graph.getNodeById(id);
for (const input in data.inputs ?? {}) {
const value = data.inputs[input];
if (value instanceof Array) {
const [fromId, fromSlot] = value;
const fromNode = app.graph.getNodeById(fromId);
const toSlot = node.inputs?.findIndex((inp) => inp.name === input);
if (toSlot !== -1) {
fromNode.connect(fromSlot, node, toSlot);
}
} else {
const widget = node.widgets?.find((w) => w.name === input);
if (widget) {
widget.value = value;
widget.callback?.(value);
}
}
}
}

app.graph.arrange();
}

/**
* Registers a Comfy web extension with the app
* @param {ComfyExtension} extension
