@@ -61,6 +61,9 @@ def set_model_sampler_cfg_function(self, sampler_cfg_function):
61
61
else :
62
62
self .model_options ["sampler_cfg_function" ] = sampler_cfg_function
63
63
64
+ def set_model_sampler_post_cfg_function (self , post_cfg_function ):
65
+ self .model_options ["sampler_post_cfg_function" ] = self .model_options .get ("sampler_post_cfg_function" , []) + [post_cfg_function ]
66
+
64
67
def set_model_unet_function_wrapper (self , unet_wrapper_function ):
65
68
self .model_options ["model_function_wrapper" ] = unet_wrapper_function
66
69
@@ -70,25 +73,29 @@ def set_model_patch(self, patch, name):
70
73
to ["patches" ] = {}
71
74
to ["patches" ][name ] = to ["patches" ].get (name , []) + [patch ]
72
75
73
- def set_model_patch_replace (self , patch , name , block_name , number ):
76
+ def set_model_patch_replace (self , patch , name , block_name , number , transformer_index = None ):
74
77
to = self .model_options ["transformer_options" ]
75
78
if "patches_replace" not in to :
76
79
to ["patches_replace" ] = {}
77
80
if name not in to ["patches_replace" ]:
78
81
to ["patches_replace" ][name ] = {}
79
- to ["patches_replace" ][name ][(block_name , number )] = patch
82
+ if transformer_index is not None :
83
+ block = (block_name , number , transformer_index )
84
+ else :
85
+ block = (block_name , number )
86
+ to ["patches_replace" ][name ][block ] = patch
80
87
81
88
def set_model_attn1_patch (self , patch ):
82
89
self .set_model_patch (patch , "attn1_patch" )
83
90
84
91
def set_model_attn2_patch (self , patch ):
85
92
self .set_model_patch (patch , "attn2_patch" )
86
93
87
- def set_model_attn1_replace (self , patch , block_name , number ):
88
- self .set_model_patch_replace (patch , "attn1" , block_name , number )
94
+ def set_model_attn1_replace (self , patch , block_name , number , transformer_index = None ):
95
+ self .set_model_patch_replace (patch , "attn1" , block_name , number , transformer_index )
89
96
90
- def set_model_attn2_replace (self , patch , block_name , number ):
91
- self .set_model_patch_replace (patch , "attn2" , block_name , number )
97
+ def set_model_attn2_replace (self , patch , block_name , number , transformer_index = None ):
98
+ self .set_model_patch_replace (patch , "attn2" , block_name , number , transformer_index )
92
99
93
100
def set_model_attn1_output_patch (self , patch ):
94
101
self .set_model_patch (patch , "attn1_output_patch" )
0 commit comments