-
Notifications
You must be signed in to change notification settings - Fork 7
/
model.lua
80 lines (61 loc) · 2.63 KB
/
model.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
-- Load the serialized MVCNN from opt.mvcnn_dir and convert it to float.
mvcnn_net_path = opt.mvcnn_dir .. '/mvcnn.net'
local loadedMvcnn = torch.load(mvcnn_net_path)
mvcnn = loadedMvcnn:float()

-- Two sub-modules of the loaded network are reused below:
--   modules[1] -> mvcnn_cnn1, applied once to the data right here
--   modules[4] -> mvcnn_cnn2, cloned later as the per-step classifier
local mvcnnStages = mvcnn.modules
mvcnn_cnn1 = mvcnnStages[1]
mvcnn_cnn2 = mvcnnStages[4]

-- Push all of AllData through mvcnn_cnn1 once, up front, and overwrite AllData
-- with the result (the original author's "little trick" to speed up training).
-- NOTE(review): forward() normally returns the module's own output buffer, so
-- AllData presumably aliases mvcnn_cnn1.output from here on -- confirm that
-- mvcnn_cnn1 is never forwarded again.
AllData = mvcnn_cnn1:forward(AllData)
-- Optional on-disk cache of the precomputed data (disabled):
--torch.save('Alldata',AllData)
--AllData=torch.load('Alldata')

-- Input network of the first recurrent module: a single ViewSelect layer
-- (project-defined) built from the precomputed data, opt.view_num and viewsLoc.
-- mvcnn_cnn1 is not appended here; presumably because its output is already
-- stored in AllData.
viewFeatureNet = nn.Sequential()
   :add(ViewSelect(AllData, opt.view_num, viewsLoc, false))
--viewFeatureNet:add(mvcnn_cnn1)
-- hiddenSize -> hiddenSize linear layer. It is not passed to rnn1 below (which
-- uses nn.Identity() in the feedback slot); kept as a global in case other
-- scripts read it.
recurrent1 = nn.Linear(opt.hiddenSize, opt.hiddenSize)

-- Merge step handed to the recurrent module: join the incoming table, reshape
-- it to 2 x hiddenSize, then reduce with nn.Max(1, 2).
-- NOTE(review): presumably this is an element-wise max over the two
-- hiddenSize-long vectors being merged -- confirm against nn.Max semantics.
-- Earlier variant, disabled:
-- :add(nn.ParallelTable():add(nn.Reshape(opt.hiddenSize,1,true)):add(nn.Reshape(opt.hiddenSize,1,true)))
mergeModule = nn.Sequential()
mergeModule:add(nn.JoinTable(1, 1))
mergeModule:add(nn.Reshape(2, opt.hiddenSize))
mergeModule:add(nn.Max(1, 2))

-- First recurrent module. Argument roles follow the rnn package's
-- nn.Recurrent(start, input, feedback, transfer, rho, merge) signature
-- (TODO confirm against the installed rnn version).
rnn1 = nn.Recurrent(
   nn.Identity(),   -- start
   viewFeatureNet,  -- input
   nn.Identity(),   -- feedback
   nn.Identity(),   -- transfer
   99999,           -- rho
   mergeModule      -- merge
)
-- Action head (nbvRegNet): maps a hiddenSize vector to a 2-value action.
--   1. linear projection hiddenSize -> 2
--   2. clamp to [-1.11, 1.11]
--   3. sample from a normal distribution (std opt.locatorStd); this layer uses
--      the REINFORCE learning rule
--   4. clamp the sample to [-1.11, 1.11] again
nbvRegNet = nn.Sequential()
   :add(nn.Linear(opt.hiddenSize, 2))
   :add(nn.HardTanh(-1.11, 1.11))
   :add(nn.ReinforceNormal(opt.locatorStd, opt.stochastic))
   :add(nn.HardTanh(-1.11, 1.11))
-- Embeds a 2-value location into opt.locationFeatureSize features, followed by
-- the transfer function named by opt.transfer (looked up in the nn table).
locationFeatureNet = nn.Sequential()
   :add(nn.Linear(2, opt.locationFeatureSize))
   :add(nn[opt.transfer]())

-- viewglimpse takes a two-element table: the first element passes through
-- unchanged, the second goes through locationFeatureNet, and the two results
-- are multiplied element-wise by nn.CMulTable.
-- NOTE(review): CMulTable needs both branches to produce matching sizes, so
-- opt.locationFeatureSize presumably equals the size of the first input.
viewglimpse = nn.Sequential()
local glimpseBranches = nn.ParallelTable()
glimpseBranches:add(nn.Identity())
glimpseBranches:add(locationFeatureNet)
viewglimpse:add(glimpseBranches)
viewglimpse:add(nn.CMulTable())
-- Second recurrent module ("rnn2"); it lives inside `action`.
-- recurrent2 is the hiddenSize -> hiddenSize feedback layer of that module.
recurrent2 = nn.Linear(opt.hiddenSize, opt.hiddenSize)

-- action = recurrent module over viewglimpse, followed by the nbvRegNet head.
action = nn.Sequential()
local actionRecurrence = nn.Recurrent(
   opt.hiddenSize,       -- start (given as a size)
   viewglimpse,          -- input
   recurrent2,           -- feedback
   nn[opt.transfer](),   -- transfer
   99999                 -- rho
)
action:add(actionRecurrence)
action:add(nbvRegNet)
-- Couples rnn1 with the action module for opt.rho steps. RecurrentAttention is
-- referenced here as a global rather than through the nn table and receives
-- the extra arguments opt.view_num and viewsLoc.
-- NOTE(review): presumably a project-specific variant -- see its definition
-- for the exact argument contract.
attention = RecurrentAttention(
   rnn1,
   action,
   opt.rho,
   {opt.hiddenSize},
   opt.view_num,
   viewsLoc
)

-- The agent is the reinforcement-learning model: input conversion first, then
-- the recurrent attention module. The classifiers are appended further down.
agent = nn.Sequential()
   :add(nn.Convert())
   :add(attention)
-- Classifier stage: one copy of mvcnn_cnn2 for each of the opt.rho steps,
-- collected in a ParallelTable and appended to the agent. Each copy is cloned
-- with 'weight', 'bias', 'gradWeight' and 'gradBias' passed to clone(), i.e.
-- the copies share those tensors with mvcnn_cnn2 instead of owning new ones.
paraDealer = nn.ParallelTable()
for _ = 1, opt.rho do
   paraDealer:add(mvcnn_cnn2:clone('weight', 'bias', 'gradWeight', 'gradBias'))
end
agent:add(paraDealer)
--classifier=nn.Sequential()
-- :add(mvcnn_cnn2)
--classifier=nn.Sequential()
-- :add(nn.Linear(opt.hiddenSize, #ds:classes()))
-- :add(nn.LogSoftMax())
--agent:add(classifier)
-- Baseline reward predictor: a constant 1 followed by nn.Add(1), so the only
-- parameter is the bias of the Add layer. The input-dependent alternative is
-- kept for reference:
--baselineNet:add(nn.Linear(opt.hiddenSize,1))
baselineNet = nn.Sequential()
   :add(nn.Constant(1, 1))
   :add(nn.Add(1))

-- baselineNet is only built here; it is not attached to the agent in this
-- file. The wiring below is disabled; with it enabled the output would be
-- {classpred, {classpred, basereward}}.
--concat = nn.ConcatTable():add(nn.Identity()):add(baselineNet)
--agent:add(concat)