-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathplot_lr_scheduler.py
68 lines (50 loc) · 2.11 KB
/
plot_lr_scheduler.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import argparse
import math

import matplotlib.pyplot as plt
import torch

from utils import seed_all, load_config
from models.schedulers import get_fm_scheduler
def get_scheduler_results(time_steps, config_dir):
    """Evaluate the scheduler described by one config file over a grid of times.

    Loads the config at ``config_dir``, builds a scheduler from its
    ``scheduler`` section via ``get_fm_scheduler``, and calls
    ``seed_all(config.train.seed)``.

    Args:
        time_steps: iterable of time values; each element is passed as-is to
            the scheduler methods.
        config_dir: path of the config file handed to ``load_config``.

    Returns:
        A tuple of three lists, each with one entry per element of
        ``time_steps``:
        ``scheduler.get_interpolant(t)``, ``scheduler.get_log_deriv(t)``, and
        ``scheduler.get_log_deriv(t) * (t - 1)``.
    """
    cfg = load_config(config_dir)
    fm_scheduler = get_fm_scheduler(cfg.scheduler)
    # NOTE(review): seeding runs after the scheduler is built, so it cannot
    # influence construction; confirm that ordering is intended.
    seed_all(cfg.train.seed)

    interpolants, log_derivs, scaled_log_derivs = [], [], []
    for step in time_steps:
        interpolants.append(fm_scheduler.get_interpolant(step))
        log_deriv = fm_scheduler.get_log_deriv(step)
        log_derivs.append(log_deriv)
        # Log-derivative weighted by (t - 1); this is what main plots as 'loss_scale'.
        scaled_log_derivs.append(log_deriv * (step - 1))
    return interpolants, log_derivs, scaled_log_derivs
def plot_lr(time_steps, cos, exp, linear, savename, *, clip=100):
    """Plot one per-time quantity for the cosine, exponential and linear schedulers.

    The figure is saved to ``f'{savename}.png'``, shown with ``plt.show()``,
    and then closed.

    Args:
        time_steps: x-axis values shared by all three curves.
        cos: values for the cosine scheduler, one per entry of ``time_steps``.
        exp: values for the exponential scheduler, one per entry of ``time_steps``.
        linear: values for the linear scheduler, one per entry of ``time_steps``.
        savename: base name of the output PNG. With underscores replaced by
            spaces, it is also used as the y-axis label and in the title.
        clip: the y-axis range is limited to ``[-clip, clip]``, so a curve that
            blows up does not flatten the others. Defaults to 100.
    """
    label = savename.replace('_', ' ')

    # Use a fresh figure per call. Drawing on pyplot's implicit current figure
    # means that under a non-interactive backend, where plt.show() does not
    # close anything, every call lands on the same axes. Later PNGs would then
    # also contain the earlier curves.
    fig, ax = plt.subplots()
    ax.plot(time_steps, cos, label='Cosine scheduler')
    ax.plot(time_steps, exp, label='Exponential scheduler')
    ax.plot(time_steps, linear, label='Linear scheduler')
    ax.set_xlabel('Time')
    ax.set_ylabel(f'{label}')

    # float() accepts Python/NumPy scalars and single-element tensors alike.
    # NaNs are dropped because matplotlib rejects NaN axis limits. +/-inf is
    # kept: it simply clamps to +/-clip, as before.
    values = [float(v) for series in (cos, exp, linear) for v in series]
    values = [v for v in values if not math.isnan(v)]
    if values:
        y_min = max(min(values), -clip)
        y_max = min(max(values), clip)
        # If every value lies beyond the clip, or all values are equal, the
        # clamped limits would be inverted or degenerate. Leave autoscaling on.
        if y_min < y_max:
            ax.set_ylim(y_min, y_max)

    ax.set_title(f'{label} for different schedulers')
    ax.legend()
    fig.savefig(f'{savename}.png')
    plt.show()
    plt.close(fig)
if __name__ == '__main__':
    # NOTE(review): `--device` is parsed but never used below; it is kept so
    # existing command lines keep working.
    cli = argparse.ArgumentParser()
    cli.add_argument('--device', type=str, default='cuda')
    cli_args = cli.parse_args()

    # Times from 0 up to, but not including, 0.99 in steps of 0.01
    # (torch.arange excludes its end value).
    grid = torch.arange(0, 0.99, 0.01)

    # Order matters: it matches plot_lr's (cos, exp, linear) parameters.
    config_paths = (
        'configs/cifar10_cos.yml',
        'configs/cifar10_exp.yml',
        'configs/cifar10_linear.yml',
    )
    results_by_config = [get_scheduler_results(grid, path) for path in config_paths]

    grid = grid.numpy()
    # Component `index` of every scheduler's result tuple is drawn on one figure.
    for index, plot_name in enumerate(('schedule', 'derivative', 'loss_scale')):
        plot_lr(grid, *(res[index] for res in results_by_config), plot_name)