-
Notifications
You must be signed in to change notification settings - Fork 15
/
main.lua
78 lines (66 loc) · 3.03 KB
/
main.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
-----------------------------------------------------------------------------
-- IMPORT MODULES: OPTIONS, DATALOADER, CHECKPOINTS, MODEL, TRAINER, UTILS
-----------------------------------------------------------------------------
require 'torch'
require 'paths'
local opts = require 'opts'; local opt = opts.parse(arg)
local DataLoader = require 'dataloader'
local checkpoints = require 'checkpoints'
local models = require 'models/init'
local Trainer = require 'train'
local dict_utils = require 'utils/dict_utils'
local io_utils = require 'utils/io_utils'
-----------------------------------------------------------------------------
-- INITIALIZATION
-----------------------------------------------------------------------------
-- Global Torch configuration: float tensors by default, one thread.
torch.setdefaulttensortype('torch.FloatTensor')
torch.setnumthreads(1)

-- Seed the CPU generator and the GPU generators with the same value.
-- cutorch is required explicitly: this file uses the `cutorch` global, so it
-- must not depend on some other module having loaded it as a side effect.
-- (`require` is cached, so this is a no-op if it is already loaded.)
require 'cutorch'
torch.manualSeed(opt.manualSeed)
cutorch.manualSeedAll(opt.manualSeed)

-- Build the data loaders, look up the latest checkpoint, then create the
-- model and the trainer from them.
-- NOTE(review): `checkp` is presumably nil when there is nothing to resume;
-- the later `if checkp`-style start-epoch selection relies on that -- confirm
-- against checkpoints.latest.
local trainLoader, valLoader = DataLoader.create(opt)
local checkp, optimState = checkpoints.latest(opt)
local model = models.setup(opt, checkp)
local trainer = Trainer(model, opt, optimState)

-- Validation-only mode: run a single test pass on the validation loader and
-- stop the script. The result of the pass is not used here.
if opt.valOnly then
   trainer:test(1, valLoader, 'val')
   return
end
-----------------------------------------------------------------------------
-- CONFIGURE START POINTS AND HISTORY
-----------------------------------------------------------------------------
-- Training/validation histories are read through io_utils.loadt7 from files
-- inside the opt.resume directory.
-- NOTE(review): loadt7 receives `checkp` as its first argument, presumably to
-- decide whether there is anything to load -- confirm in utils/io_utils.
local train_hist, val_hist
do
   local resumeDir = opt.resume
   train_hist = io_utils.loadt7(checkp, paths.concat(resumeDir, 'train_hist.t7'))
   val_hist = io_utils.loadt7(checkp, paths.concat(resumeDir, 'val_hist.t7'))
end

-- Continue right after the checkpointed epoch when a checkpoint exists;
-- otherwise start from the epoch given in the options.
local startEpoch
if checkp then
   startEpoch = checkp.epoch + 1
else
   startEpoch = opt.startEpoch
end
--- Merge one set of results into the running history of a split and save it.
-- The merged history replaces the file-level `train_hist` / `val_hist`
-- variable and is written with torch.save to `<opt.save>/<split>_hist.t7`.
-- @param epoch   epoch the results belong to (not used by this function;
--                kept so existing callers stay valid)
-- @param history value merged into the history via dict_utils.insertSubDicts
-- @param split   'train' or 'val'; any other value raises an error
local function add_history(epoch, history, split)
   -- Validate first so nothing is modified or written for a bad split.
   if split ~= 'train' and split ~= 'val' then
      -- tostring() keeps the message intact for nil / non-string values;
      -- on Lua 5.1 a bare '%s' would raise its own "bad argument" error.
      -- Level 2 attributes the error to the caller that passed the split.
      error(string.format('Unknown split: %s', tostring(split)), 2)
   end

   local updated
   if split == 'train' then
      train_hist = dict_utils.insertSubDicts(train_hist, history)
      updated = train_hist
   else
      val_hist = dict_utils.insertSubDicts(val_hist, history)
      updated = val_hist
   end

   torch.save(paths.concat(opt.save, split .. '_hist.t7'), updated)
end
-----------------------------------------------------------------------------
---- START TRAINING
-----------------------------------------------------------------------------
-- Main loop: runs every epoch from startEpoch through opt.nEpochs (inclusive).
-- Each iteration trains once, optionally checkpoints, records and plots the
-- training history, and optionally validates.
for epoch = startEpoch, opt.nEpochs do
   -- TRAIN FOR A SINGLE EPOCH
   local trainLoss = trainer:train(epoch, trainLoader, 'train')

   -- SAVE CHECKPOINTS
   -- Besides the regular interval, always checkpoint the last epoch so the
   -- final model is not lost when opt.nEpochs is not a multiple of
   -- opt.saveInterval.
   if (epoch % opt.saveInterval == 0) or (epoch == opt.nEpochs) then
      print(string.format("\t**** Epoch %d saving checkpoint ****", epoch))
      checkpoints.save(opt, model, trainer.optimState, epoch)
   end

   -- SAVE AND PLOT RESULTS FOR TRAINING STAGE
   add_history(epoch, trainLoss, 'train')
   io_utils.plot_results_compact(train_hist, opt.logDir, 'train')

   -- VALIDATION ON SYNTHETIC DATA
   -- Runs only every opt.val_interval epochs; its results go to the 'val'
   -- history and plot.
   if (epoch % opt.val_interval == 0) then
      local valResult = trainer:test(epoch, valLoader, 'val')
      add_history(epoch, valResult, 'val')
      io_utils.plot_results_compact(val_hist, opt.logDir, 'val')
   end

   -- Force a full GC cycle at the end of each epoch.
   collectgarbage()
end