# 02_hbp_sequential_cifar10.py
#
# Trains a fully-connected sigmoid network on CIFAR-10 using Hessian
# backpropagation (bpexts.hbp) and the CGNewton optimizer, then saves a plot
# of the logged batch loss and batch accuracy.
import matplotlib.pyplot as plt
import torch
import torchvision
import torchvision.transforms as transforms
from bpexts.hbp.crossentropy import HBPCrossEntropyLoss
from bpexts.hbp.linear import HBPLinear
from bpexts.hbp.sequential import HBPSequential
from bpexts.hbp.sigmoid import HBPSigmoid
from bpexts.optim.cg_newton import CGNewton
from bpexts.utils import set_seeds
# run on the first CUDA device when torch can see one, otherwise on the CPU
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# fixed seed (bpexts.utils.set_seeds) so that runs are repeatable
set_seeds(0)

# number of images per mini-batch
batch_size = 500

# directory the CIFAR-10 data is downloaded to
data_dir = "~/tmp/CIFAR10"

# training split of CIFAR-10, images converted to tensors
train_set = torchvision.datasets.CIFAR10(
    root=data_dir,
    train=True,
    transform=transforms.ToTensor(),
    download=True,
)

# shuffled mini-batch loader over the training split
train_loader = torch.utils.data.DataLoader(
    dataset=train_set,
    batch_size=batch_size,
    shuffle=True,
)
# widths of consecutive layers: 3072 inputs, four hidden layers, 10 outputs
layer_widths = (3072, 1024, 512, 256, 128, 10)
last_linear = len(layer_widths) - 2

# alternate HBPLinear / HBPSigmoid, constructed in network order; the final
# linear layer is not followed by an activation
layers = []
for position, (fan_in, fan_out) in enumerate(
    zip(layer_widths[:-1], layer_widths[1:])
):
    layers.append(HBPLinear(in_features=fan_in, out_features=fan_out, bias=True))
    if position != last_linear:
        layers.append(HBPSigmoid())

# sequential model
model = HBPSequential(*layers)

# load to device and show the architecture
model.to(device)
print(model)
# cross-entropy loss from the bpexts HBP module
loss_func = HBPCrossEntropyLoss()

# learning rate and regularization
lr, alpha = 0.1, 0.02

# convergence criteria for CG
cg_maxiter, cg_atol, cg_tol = 50, 0.0, 0.1

# construct the optimizer from the settings above
cg_newton_kwargs = {
    "lr": lr,
    "alpha": alpha,
    "cg_atol": cg_atol,
    "cg_tol": cg_tol,
    "cg_maxiter": cg_maxiter,
}
optimizer = CGNewton(model.parameters(), **cg_newton_kwargs)

# use the PCH with absolute values of second-order module effects
modify_2nd_order_terms = "abs"

# train for thirty epochs
num_epochs = 30

# logged metrics: position in training (in epochs), batch loss, batch accuracy
train_epoch, batch_loss, batch_acc = [], [], []

# running count of processed samples and the number of samples in one epoch,
# used to express training progress as a fractional epoch
samples = 0
samples_per_epoch = 50000.0
# Training loop. For every mini-batch: forward pass, gradient backward pass,
# Hessian backpropagation through the model, then one CGNewton step. Batch
# loss and accuracy are logged every 15 iterations and printed every 20.

# number of mini-batches per epoch; the loader does not change between epochs
iters = len(train_loader)

for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        # flatten every image to a 3072-vector and load to device
        images = images.reshape(-1, 3072).to(device)
        labels = labels.to(device)

        # 1) forward pass
        outputs = model(images)
        loss = loss_func(outputs, labels)

        # set gradients to zero
        optimizer.zero_grad()

        # Hessian backpropagation and backward pass
        # 2) Hessian of the loss w.r.t. the model output
        output_hessian = loss_func.batch_summed_hessian(loss, outputs)

        # 3) compute gradients
        loss.backward()

        # 4) propagate Hessian back through the graph
        model.backward_hessian(
            output_hessian, modify_2nd_order_terms=modify_2nd_order_terms
        )

        # 5) second-order optimization step
        optimizer.step()

        # compute statistics from the outputs of the forward pass above
        # (i.e. before the parameter update)
        total = labels.size(0)
        _, predicted = torch.max(outputs, 1)
        correct = (predicted == labels).sum().item()
        accuracy = correct / total

        # update lists every 15 iterations
        samples += total
        if i % 15 == 0:
            train_epoch.append(samples / samples_per_epoch)
            batch_loss.append(loss.item())
            batch_acc.append(accuracy)

        # print every 20 iterations
        if i % 20 == 0:
            print(
                "Epoch [{}/{}], Iter. [{}/{}], Loss: {:.4f}, Acc.: {:.4f}".format(
                    epoch + 1, num_epochs, i + 1, iters, loss.item(), accuracy
                )
            )
# Plot the logged metrics side by side and save them to a PNG file.
# plt.subplots takes (nrows, ncols): one row, two columns gives one axes
# object for the loss panel and one for the accuracy panel.
fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(7, 3))

# plot batch loss
loss_ax.plot(train_epoch, batch_loss, color="darkorange")
loss_ax.set_xlabel("epoch")
loss_ax.set_ylabel("batch loss")

# plot batch accuracy
acc_ax.plot(train_epoch, batch_acc, color="darkblue")
acc_ax.set_xlabel("epoch")
acc_ax.set_ylabel("batch accuracy")

# save plot
fig.tight_layout()
fig.savefig("02_hbp_sequential_cifar10_metrics.png")