Skip to content

Commit

Permalink
Fix SAG not working with cfg 1.0
Browse files Browse the repository at this point in the history
  • Loading branch information
comfyanonymous committed Dec 18, 2023
1 parent 8cf1daa commit 571ea8c
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 4 deletions.
8 changes: 6 additions & 2 deletions comfy/model_patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,14 +55,18 @@ def is_clone(self, other):
def memory_required(self, input_shape):
return self.model.memory_required(input_shape=input_shape)

def set_model_sampler_cfg_function(self, sampler_cfg_function):
def set_model_sampler_cfg_function(self, sampler_cfg_function, disable_cfg1_optimization=False):
if len(inspect.signature(sampler_cfg_function).parameters) == 3:
self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"]) #Old way
else:
self.model_options["sampler_cfg_function"] = sampler_cfg_function
if disable_cfg1_optimization:
self.model_options["disable_cfg1_optimization"] = True

def set_model_sampler_post_cfg_function(self, post_cfg_function):
def set_model_sampler_post_cfg_function(self, post_cfg_function, disable_cfg1_optimization=False):
self.model_options["sampler_post_cfg_function"] = self.model_options.get("sampler_post_cfg_function", []) + [post_cfg_function]
if disable_cfg1_optimization:
self.model_options["disable_cfg1_optimization"] = True

def set_model_unet_function_wrapper(self, unet_wrapper_function):
self.model_options["model_function_wrapper"] = unet_wrapper_function
Expand Down
2 changes: 1 addition & 1 deletion comfy/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options):
#The main sampling function shared by all the samplers
#Returns denoised
def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None):
if math.isclose(cond_scale, 1.0):
if math.isclose(cond_scale, 1.0) and model_options.get("disable_cfg1_optimization", False) == False:
uncond_ = None
else:
uncond_ = uncond
Expand Down
2 changes: 1 addition & 1 deletion comfy_extras/nodes_sag.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def post_cfg_function(args):
(sag, _) = comfy.samplers.calc_cond_uncond_batch(model, uncond, None, degraded_noised, sigma, model_options)
return cfg_result + (degraded - sag) * sag_scale

m.set_model_sampler_post_cfg_function(post_cfg_function)
m.set_model_sampler_post_cfg_function(post_cfg_function, disable_cfg1_optimization=True)

# from diffusers:
# unet.mid_block.attentions[0].transformer_blocks[0].attn1.patch
Expand Down

0 comments on commit 571ea8c

Please sign in to comment.