7
7
import comfy .utils
8
8
import comfy .model_management
9
9
10
+ def apply_weight_decompose (dora_scale , weight ):
11
+ weight_norm = (
12
+ weight .transpose (0 , 1 )
13
+ .reshape (weight .shape [1 ], - 1 )
14
+ .norm (dim = 1 , keepdim = True )
15
+ .reshape (weight .shape [1 ], * [1 ] * (weight .dim () - 1 ))
16
+ .transpose (0 , 1 )
17
+ )
18
+
19
+ return weight * (dora_scale / weight_norm )
20
+
21
+
10
22
class ModelPatcher :
11
23
def __init__ (self , model , load_device , offload_device , size = 0 , current_device = None , weight_inplace_update = False ):
12
24
self .size = size
@@ -309,6 +321,7 @@ def calculate_weight(self, patches, weight, key):
309
321
elif patch_type == "lora" : #lora/locon
310
322
mat1 = comfy .model_management .cast_to_device (v [0 ], weight .device , torch .float32 )
311
323
mat2 = comfy .model_management .cast_to_device (v [1 ], weight .device , torch .float32 )
324
+ dora_scale = v [4 ]
312
325
if v [2 ] is not None :
313
326
alpha *= v [2 ] / mat2 .shape [0 ]
314
327
if v [3 ] is not None :
@@ -318,6 +331,8 @@ def calculate_weight(self, patches, weight, key):
318
331
mat2 = torch .mm (mat2 .transpose (0 , 1 ).flatten (start_dim = 1 ), mat3 .transpose (0 , 1 ).flatten (start_dim = 1 )).reshape (final_shape ).transpose (0 , 1 )
319
332
try :
320
333
weight += (alpha * torch .mm (mat1 .flatten (start_dim = 1 ), mat2 .flatten (start_dim = 1 ))).reshape (weight .shape ).type (weight .dtype )
334
+ if dora_scale is not None :
335
+ weight = apply_weight_decompose (comfy .model_management .cast_to_device (dora_scale , weight .device , torch .float32 ), weight )
321
336
except Exception as e :
322
337
logging .error ("ERROR {} {} {}" .format (patch_type , key , e ))
323
338
elif patch_type == "lokr" :
@@ -328,6 +343,7 @@ def calculate_weight(self, patches, weight, key):
328
343
w2_a = v [5 ]
329
344
w2_b = v [6 ]
330
345
t2 = v [7 ]
346
+ dora_scale = v [8 ]
331
347
dim = None
332
348
333
349
if w1 is None :
@@ -357,6 +373,8 @@ def calculate_weight(self, patches, weight, key):
357
373
358
374
try :
359
375
weight += alpha * torch .kron (w1 , w2 ).reshape (weight .shape ).type (weight .dtype )
376
+ if dora_scale is not None :
377
+ weight = apply_weight_decompose (comfy .model_management .cast_to_device (dora_scale , weight .device , torch .float32 ), weight )
360
378
except Exception as e :
361
379
logging .error ("ERROR {} {} {}" .format (patch_type , key , e ))
362
380
elif patch_type == "loha" :
@@ -366,6 +384,7 @@ def calculate_weight(self, patches, weight, key):
366
384
alpha *= v [2 ] / w1b .shape [0 ]
367
385
w2a = v [3 ]
368
386
w2b = v [4 ]
387
+ dora_scale = v [7 ]
369
388
if v [5 ] is not None : #cp decomposition
370
389
t1 = v [5 ]
371
390
t2 = v [6 ]
@@ -386,19 +405,25 @@ def calculate_weight(self, patches, weight, key):
386
405
387
406
try :
388
407
weight += (alpha * m1 * m2 ).reshape (weight .shape ).type (weight .dtype )
408
+ if dora_scale is not None :
409
+ weight = apply_weight_decompose (comfy .model_management .cast_to_device (dora_scale , weight .device , torch .float32 ), weight )
389
410
except Exception as e :
390
411
logging .error ("ERROR {} {} {}" .format (patch_type , key , e ))
391
412
elif patch_type == "glora" :
392
413
if v [4 ] is not None :
393
414
alpha *= v [4 ] / v [0 ].shape [0 ]
394
415
416
+ dora_scale = v [5 ]
417
+
395
418
a1 = comfy .model_management .cast_to_device (v [0 ].flatten (start_dim = 1 ), weight .device , torch .float32 )
396
419
a2 = comfy .model_management .cast_to_device (v [1 ].flatten (start_dim = 1 ), weight .device , torch .float32 )
397
420
b1 = comfy .model_management .cast_to_device (v [2 ].flatten (start_dim = 1 ), weight .device , torch .float32 )
398
421
b2 = comfy .model_management .cast_to_device (v [3 ].flatten (start_dim = 1 ), weight .device , torch .float32 )
399
422
400
423
try :
401
424
weight += ((torch .mm (b2 , b1 ) + torch .mm (torch .mm (weight .flatten (start_dim = 1 ), a2 ), a1 )) * alpha ).reshape (weight .shape ).type (weight .dtype )
425
+ if dora_scale is not None :
426
+ weight = apply_weight_decompose (comfy .model_management .cast_to_device (dora_scale , weight .device , torch .float32 ), weight )
402
427
except Exception as e :
403
428
logging .error ("ERROR {} {} {}" .format (patch_type , key , e ))
404
429
else :
0 commit comments