-
Notifications
You must be signed in to change notification settings - Fork 4
/
optims.py
147 lines (127 loc) · 4.17 KB
/
optims.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import torch
import math
def set_optimizer(model, init_lr, weight_decay):
    """Build an AdamW optimizer over the trainable parameters of ``model``.

    Trainable parameters are split into two groups: those that receive
    ``weight_decay`` and those that are exempt from it. A parameter is
    exempt when it has fewer than two dimensions or when its name contains
    one of the substrings ``"bias"``, ``"ln"`` or ``"bn"``. Parameters with
    ``requires_grad`` set to False are skipped entirely.

    The name of every trainable parameter, followed by the total element
    count, is printed to stdout.

    Args:
        model: object exposing ``named_parameters()`` that yields
            ``(name, parameter)`` pairs.
        init_lr: learning rate; converted with ``float()``.
        weight_decay: decay applied to the non-exempt group; converted
            with ``float()``.

    Returns:
        The configured ``torch.optim.AdamW`` instance.
    """
    # TODO make optimizer class and configurations
    no_decay_markers = ("bias", "ln", "bn")
    decayed, undecayed = [], []
    trainable_count = 0

    for name, param in model.named_parameters():
        if param.requires_grad:
            print(name)
            exempt = param.ndim < 2 or any(
                marker in name for marker in no_decay_markers
            )
            bucket = undecayed if exempt else decayed
            bucket.append(param)
            trainable_count += param.data.nelement()
        # Frozen weights fall through untouched.

    print("number of trainable parameters: %d" % trainable_count)

    decay = float(weight_decay)
    groups = [
        {"params": decayed, "weight_decay": decay},
        {"params": undecayed, "weight_decay": 0},
    ]
    return torch.optim.AdamW(
        groups,
        lr=float(init_lr),
        weight_decay=decay,
        betas=(0.9, 0.999),
    )
class LinearWarmupStepLRScheduler:
    """Linear warmup during epoch 0, then per-epoch step decay.

    Args:
        optimizer: optimizer whose ``param_groups`` learning rates are set.
        max_epoch: stored on the instance; not read by ``step``.
        min_lr: lower bound handed to the step-decay schedule.
        init_lr: learning rate reached at the end of warmup and used as the
            base of the decay.
        decay_rate: per-epoch decay factor handed to the decay schedule.
        warmup_start_lr: learning rate at warmup step 0; a negative value
            means "use ``init_lr``".
        warmup_steps: number of warmup steps.
        **kwargs: accepted and ignored.
    """

    def __init__(
        self,
        optimizer,
        max_epoch,
        min_lr,
        init_lr,
        decay_rate=1,
        warmup_start_lr=-1,
        warmup_steps=0,
        **kwargs
    ):
        # A negative start value is the sentinel for "begin at init_lr".
        if warmup_start_lr < 0:
            warmup_start_lr = init_lr

        self.optimizer = optimizer
        self.max_epoch = max_epoch

        self.init_lr = init_lr
        self.min_lr = min_lr
        self.decay_rate = decay_rate

        self.warmup_start_lr = warmup_start_lr
        self.warmup_steps = warmup_steps

    def step(self, cur_epoch, cur_step):
        """Set the learning rate for the given epoch / in-epoch step."""
        if cur_epoch != 0:
            # Past the first epoch: decay depends on the epoch index only.
            step_lr_schedule(
                optimizer=self.optimizer,
                epoch=cur_epoch,
                init_lr=self.init_lr,
                min_lr=self.min_lr,
                decay_rate=self.decay_rate,
            )
            return

        # First epoch: ramp from warmup_start_lr towards init_lr.
        warmup_lr_schedule(
            optimizer=self.optimizer,
            step=cur_step,
            max_step=self.warmup_steps,
            init_lr=self.warmup_start_lr,
            max_lr=self.init_lr,
        )
