-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.lua
34 lines (28 loc) · 860 Bytes
/
train.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
--- Zoomout model training.
-- Performs one optimisation step of `model` on a single (inputs, targets)
-- pair by handing a loss/gradient closure to `optimMethod`.
--
-- Relies on names that are not defined in this block and must already exist
-- as globals when it is called: `parameters`, `gradParameters`, `criterion`,
-- `optimMethod`, `optimState` and `optim`.
-- NOTE(review): presumably `parameters` / `gradParameters` are the flattened
-- tensors returned by model:getParameters() -- confirm in the calling script.
--
-- Side effects:
--   * increments the global counter `epoch` (initialised to 1 when unset);
--   * prints a progress line whenever `epoch` is a multiple of 500;
--   * when `optimMethod` is `optim.asgd`, stores the optimiser's third
--     return value in the global `average`.
--
-- @param model   object exposing :forward(inputs) and :backward(inputs, grad)
-- @param inputs  value passed to model:forward / model:backward
-- @param targets value passed to criterion:forward / criterion:backward
function train(model, inputs, targets)
   -- `epoch` is intentionally global so the count survives across calls.
   epoch = epoch or 1

   -- Progress report every 500 calls.
   if epoch % 500 == 0 then
      print("==> online epoch # " .. epoch)
   end

   -- Closure given to the optimiser: evaluates f(X) and df/dX.
   local function feval(x)
      -- Only copy when the optimiser passes a different object than the
      -- global `parameters` (the comparison is by reference).
      if x ~= parameters then
         parameters:copy(x)
      end

      -- Clear the gradient buffer before the backward pass.
      gradParameters:zero()

      local output = model:forward(inputs)
      collectgarbage()
      local f = criterion:forward(output, targets)
      local df_do = criterion:backward(output, targets)
      model:backward(inputs, df_do)

      return f, gradParameters
   end

   if optimMethod == optim.asgd then
      -- Keep only the third return value; select() avoids clobbering the
      -- global `_` with the first two.
      average = select(3, optimMethod(feval, parameters, optimState))
   else
      optimMethod(feval, parameters, optimState)
   end

   epoch = epoch + 1
end