RewardCriterion.lua
local RewardCriterion, parent = torch.class("RewardCriterion", "nn.Criterion")

function RewardCriterion:__init(module, scale, baselineNet, viewsLoc, entropyParam, mvCostParam)
   parent.__init(self)
   self.module = module                              -- so it can call module:reinforce(reward)
   self.scale = scale or 1                           -- scale of reward
   self.MSECriterion = nn.MSECriterion()             -- baseline criterion
   self.ClassNLLCriterion = nn.ClassNLLCriterion()   -- loss criterion
   self.sizeAverage = true
   self.gradInput = {}
   self.allReward = {}
   self.entropyParam = entropyParam or 5             -- weight of the entropy-decrease term of the reward
   self.mvCostParam = mvCostParam or 0.1             -- weight of the movement-cost term of the reward
   self.baselineNet = baselineNet                    -- network that predicts the baseline reward
   self.viewsLoc = viewsLoc                          -- predefined view locations (viewNum x 2)
end
-- Map a continuous 2D location to the index of the nearest predefined view
function RewardCriterion:getViewId(location)
   local dists = (self.viewsLoc - location:clone():resize(1, 2):repeatTensor(self.viewsLoc:size(1), 1)):norm(2, 2)
   local value, results = dists:min(1)
   local viewId = results[1][1]
   return viewId
end
-- Shannon entropy of a log-probability vector (e.g. a LogSoftMax output)
function RewardCriterion:getShannoEntropy(classPred)
   local prob = torch.exp(classPred)
   local entropy = -torch.dot(prob, classPred)
   return entropy
end
-- Movement cost: half the Euclidean distance between two consecutive locations
function RewardCriterion:computeMvCost(loc_k1, loc_k2)
   local delt_loc = loc_k2 - loc_k1
   local dist = delt_loc:norm() / 2
   return dist
end
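-- The per-sample reward assembled in updateOutput below combines three terms over
-- the rho glimpses: +1 for every glimpse whose argmax matches the target,
-- + entropyParam * (entropy decrease relative to the previous glimpse) for correct
-- glimpses after the first, and - mvCostParam * movement cost between the two
-- locations. The reward is zeroed if any location after the first leaves [-1, 1]
-- or if any predefined view is visited more than once, and is finally divided by rho.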
function RewardCriterion:updateOutput(input, target)
   assert(torch.type(input) == 'table')
   self.reward = self.reward or input[1].new()
   self.reward:resize(target:size(1)):fill(0)

   -- attended locations, one (batchSize x 2) tensor per glimpse
   local ra = self.module:findModules('RecurrentAttention')[1]
   local locations = ra.actions
   local rho = #locations             -- number of glimpses
   local ram_softmax = input          -- per-glimpse class log-probabilities

   for i = 1, target:size(1) do
      local tmp_target = target[i]
      local tmp_entropy = torch.Tensor(rho):zero()
      local tmp_maxId = torch.LongTensor(rho):zero()
      local tmp_maxPro = torch.Tensor(rho):zero()
      -- entropy and argmax of the class prediction at each glimpse
      for j = 1, rho do
         local value, maxId = torch.max(ram_softmax[j][i], 1)
         tmp_entropy[j] = self:getShannoEntropy(ram_softmax[j][i])
         tmp_maxId[j] = maxId[1]
         tmp_maxPro[j] = value[1]
      end
      -- +1 if the first glimpse already classifies correctly
      if tmp_maxId[1] == tmp_target then
         self.reward[i] = self.reward[i] + 1
      end
      local flag = false
      for k = 2, rho do
         -- correct glimpse: +1, plus the entropy decrease, minus the movement cost
         if tmp_maxId[k] == tmp_target then
            self.reward[i] = self.reward[i] + 1
            local deltEntropy = tmp_entropy[k-1] - tmp_entropy[k]
            self.reward[i] = self.reward[i] + deltEntropy * self.entropyParam
               - self.mvCostParam * self:computeMvCost(locations[k-1][i], locations[k][i])
         end
         -- flag samples whose location leaves the valid [-1, 1] range
         if locations[k][i]:clone():abs():max() > 1 then
            flag = true
         end
      end
      if flag then
         self.reward[i] = 0
      end
      -- zero the reward if any predefined view is visited more than once
      local view_num = self.viewsLoc:size(1)
      local tmp_viewIdsCount = torch.LongTensor(view_num):fill(0)
      for k = 1, rho do
         local tmp_viewId = self:getViewId(locations[k][i])
         tmp_viewIdsCount[tmp_viewId] = tmp_viewIdsCount[tmp_viewId] + 1
      end
      if tmp_viewIdsCount:gt(1):sum() > 0 then
         self.reward[i] = 0
      end
   end

   self.reward:div(rho)
   self.output = -self.reward:sum()
   if self.sizeAverage then
      self.output = self.output / input[1]:size(1)
   end
   return self.output
end
function RewardCriterion:updateGradInput(input, target)
   local rho = #input
   local baseline = self.baselineNet:forward(input[rho])
   -- reduce variance of reward using baseline
   self.vrReward = self.vrReward or self.reward.new()
   self.vrReward:resizeAs(self.reward):copy(self.reward)
   -- self.vrReward:add(-1, baseline)
   if self.sizeAverage then
      self.vrReward:div(input[1]:size(1))
   end
   -- broadcast reward to modules
   self.vrReward:mul(self.scale)
   self.module:reinforce(self.vrReward)
   -- the criterion itself passes back zero gradients for the class predictions
   for i = 1, #input do
      -- self.gradInput[i] = self.ClassNLLCriterion:backward(input[i], target)
      self.gradInput[i] = self.gradInput[i] or input[i].new()
      self.gradInput[i]:resizeAs(input[i]):zero()
   end
   -- learn the baseline reward by regressing it towards the observed reward
   self.baselineNet:zeroGradParameters()
   local gradInput_baseline = self.MSECriterion:backward(baseline, self.reward)
   self.baselineNet:backward(input[rho], gradInput_baseline)
   return self.gradInput
end
function RewardCriterion:type(type)
   -- detach self.module so it is not cast along with the criterion
   local module = self.module
   self.module = nil
   local ret = parent.type(self, type)
   self.module = module
   return ret
end