class LinearWarmupCosineLRScheduler:
    """Linear warmup followed by cosine decay, driven per iteration.

    The schedule is indexed by the global iteration
    ``cur_epoch * iters_per_epoch + cur_step``. While that index is below
    ``warmup_steps`` the learning rate is ramped linearly from
    ``warmup_start_lr`` to ``init_lr``; afterwards it follows a cosine curve
    over ``max_epoch * iters_per_epoch`` iterations.

    Args:
        optimizer: optimizer whose ``param_groups`` learning rates are set.
        max_epoch: number of epochs spanned by the cosine curve.
        iters_per_epoch: iterations per epoch, used to compute the global
            iteration index.
        min_lr: learning rate at the end of the cosine curve.
        init_lr: learning rate at the end of warmup / start of the curve.
        warmup_steps: number of warmup iterations (may exceed one epoch).
        warmup_start_lr: learning rate at warmup iteration 0; a negative
            value means "use ``init_lr``".
        **kwargs: accepted and ignored.
    """

    def __init__(
        self,
        optimizer,
        max_epoch,
        iters_per_epoch,
        min_lr,
        init_lr,
        warmup_steps=0,
        warmup_start_lr=-1,
        **kwargs
    ):
        self.optimizer = optimizer

        self.max_epoch = max_epoch
        self.iters_per_epoch = iters_per_epoch
        self.min_lr = min_lr

        self.init_lr = init_lr
        self.warmup_steps = warmup_steps
        # A negative start value is the sentinel for "begin at init_lr".
        self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr

    def step(self, cur_epoch, cur_step):
        """Set the learning rate for the given epoch / in-epoch step."""
        total_cur_step = cur_epoch * self.iters_per_epoch + cur_step
        if total_cur_step < self.warmup_steps:
            # The warmup position must be the global step, matching the
            # condition above. Passing the in-epoch step would restart the
            # ramp at every epoch boundary whenever warmup spans more than
            # one epoch.
            warmup_lr_schedule(
                step=total_cur_step,
                optimizer=self.optimizer,
                max_step=self.warmup_steps,
                init_lr=self.warmup_start_lr,
                max_lr=self.init_lr,
            )
        else:
            # The cosine helper's "epoch" arguments are expressed in
            # iterations here so the decay is smooth within an epoch.
            cosine_lr_schedule(
                epoch=total_cur_step,
                optimizer=self.optimizer,
                max_epoch=self.max_epoch * self.iters_per_epoch,
                init_lr=self.init_lr,
                min_lr=self.min_lr,
            )
def cosine_lr_schedule(optimizer, epoch, max_epoch, init_lr, min_lr):
    """Set every param group's lr to the cosine-decayed value.

    The value is ``init_lr`` at ``epoch == 0`` and ``min_lr`` at
    ``epoch == max_epoch``, following half a cosine period in between.
    """
    amplitude = init_lr - min_lr
    phase = math.pi * epoch / max_epoch
    new_lr = amplitude * 0.5 * (1.0 + math.cos(phase)) + min_lr
    for group in optimizer.param_groups:
        group["lr"] = new_lr
def warmup_lr_schedule(optimizer, step, max_step, init_lr, max_lr):
    """Set every param group's lr to the linearly warmed-up value.

    The value moves from ``init_lr`` at step 0 towards ``max_lr`` at
    ``max_step`` and is capped at ``max_lr``. A ``max_step`` below 1 is
    treated as 1 to avoid dividing by zero.
    """
    horizon = max(max_step, 1)
    ramped = init_lr + (max_lr - init_lr) * step / horizon
    new_lr = min(max_lr, ramped)
    for group in optimizer.param_groups:
        group["lr"] = new_lr
def step_lr_schedule(optimizer, epoch, init_lr, min_lr, decay_rate):
    """Set every param group's lr to the step-decayed value.

    The value is ``init_lr * decay_rate ** epoch``, never lower than
    ``min_lr``.
    """
    decayed = init_lr * (decay_rate**epoch)
    new_lr = max(min_lr, decayed)
    for group in optimizer.param_groups:
        group["lr"] = new_lr