diff --git a/comfy/lora.py b/comfy/lora.py index 637380d54ce..096285bba3e 100644 --- a/comfy/lora.py +++ b/comfy/lora.py @@ -21,6 +21,12 @@ def load_lora(lora, to_load): alpha = lora[alpha_name].item() loaded_keys.add(alpha_name) + dora_scale_name = "{}.dora_scale".format(x) + dora_scale = None + if dora_scale_name in lora.keys(): + dora_scale = lora[dora_scale_name] + loaded_keys.add(dora_scale_name) + regular_lora = "{}.lora_up.weight".format(x) diffusers_lora = "{}_lora.up.weight".format(x) transformers_lora = "{}.lora_linear_layer.up.weight".format(x) @@ -44,7 +50,7 @@ def load_lora(lora, to_load): if mid_name is not None and mid_name in lora.keys(): mid = lora[mid_name] loaded_keys.add(mid_name) - patch_dict[to_load[x]] = ("lora", (lora[A_name], lora[B_name], alpha, mid)) + patch_dict[to_load[x]] = ("lora", (lora[A_name], lora[B_name], alpha, mid, dora_scale)) loaded_keys.add(A_name) loaded_keys.add(B_name) @@ -65,7 +71,7 @@ def load_lora(lora, to_load): loaded_keys.add(hada_t1_name) loaded_keys.add(hada_t2_name) - patch_dict[to_load[x]] = ("loha", (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2)) + patch_dict[to_load[x]] = ("loha", (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2, dora_scale)) loaded_keys.add(hada_w1_a_name) loaded_keys.add(hada_w1_b_name) loaded_keys.add(hada_w2_a_name) @@ -117,7 +123,7 @@ def load_lora(lora, to_load): loaded_keys.add(lokr_t2_name) if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None): - patch_dict[to_load[x]] = ("lokr", (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2)) + patch_dict[to_load[x]] = ("lokr", (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2, dora_scale)) #glora a1_name = "{}.a1.weight".format(x) @@ -125,7 +131,7 @@ def load_lora(lora, to_load): b1_name = "{}.b1.weight".format(x) b2_name = "{}.b2.weight".format(x) if a1_name in lora: - patch_dict[to_load[x]] = ("glora", (lora[a1_name], lora[a2_name], lora[b1_name], lora[b2_name], alpha)) + patch_dict[to_load[x]] = ("glora", (lora[a1_name], lora[a2_name], lora[b1_name], lora[b2_name], alpha, dora_scale)) loaded_keys.add(a1_name) loaded_keys.add(a2_name) loaded_keys.add(b1_name) diff --git a/comfy/model_management.py b/comfy/model_management.py index ae30130d1f8..b644b141747 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -363,20 +363,27 @@ def unload_model_clones(model, unload_weights_only=True, force_unload=True): return unload_weight def free_memory(memory_required, device, keep_loaded=[]): - unloaded_model = False + unloaded_model = [] + can_unload = [] + for i in range(len(current_loaded_models) -1, -1, -1): - if not DISABLE_SMART_MEMORY: - if get_free_memory(device) > memory_required: - break shift_model = current_loaded_models[i] if shift_model.device == device: if shift_model not in keep_loaded: - m = current_loaded_models.pop(i) - m.model_unload() - del m - unloaded_model = True + can_unload.append((sys.getrefcount(shift_model.model), shift_model.model_memory(), i)) + + for x in sorted(can_unload): + i = x[-1] + if not DISABLE_SMART_MEMORY: + if get_free_memory(device) > memory_required: + break + current_loaded_models[i].model_unload() + unloaded_model.append(i) + + for i in sorted(unloaded_model, reverse=True): + current_loaded_models.pop(i) - if unloaded_model: + if len(unloaded_model) > 0: soft_empty_cache() else: if vram_state != VRAMState.HIGH_VRAM: diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index aa78302d215..8dda84cfd8b 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -7,6 +7,18 @@ import comfy.utils import comfy.model_management +def apply_weight_decompose(dora_scale, weight): + weight_norm = ( + weight.transpose(0, 1) + .reshape(weight.shape[1], -1) + .norm(dim=1, keepdim=True) + .reshape(weight.shape[1], *[1] * (weight.dim() - 1)) + .transpose(0, 1) + ) + + return weight * (dora_scale / weight_norm) + + class ModelPatcher: def __init__(self, model, load_device, offload_device, size=0, current_device=None, weight_inplace_update=False): self.size = size @@ -309,6 +321,7 @@ def calculate_weight(self, patches, weight, key): elif patch_type == "lora": #lora/locon mat1 = comfy.model_management.cast_to_device(v[0], weight.device, torch.float32) mat2 = comfy.model_management.cast_to_device(v[1], weight.device, torch.float32) + dora_scale = v[4] if v[2] is not None: alpha *= v[2] / mat2.shape[0] if v[3] is not None: @@ -318,6 +331,8 @@ def calculate_weight(self, patches, weight, key): mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1) try: weight += (alpha * torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1))).reshape(weight.shape).type(weight.dtype) + if dora_scale is not None: + weight = apply_weight_decompose(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight) except Exception as e: logging.error("ERROR {} {} {}".format(patch_type, key, e)) elif patch_type == "lokr": @@ -328,6 +343,7 @@ def calculate_weight(self, patches, weight, key): w2_a = v[5] w2_b = v[6] t2 = v[7] + dora_scale = v[8] dim = None if w1 is None: @@ -357,6 +373,8 @@ def calculate_weight(self, patches, weight, key): try: weight += alpha * torch.kron(w1, w2).reshape(weight.shape).type(weight.dtype) + if dora_scale is not None: + weight = apply_weight_decompose(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight) except Exception as e: logging.error("ERROR {} {} {}".format(patch_type, key, e)) elif patch_type == "loha": @@ -366,6 +384,7 @@ def calculate_weight(self, patches, weight, key): alpha *= v[2] / w1b.shape[0] w2a = v[3] w2b = v[4] + dora_scale = v[7] if v[5] is not None: #cp decomposition t1 = v[5] t2 = v[6] @@ -386,12 +405,16 @@ def calculate_weight(self, patches, weight, key): try: weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype) + if dora_scale is not None: + weight = apply_weight_decompose(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight) except Exception as e: logging.error("ERROR {} {} {}".format(patch_type, key, e)) elif patch_type == "glora": if v[4] is not None: alpha *= v[4] / v[0].shape[0] + dora_scale = v[5] + a1 = comfy.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, torch.float32) a2 = comfy.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, torch.float32) b1 = comfy.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, torch.float32) @@ -399,6 +422,8 @@ def calculate_weight(self, patches, weight, key): try: weight += ((torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)) * alpha).reshape(weight.shape).type(weight.dtype) + if dora_scale is not None: + weight = apply_weight_decompose(comfy.model_management.cast_to_device(dora_scale, weight.device, torch.float32), weight) except Exception as e: logging.error("ERROR {} {} {}".format(patch_type, key, e)) else: diff --git a/execution.py b/execution.py index 7cf09465bda..64978fbe30e 100644 --- a/execution.py +++ b/execution.py @@ -392,7 +392,6 @@ def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]): d = self.outputs_ui.pop(x) del d - comfy.model_management.cleanup_models() self.add_message("execution_cached", { "nodes": list(current_outputs) , "prompt_id": prompt_id}, broadcast=False) diff --git a/main.py b/main.py index 3dee72e3a3f..b3a3ebea8cf 100644 --- a/main.py +++ b/main.py @@ -139,6 +139,7 @@ def prompt_worker(q, server): if need_gc: current_time = time.perf_counter() if (current_time - last_gc_collect) > gc_collect_interval: + comfy.model_management.cleanup_models() gc.collect() comfy.model_management.soft_empty_cache() last_gc_collect = current_time