Skip to content

Commit e995633

Browse files
committed
Merge branch 'master' into beta
2 parents d0ba2a6 + 58812ab commit e995633

23 files changed

+680
-232
lines changed

comfy/controlnet.py

Lines changed: 7 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -138,11 +138,13 @@ def control_merge(self, control_input, control_output, control_prev, output_dtyp
138138
return out
139139

140140
class ControlNet(ControlBase):
141-
def __init__(self, control_model, global_average_pooling=False, device=None, load_device=None, manual_cast_dtype=None):
141+
def __init__(self, control_model=None, global_average_pooling=False, device=None, load_device=None, manual_cast_dtype=None):
142142
super().__init__(device)
143143
self.control_model = control_model
144144
self.load_device = load_device
145-
self.control_model_wrapped = comfy.model_patcher.ModelPatcher(self.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device())
145+
if control_model is not None:
146+
self.control_model_wrapped = comfy.model_patcher.ModelPatcher(self.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device())
147+
146148
self.global_average_pooling = global_average_pooling
147149
self.model_sampling_current = None
148150
self.manual_cast_dtype = manual_cast_dtype
@@ -183,7 +185,9 @@ def get_control(self, x_noisy, t, cond, batched_number):
183185
return self.control_merge(None, control, control_prev, output_dtype)
184186

185187
def copy(self):
186-
c = ControlNet(self.control_model, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
188+
c = ControlNet(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
189+
c.control_model = self.control_model
190+
c.control_model_wrapped = self.control_model_wrapped
187191
self.copy_to(c)
188192
return c
189193

comfy/model_base.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -503,8 +503,10 @@ def __init__(self, model_config, model_type=ModelType.EPS, device=None):
503503
class SDXL_instructpix2pix(IP2P, SDXL):
504504
def __init__(self, model_config, model_type=ModelType.EPS, device=None):
505505
super().__init__(model_config, model_type, device=device)
506-
# self.process_ip2p_image_in = lambda image: comfy.latent_formats.SDXL().process_in(image)
507-
self.process_ip2p_image_in = lambda image: image
506+
if model_type == ModelType.V_PREDICTION_EDM:
507+
self.process_ip2p_image_in = lambda image: comfy.latent_formats.SDXL().process_in(image) #cosxl ip2p
508+
else:
509+
self.process_ip2p_image_in = lambda image: image #diffusers ip2p
508510

509511

510512
class StableCascade_C(BaseModel):

comfy/model_detection.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -357,7 +357,14 @@ def unet_config_from_diffusers_unet(state_dict, dtype=None):
357357
'context_dim': 1024, 'num_head_channels': 64, 'transformer_depth_output': [1, 1, 1, 1, 1, 1],
358358
'use_temporal_attention': False, 'use_temporal_resblock': False, 'disable_self_attentions': [True, False, False]}
359359

360-
supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega, KOALA_700M, KOALA_1B, SD09_XS, SDXL_diffusers_ip2p]
360+
SD_XS = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
361+
'adm_in_channels': None, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': [1, 1, 1],
362+
'transformer_depth': [0, 1, 1], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': -2, 'use_linear_in_transformer': False,
363+
'context_dim': 768, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 1, 1, 1, 1],
364+
'use_temporal_attention': False, 'use_temporal_resblock': False}
365+
366+
367+
supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega, KOALA_700M, KOALA_1B, SD09_XS, SD_XS, SDXL_diffusers_ip2p]
361368

362369
for unet_config in supported_models:
363370
matches = True

comfy/model_management.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -399,6 +399,8 @@ def load_models_gpu(models, memory_required=0):
399399
inference_memory = minimum_inference_memory()
400400
extra_mem = max(inference_memory, memory_required)
401401

402+
models = set(models)
403+
402404
models_to_load = []
403405
models_already_loaded = []
404406
for x in models:

comfy/model_patcher.py

Lines changed: 11 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -150,6 +150,15 @@ def set_model_output_block_patch(self, patch):
150150
def add_object_patch(self, name, obj):
151151
self.object_patches[name] = obj
152152

153+
def get_model_object(self, name):
154+
if name in self.object_patches:
155+
return self.object_patches[name]
156+
else:
157+
if name in self.object_patches_backup:
158+
return self.object_patches_backup[name]
159+
else:
160+
return comfy.utils.get_attr(self.model, name)
161+
153162
def model_patches_to(self, device):
154163
to = self.model_options["transformer_options"]
155164
if "patches" in to:
@@ -278,7 +287,7 @@ def __call__(self, weight):
278287
if weight_key in self.patches:
279288
m.weight_function = LowVramPatch(weight_key, self)
280289
if bias_key in self.patches:
281-
m.bias_function = LowVramPatch(weight_key, self)
290+
m.bias_function = LowVramPatch(bias_key, self)
282291

