Merge branch 'master' into beta
jn-jairo committed Dec 10, 2023
2 parents 8fe502d + 340177e commit 9de5a86
Showing 4 changed files with 55 additions and 18 deletions.
10 changes: 7 additions & 3 deletions README.md
@@ -93,23 +93,27 @@ Put your SD checkpoints (the huge ckpt/safetensors files) in: models/checkpoints

Put your VAE in: models/vae

-Note: pytorch does not support python 3.12 yet so make sure your python version is 3.11 or earlier.
+Note: pytorch stable does not support python 3.12 yet. If you have python 3.12 you will have to use the nightly version of pytorch. If you run into issues you should try python 3.11 instead.

### AMD GPUs (Linux only)
AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version:

```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm5.6```

-This is the command to install the nightly with ROCm 5.7 that might have some performance improvements:
+This is the command to install the nightly with ROCm 5.7 which has a python 3.12 package and might have some performance improvements:

```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm5.7```

### NVIDIA

-Nvidia users should install pytorch using this command:
+Nvidia users should install stable pytorch using this command:

```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu121```

+This is the command to install pytorch nightly instead which has a python 3.12 package and might have performance improvements:
+
+```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121```
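After installing either the stable or nightly build (CUDA or ROCm), a quick way to confirm that pytorch can see the GPU is a short check like the following; this is a minimal sketch added for illustration, not part of the README's own instructions:

```python
import torch

print(torch.__version__)          # e.g. a +cu121 or +rocm5.7 suffix shows which build is installed
print(torch.cuda.is_available())  # True if pytorch can see the NVIDIA (or ROCm) GPU
```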

#### Troubleshooting

If you get the "Torch not compiled with CUDA enabled" error, uninstall torch with:
27 changes: 19 additions & 8 deletions comfy/lora.py
@@ -43,7 +43,7 @@ def load_lora(lora, to_load):
if mid_name is not None and mid_name in lora.keys():
mid = lora[mid_name]
loaded_keys.add(mid_name)
-patch_dict[to_load[x]] = (lora[A_name], lora[B_name], alpha, mid)
+patch_dict[to_load[x]] = ("lora", (lora[A_name], lora[B_name], alpha, mid))
loaded_keys.add(A_name)
loaded_keys.add(B_name)

@@ -64,7 +64,7 @@ def load_lora(lora, to_load):
loaded_keys.add(hada_t1_name)
loaded_keys.add(hada_t2_name)

-patch_dict[to_load[x]] = (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2)
+patch_dict[to_load[x]] = ("loha", (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2))
loaded_keys.add(hada_w1_a_name)
loaded_keys.add(hada_w1_b_name)
loaded_keys.add(hada_w2_a_name)
@@ -116,8 +116,19 @@ def load_lora(lora, to_load):
loaded_keys.add(lokr_t2_name)

if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None):
-patch_dict[to_load[x]] = (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2)
+patch_dict[to_load[x]] = ("lokr", (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2))

+#glora
+a1_name = "{}.a1.weight".format(x)
+a2_name = "{}.a2.weight".format(x)
+b1_name = "{}.b1.weight".format(x)
+b2_name = "{}.b2.weight".format(x)
+if a1_name in lora:
+patch_dict[to_load[x]] = ("glora", (lora[a1_name], lora[a2_name], lora[b1_name], lora[b2_name], alpha))
+loaded_keys.add(a1_name)
+loaded_keys.add(a2_name)
+loaded_keys.add(b1_name)
+loaded_keys.add(b2_name)

w_norm_name = "{}.w_norm".format(x)
b_norm_name = "{}.b_norm".format(x)
@@ -126,21 +126,21 @@

if w_norm is not None:
loaded_keys.add(w_norm_name)
-patch_dict[to_load[x]] = (w_norm,)
+patch_dict[to_load[x]] = ("diff", (w_norm,))
if b_norm is not None:
loaded_keys.add(b_norm_name)
patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = (b_norm,)
patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = ("diff", (b_norm,))

diff_name = "{}.diff".format(x)
diff_weight = lora.get(diff_name, None)
if diff_weight is not None:
-patch_dict[to_load[x]] = (diff_weight,)
+patch_dict[to_load[x]] = ("diff", (diff_weight,))
loaded_keys.add(diff_name)

diff_bias_name = "{}.diff_b".format(x)
diff_bias = lora.get(diff_bias_name, None)
if diff_bias is not None:
patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = (diff_bias,)
patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = ("diff", (diff_bias,))
loaded_keys.add(diff_bias_name)

for x in lora.keys():
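For context, this change moves every `patch_dict` entry from a bare tuple to a tagged `("patch_type", data)` pair and adds GLoRA detection, which probes four per-module key names. Below is a standalone sketch of that lookup, assuming a plain dict of tensors rather than ComfyUI's real loader state; `detect_glora` and `fake_lora` are illustrative names, not part of the repository:

```python
import torch

def detect_glora(lora: dict, x: str, alpha: float):
    """Return a tagged ("glora", ...) patch entry for module `x`, or None.

    Mirrors the shape of the new load_lora() branch: four per-module keys
    (a1/a2/b1/b2) plus the shared alpha, stored under an explicit type tag.
    """
    a1_name = "{}.a1.weight".format(x)
    a2_name = "{}.a2.weight".format(x)
    b1_name = "{}.b1.weight".format(x)
    b2_name = "{}.b2.weight".format(x)
    if a1_name in lora:
        return ("glora", (lora[a1_name], lora[a2_name], lora[b1_name], lora[b2_name], alpha))
    return None

# Hypothetical state dict holding 4x4 GLoRA factors for one module named "layer":
fake_lora = {"layer.{}.weight".format(k): torch.randn(4, 4) for k in ("a1", "a2", "b1", "b2")}
print(detect_glora(fake_lora, "layer", 1.0)[0])  # -> "glora"
```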
12 changes: 8 additions & 4 deletions comfy/model_management.py
@@ -567,15 +567,19 @@ def cast_to_device(tensor, device, dtype, copy=False):
elif is_intel_xpu():
device_supports_cast = True

+non_blocking = True
+if is_device_mps(device):
+non_blocking = False #pytorch bug? mps doesn't support non blocking
+
if device_supports_cast:
if copy:
if tensor.device == device:
-return tensor.to(dtype, copy=copy, non_blocking=True)
-return tensor.to(device, copy=copy, non_blocking=True).to(dtype, non_blocking=True)
+return tensor.to(dtype, copy=copy, non_blocking=non_blocking)
+return tensor.to(device, copy=copy, non_blocking=non_blocking).to(dtype, non_blocking=non_blocking)
else:
-return tensor.to(device, non_blocking=True).to(dtype, non_blocking=True)
+return tensor.to(device, non_blocking=non_blocking).to(dtype, non_blocking=non_blocking)
else:
-return tensor.to(device, dtype, copy=copy, non_blocking=True)
+return tensor.to(device, dtype, copy=copy, non_blocking=non_blocking)

def xformers_enabled():
global directml_enabled
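The behavioural change in `cast_to_device()` is small: non-blocking copies stay on everywhere except Apple's MPS backend, where they are disabled to work around what the in-code comment suspects is a pytorch bug. A simplified standalone sketch of the same guard follows; `is_device_mps` and `cast_to_device_sketch` are stand-ins written for illustration, not ComfyUI's actual helpers:

```python
import torch

def is_device_mps(device: torch.device) -> bool:
    # Stand-in for the helper used in comfy/model_management.py.
    return device.type == "mps"

def cast_to_device_sketch(tensor: torch.Tensor, device, dtype) -> torch.Tensor:
    # MPS gets a blocking copy; every other backend keeps non_blocking=True.
    non_blocking = not is_device_mps(torch.device(device))
    return tensor.to(device, dtype, non_blocking=non_blocking)

# Example: move a half-precision tensor to the CPU as float32.
print(cast_to_device_sketch(torch.ones(2, dtype=torch.float16), "cpu", torch.float32).dtype)
```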
24 changes: 21 additions & 3 deletions comfy/model_patcher.py
@@ -217,13 +217,19 @@ def calculate_weight(self, patches, weight, key):
v = (self.calculate_weight(v[1:], v[0].clone(), key), )

if len(v) == 1:
patch_type = "diff"
elif len(v) == 2:
patch_type = v[0]
v = v[1]

if patch_type == "diff":
w1 = v[0]
if alpha != 0.0:
if w1.shape != weight.shape:
print("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
else:
weight += alpha * comfy.model_management.cast_to_device(w1, weight.device, weight.dtype)
-elif len(v) == 4: #lora/locon
+elif patch_type == "lora": #lora/locon
mat1 = comfy.model_management.cast_to_device(v[0], weight.device, torch.float32)
mat2 = comfy.model_management.cast_to_device(v[1], weight.device, torch.float32)
if v[2] is not None:
@@ -237,7 +243,7 @@ def calculate_weight(self, patches, weight, key):
weight += (alpha * torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1))).reshape(weight.shape).type(weight.dtype)
except Exception as e:
print("ERROR", key, e)
-elif len(v) == 8: #lokr
+elif patch_type == "lokr":
w1 = v[0]
w2 = v[1]
w1_a = v[3]
@@ -276,7 +282,7 @@ def calculate_weight(self, patches, weight, key):
weight += alpha * torch.kron(w1, w2).reshape(weight.shape).type(weight.dtype)
except Exception as e:
print("ERROR", key, e)
-else: #loha
+elif patch_type == "loha":
w1a = v[0]
w1b = v[1]
if v[2] is not None:
@@ -305,6 +311,18 @@ def calculate_weight(self, patches, weight, key):
weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype)
except Exception as e:
print("ERROR", key, e)
elif patch_type == "glora":
if v[4] is not None:
alpha *= v[4] / v[0].shape[0]

a1 = comfy.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, torch.float32)
a2 = comfy.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, torch.float32)
b1 = comfy.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, torch.float32)
b2 = comfy.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, torch.float32)

weight += ((torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)) * alpha).reshape(weight.shape).type(weight.dtype)
else:
print("patch type not recognized", patch_type, key)

return weight

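For reference, `calculate_weight()` now reads the patch type from the tagged tuple (falling back to "diff" for old-style single-element entries) and dispatches on it, and the new "glora" branch merges two terms into the weight: a plain low-rank product B2·B1 and a weight-dependent product (W·A2)·A1, both scaled by alpha. A tiny numeric sketch of that formula is below; the shapes are chosen only for illustration and the dtype/device casting of the real code is omitted:

```python
import torch

out_dim, in_dim, rank = 8, 6, 2
weight = torch.randn(out_dim, in_dim)
alpha = 1.0

# Hypothetical GLoRA factors with shapes compatible with the merge expression:
a1 = torch.randn(rank, in_dim)
a2 = torch.randn(in_dim, rank)
b1 = torch.randn(rank, in_dim)
b2 = torch.randn(out_dim, rank)

# Same expression as the patched calculate_weight():
#   weight += ((b2 @ b1) + (weight @ a2) @ a1) * alpha
update = torch.mm(b2, b1) + torch.mm(torch.mm(weight, a2), a1)
merged = weight + alpha * update
print(merged.shape)  # torch.Size([8, 6])
```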
