run.py (forked from ildoonet/pytorch-gradual-warmup-lr)
import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import CosineAnnealingLR

from warmup_scheduler import GradualWarmupScheduler


def plot(lr_list):
    """Plot the learning rate recorded at each epoch."""
    import matplotlib.pyplot as plt
    from matplotlib.pyplot import MultipleLocator

    plt.figure()
    ax = plt.gca()
    ax.xaxis.set_major_locator(MultipleLocator(1))  # one x-axis tick per epoch
    x = range(1, len(lr_list) + 1)
    plt.plot(x, lr_list)
    plt.show()


if __name__ == '__main__':
    # A dummy "model": a single parameter is enough to drive the optimizer.
    model = [torch.nn.Parameter(torch.randn(2, 2, requires_grad=True))]
    optim = SGD(model, lr=0.1)
    epochs = 20

    # scheduler_warmup is chained with lr_scheduler: 5 warm-up epochs,
    # then cosine annealing over the remaining epochs down to eta_min.
    lr_scheduler = CosineAnnealingLR(optim, T_max=epochs - 5, eta_min=0.02)
    scheduler_warmup = GradualWarmupScheduler(optim, multiplier=1, total_epoch=5, after_scheduler=lr_scheduler)

    # this zero-gradient update is needed to avoid a warning message, issue #8.
    optim.zero_grad()
    optim.step()
    scheduler_warmup.step()

    lr_list = []
    for epoch in range(epochs):
        current_lr = optim.param_groups[0]['lr']
        optim.step()             # the real training step (forward/backward) would go here
        scheduler_warmup.step()  # advance the warm-up, then the cosine schedule
        print(epoch + 1, current_lr)
        lr_list.append(current_lr)

    plot(lr_list)
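For context, below is a minimal sketch of how the warm-up learning rate is typically computed when multiplier=1, based on the upstream ildoonet implementation; warmup_lr is a hypothetical helper written here for illustration, not part of the library's API. During the first total_epoch scheduler steps the LR rises linearly from 0 toward the base LR, after which after_scheduler (here CosineAnnealingLR) takes over.

# Hypothetical helper illustrating the assumed warm-up rule (multiplier == 1):
# the LR ramps linearly from 0 up to base_lr over total_epoch steps.
def warmup_lr(base_lr, last_epoch, total_epoch):
    return base_lr * float(last_epoch) / total_epoch

# With base_lr=0.1 and total_epoch=5 as in run.py above, the first five
# scheduler steps would yield roughly 0.02, 0.04, 0.06, 0.08, 0.10,
# after which the cosine annealing schedule begins.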