Skip to content

Commit 12cac0f

Browse files
committed
Merge branch 'master' into beta
2 parents c186944 + 329c571 commit 12cac0f

File tree

13 files changed

+714
-263
lines changed

13 files changed

+714
-263
lines changed

comfy/model_detection.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,13 @@ def unet_config_from_diffusers_unet(state_dict, dtype):
289289
'channel_mult': [1, 2, 4], 'transformer_depth_middle': -1, 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64,
290290
'use_temporal_attention': False, 'use_temporal_resblock': False}
291291

292-
supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B]
292+
Segmind_Vega = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
293+
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
294+
'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 1, 1, 2, 2], 'transformer_depth_output': [0, 0, 0, 1, 1, 1, 2, 2, 2],
295+
'channel_mult': [1, 2, 4], 'transformer_depth_middle': -1, 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64,
296+
'use_temporal_attention': False, 'use_temporal_resblock': False}
297+
298+
supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega]
293299

294300
for unet_config in supported_models:
295301
matches = True

comfy/model_patcher.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,9 @@ def set_model_sampler_cfg_function(self, sampler_cfg_function):
6161
else:
6262
self.model_options["sampler_cfg_function"] = sampler_cfg_function
6363

64+
def set_model_sampler_post_cfg_function(self, post_cfg_function):
65+
self.model_options["sampler_post_cfg_function"] = self.model_options.get("sampler_post_cfg_function", []) + [post_cfg_function]
66+
6467
def set_model_unet_function_wrapper(self, unet_wrapper_function):
6568
self.model_options["model_function_wrapper"] = unet_wrapper_function
6669

@@ -70,25 +73,29 @@ def set_model_patch(self, patch, name):
7073
to["patches"] = {}
7174
to["patches"][name] = to["patches"].get(name, []) + [patch]
7275

73-
def set_model_patch_replace(self, patch, name, block_name, number):
76+
def set_model_patch_replace(self, patch, name, block_name, number, transformer_index=None):
7477
to = self.model_options["transformer_options"]
7578
if "patches_replace" not in to:
7679
to["patches_replace"] = {}
7780
if name not in to["patches_replace"]:
7881
to["patches_replace"][name] = {}
79-
to["patches_replace"][name][(block_name, number)] = patch
82+
if transformer_index is not None:
83+
block = (block_name, number, transformer_index)
84+
else:
85+
block = (block_name, number)
86+
to["patches_replace"][name][block] = patch
8087

8188
def set_model_attn1_patch(self, patch):
8289
self.set_model_patch(patch, "attn1_patch")
8390

8491
def set_model_attn2_patch(self, patch):
8592
self.set_model_patch(patch, "attn2_patch")
8693

87-
def set_model_attn1_replace(self, patch, block_name, number):
88-
self.set_model_patch_replace(patch, "attn1", block_name, number)
94+
def set_model_attn1_replace(self, patch, block_name, number, transformer_index=None):
95+
self.set_model_patch_replace(patch, "attn1", block_name, number, transformer_index)
8996

90-
def set_model_attn2_replace(self, patch, block_name, number):
91-
self.set_model_patch_replace(patch, "attn2", block_name, number)
97+
def set_model_attn2_replace(self, patch, block_name, number, transformer_index=None):
98+
self.set_model_patch_replace(patch, "attn2", block_name, number, transformer_index)
9299

93100
def set_model_attn1_output_patch(self, patch):
94101
self.set_model_patch(patch, "attn1_output_patch")

0 commit comments

Comments (0)