
Commit

Merge branch 'master' into beta
jn-jairo committed Dec 18, 2023
2 parents 0bc4696 + 571ea8c commit f4b532c
Showing 4 changed files with 22 additions and 7 deletions.
10 changes: 10 additions & 0 deletions comfy/model_base.py
@@ -126,9 +126,15 @@ def blank_inpaint_image_like(latent_image):
cond_concat.append(blank_inpaint_image_like(noise))
data = torch.cat(cond_concat, dim=1)
out['c_concat'] = comfy.conds.CONDNoiseShape(data)

adm = self.encode_adm(**kwargs)
if adm is not None:
out['y'] = comfy.conds.CONDRegular(adm)

+ cross_attn = kwargs.get("cross_attn", None)
+ if cross_attn is not None:
+ out['c_crossattn'] = comfy.conds.CONDCrossAttn(cross_attn)

return out

def load_model_weights(self, sd, unet_prefix=""):
@@ -322,6 +328,10 @@ def extra_conds(self, **kwargs):

out['c_concat'] = comfy.conds.CONDNoiseShape(latent_image)

+ cross_attn = kwargs.get("cross_attn", None)
+ if cross_attn is not None:
+ out['c_crossattn'] = comfy.conds.CONDCrossAttn(cross_attn)

if "time_conditioning" in kwargs:
out["time_context"] = comfy.conds.CONDCrossAttn(kwargs["time_conditioning"])

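Both hunks add the same optional-kwarg pattern to the two model classes: if the caller supplies a cross_attn tensor, it is forwarded to the sampler under the 'c_crossattn' key wrapped in comfy.conds.CONDCrossAttn; otherwise nothing is added. A minimal sketch of that pattern in isolation (the helper name and the tensor shape are illustrative, not part of this commit):

    import torch
    import comfy.conds

    def build_extra_conds(**kwargs):
        # Same pattern as the added lines above: an optional cross_attn kwarg
        # becomes a CONDCrossAttn entry in the conditioning dict.
        out = {}
        cross_attn = kwargs.get("cross_attn", None)
        if cross_attn is not None:
            out['c_crossattn'] = comfy.conds.CONDCrossAttn(cross_attn)
        return out

    build_extra_conds()                                    # {} -- key omitted when absent
    build_extra_conds(cross_attn=torch.zeros(1, 77, 768))  # {'c_crossattn': CONDCrossAttn(...)}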
8 changes: 6 additions & 2 deletions comfy/model_patcher.py
@@ -55,14 +55,18 @@ def is_clone(self, other):
def memory_required(self, input_shape):
return self.model.memory_required(input_shape=input_shape)

- def set_model_sampler_cfg_function(self, sampler_cfg_function):
+ def set_model_sampler_cfg_function(self, sampler_cfg_function, disable_cfg1_optimization=False):
if len(inspect.signature(sampler_cfg_function).parameters) == 3:
self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"]) #Old way
else:
self.model_options["sampler_cfg_function"] = sampler_cfg_function
+ if disable_cfg1_optimization:
+ self.model_options["disable_cfg1_optimization"] = True

- def set_model_sampler_post_cfg_function(self, post_cfg_function):
+ def set_model_sampler_post_cfg_function(self, post_cfg_function, disable_cfg1_optimization=False):
self.model_options["sampler_post_cfg_function"] = self.model_options.get("sampler_post_cfg_function", []) + [post_cfg_function]
+ if disable_cfg1_optimization:
+ self.model_options["disable_cfg1_optimization"] = True

def set_model_unet_function_wrapper(self, unet_wrapper_function):
self.model_options["model_function_wrapper"] = unet_wrapper_function
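Both setter methods gain a disable_cfg1_optimization flag: passing True records "disable_cfg1_optimization" in model_options, which the sampler consults before taking its cfg=1 shortcut (see the samplers.py change below). A hedged sketch of how a custom node might use it, following the style of the nodes_sag.py change further down; the clone() call and the args["denoised"] key are assumptions about the surrounding API, not part of this diff:

    # Register a post-CFG hook that must keep running even at cfg == 1.0,
    # so the unconditional batch is not skipped by the fast path.
    m = model.clone()

    def my_post_cfg(args):
        # Assumed callback payload: args["denoised"] holds the combined CFG result.
        return args["denoised"]

    m.set_model_sampler_post_cfg_function(my_post_cfg, disable_cfg1_optimization=True)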
9 changes: 5 additions & 4 deletions comfy/samplers.py
Expand Up @@ -244,7 +244,7 @@ def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options):
#The main sampling function shared by all the samplers
#Returns denoised
def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None):
- if math.isclose(cond_scale, 1.0):
+ if math.isclose(cond_scale, 1.0) and model_options.get("disable_cfg1_optimization", False) == False:
uncond_ = None
else:
uncond_ = uncond
@@ -599,6 +599,10 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model
calculate_start_end_timesteps(model, negative)
calculate_start_end_timesteps(model, positive)

+ if hasattr(model, 'extra_conds'):
+ positive = encode_model_conds(model.extra_conds, positive, noise, device, "positive", latent_image=latent_image, denoise_mask=denoise_mask)
+ negative = encode_model_conds(model.extra_conds, negative, noise, device, "negative", latent_image=latent_image, denoise_mask=denoise_mask)

#make sure each cond area has an opposite one with the same area
for c in positive:
create_cond_with_same_area_if_none(negative, c)
@@ -613,9 +617,6 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model
if latent_image is not None:
latent_image = model.process_latent_in(latent_image)

- if hasattr(model, 'extra_conds'):
- positive = encode_model_conds(model.extra_conds, positive, noise, device, "positive", latent_image=latent_image, denoise_mask=denoise_mask)
- negative = encode_model_conds(model.extra_conds, negative, noise, device, "negative", latent_image=latent_image, denoise_mask=denoise_mask)

extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": model_options, "seed":seed}

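The cond_scale guard near the top of this file's diff is the other half of the new flag: the unconditional pass is dropped only when cond_scale is effectively 1.0 and nothing in model_options has opted out. A condensed restatement of that decision, assuming model_options is the same dict the ModelPatcher setters write into:

    import math

    def needs_uncond(cond_scale, model_options):
        # Mirrors the check in sampling_function: skip uncond only when CFG is
        # ~1.0 and no patch set disable_cfg1_optimization.
        skip = math.isclose(cond_scale, 1.0) and not model_options.get("disable_cfg1_optimization", False)
        return not skip

    needs_uncond(1.0, {})                                   # False: uncond batch skipped
    needs_uncond(1.0, {"disable_cfg1_optimization": True})  # True: a patch opted out
    needs_uncond(7.5, {})                                    # True: normal CFG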
2 changes: 1 addition & 1 deletion comfy_extras/nodes_sag.py
@@ -151,7 +151,7 @@ def post_cfg_function(args):
(sag, _) = comfy.samplers.calc_cond_uncond_batch(model, uncond, None, degraded_noised, sigma, model_options)
return cfg_result + (degraded - sag) * sag_scale

- m.set_model_sampler_post_cfg_function(post_cfg_function)
+ m.set_model_sampler_post_cfg_function(post_cfg_function, disable_cfg1_optimization=True)

# from diffusers:
# unet.mid_block.attentions[0].transformer_blocks[0].attn1.patch
