@@ -174,40 +174,41 @@ def model_state_dict(self, filter_prefix=None):
174
174
sd .pop (k )
175
175
return sd
176
176
177
- def patch_model (self , device_to = None ):
177
+ def patch_model (self , device_to = None , patch_weights = True ):
178
178
for k in self .object_patches :
179
179
old = getattr (self .model , k )
180
180
if k not in self .object_patches_backup :
181
181
self .object_patches_backup [k ] = old
182
182
setattr (self .model , k , self .object_patches [k ])
183
183
184
- model_sd = self .model_state_dict ()
185
- for key in self .patches :
186
- if key not in model_sd :
187
- print ("could not patch. key doesn't exist in model:" , key )
188
- continue
184
+ if patch_weights :
185
+ model_sd = self .model_state_dict ()
186
+ for key in self .patches :
187
+ if key not in model_sd :
188
+ print ("could not patch. key doesn't exist in model:" , key )
189
+ continue
189
190
190
- weight = model_sd [key ]
191
+ weight = model_sd [key ]
191
192
192
- inplace_update = self .weight_inplace_update
193
+ inplace_update = self .weight_inplace_update
193
194
194
- if key not in self .backup :
195
- self .backup [key ] = weight .to (device = self .offload_device , copy = inplace_update )
195
+ if key not in self .backup :
196
+ self .backup [key ] = weight .to (device = self .offload_device , copy = inplace_update )
196
197
197
- if device_to is not None :
198
- temp_weight = comfy .model_management .cast_to_device (weight , device_to , torch .float32 , copy = True )
199
- else :
200
- temp_weight = weight .to (torch .float32 , copy = True )
201
- out_weight = self .calculate_weight (self .patches [key ], temp_weight , key ).to (weight .dtype )
202
- if inplace_update :
203
- comfy .utils .copy_to_param (self .model , key , out_weight )
204
- else :
205
- comfy .utils .set_attr (self .model , key , out_weight )
206
- del temp_weight
198
+ if device_to is not None :
199
+ temp_weight = comfy .model_management .cast_to_device (weight , device_to , torch .float32 , copy = True )
200
+ else :
201
+ temp_weight = weight .to (torch .float32 , copy = True )
202
+ out_weight = self .calculate_weight (self .patches [key ], temp_weight , key ).to (weight .dtype )
203
+ if inplace_update :
204
+ comfy .utils .copy_to_param (self .model , key , out_weight )
205
+ else :
206
+ comfy .utils .set_attr (self .model , key , out_weight )
207
+ del temp_weight
207
208
208
- if device_to is not None :
209
- self .model .to (device_to )
210
- self .current_device = device_to
209
+ if device_to is not None :
210
+ self .model .to (device_to )
211
+ self .current_device = device_to
211
212
212
213
return self .model
213
214
0 commit comments