283292
m.prev_comfy_cast_weights = m.comfy_cast_weights
284293
m.comfy_cast_weights = True
@@ -462,4 +471,4 @@ def unpatch_model(self, device_to=None, unpatch_weights=True):
462471
for k in keys:
463472
comfy.utils.set_attr(self.model, k, self.object_patches_backup[k])
464473

465-
self.object_patches_backup = {}
474+
self.object_patches_backup.clear()

comfy/sample.py

Lines changed: 8 additions & 82 deletions
Original file line number | Diff line number | Diff line change
@@ -1,10 +1,9 @@
11
import torch
22
import comfy.model_management
33
import comfy.samplers
4-
import comfy.conds
54
import comfy.utils
6-
import math
75
import numpy as np
6+
import logging
87

98
def prepare_noise(latent_image, seed, noise_inds=None):
109
"""
@@ -25,94 +24,21 @@ def prepare_noise(latent_image, seed, noise_inds=None):
2524
noises = torch.cat(noises, axis=0)
2625
return noises
2726

28-
def prepare_mask(noise_mask, shape, device):
29-
"""ensures noise mask is of proper dimensions"""
30-
noise_mask = torch.nn.functional.interpolate(noise_mask.reshape((-1, 1, noise_mask.shape[-2], noise_mask.shape[-1])), size=(shape[2], shape[3]), mode="bilinear")
31-
noise_mask = torch.cat([noise_mask] * shape[1], dim=1)
32-
noise_mask = comfy.utils.repeat_to_batch_size(noise_mask, shape[0])
33-
noise_mask = noise_mask.to(device)
34-
return noise_mask
35-
36-
def get_models_from_cond(cond, model_type):
37-
models = []
38-
for c in cond:
39-
if model_type in c:
40-
models += [c[model_type]]
41-
return models
42-
43-
def convert_cond(cond):
44-
out = []
45-
for c in cond:
46-
temp = c[1].copy()
47-
model_conds = temp.get("model_conds", {})
48-
if c[0] is not None:
49-
model_conds["c_crossattn"] = comfy.conds.CONDCrossAttn(c[0]) #TODO: remove
50-
temp["cross_attn"] = c[0]
51-
temp["model_conds"] = model_conds
52-
out.append(temp)
53-
return out
54-
55-
def get_additional_models(positive, negative, dtype):
56-
"""loads additional models in positive and negative conditioning"""
57-
control_nets = set(get_models_from_cond(positive, "control") + get_models_from_cond(negative, "control"))
58-
59-
inference_memory = 0
60-
control_models = []
61-
for m in control_nets:
62-
control_models += m.get_models()
63-
inference_memory += m.inference_memory_requirements(dtype)
64-
65-
gligen = get_models_from_cond(positive, "gligen") + get_models_from_cond(negative, "gligen")
66-
gligen = [x[1] for x in gligen]
67-
models = control_models + gligen
68-
return models, inference_memory
69-
70-
def cleanup_additional_models(models):
71-
"""cleanup additional models that were loaded"""
72-
for m in models:
73-
if hasattr(m, 'cleanup'):
74-
m.cleanup()
75-
7627
def prepare_sampling(model, noise_shape, positive, negative, noise_mask):
77-
device = model.load_device
78-
positive = convert_cond(positive)
79-
negative = convert_cond(negative)
80-
81-
if noise_mask is not None:
82-
noise_mask = prepare_mask(noise_mask, noise_shape, device)
83-
84-
real_model = None
85-
models, inference_memory = get_additional_models(positive, negative, model.model_dtype())
86-
comfy.model_management.load_models_gpu([model] + models, model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:])) + inference_memory)
87-
real_model = model.model
88-
89-
return real_model, positive, negative, noise_mask, models
28+
logging.warning("Warning: comfy.sample.prepare_sampling isn't used anymore and can be removed")
29+
return model, positive, negative, noise_mask, []
9030

31+
def cleanup_additional_models(models):
32+
logging.warning("Warning: comfy.sample.cleanup_additional_models isn't used anymore and can be removed")
9133

9234
def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, noise_mask=None, sigmas=None, callback=None, disable_pbar=False, seed=None):
93-
real_model, positive_copy, negative_copy, noise_mask, models = prepare_sampling(model, noise.shape, positive, negative, noise_mask)
94-
95-
noise = noise.to(model.load_device)
96-
latent_image = latent_image.to(model.load_device)
35+
sampler = comfy.samplers.KSampler(model, steps=steps, device=model.load_device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options)
9736

98-
sampler = comfy.samplers.KSampler(real_model, steps=steps, device=model.load_device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options)
99-
100-
samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar, seed=seed)
37+
samples = sampler.sample(noise, positive, negative, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar, seed=seed)
10138
samples = samples.to(comfy.model_management.intermediate_device())
102-
103-
cleanup_additional_models(models)
104-
cleanup_additional_models(set(get_models_from_cond(positive_copy, "control") + get_models_from_cond(negative_copy, "control")))
10539
return samples
10640

10741
def sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=None, callback=None, disable_pbar=False, seed=None):
108-
real_model, positive_copy, negative_copy, noise_mask, models = prepare_sampling(model, noise.shape, positive, negative, noise_mask)
109-
noise = noise.to(model.load_device)
110-
latent_image = latent_image.to(model.load_device)
111-
sigmas = sigmas.to(model.load_device)
112-
113-
samples = comfy.samplers.sample(real_model, noise, positive_copy, negative_copy, cfg, model.load_device, sampler, sigmas, model_options=model.model_options, latent_image=latent_image, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
42+
samples = comfy.samplers.sample(model, noise, positive, negative, cfg, model.load_device, sampler, sigmas, model_options=model.model_options, latent_image=latent_image, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
11443
samples = samples.to(comfy.model_management.intermediate_device())
115-
cleanup_additional_models(models)
116-
cleanup_additional_models(set(get_models_from_cond(positive_copy, "control") + get_models_from_cond(negative_copy, "control")))
11744
return samples
118-

comfy/sampler_helpers.py

Lines changed: 76 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,76 @@
1+
import torch
2+
import comfy.model_management
3+
import comfy.conds
4+
5+
def prepare_mask(noise_mask, shape, device):
6+
"""ensures noise mask is of proper dimensions"""
7+
noise_mask = torch.nn.functional.interpolate(noise_mask.reshape((-1, 1, noise_mask.shape[-2], noise_mask.shape[-1])), size=(shape[2], shape[3]), mode="bilinear")
8+
noise_mask = torch.cat([noise_mask] * shape[1], dim=1)
9+
noise_mask = comfy.utils.repeat_to_batch_size(noise_mask, shape[0])
10+
noise_mask = noise_mask.to(device)
11+
return noise_mask
12+
13+
def get_models_from_cond(cond, model_type):
14+
models = []
15+
for c in cond:
16+
if model_type in c:
17+
models += [c[model_type]]
18+
return models
19+
20+
def convert_cond(cond):
21+
out = []
22+
for c in cond:
23+
temp = c[1].copy()
24+
model_conds = temp.get("model_conds", {})
25+
if c[0] is not None:
26+
model_conds["c_crossattn"] = comfy.conds.CONDCrossAttn(c[0]) #TODO: remove
27+
temp["cross_attn"] = c[0]
28+
temp["model_conds"] = model_conds
29+
out.append(temp)
30+
return out
31+
32+
def get_additional_models(conds, dtype):
33+
"""loads additional models in conditioning"""
34+
cnets = []
35+
gligen = []
36+
37+
for k in conds:
38+
cnets += get_models_from_cond(conds[k], "control")
39+
gligen += get_models_from_cond(conds[k], "gligen")
40+
41+
control_nets = set(cnets)
42+
43+
inference_memory = 0
44+
control_models = []
45+
for m in control_nets:
46+
control_models += m.get_models()
47+
inference_memory += m.inference_memory_requirements(dtype)
48+
49+
gligen = [x[1] for x in gligen]
50+
models = control_models + gligen
51+
return models, inference_memory
52+
53+
def cleanup_additional_models(models):
54+
"""cleanup additional models that were loaded"""
55+
for m in models:
56+
if hasattr(m, 'cleanup'):
57+
m.cleanup()
58+
59+
60+
def prepare_sampling(model, noise_shape, conds):
61+
device = model.load_device
62+
real_model = None
63+
models, inference_memory = get_additional_models(conds, model.model_dtype())
64+
comfy.model_management.load_models_gpu([model] + models, model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:])) + inference_memory)
65+
real_model = model.model
66+
67+
return real_model, conds, models
68+
69+
def cleanup_models(conds, models):
70+
cleanup_additional_models(models)
71+
72+
control_cleanup = []
73+
for k in conds:
74+
control_cleanup += get_models_from_cond(conds[k], "control")
75+
76+
cleanup_additional_models(set(control_cleanup))

0 commit comments

Comments (0)