"""
Running this script eventually gives:
23
eval: split train. loss 4.073383e-03. error 0.62%. misses: 45
eval: split test . loss 2.838382e-02. error 4.09%. misses: 82
"""
import os
import json
import argparse
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tensorboardX import SummaryWriter # pip install tensorboardX
# -----------------------------------------------------------------------------
class Net(nn.Module):
    """ 1989 LeCun ConvNet per description in the paper """

    def __init__(self):
        super().__init__()

        # initialization as described in the paper to the best of my ability, but it doesn't look right...
        winit = lambda fan_in, *shape: (torch.rand(*shape) - 0.5) * 2 * 2.4 / fan_in**0.5
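        # note: the lambda above samples uniformly from [-2.4/sqrt(fan_in), +2.4/sqrt(fan_in)]:
        # (torch.rand(...) - 0.5) * 2 is uniform in [-1, 1], then scaled by 2.4 / fan_in**0.5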
        macs = 0 # keep track of MACs (multiply accumulates)
        acts = 0 # keep track of number of activations

        # H1 layer parameters and their initialization
        self.H1w = nn.Parameter(winit(5*5*1, 12, 1, 5, 5))
        self.H1b = nn.Parameter(torch.zeros(12, 8, 8)) # presumably init to zero for biases
        assert self.H1w.nelement() + self.H1b.nelement() == 1068
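        # the 1068 figure checks out as 12*1*5*5 = 300 weights + 12*8*8 = 768 biases
        # (note that the bias is per output unit, not per filter)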
        macs += (5*5*1) * (8*8) * 12
        acts += (8*8) * 12

        # H2 layer parameters and their initialization
        """
        H2 neurons all connect to only 8 of the 12 input planes, with an unspecified pattern.
        I am going to assume the most sensible block pattern where 4 planes at a time connect
        to differently overlapping groups of 8/12 input planes. We will implement this with 3
        separate convolutions that we concatenate the results of.
        """
        self.H2w = nn.Parameter(winit(5*5*8, 12, 8, 5, 5))
        self.H2b = nn.Parameter(torch.zeros(12, 4, 4)) # presumably init to zero for biases
        assert self.H2w.nelement() + self.H2b.nelement() == 2592
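        # 2592 = 12*8*5*5 = 2400 weights + 12*4*4 = 192 biases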
        macs += (5*5*8) * (4*4) * 12
        acts += (4*4) * 12

        # H3 is a fully connected layer
        self.H3w = nn.Parameter(winit(4*4*12, 4*4*12, 30))
        self.H3b = nn.Parameter(torch.zeros(30))
        assert self.H3w.nelement() + self.H3b.nelement() == 5790
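        # 5790 = (4*4*12)*30 = 5760 weights + 30 biases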
        macs += (4*4*12) * 30
        acts += 30

        # the output layer is also a fully connected layer
        self.outw = nn.Parameter(winit(30, 30, 10))
        self.outb = nn.Parameter(-torch.ones(10)) # 9/10 targets are -1, so makes sense to init slightly towards it
        assert self.outw.nelement() + self.outb.nelement() == 310
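        # 310 = 30*10 = 300 weights + 10 biases; the grand total is then
        # 1068 + 2592 + 5790 + 310 = 9760, matching the count reported in the paper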
        macs += 30 * 10
        acts += 10

        self.macs = macs
        self.acts = acts

    def forward(self, x):

        # x has shape (1, 1, 16, 16)
        x = F.pad(x, (2, 2, 2, 2), 'constant', -1.0) # pad by two using constant -1 for background
        x = F.conv2d(x, self.H1w, stride=2) + self.H1b
        x = torch.tanh(x)

        # x is now shape (1, 12, 8, 8)
        x = F.pad(x, (2, 2, 2, 2), 'constant', -1.0) # pad by two using constant -1 for background
        slice1 = F.conv2d(x[:, 0:8], self.H2w[0:4], stride=2) # first 4 planes look at first 8 input planes
        slice2 = F.conv2d(x[:, 4:12], self.H2w[4:8], stride=2) # next 4 planes look at last 8 input planes
        slice3 = F.conv2d(torch.cat((x[:, 0:4], x[:, 8:12]), dim=1), self.H2w[8:12], stride=2) # last 4 planes are cross
        x = torch.cat((slice1, slice2, slice3), dim=1) + self.H2b
        x = torch.tanh(x)

        # x is now shape (1, 12, 4, 4)
        x = x.flatten(start_dim=1) # (1, 12*4*4)
        x = x @ self.H3w + self.H3b
        x = torch.tanh(x)

        # x is now shape (1, 30)
        x = x @ self.outw + self.outb
        x = torch.tanh(x)

        # x is finally shape (1, 10)
        return x
# -----------------------------------------------------------------------------
if __name__ == '__main__':

    parser = argparse.ArgumentParser(description="Train a 1989 LeCun ConvNet on digits")
    parser.add_argument('--learning-rate', '-l', type=float, default=0.03, help="SGD learning rate")
    parser.add_argument('--output-dir', '-o', type=str, default='out/base', help="output directory for training logs")
    args = parser.parse_args()
    print(vars(args))

    # init rng
    torch.manual_seed(1337)
    np.random.seed(1337)
    torch.use_deterministic_algorithms(True)

    # set up logging
    os.makedirs(args.output_dir, exist_ok=True)
    with open(os.path.join(args.output_dir, 'args.json'), 'w') as f:
        json.dump(vars(args), f, indent=2)
    writer = SummaryWriter(args.output_dir)

    # init a model
    model = Net()
    print("model stats:")
    print("# params: ", sum(p.numel() for p in model.parameters())) # in paper total is 9,760
    print("# MACs: ", model.macs)
    print("# activations: ", model.acts)

    # init data
    Xtr, Ytr = torch.load('train1989.pt')
    Xte, Yte = torch.load('test1989.pt')

    # init optimizer
    optimizer = optim.SGD(model.parameters(), lr=args.learning_rate)
    def eval_split(split):
        # eval the full train/test set, batched implementation for efficiency
        model.eval()
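        # no gradients are needed during evaluation; wrapping this in torch.no_grad()
        # would save memory, though it does not change the numbers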
        X, Y = (Xtr, Ytr) if split == 'train' else (Xte, Yte)
        Yhat = model(X)
        loss = torch.mean((Y - Yhat)**2)
        err = torch.mean((Y.argmax(dim=1) != Yhat.argmax(dim=1)).float())
        print(f"eval: split {split:5s}. loss {loss.item():e}. error {err.item()*100:.2f}%. misses: {int(err.item()*Y.size(0))}")
        writer.add_scalar(f'error/{split}', err.item()*100, pass_num)
        writer.add_scalar(f'loss/{split}', loss.item(), pass_num)
    # train
    for pass_num in range(23):

        # perform one epoch of training
        model.train()
        for step_num in range(Xtr.size(0)):

            # fetch a single example into a batch of 1
            x, y = Xtr[[step_num]], Ytr[[step_num]]

            # forward the model and the loss
            yhat = model(x)
            loss = torch.mean((y - yhat)**2)

            # calculate the gradient and update the parameters
            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            optimizer.step()
        # after each epoch evaluate the train and test error / metrics
        print(pass_num + 1)
        eval_split('train')
        eval_split('test')
    # save final model to file
    torch.save(model.state_dict(), os.path.join(args.output_dir, 'model.pt'))
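
# example invocation (assumes train1989.pt and test1989.pt already exist in the
# working directory, produced by a separate data preparation step):
#   python repro.py --learning-rate 0.03 --output-dir out/base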