This repository has been archived by the owner on Mar 29, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 9
/
lora_save.py
43 lines (36 loc) · 1.55 KB
/
lora_save.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import math
import os

import comfy
import comfy.utils
import folder_paths
class LoraSave:
    """ComfyUI output node that bakes merge strengths into a LoRA's weights
    and saves the result as a file in the configured loras folder.

    Expects `lora` to be a dict with keys "lora" (the tensor state dict),
    "strength_model" and "strength_clip" — the shape produced by the merge
    nodes in this package (NOTE(review): verify against the producing node).
    """

    def __init__(self):
        # Kept for interface parity with the loader nodes; unused here.
        self.loaded_lora = None

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"lora": ("LoRA",),
                             "file_name": ("STRING", {"multiline": False, "default": "merged"}),
                             "extension": (["safetensors"],),
                             }}

    RETURN_TYPES = ()
    FUNCTION = "lora_save"
    CATEGORY = "lora_merge"
    OUTPUT_NODE = True

    @staticmethod
    def _scale_state_dict(lora):
        """Return a state dict with the strengths baked into the weights.

        The effective LoRA delta is (up @ down) * strength, so each strength
        is split as sqrt(|strength|) onto the lora_up and lora_down halves;
        the sign is carried by lora_up alone so negative strengths survive
        the split. Keys containing "lora_te" (text encoder) use the clip
        strength, all others the model strength. Keys that are neither up
        nor down (e.g. alpha) are copied unscaled.
        """
        if lora["strength_model"] == 1 and lora["strength_clip"] == 1:
            # Identity strengths: nothing to bake in, reuse the dict as-is.
            return lora["lora"]
        new_state_dict = {}
        for key, value in lora["lora"].items():
            scale = lora["strength_clip"] if "lora_te" in key else lora["strength_model"]
            sqrt_scale = math.sqrt(abs(scale))
            if "lora_up" in key:
                # Sign on the up half only: up*down ends up == scale*delta.
                new_state_dict[key] = value * sqrt_scale * (1 if scale >= 0 else -1)
            elif "lora_down" in key:
                new_state_dict[key] = value * sqrt_scale
            else:
                new_state_dict[key] = value
        return new_state_dict

    def lora_save(self, lora, file_name, extension):
        """Bake strengths into `lora` and save it as file_name.extension
        in the first configured loras directory. Returns an empty UI dict."""
        # First path of the "loras" entry: (paths_list, extensions) tuple.
        save_path = os.path.join(folder_paths.folder_names_and_paths["loras"][0][0],
                                 file_name + "." + extension)
        new_state_dict = self._scale_state_dict(lora)
        print(f"Saving LoRA to {save_path}")
        comfy.utils.save_torch_file(new_state_dict, save_path)
        return {}