Skip to content

Commit

Permalink
fix bugs in lora support
Browse files Browse the repository at this point in the history
  • Loading branch information
AUTOMATIC1111 committed Jul 16, 2024
1 parent 7e5cdaa commit 2b50233
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions extensions-builtin/Lora/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,7 +398,7 @@ def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Li
if weights_backup is not None:
if isinstance(self, torch.nn.MultiheadAttention):
restore_weights_backup(self, 'in_proj_weight', weights_backup[0])
restore_weights_backup(self.out_proj, 'weight', weights_backup[0])
restore_weights_backup(self.out_proj, 'weight', weights_backup[1])
else:
restore_weights_backup(self, 'weight', weights_backup)

Expand Down Expand Up @@ -437,7 +437,7 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
bias_backup = getattr(self, "network_bias_backup", None)
if bias_backup is None and wanted_names != ():
if isinstance(self, torch.nn.MultiheadAttention) and self.out_proj.bias is not None:
bias_backup = store_weights_backup(self.out_proj)
bias_backup = store_weights_backup(self.out_proj.bias)
elif getattr(self, 'bias', None) is not None:
bias_backup = store_weights_backup(self.bias)
else:
Expand Down

0 comments on commit 2b50233

Please sign in to comment.