Skip to content

Commit ef4f603

Browse files
Fix model patches not working in custom sampling scheduler nodes.
1 parent a7874d1 commit ef4f603

File tree

2 files changed

+30
-25
lines changed

2 files changed

+30
-25
lines changed

comfy/model_patcher.py

Lines changed: 24 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -174,40 +174,41 @@ def model_state_dict(self, filter_prefix=None):
174174
sd.pop(k)
175175
return sd
176176

177-
def patch_model(self, device_to=None):
177+
def patch_model(self, device_to=None, patch_weights=True):
178178
for k in self.object_patches:
179179
old = getattr(self.model, k)
180180
if k not in self.object_patches_backup:
181181
self.object_patches_backup[k] = old
182182
setattr(self.model, k, self.object_patches[k])
183183

184-
model_sd = self.model_state_dict()
185-
for key in self.patches:
186-
if key not in model_sd:
187-
print("could not patch. key doesn't exist in model:", key)
188-
continue
184+
if patch_weights:
185+
model_sd = self.model_state_dict()
186+
for key in self.patches:
187+
if key not in model_sd:
188+
print("could not patch. key doesn't exist in model:", key)
189+
continue
189190

190-
weight = model_sd[key]
191+
weight = model_sd[key]
191192

192-
inplace_update = self.weight_inplace_update
193+
inplace_update = self.weight_inplace_update
193194

194-
if key not in self.backup:
195-
self.backup[key] = weight.to(device=self.offload_device, copy=inplace_update)
195+
if key not in self.backup:
196+
self.backup[key] = weight.to(device=self.offload_device, copy=inplace_update)
196197

197-
if device_to is not None:
198-
temp_weight = comfy.model_management.cast_to_device(weight, device_to, torch.float32, copy=True)
199-
else:
200-
temp_weight = weight.to(torch.float32, copy=True)
201-
out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
202-
if inplace_update:
203-
comfy.utils.copy_to_param(self.model, key, out_weight)
204-
else:
205-
comfy.utils.set_attr(self.model, key, out_weight)
206-
del temp_weight
198+
if device_to is not None:
199+
temp_weight = comfy.model_management.cast_to_device(weight, device_to, torch.float32, copy=True)
200+
else:
201+
temp_weight = weight.to(torch.float32, copy=True)
202+
out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
203+
if inplace_update:
204+
comfy.utils.copy_to_param(self.model, key, out_weight)
205+
else:
206+
comfy.utils.set_attr(self.model, key, out_weight)
207+
del temp_weight
207208

208-
if device_to is not None:
209-
self.model.to(device_to)
210-
self.current_device = device_to
209+
if device_to is not None:
210+
self.model.to(device_to)
211+
self.current_device = device_to
211212

212213
return self.model
213214

comfy_extras/nodes_custom_sampler.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,9 @@ def get_sigmas(self, model, scheduler, steps, denoise):
2626
if denoise < 1.0:
2727
total_steps = int(steps/denoise)
2828

29-
sigmas = comfy.samplers.calculate_sigmas_scheduler(model.model, scheduler, total_steps).cpu()
29+
inner_model = model.patch_model(patch_weights=False)
30+
sigmas = comfy.samplers.calculate_sigmas_scheduler(inner_model, scheduler, total_steps).cpu()
31+
model.unpatch_model()
3032
sigmas = sigmas[-(steps + 1):]
3133
return (sigmas, )
3234

@@ -104,7 +106,9 @@ def INPUT_TYPES(s):
104106
def get_sigmas(self, model, steps, denoise):
105107
start_step = 10 - int(10 * denoise)
106108
timesteps = torch.flip(torch.arange(1, 11) * 100 - 1, (0,))[start_step:start_step + steps]
107-
sigmas = model.model.model_sampling.sigma(timesteps)
109+
inner_model = model.patch_model(patch_weights=False)
110+
sigmas = inner_model.model_sampling.sigma(timesteps)
111+
model.unpatch_model()
108112
sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])
109113
return (sigmas, )
110114

0 commit comments

Comments
 (0)