Skip to content

Commit ae77590

Browse files
dora_scale support for lora file.
1 parent c6de09b commit ae77590

File tree

2 files changed

+35
-4
lines changed

2 files changed

+35
-4
lines changed

comfy/lora.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,12 @@ def load_lora(lora, to_load):
2121
alpha = lora[alpha_name].item()
2222
loaded_keys.add(alpha_name)
2323

24+
dora_scale_name = "{}.dora_scale".format(x)
25+
dora_scale = None
26+
if dora_scale_name in lora.keys():
27+
dora_scale = lora[dora_scale_name]
28+
loaded_keys.add(dora_scale_name)
29+
2430
regular_lora = "{}.lora_up.weight".format(x)
2531
diffusers_lora = "{}_lora.up.weight".format(x)
2632
transformers_lora = "{}.lora_linear_layer.up.weight".format(x)
@@ -44,7 +50,7 @@ def load_lora(lora, to_load):
4450
if mid_name is not None and mid_name in lora.keys():
4551
mid = lora[mid_name]
4652
loaded_keys.add(mid_name)
47-
patch_dict[to_load[x]] = ("lora", (lora[A_name], lora[B_name], alpha, mid))
53+
patch_dict[to_load[x]] = ("lora", (lora[A_name], lora[B_name], alpha, mid, dora_scale))
4854
loaded_keys.add(A_name)
4955
loaded_keys.add(B_name)
5056

@@ -65,7 +71,7 @@ def load_lora(lora, to_load):
6571
loaded_keys.add(hada_t1_name)
6672
loaded_keys.add(hada_t2_name)
6773

68-
patch_dict[to_load[x]] = ("loha", (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2))
74+
patch_dict[to_load[x]] = ("loha", (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2, dora_scale))
6975
loaded_keys.add(hada_w1_a_name)
7076
loaded_keys.add(hada_w1_b_name)
7177
loaded_keys.add(hada_w2_a_name)
@@ -117,15 +123,15 @@ def load_lora(lora, to_load):
117123
loaded_keys.add(lokr_t2_name)
118124

119125
if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None):
120-
patch_dict[to_load[x]] = ("lokr", (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2))
126+
patch_dict[to_load[x]] = ("lokr", (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2, dora_scale))
121127

122128
#glora
123129
a1_name = "{}.a1.weight".format(x)
124130
a2_name = "{}.a2.weight".format(x)
125131
b1_name = "{}.b1.weight".format(x)
126132
b2_name = "{}.b2.weight".format(x)
127133
if a1_name in lora:
128-
patch_dict[to_load[x]] = ("glora", (lora[a1_name], lora[a2_name], lora[b1_name], lora[b2_name], alpha))
134+
patch_dict[to_load[x]] = ("glora", (lora[a1_name], lora[a2_name], lora[b1_name], lora[b2_name], alpha, dora_scale))
129135
loaded_keys.add(a1_name)
130136
loaded_keys.add(a2_name)
131137
loaded_keys.add(b1_name)

comfy/model_patcher.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,18 @@
77
import comfy.utils
88
import comfy.model_management
99

10+
def apply_weight_decompose(dora_scale, weight):
11+
weight_norm = (
12+
weight.transpose(0, 1)
13+
.reshape(weight.shape[1], -1)
14+
.norm(dim=1, keepdim=True)
15+
.reshape(weight.shape[1], *[1] * (weight.dim() - 1))
16+
.transpose(0, 1)
17+
)
18+
19+
return weight * (dora_scale / weight_norm)
20+
21+
1022
class ModelPatcher:
1123
def __init__(self, model, load_device, offload_device, size=0, current_device=None, weight_inplace_update=False):
1224
self.size = size
@@ -309,6 +321,7 @@ def calculate_weight(self, patches, weight, key):
309321
elif patch_type == "lora": #lora/locon
310322
mat1 = comfy.model_management.cast_to_device(v[0], weight.device, torch.float32)
311323
mat2 = comfy.model_management.cast_to_device(v[1], weight.device, torch.float32)
324+
dora_scale = v[4]
312325
if v[2] is not None:
313326
alpha *= v[2] / mat2.shape[0]
314327
if v[3] is not None:
@@ -318,6 +331,8 @@ def calculate_weight(self, patches, weight, key):
318331
mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
319332
try:
320333
weight += (alpha * torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1))).reshape(weight.shape).type(weight.dtype)
334+
if dora_scale is not None:
335+
weight = apply_weight_decompose(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
321336
except Exception as e:
322337
logging.error("ERROR {} {} {}".format(patch_type, key, e))
323338
elif patch_type == "lokr":
@@ -328,6 +343,7 @@ def calculate_weight(self, patches, weight, key):
328343
w2_a = v[5]
329344
w2_b = v[6]
330345
t2 = v[7]
346+
dora_scale = v[8]
331347
dim = None
332348

333349
if w1 is None:
@@ -357,6 +373,8 @@ def calculate_weight(self, patches, weight, key):
357373

358374
try:
359375
weight += alpha * torch.kron(w1, w2).reshape(weight.shape).type(weight.dtype)
376+
if dora_scale is not None:
377+
weight = apply_weight_decompose(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
360378
except Exception as e:
361379
logging.error("ERROR {} {} {}".format(patch_type, key, e))
362380
elif patch_type == "loha":
@@ -366,6 +384,7 @@ def calculate_weight(self, patches, weight, key):
366384
alpha *= v[2] / w1b.shape[0]
367385
w2a = v[3]
368386
w2b = v[4]
387+
dora_scale = v[7]
369388
if v[5] is not None: #cp decomposition
370389
t1 = v[5]
371390
t2 = v[6]
@@ -386,19 +405,25 @@ def calculate_weight(self, patches, weight, key):
386405

387406
try:
388407
weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype)
408+
if dora_scale is not None:
409+
weight = apply_weight_decompose(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
389410
except Exception as e:
390411
logging.error("ERROR {} {} {}".format(patch_type, key, e))
391412
elif patch_type == "glora":
392413
if v[4] is not None:
393414
alpha *= v[4] / v[0].shape[0]
394415

416+
dora_scale = v[5]
417+
395418
a1 = comfy.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, torch.float32)
396419
a2 = comfy.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, torch.float32)
397420
b1 = comfy.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, torch.float32)
398421
b2 = comfy.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, torch.float32)
399422

400423
try:
401424
weight += ((torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)) * alpha).reshape(weight.shape).type(weight.dtype)
425+
if dora_scale is not None:
426+
weight = apply_weight_decompose(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
402427
except Exception as e:
403428
logging.error("ERROR {} {} {}".format(patch_type, key, e))
404429
else:

0 commit comments

Comments
 (0)