Commit

Merge branch 'master' into beta
jn-jairo committed Mar 26, 2024
2 parents fdebd80 + ae77590 commit 268f79e
Showing 5 changed files with 52 additions and 14 deletions.
14 changes: 10 additions & 4 deletions comfy/lora.py
@@ -21,6 +21,12 @@ def load_lora(lora, to_load):
             alpha = lora[alpha_name].item()
             loaded_keys.add(alpha_name)
 
+        dora_scale_name = "{}.dora_scale".format(x)
+        dora_scale = None
+        if dora_scale_name in lora.keys():
+            dora_scale = lora[dora_scale_name]
+            loaded_keys.add(dora_scale_name)
+
         regular_lora = "{}.lora_up.weight".format(x)
         diffusers_lora = "{}_lora.up.weight".format(x)
         transformers_lora = "{}.lora_linear_layer.up.weight".format(x)
@@ -44,7 +50,7 @@ def load_lora(lora, to_load):
             if mid_name is not None and mid_name in lora.keys():
                 mid = lora[mid_name]
                 loaded_keys.add(mid_name)
-            patch_dict[to_load[x]] = ("lora", (lora[A_name], lora[B_name], alpha, mid))
+            patch_dict[to_load[x]] = ("lora", (lora[A_name], lora[B_name], alpha, mid, dora_scale))
             loaded_keys.add(A_name)
             loaded_keys.add(B_name)
@@ -65,7 +71,7 @@ def load_lora(lora, to_load):
                 loaded_keys.add(hada_t1_name)
                 loaded_keys.add(hada_t2_name)
 
-            patch_dict[to_load[x]] = ("loha", (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2))
+            patch_dict[to_load[x]] = ("loha", (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2, dora_scale))
             loaded_keys.add(hada_w1_a_name)
             loaded_keys.add(hada_w1_b_name)
             loaded_keys.add(hada_w2_a_name)
@@ -117,15 +123,15 @@ def load_lora(lora, to_load):
                 loaded_keys.add(lokr_t2_name)
 
             if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None):
-                patch_dict[to_load[x]] = ("lokr", (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2))
+                patch_dict[to_load[x]] = ("lokr", (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2, dora_scale))
 
             #glora
             a1_name = "{}.a1.weight".format(x)
             a2_name = "{}.a2.weight".format(x)
             b1_name = "{}.b1.weight".format(x)
             b2_name = "{}.b2.weight".format(x)
             if a1_name in lora:
-                patch_dict[to_load[x]] = ("glora", (lora[a1_name], lora[a2_name], lora[b1_name], lora[b2_name], alpha))
+                patch_dict[to_load[x]] = ("glora", (lora[a1_name], lora[a2_name], lora[b1_name], lora[b2_name], alpha, dora_scale))
                 loaded_keys.add(a1_name)
                 loaded_keys.add(a2_name)
                 loaded_keys.add(b1_name)
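Taken together, the lora.py changes teach the loader to recognize an optional per-module DoRA magnitude tensor stored under the key "{}.dora_scale" and to thread it through every patch tuple (lora, loha, lokr, and glora) as a trailing element. As a minimal sketch of that key convention, one could scan a loaded state dict for modules that carry DoRA magnitudes; the helper name find_dora_modules is hypothetical and not part of this commit:

def find_dora_modules(lora_sd):
    # A module "x" carries a DoRA magnitude when "x.dora_scale" exists
    # alongside its usual low-rank weights.
    suffix = ".dora_scale"
    return [k[:-len(suffix)] for k in lora_sd.keys() if k.endswith(suffix)]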
25 changes: 16 additions & 9 deletions comfy/model_management.py
@@ -363,20 +363,27 @@ def unload_model_clones(model, unload_weights_only=True, force_unload=True):
     return unload_weight
 
 def free_memory(memory_required, device, keep_loaded=[]):
-    unloaded_model = False
+    unloaded_model = []
+    can_unload = []
+
     for i in range(len(current_loaded_models) -1, -1, -1):
-        if not DISABLE_SMART_MEMORY:
-            if get_free_memory(device) > memory_required:
-                break
         shift_model = current_loaded_models[i]
         if shift_model.device == device:
             if shift_model not in keep_loaded:
-                m = current_loaded_models.pop(i)
-                m.model_unload()
-                del m
-                unloaded_model = True
+                can_unload.append((sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
+
+    for x in sorted(can_unload):
+        i = x[-1]
+        if not DISABLE_SMART_MEMORY:
+            if get_free_memory(device) > memory_required:
+                break
+        current_loaded_models[i].model_unload()
+        unloaded_model.append(i)
+
+    for i in sorted(unloaded_model, reverse=True):
+        current_loaded_models.pop(i)
 
-    if unloaded_model:
+    if len(unloaded_model) > 0:
         soft_empty_cache()
     else:
         if vram_state != VRAMState.HIGH_VRAM:
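The free_memory rewrite changes the unload policy: rather than popping models newest-first until enough memory is free, it first gathers every unloadable model as a (refcount, memory, index) tuple, unloads in sorted order so the models with the fewest live references go first (ties broken by smaller footprint), and only then compacts current_loaded_models by popping indices in reverse. A toy illustration of that tuple ordering, with made-up refcounts and sizes:

# Candidates mimic (sys.getrefcount(model), model_memory(), index); sorted()
# compares the refcount first, then the memory size, then the index.
can_unload = [(4, 2_000_000_000, 0), (2, 3_000_000_000, 2), (2, 500_000_000, 1)]
unload_order = [x[-1] for x in sorted(can_unload)]
print(unload_order)  # [1, 2, 0]: fewest references first, smaller size on ties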
25 changes: 25 additions & 0 deletions comfy/model_patcher.py
@@ -7,6 +7,18 @@
 import comfy.utils
 import comfy.model_management
 
+def apply_weight_decompose(dora_scale, weight):
+    weight_norm = (
+        weight.transpose(0, 1)
+        .reshape(weight.shape[1], -1)
+        .norm(dim=1, keepdim=True)
+        .reshape(weight.shape[1], *[1] * (weight.dim() - 1))
+        .transpose(0, 1)
+    )
+
+    return weight * (dora_scale / weight_norm)
+
+
 class ModelPatcher:
     def __init__(self, model, load_device, offload_device, size=0, current_device=None, weight_inplace_update=False):
         self.size = size
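The new apply_weight_decompose helper performs the DoRA-style renormalization: the transpose/reshape chain computes one norm per input column of the patched weight (taken over the output and any kernel dimensions), and the weight is rescaled so each column's magnitude becomes the learned dora_scale. A small property check, assuming dora_scale broadcasts as a (1, in_features) vector over a plain 2-D linear weight, which is what the reshape chain implies:

import torch

weight = torch.randn(8, 4)               # (out_features, in_features)
dora_scale = torch.rand(1, 4) + 0.5      # assumed per-input-column magnitudes

norm = weight.norm(dim=0, keepdim=True)  # equals the chain above for 2-D weights
decomposed = weight * (dora_scale / norm)
# Each column norm now equals the learned magnitude:
print(torch.allclose(decomposed.norm(dim=0, keepdim=True), dora_scale, atol=1e-5))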
@@ -309,6 +321,7 @@ def calculate_weight(self, patches, weight, key):
elif patch_type == "lora": #lora/locon
mat1 = comfy.model_management.cast_to_device(v[0], weight.device, torch.float32)
mat2 = comfy.model_management.cast_to_device(v[1], weight.device, torch.float32)
dora_scale = v[4]
if v[2] is not None:
alpha *= v[2] / mat2.shape[0]
if v[3] is not None:
@@ -318,6 +331,8 @@ def calculate_weight(self, patches, weight, key):
                     mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
                 try:
                     weight += (alpha * torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1))).reshape(weight.shape).type(weight.dtype)
+                    if dora_scale is not None:
+                        weight = apply_weight_decompose(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
                 except Exception as e:
                     logging.error("ERROR {} {} {}".format(patch_type, key, e))
             elif patch_type == "lokr":
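For context, the "lora" branch above computes its additive patch as alpha * mat1 @ mat2, the up-projection times the down-projection, and with this commit renormalizes afterwards when a dora_scale is present. A toy sketch under assumed 2-D shapes; the real code also handles the optional mid tensor and flattens convolution weights:

import torch

out_dim, in_dim, rank = 8, 6, 2
weight = torch.randn(out_dim, in_dim)
mat1 = torch.randn(out_dim, rank)  # lora_up
mat2 = torch.randn(rank, in_dim)   # lora_down
alpha = 1.0 / rank                 # mirrors alpha *= v[2] / mat2.shape[0]

weight += alpha * torch.mm(mat1, mat2)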
@@ -328,6 +343,7 @@ def calculate_weight(self, patches, weight, key):
                 w2_a = v[5]
                 w2_b = v[6]
                 t2 = v[7]
+                dora_scale = v[8]
                 dim = None
 
                 if w1 is None:
@@ -357,6 +373,8 @@ def calculate_weight(self, patches, weight, key):

                 try:
                     weight += alpha * torch.kron(w1, w2).reshape(weight.shape).type(weight.dtype)
+                    if dora_scale is not None:
+                        weight = apply_weight_decompose(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
                 except Exception as e:
                     logging.error("ERROR {} {} {}".format(patch_type, key, e))
             elif patch_type == "loha":
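The "lokr" branch in the hunk above builds its patch as a Kronecker product of two small factors (each of which may itself be reconstructed from a low-rank pair), which expands to the full weight shape. A toy sketch of just the final expansion step:

import torch

w1 = torch.randn(2, 3)
w2 = torch.randn(4, 2)
patch = torch.kron(w1, w2)  # shape (8, 6), i.e. (2*4, 3*2)
print(patch.shape)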
@@ -366,6 +384,7 @@ def calculate_weight(self, patches, weight, key):
                     alpha *= v[2] / w1b.shape[0]
                 w2a = v[3]
                 w2b = v[4]
+                dora_scale = v[7]
                 if v[5] is not None: #cp decomposition
                     t1 = v[5]
                     t2 = v[6]
@@ -386,19 +405,25 @@ def calculate_weight(self, patches, weight, key):

                 try:
                     weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype)
+                    if dora_scale is not None:
+                        weight = apply_weight_decompose(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
                 except Exception as e:
                     logging.error("ERROR {} {} {}".format(patch_type, key, e))
             elif patch_type == "glora":
                 if v[4] is not None:
                     alpha *= v[4] / v[0].shape[0]
 
+                dora_scale = v[5]
+
                 a1 = comfy.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, torch.float32)
                 a2 = comfy.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, torch.float32)
                 b1 = comfy.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, torch.float32)
                 b2 = comfy.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, torch.float32)
 
                 try:
                     weight += ((torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)) * alpha).reshape(weight.shape).type(weight.dtype)
+                    if dora_scale is not None:
+                        weight = apply_weight_decompose(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight)
                 except Exception as e:
                     logging.error("ERROR {} {} {}".format(patch_type, key, e))
             else:
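Among the four patch types, "glora" is the only one whose update depends on the current weight itself: it adds (b2 @ b1 + (W @ a2) @ a1) * alpha, a plain low-rank term plus a weight-conditioned low-rank term, again followed by the optional DoRA renormalization. A toy sketch under assumed 2-D shapes (the real code flattens convolution weights first):

import torch

out_dim, in_dim, rank = 8, 6, 2
weight = torch.randn(out_dim, in_dim)
a1 = torch.randn(rank, in_dim)
a2 = torch.randn(in_dim, rank)
b1 = torch.randn(rank, in_dim)
b2 = torch.randn(out_dim, rank)
alpha = 1.0 / rank  # mirrors alpha *= v[4] / v[0].shape[0]

weight += (torch.mm(b2, b1) + torch.mm(torch.mm(weight, a2), a1)) * alpha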
1 change: 0 additions & 1 deletion execution.py
@@ -392,7 +392,6 @@ def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
                 d = self.outputs_ui.pop(x)
                 del d
 
-        comfy.model_management.cleanup_models()
         self.add_message("execution_cached",
                         { "nodes": list(current_outputs) , "prompt_id": prompt_id},
                         broadcast=False)
1 change: 1 addition & 0 deletions main.py
@@ -139,6 +139,7 @@ def prompt_worker(q, server):
         if need_gc:
             current_time = time.perf_counter()
             if (current_time - last_gc_collect) > gc_collect_interval:
+                comfy.model_management.cleanup_models()
                 gc.collect()
                 comfy.model_management.soft_empty_cache()
                 last_gc_collect = current_time
