
Commit a2e0701
Refactoring to use controller
Parent: e942ca4

File tree: 4 files changed (+186, -5 lines)
File renamed without changes.

loraw/loraw_controller.py (new file, +61 lines)

from torch import optim
from ema_pytorch import EMA

from .loraw_network import LoRAWNetwork
from .loraw_module import LoRAWModule


class LoRAWController:
    def __init__(self, target_model, target_config) -> None:
        self.target_model = target_model
        self.target_config = target_config

        self.lr = 0
        self.lora_ema = None

    def create_diffuser_lora(
        self,
        lora_dim=16,
        alpha=1,
        dropout=None,
    ):
        self.lora = LoRAWNetwork(
            net=self.target_model,
            target_subnets=["downsamples", "upsamples"],
            target_modules=["Attention"],
            lora_dim=lora_dim,
            alpha=alpha,
            dropout=dropout,
            multiplier=1.0,
            module_class=LoRAWModule,
            verbose=False,
        )

    def configure_optimizer_patched(self):
        return optim.Adam([*self.lora.parameters()], lr=self.lr)

    def on_before_zero_grad_patched(self, *args, **kwargs):
        self.lora_ema.update()

    def activate(self, training_wrapper=None):
        # self.lora.to(device=self.target_model.device)
        self.lora.activate()

        if training_wrapper is not None:
            # Freeze main diffusion model
            self.target_model.requires_grad_(False)
            self.lora.requires_grad_(True)

            # Replace optimizer to use lora parameters
            self.lr = training_wrapper.lr
            training_wrapper.configure_optimizers = self.configure_optimizer_patched

            # Replace ema update
            self.lora_ema = EMA(
                self.lora,
                beta=0.9999,
                power=3/4,
                update_every=1,
                update_after_step=1
            )
            training_wrapper.on_before_zero_grad = self.on_before_zero_grad_patched
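
For context, a minimal usage sketch of the new controller. It assumes a diffusion model object whose .model attribute exposes downsamples/upsamples subnets and a training wrapper exposing lr, configure_optimizers, and on_before_zero_grad (the attributes the controller patches above); build_model, build_training_wrapper, and model_config are hypothetical placeholders, not part of this commit.

from loraw.loraw_controller import LoRAWController  # assumes the repo root is on the path

# Hypothetical setup: build_model / build_training_wrapper / model_config stand in
# for whatever the surrounding training script provides.
model = build_model(model_config)                      # exposes model.model.downsamples / .upsamples
training_wrapper = build_training_wrapper(model, lr=1e-4)

controller = LoRAWController(target_model=model, target_config=model_config)
controller.create_diffuser_lora(lora_dim=16, alpha=1, dropout=None)

# Injects the LoRAW modules, freezes the base model, and patches the wrapper so
# only the LoRA parameters are optimized and EMA tracks the LoRA network.
controller.activate(training_wrapper=training_wrapper)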

models/loraw_module.py renamed to loraw/loraw_module.py (+5, -5)

@@ -2,6 +2,7 @@
 import torch
 from torch import nn
 
+
 class LoRAWModule(nn.Module):
     def __init__(
         self,
@@ -28,7 +29,7 @@ def __init__(
 
         module_type = orig_module.__class__.__name__
 
-        if module_type == "Conv1d":
+        if module_type == "Conv1d":
             in_dim = orig_module.in_channels
             out_dim = orig_module.out_channels
             kernel_size = orig_module.kernel_size
@@ -38,7 +39,7 @@ def __init__(
                 in_dim, self.lora_dim, kernel_size, stride, padding, bias=False
             )
             self.lora_up = torch.nn.Conv1d(self.lora_dim, out_dim, 1, 1, bias=False)
-        else:
+        elif module_type == "Linear":
             in_dim = orig_module.in_features
             out_dim = orig_module.out_features
             self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
@@ -50,11 +51,10 @@ def __init__(
         self.scale = alpha / self.lora_dim
         self.register_buffer("alpha", torch.tensor(alpha))
 
-        # same as microsoft's
        torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
        torch.nn.init.zeros_(self.lora_up.weight)
 
-    def apply_to(self):
+    def activate(self):
         self.orig_forward = self.orig_module.forward
         self.orig_module.forward = self.forward
         del self.orig_module
@@ -95,4 +95,4 @@ def forward(self, x):
 
         lx = self.lora_up(lx)
 
-        return orig_forwarded + lx * self.multiplier * scale
+        return orig_forwarded + lx * self.multiplier * scale
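
For readers skimming the diff, this is what the renamed hook does at run time: activate() (formerly apply_to()) swaps the wrapped layer's forward for one that adds a trainable low-rank correction on top of the original output. Below is a condensed sketch of that mechanism, assuming a plain nn.Linear target; TinyLoRA is an illustrative name, and the real LoRAWModule also handles Conv1d, dropout, and cleanup of the original reference.

import torch
from torch import nn

class TinyLoRA(nn.Module):
    # Minimal illustration of the LoRAW bypass, not the full LoRAWModule.
    def __init__(self, orig_module: nn.Linear, lora_dim=4, alpha=1, multiplier=1.0):
        super().__init__()
        self.orig_module = orig_module
        self.lora_down = nn.Linear(orig_module.in_features, lora_dim, bias=False)
        self.lora_up = nn.Linear(lora_dim, orig_module.out_features, bias=False)
        self.multiplier = multiplier
        self.scale = alpha / lora_dim
        nn.init.kaiming_uniform_(self.lora_down.weight, a=5 ** 0.5)
        nn.init.zeros_(self.lora_up.weight)  # bypass starts as a no-op

    def activate(self):
        # Same trick as LoRAWModule.activate(): hijack the original forward
        self.orig_forward = self.orig_module.forward
        self.orig_module.forward = self.forward

    def forward(self, x):
        return self.orig_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale

layer = nn.Linear(64, 64)
lora = TinyLoRA(layer, lora_dim=4)
lora.activate()
y = layer(torch.randn(2, 64))  # now routed through TinyLoRA.forward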

loraw/loraw_network.py (new file, +120 lines)

import torch
from torch import nn
from typing import List

from .loraw_module import LoRAWModule


class LoRAWNetwork(nn.Module):
    def __init__(
        self,
        net,
        target_subnets=None,
        target_modules=["SelfAttention1d"],
        multiplier=1.0,
        lora_dim=4,
        alpha=1,
        dropout=None,
        module_class=LoRAWModule,
        verbose=False,
    ):
        super().__init__()

        self.lora_map = {}
        self.multiplier = multiplier
        self.lora_dim = lora_dim
        self.alpha = alpha
        self.dropout = dropout

        def create_modules(
            root_name, root_module: nn.Module, target_replace_modules
        ) -> nn.ModuleList:
            loras = nn.ModuleList()
            skipped = nn.ModuleList()
            for name, module in root_module.named_modules():
                if module.__class__.__name__ in target_replace_modules:
                    for child_name, child_module in module.named_modules():
                        is_linear = child_module.__class__.__name__ == "Linear"
                        is_conv1d = child_module.__class__.__name__ == "Conv1d"

                        if is_linear or is_conv1d:
                            # f-string prefix added so the name is actually interpolated
                            lora_name = f"lora.{root_name}.{name}.{child_name}"
                            lora_name = lora_name.replace(".", "_")

                            lora = module_class(
                                lora_name,
                                child_module,
                                multiplier=self.multiplier,
                                lora_dim=self.lora_dim,
                                alpha=self.alpha,
                                dropout=self.dropout
                            )
                            loras.append(lora)
            return loras, skipped

        for subnet_name in target_subnets:
            if hasattr(net.model, subnet_name):
                subnet = getattr(net.model, subnet_name)
                self.lora_map[subnet_name], _ = create_modules(subnet_name, subnet, target_modules)
                print(f"Created LoRAW for {subnet_name}: {len(self.lora_map[subnet_name])} modules.")

                '''
                if verbose and len(skipped) > 0:
                    print(
                        f"because block_lr_weight is 0 or dim (rank) is 0, {len(skipped)} LoRA modules are skipped:"
                    )
                    for name in skipped:
                        print(f"\t{name}")

                self.up_lr_weight: List[float] = None
                self.down_lr_weight: List[float] = None
                self.mid_lr_weight: float = None
                self.block_lr = False

                # assertion
                names = set()
                for lora in self.unet_loras:
                    assert (
                        lora.lora_name not in names
                    ), f"duplicated lora name: {lora.lora_name}"
                    names.add(lora.lora_name)
                '''
            else:
                print(f"Skipping {subnet_name}: not present in this network")

    def set_multiplier(self, multiplier):
        self.multiplier = multiplier
        # Iterate over the subnet map (the refactor replaced the old unet_loras list)
        for subnet in self.lora_map.values():
            for lora in subnet:
                lora.multiplier = self.multiplier

    def activate(self):
        for subnet_name, subnet in self.lora_map.items():
            for lora in subnet:
                lora.activate()
                self.add_module(lora.lora_name, lora)
            print(f"Injected {len(subnet)} LoRAW modules into {subnet_name}")

    def is_mergeable(self):
        return True

    def save_weights(self, file, dtype=torch.float16):
        state_dict = self.state_dict()

        if dtype is not None:
            for key in list(state_dict.keys()):
                v = state_dict[key]
                v = v.detach().clone().to("cpu").to(dtype)
                state_dict[key] = v

        torch.save(state_dict, file)

    def load_weights(self, file):
        weights_sd = torch.load(file, map_location="cpu")
        info = self.load_state_dict(weights_sd, strict=False)
        return info
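
A sketch of how the new save_weights/load_weights pair might be used around training. It assumes the controller flow shown earlier, so activate() has already registered the LoRA submodules (which is what places them into state_dict()); new_model, model_config, and the checkpoint path are placeholders, not names defined by this commit.

import torch
from loraw.loraw_controller import LoRAWController

# After training: persist only the low-rank weights, cast to fp16 on CPU.
controller.lora.save_weights("loraw_checkpoint.pt", dtype=torch.float16)

# Reuse: rebuild the same LoRA topology on a fresh model, inject it, then restore.
new_controller = LoRAWController(target_model=new_model, target_config=model_config)
new_controller.create_diffuser_lora(lora_dim=16, alpha=1)
new_controller.activate()
info = new_controller.lora.load_weights("loraw_checkpoint.pt")
print(info)  # load_state_dict(strict=False) report: missing / unexpected keys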
