-
Notifications
You must be signed in to change notification settings - Fork 2
/
opts.lua
92 lines (87 loc) · 4.31 KB
/
opts.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
--
-- Copyright (c) 2014, Facebook, Inc.
-- All rights reserved.
--
-- This source code is licensed under the BSD-style license found in the
-- LICENSE file in the root directory of this source tree. An additional grant
-- of patent rights can be found in the PATENTS file in the same directory.
--
-----------------------
require 'paths'
local M = { }

--- Return the value that follows the last '-netType' flag in args.
-- The network type has to be known before cmd:parse() runs, because the
-- selected network module registers its own options on the CmdLine object
-- (see M.parse below). A trailing '-netType' with no value after it is
-- ignored rather than producing nil.
-- @param args array of raw command-line strings, or nil
-- @return the network type string, or '' when none was given
local function findNetType(args)
   local netType = ''
   if args == nil then
      return netType
   end
   for i = 1, #args do
      -- no early break: the last occurrence wins
      if args[i] == '-netType' and args[i + 1] ~= nil then
         netType = args[i + 1]
      end
   end
   return netType
end

--- Parse command-line arguments into an options table.
-- Registers the general, data, training, testing/eval, optimization, model
-- and run options. If '-netType' appears in arg, models/<netType>.lua is
-- loaded (with models/basic_model.lua as fallback for fields it does not
-- define) and its arguments(cmd) hook adds network-specific options before
-- parsing. Reports an error through cmd:error unless at least one of
-- -train, -eval or -test is set.
-- @param arg array of command-line strings (nil is treated as empty)
-- @return opt table of parsed options; opt.cache gets the dataset name
--         appended and opt.save is cache/<encoded options>/<timestamp>
function M.parse(arg)
   local cmd = torch.CmdLine()
   cmd:text()
   cmd:text('Torch-7 Training script')
   cmd:text()
   cmd:text('Options:')
   ------------- General options ---------------------
   cmd:option('-cache', 'checkpoint/', 'subdirectory in which to save/log experiments')
   cmd:option('-data', '/path/to/dataset/folder', 'dataset folder')
   ------------- Data options ------------------------
   cmd:option('-manualSeed', 2, 'Manually set RNG seed')
   cmd:option('-GPU', 1, 'Default preferred GPU')
   cmd:option('-nGPU', 1, 'Number of GPUs to use by default')
   cmd:option('-nDonkeys', 2, 'number of donkeys to initialize (data loading threads)')
   cmd:option('-imageSize', 256, 'image size')
   cmd:option('-imageCrop', 224, 'cropping size')
   ------------- Training options --------------------
   cmd:option('-nEpochs', 20, 'Number of total epochs to run')
   cmd:option('-epochSize', 10000, 'Number of iterations per epoch')
   cmd:option('-epochNumber', 1, 'Manual epoch number (useful on restarts)')
   cmd:option('-batchSize', 128, 'mini-batch size (1 = pure stochastic)')
   cmd:option('-iterSize', 1, 'Number of batches per iteration')
   ------------- Testing/Eval options ----------------
   cmd:option('-nEpochsTest', 1, 'Number of epochs to perform one testing')
   cmd:option('-nEpochsEval', 1, 'Number of epochs to perform one evaluation')
   ------------- Optimization options ----------------
   cmd:option('-LR', 0.0, 'learning rate; if set, overrides default LR/WD recipe')
   cmd:option('-momentum', 0.9, 'momentum')
   cmd:option('-weightDecay', 5e-4, 'weight decay')
   ------------- Model options -----------------------
   -- NOTE(review): if -netType is omitted, opt.netType still defaults to
   -- 'alexnet', but that network's own options are not registered below
   -- (only an explicit -netType triggers loading the model file). Confirm
   -- this is intended.
   cmd:option('-netType', 'alexnet', 'your deep-net implementation')
   cmd:option('-dataset', 'ilsvrc', 'Select a customized dataset loader')
   cmd:option('-retrain', 'none', 'provide path to model to retrain with')
   ------------- Run Options -------------------------
   cmd:option('-train', false, 'run train procedure, note that not every dataset support trainDataLoader')
   cmd:option('-eval', false, 'run eval procedure, note that not every dataset support evalDataLoader')
   cmd:option('-test', false, 'run test procedure, note that not every dataset support testDataLoader')
   cmd:option('-pipeline', 'standard','run a standard/customized train,test,eval procedure')
   ------------- Moony classifier path ----------------
   cmd:option('-nn4model', '', '/pathh/to/trained/model')
   -- NOTE: the -train, -eval and -test options are currently not passed to
   -- donkey.lua; this will be improved in the future.
   cmd:text()
   ------------ Options from the specified network -------------
   -- The network module adds options to cmd, so its name must be pulled out
   -- of the raw arguments before cmd:parse() is called.
   local netType = findNetType(arg)
   if netType ~= '' then
      cmd:text('Network "'..netType..'" options:')
      -- all models should inherit from a basic model: any field the network
      -- module does not define (e.g. arguments) is looked up in basic_model
      local basicnet = paths.dofile('models/basic_model.lua')
      local net = paths.dofile('models/' .. netType .. '.lua')
      setmetatable(net, {__index = basicnet})
      net.arguments(cmd)
      cmd:text()
   end
   local opt = cmd:parse(arg or {})
   if (not opt.train) and (not opt.eval) and (not opt.test) then
      cmd:error('Must specify at least one running scheme: train, eval or test.')
   end
   -- append dataset to cache name; uses the torch `paths` library required by
   -- this file (there is no `path` global)
   opt.cache = paths.concat(opt.cache, opt.dataset)
   -- save directory: network type plus the non-default options, excluding
   -- the keys listed in the ignore table
   opt.save = paths.concat(
      opt.cache,
      cmd:string(opt.netType, opt,
         {netType=true, retrain=true, cache=true, data=true, nn4model=true}))
   -- add date/time: os.date() with all spaces stripped. The parentheses keep
   -- only gsub's first result (the string, not the replacement count).
   local stamp = (os.date():gsub(' ', ''))
   opt.save = paths.concat(opt.save, stamp)
   return opt
end

return M