-
-
Notifications
You must be signed in to change notification settings - Fork 13
/
Copy pathCLR.py
62 lines (50 loc) · 2.21 KB
/
CLR.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import math
import matplotlib.pyplot as plt
class CLR(object):
    """
    Learning-rate range test from "Cyclical Learning Rates for Training Neural
    Networks" (https://arxiv.org/abs/1506.01186). The learning rate is increased
    geometrically from ``base_lr`` toward ``max_lr`` over ``bn`` iterations until
    the loss explodes; the practitioner then picks a learning rate roughly one
    order of magnitude below the loss minimum for real training.

    Args:
        optim:   Optimizer used in training (stored for the caller's use; this
                 class does not mutate it itself).
        bn:      Total number of iterations for the test run. Must be >= 2,
                 because the per-iteration multiplier is derived from ``bn - 1``
                 geometric steps between ``base_lr`` and ``max_lr``.
        base_lr: Lower boundary / initial learning rate for the test run.
                 A small value such as 1e-4 is advised. Default: 1e-5.
        max_lr:  Upper boundary for the learning rate; defines the sweep
                 amplitude (``max_lr - base_lr``). It may never be reached if
                 the loss explodes first. A large value such as 10 or 100 is
                 advised. Default: 100.

    Raises:
        ValueError: if ``bn < 2`` or if ``0 < base_lr < max_lr`` does not hold.
    """

    def __init__(self, optim, bn, base_lr=1e-5, max_lr=100):
        # Validate up front: bn - 1 is a divisor, and the geometric ratio
        # requires strictly positive, increasing bounds.
        if bn < 2:
            raise ValueError("bn must be >= 2, got {}".format(bn))
        if base_lr <= 0 or max_lr <= base_lr:
            raise ValueError(
                "need 0 < base_lr < max_lr, got base_lr={}, max_lr={}".format(
                    base_lr, max_lr))
        self.base_lr = base_lr
        self.max_lr = max_lr
        self.optim = optim
        self.bn = bn - 1                       # number of geometric steps
        ratio = self.max_lr / self.base_lr
        self.mult = ratio ** (1 / self.bn)     # per-iteration LR multiplier
        self.best_loss = 1e9                   # sentinel; replaced on iter 1
        self.iteration = 0
        self.lrs = []                          # LR history for plot()
        self.losses = []                       # loss history for plot()

    def calc_lr(self, loss):
        """
        Record ``loss`` for the current iteration and return the next learning
        rate, or ``-1`` as a stop signal when the loss is NaN or has exploded
        past 4x the best loss seen so far.
        """
        self.iteration += 1
        # Stop signal: loss diverged. Checked before updating best_loss so a
        # sudden blow-up is caught against the previous best.
        if math.isnan(loss) or loss > 4 * self.best_loss:
            return -1
        # Seed best_loss with the very first loss (the original sentinel 1e9
        # must not survive iteration 1), then track the running minimum.
        if loss < self.best_loss or self.iteration == 1:
            self.best_loss = loss
        lr = self.base_lr * (self.mult ** self.iteration)
        self.lrs.append(lr)
        self.losses.append(loss)
        return lr

    def plot(self, start=10, end=-5):
        """
        Plot loss vs. learning rate on a log-x axis, trimming the noisy first
        ``start`` points and the last ``-end`` (exploding) points.
        """
        plt.xlabel("Learning Rate")
        plt.ylabel("Losses")
        plt.plot(self.lrs[start:end], self.losses[start:end])
        plt.xscale('log')