Merge branch 'master' into beta
jn-jairo committed Nov 17, 2023
2 parents f902999 + 7e3fe3a · commit 7ab2121
Showing 9 changed files with 80 additions and 8 deletions.
4 changes: 2 additions & 2 deletions comfy/controlnet.py
@@ -33,7 +33,7 @@ def __init__(self, device=None):
         self.cond_hint_original = None
         self.cond_hint = None
         self.strength = 1.0
-        self.timestep_percent_range = (1.0, 0.0)
+        self.timestep_percent_range = (0.0, 1.0)
         self.timestep_range = None

         if device is None:
@@ -42,7 +42,7 @@ def __init__(self, device=None):
         self.previous_controlnet = None
         self.global_average_pooling = False

-    def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(1.0, 0.0)):
+    def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0)):
         self.cond_hint_original = cond_hint
         self.strength = strength
         self.timestep_percent_range = timestep_percent_range
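
Note: these two edits flip the stored range from the old inverted form (default (1.0, 0.0), where 1.0 meant the start of sampling) to a plain (start_percent, end_percent) pair with 0.0 = start and 1.0 = end. A minimal usage sketch under the new convention (control_net and hint are assumed to already exist; this snippet is not part of the commit):

    # apply the hint only for the first 60% of sampling
    c_net = control_net.copy().set_cond_hint(hint, strength=0.8, timestep_percent_range=(0.0, 0.6))
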
9 changes: 8 additions & 1 deletion comfy/ldm/modules/diffusionmodules/openaimodel.py
@@ -255,7 +255,10 @@ def apply_control(h, control, name):
     if control is not None and name in control and len(control[name]) > 0:
         ctrl = control[name].pop()
         if ctrl is not None:
-            h += ctrl
+            try:
+                h += ctrl
+            except:
+                print("warning control could not be applied", h.shape, ctrl.shape)
     return h

 class UNetModel(nn.Module):
@@ -630,6 +633,10 @@ def forward(self, x, timesteps=None, context=None, y=None, control=None, transfo
                 h = p(h, transformer_options)

             hs.append(h)
+            if "input_block_patch_after_skip" in transformer_patches:
+                patch = transformer_patches["input_block_patch_after_skip"]
+                for p in patch:
+                    h = p(h, transformer_options)

         transformer_options["block"] = ("middle", 0)
         h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options)
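
Two changes here: apply_control() now catches shape-mismatch errors when adding a control tensor and prints a warning instead of aborting the whole sampling run, and the input-block loop gains an "input_block_patch_after_skip" hook. Because the hook runs after hs.append(h), a patch registered there alters the tensor flowing into deeper blocks while the stored skip connection keeps its original contents. A hedged sketch of a patch for this hook (the function name is illustrative, not from the commit):

    # identity patch at the new hook point; transformer_options["block"]
    # holds ("input", block_index), so a patch can target one block
    def after_skip_patch(h, transformer_options):
        block_type, block_index = transformer_options["block"]
        return h  # transform h here without touching the saved skip tensor
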
5 changes: 4 additions & 1 deletion comfy/model_patcher.py
@@ -37,7 +37,7 @@ def model_size(self):
         return size

     def clone(self):
-        n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device)
+        n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device, weight_inplace_update=self.weight_inplace_update)
         n.patches = {}
         for k in self.patches:
             n.patches[k] = self.patches[k][:]
@@ -99,6 +99,9 @@ def set_model_attn2_output_patch(self, patch):
     def set_model_input_block_patch(self, patch):
         self.set_model_patch(patch, "input_block_patch")

+    def set_model_input_block_patch_after_skip(self, patch):
+        self.set_model_patch(patch, "input_block_patch_after_skip")
+
     def set_model_output_block_patch(self, patch):
         self.set_model_patch(patch, "output_block_patch")

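
clone() previously dropped weight_inplace_update, so clones silently reverted to out-of-place weight updates; the flag is now forwarded. The new setter wires a patch into the after-skip hook added in openaimodel.py above. A hedged sketch, reusing the after_skip_patch example from that section (model is assumed to be a loaded ModelPatcher):

    m = model.clone()  # the clone now keeps weight_inplace_update
    m.set_model_input_block_patch_after_skip(after_skip_patch)
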
5 changes: 5 additions & 0 deletions comfy/model_sampling.py
@@ -76,5 +76,10 @@ def sigma(self, timestep):
         return log_sigma.exp()

     def percent_to_sigma(self, percent):
+        if percent <= 0.0:
+            return torch.tensor(999999999.9)
+        if percent >= 1.0:
+            return torch.tensor(0.0)
+        percent = 1.0 - percent
         return self.sigma(torch.tensor(percent * 999.0))
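
percent_to_sigma() now clamps out-of-range inputs and takes the percent in the new orientation, flipping it internally with 1.0 - percent: 0.0 maps to an effectively infinite sigma ("from the very start") and 1.0 to sigma 0 ("until the very end"), so open-ended ranges always match. A rough worked example given those definitions (ms is assumed to be a model_sampling instance):

    ms.percent_to_sigma(0.0)   # tensor(999999999.9): matches any sigma
    ms.percent_to_sigma(1.0)   # tensor(0.0): never cuts off early
    ms.percent_to_sigma(0.5)   # sigma at timestep (1.0 - 0.5) * 999 = 499.5
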

2 changes: 2 additions & 0 deletions comfy/samplers.py
@@ -220,6 +220,8 @@ def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options):
                 transformer_options["patches"] = patches

             transformer_options["cond_or_uncond"] = cond_or_uncond[:]
+            transformer_options["sigmas"] = timestep
+
             c['transformer_options'] = transformer_options

             if 'model_function_wrapper' in model_options:
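
Storing the current timestep (the sigmas) in transformer_options is what lets block patches make schedule-dependent decisions; the PatchModelAddDownscale node below reads it back. A hedged sketch of the read side (names are illustrative):

    def my_patch(h, transformer_options):
        sigma = transformer_options["sigmas"][0].item()  # current noise level
        # act only inside a chosen sigma window
        return h
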
2 changes: 1 addition & 1 deletion comfy/utils.py
@@ -284,7 +284,7 @@ def set_attr(obj, attr, value):
     for name in attrs[:-1]:
         obj = getattr(obj, name)
     prev = getattr(obj, attrs[-1])
-    setattr(obj, attrs[-1], torch.nn.Parameter(value))
+    setattr(obj, attrs[-1], torch.nn.Parameter(value, requires_grad=False))
     del prev

 def copy_to_param(obj, attr, value):
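
torch.nn.Parameter defaults to requires_grad=True, so set_attr() was marking every patched weight as trainable; passing requires_grad=False keeps these inference-time weight swaps out of autograd. A quick check of that default:

    import torch
    w = torch.zeros(4)
    torch.nn.Parameter(w).requires_grad                       # True (old behavior)
    torch.nn.Parameter(w, requires_grad=False).requires_grad  # False (new behavior)
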
5 changes: 5 additions & 0 deletions comfy_extras/nodes_model_advanced.py
@@ -66,6 +66,11 @@ def sigma(self, timestep):
         return log_sigma.exp()

     def percent_to_sigma(self, percent):
+        if percent <= 0.0:
+            return torch.tensor(999999999.9)
+        if percent >= 1.0:
+            return torch.tensor(0.0)
+        percent = 1.0 - percent
         return self.sigma(torch.tensor(percent * 999.0))

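This mirrors the percent_to_sigma() change in comfy/model_sampling.py above, so the advanced model-sampling nodes use the same percent orientation and clamping.
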
49 changes: 49 additions & 0 deletions comfy_extras/nodes_model_downscale.py
@@ -0,0 +1,49 @@
+import torch
+
+class PatchModelAddDownscale:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": { "model": ("MODEL",),
+                              "block_number": ("INT", {"default": 3, "min": 1, "max": 32, "step": 1}),
+                              "downscale_factor": ("FLOAT", {"default": 2.0, "min": 0.1, "max": 9.0, "step": 0.001}),
+                              "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
+                              "end_percent": ("FLOAT", {"default": 0.35, "min": 0.0, "max": 1.0, "step": 0.001}),
+                              "downscale_after_skip": ("BOOLEAN", {"default": True}),
+                              }}
+    RETURN_TYPES = ("MODEL",)
+    FUNCTION = "patch"
+
+    CATEGORY = "_for_testing"
+
+    def patch(self, model, block_number, downscale_factor, start_percent, end_percent, downscale_after_skip):
+        sigma_start = model.model.model_sampling.percent_to_sigma(start_percent).item()
+        sigma_end = model.model.model_sampling.percent_to_sigma(end_percent).item()
+
+        def input_block_patch(h, transformer_options):
+            if transformer_options["block"][1] == block_number:
+                sigma = transformer_options["sigmas"][0].item()
+                if sigma <= sigma_start and sigma >= sigma_end:
+                    h = torch.nn.functional.interpolate(h, scale_factor=(1.0 / downscale_factor), mode="bicubic", align_corners=False)
+            return h
+
+        def output_block_patch(h, hsp, transformer_options):
+            if h.shape[2] != hsp.shape[2]:
+                h = torch.nn.functional.interpolate(h, size=(hsp.shape[2], hsp.shape[3]), mode="bicubic", align_corners=False)
+            return h, hsp
+
+        m = model.clone()
+        if downscale_after_skip:
+            m.set_model_input_block_patch_after_skip(input_block_patch)
+        else:
+            m.set_model_input_block_patch(input_block_patch)
+        m.set_model_output_block_patch(output_block_patch)
+        return (m, )
+
+NODE_CLASS_MAPPINGS = {
+    "PatchModelAddDownscale": PatchModelAddDownscale,
+}
+
+NODE_DISPLAY_NAME_MAPPINGS = {
+    # Sampling
+    "PatchModelAddDownscale": "PatchModelAddDownscale (Kohya Deep Shrink)",
+}
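
This new node implements Kohya's "Deep Shrink" trick: while the current sigma lies between sigma_start and sigma_end, hidden states entering the chosen input block are downscaled by downscale_factor, letting the model lay out the composition at a lower effective resolution, and output_block_patch upscales h back whenever it no longer matches its skip tensor hsp. With downscale_after_skip enabled it uses the new after-skip hook, so the stored skip connections keep full resolution. A hedged sketch of driving the node directly from Python rather than through a workflow graph (model is assumed to be a loaded MODEL):

    node = PatchModelAddDownscale()
    (patched_model,) = node.patch(model, block_number=3, downscale_factor=2.0,
                                  start_percent=0.0, end_percent=0.35,
                                  downscale_after_skip=True)
    # sample with patched_model as usual; resolutions above the training
    # size should show less subject duplication
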
7 changes: 4 additions & 3 deletions nodes.py
@@ -248,8 +248,8 @@ def set_range(self, conditioning, start, end):
         c = []
         for t in conditioning:
             d = t[1].copy()
-            d['start_percent'] = 1.0 - start
-            d['end_percent'] = 1.0 - end
+            d['start_percent'] = start
+            d['end_percent'] = end
             n = [t[0], d]
             c.append(n)
         return (c, )
@@ -685,7 +685,7 @@ def apply_controlnet(self, positive, negative, control_net, image, strength, sta
                 if prev_cnet in cnets:
                     c_net = cnets[prev_cnet]
                 else:
-                    c_net = control_net.copy().set_cond_hint(control_hint, strength, (1.0 - start_percent, 1.0 - end_percent))
+                    c_net = control_net.copy().set_cond_hint(control_hint, strength, (start_percent, end_percent))
                     c_net.set_previous_controlnet(prev_cnet)
                     cnets[prev_cnet] = c_net

@@ -1799,6 +1799,7 @@ def init_custom_nodes():
         "nodes_custom_sampler.py",
         "nodes_hypertile.py",
         "nodes_model_advanced.py",
+        "nodes_model_downscale.py",
     ]

     for node_file in extras_files:
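
With percent_to_sigma() doing the inversion internally, the timestep-range conditioning node and apply_controlnet() now pass start/end percentages through unchanged instead of pre-flipping them with 1.0 - x, and the new nodes_model_downscale.py file is registered in extras_files so it loads with the other extras.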
