--Basic Usage: th main.lua
require 'sys'
require 'xlua'
require 'torch'
require 'nn'
require 'rmsprop'
require 'modules/KLDCriterion'
require 'modules/LinearCR'
require 'modules/Reparametrize'
require 'cutorch'
require 'cunn'
require 'optim'
require 'testf'
require 'utils'
require 'config'
----------------------------------------------------------------------
-- parse command-line options
--
local opt = lapp[[
   -s,--save          (default "logs")      subdirectory to save logs
   -n,--network       (default "")          reload pretrained network
   -m,--model         (default "convnet")   type of model to train: convnet | mlp | linear
   -p,--plot                                plot while training
   -o,--optimization  (default "SGD")       optimization: SGD | LBFGS
   -r,--learningRate  (default 0.0005)      learning rate, for SGD only
   --momentum         (default 0)           momentum, for SGD only
   -i,--maxIter       (default 3)           maximum nb of iterations per batch, for LBFGS
   --coefL1           (default 0)           L1 penalty on the weights
   --coefL2           (default 0)           L2 penalty on the weights
   -t,--threads       (default 4)           number of threads
   --dumpTest                               preloads model and dumps .mat for test
   -d,--datasrc       (default "")          data source directory
   -f,--fbmat         (default 0)           load fb.mattorch
   -c,--color         (default 0)           color or not
   -u,--reuse         (default 0)           reuse existing network weights
]]
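-- Example invocations (paths are illustrative, not part of the original script):
--   th main.lua                      -- fresh training run, logs/checkpoints in ./logs
--   th main.lua -s logs/run1 -u 1    -- resume from the checkpoints in logs/run1
--   th main.lua -c 1                 -- train on the color variant of the data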
--[[
if opt.fbmat == 1 then
    mattorch = require('fb.mattorch')
else
    require 'mattorch'
end
]]
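-- (The block above is left disabled; judging from the flags, -f/--fbmat selects
-- Facebook's fb.mattorch loader over plain mattorch when .mat dumping is
-- needed, e.g. together with --dumpTest.)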
-- threads
torch.setnumthreads(opt.threads)
print('<torch> set nb of threads to ' .. torch.getnumthreads())
opt.cuda = true -- this script is CUDA-only (cutorch/cunn are required above)
-- torch.manualSeed(1)
if opt.color == 1 then
    MODE_TRAINING = 'color_training'
    MODE_TEST = 'color_test'
else
    MODE_TRAINING = 'training'
    MODE_TEST = 'test'
end
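-- rmsprop settings for the update below. The learning rate is negated on
-- purpose: opfunc returns the variational lowerbound (to be maximized) and
-- its gradient, so a negative rate turns the usual descent step into ascent.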
config = {
    learningRate = -opt.learningRate,
    momentumDecay = 0.1,
    updateDecay = 0.01
}
print(config)
print('IMWIDTH:', load_batch(1,MODE_TRAINING):size())
-- model = init_network2_color_width150()
-- model = init_network2_full_150()
model = init_network2_150()
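-- init_network2_150 (and the commented alternatives above) are assumed to be
-- defined in the required config/utils modules; by their names they build the
-- grayscale, full, and color variants of the same architecture.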
-- sanity check: run one batch through the model and print the output size
function test_fw_back(model)
    local res = model:forward(load_batch(1, MODE_TRAINING):cuda())
    print(res:size())
    -- rev = model:backward(load_batch(1, MODE_TRAINING):cuda(), load_batch(1, MODE_TRAINING):cuda())
    -- print(rev:size())
end
-- test_fw_back(model)
-- criterion = nn.MSECriterion() -- does not work well at all
-- criterion = nn.GaussianCriterion()
criterion = nn.BCECriterion()
criterion.sizeAverage = false
KLD = nn.KLDCriterion()
KLD.sizeAverage = false
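-- Objective: the variational lowerbound of a VAE,
--   L = E_q[log p(x|z)] - D_KL( q(z|x) || p(z) ),
-- with BCECriterion supplying the (negated) reconstruction term and
-- KLDCriterion the KL term. sizeAverage = false sums over the batch
-- instead of averaging, so the printed values are per-batch totals.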
if opt.cuda then
criterion:cuda()
KLD:cuda()
model:cuda()
end
parameters, gradients = model:getParameters()
print('Num of parameters before reload:', #parameters)
-- if opt.reuse == 1 then
-- print("Loading old parameters!")
-- -- model = torch.load(opt.network)
-- parameters = torch.load(opt.network)
-- -- parameters, gradients = model:getParameters()
-- else
-- epoch = 0
-- state = {}
-- end
if opt.reuse == 1 then
    print("Loading old weights!")
    print(opt.save)
    lowerboundlist = torch.load(opt.save .. '/lowerbound.t7')
    lowerbound_test_list = torch.load(opt.save .. '/lowerbound_test.t7')
    state = torch.load(opt.save .. '/state.t7')
    p = torch.load(opt.save .. '/parameters.t7')
    print('Loaded p size:', #p)
    parameters:copy(p)
    epoch = lowerboundlist:size(1)
    config = torch.load(opt.save .. '/config.t7')
else
    epoch = 0
    state = {}
end
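-- Resuming expects the five checkpoint files written at the end of every
-- epoch below (parameters.t7, state.t7, lowerbound.t7, lowerbound_test.t7,
-- config.t7) to be present in opt.save.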
print('Num of parameters:', #parameters)
testLogger = optim.Logger(paths.concat(opt.save, 'test.log'))
reconstruction = 0
while true do
    epoch = epoch + 1
    local lowerbound = 0
    local time = sys.clock()

    for i = 1, num_train_batches do
        xlua.progress(i, num_train_batches)

        -- Prepare batch
        local batch = load_batch(i, MODE_TRAINING)
        if opt.cuda then
            batch = batch:cuda()
        end
        -- Optimization closure: given parameters x, returns the batch
        -- lowerbound and the gradient of the lowerbound w.r.t. the parameters
        local opfunc = function(x)
            collectgarbage()
            if x ~= parameters then
                parameters:copy(x)
            end
            model:zeroGradParameters()
            local f = model:forward(batch)

            -- reuse one persistent target buffer; deliberately not 'local'
            -- so the allocation survives across calls
            target = target or batch.new()
            target:resizeAs(f):copy(batch)

            -- reconstruction term, negated so that larger is better
            local err = -criterion:forward(f, target)
            local df_dw = criterion:backward(f, target):mul(-1)
            model:backward(batch, df_dw)

            -- KL term, evaluated at the encoder output (the first stage of
            -- the model, which feeds the Reparametrize layer)
            -- local encoder_output = model.modules[1].modules[11].output
            local encoder = model:get(1)
            local encoder_output = encoder.output
            local KLDerr = KLD:forward(encoder_output, target)
            local dKLD_dw = KLD:backward(encoder_output, target)
            encoder:backward(batch, dKLD_dw)

            local lowerbound = err + KLDerr
            if opt.verbose then
                print("BCE", err / batch:size(1))
                print("KLD", KLDerr / batch:size(1))
                print("lowerbound", lowerbound / batch:size(1))
            end

            return lowerbound, gradients
        end
        -- rmsprop (the local rmsprop.lua required above) follows the
        -- optim-style interface: it calls opfunc, takes a step, and returns
        -- the updated parameters plus a table holding the function value
        local x, batchlowerbound = rmsprop(opfunc, parameters, config, state)
        lowerbound = lowerbound + batchlowerbound[1]
    end

    print("\nEpoch: " .. epoch ..
          " Lowerbound: " .. lowerbound / num_train_batches ..
          " time: " .. sys.clock() - time)

    -- Keep track of the lowerbound over time
    if lowerboundlist then
        lowerboundlist = torch.cat(lowerboundlist, torch.Tensor(1, 1):fill(lowerbound / num_train_batches), 1)
    else
        lowerboundlist = torch.Tensor(1, 1):fill(lowerbound / num_train_batches)
    end
    -- save/log current net
    if true then -- math.fmod(epoch, 2) == 0 then
        local filename = paths.concat(opt.save, 'vxnet.net')
        os.execute('mkdir -p ' .. sys.dirname(filename))
        if paths.filep(filename) then
            os.execute('mv ' .. filename .. ' ' .. filename .. '.old')
        end
        print('<trainer> saving network to ' .. filename)
        torch.save(filename, model)
    end
    lowerbound_test = testf(false)

    -- Compute the lowerbound of the test set and save it
    if true then -- epoch % 2 == 0 then
        -- lowerbound_test = getLowerbound(testData.data)
        if lowerbound_test_list then
            lowerbound_test_list = torch.cat(lowerbound_test_list, torch.Tensor(1, 1):fill(lowerbound_test / num_test_batches), 1)
        else
            lowerbound_test_list = torch.Tensor(1, 1):fill(lowerbound_test / num_test_batches)
        end
        print('testlowerbound = ' .. lowerbound_test / num_test_batches)

        -- Save everything needed to restart training later
        torch.save(opt.save .. '/parameters.t7', parameters)
        torch.save(opt.save .. '/state.t7', state)
        torch.save(opt.save .. '/lowerbound.t7', lowerboundlist)
        torch.save(opt.save .. '/lowerbound_test.t7', lowerbound_test_list)
        torch.save(opt.save .. '/config.t7', config)
    end
    -- plot errors
    if opt.plot then
        testLogger:style{['% mean class accuracy (test set)'] = '-'}
        testLogger:plot()
    end
end