require 'nn' -- for nn.Module (also pulls in torch, which provides torch.class)

local Balance, parent = torch.class('nn.Balance', 'nn.Module')
------------------------------------------------------------------------
--[[ Balance ]]--
-- Constrains the output distribution of a preceding SoftMax so that
-- categories have equal probability when averaged over examples, i.e.
-- each category has a mean probability of 1/nCategory.
------------------------------------------------------------------------
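-- Example usage (a sketch; assumes this file has been loaded and the
-- nn package is available; the layer sizes are arbitrary):
--   mlp = nn.Sequential()
--   mlp:add(nn.Linear(100, 10))
--   mlp:add(nn.SoftMax())
--   mlp:add(nn.Balance(10)) -- estimate P(Y) over the last 10 batches
------------------------------------------------------------------------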
function Balance:__init(nBatch)
   parent.__init(self)
   -- number of past batches used to estimate P(Y)
   self.nBatch = nBatch or 10
   self.inputCache = torch.Tensor()
   self.prob = torch.Tensor()
   self.sum = torch.Tensor()
   self.batchSize = 0
   self.startIdx = 1
   self.train = true
end
function Balance:updateOutput(input)
   assert(input:dim() == 2, "Only works with 2D inputs (batches)")
   if self.batchSize ~= input:size(1) then
      self.inputCache:resize(input:size(1)*self.nBatch, input:size(2)):zero()
      self.batchSize = input:size(1)
      self.startIdx = 1
   end
   self.output:resizeAs(input):copy(input)
   if not self.train then
      return self.output
   end
   -- keep a ring buffer of the last nBatch batches of P(Y|X)
   self.inputCache:narrow(1, self.startIdx, input:size(1)):copy(input)
   -- P(X) is uniform over examples, i.e. P(X) = 1/c for a constant c
   -- P(Y) = sum_x( P(Y|X)*P(X) )
   self.prob:sum(self.inputCache, 1):div(self.prob:sum())
   -- P(X|Y) = P(Y|X)*P(X)/P(Y)
   self.output:cdiv(self.prob:resize(1,input:size(2)):expandAs(input))
   -- renormalize each row: P(Z|X) = P(X|Y) / sum_y( P(X|Y) ),
   -- so each row again sums to 1 like a SoftMax output
   self.sum:sum(self.output, 2)
   self.output:cdiv(self.sum:resize(input:size(1),1):expandAs(self.output))
   -- advance the ring buffer write position
   self.startIdx = self.startIdx + self.batchSize
   if self.startIdx > self.inputCache:size(1) then
      self.startIdx = 1
   end
   return self.output
end
function Balance:updateGradInput(input, gradOutput)
   -- backpropagate through the two divisions, treating P(Y) and the
   -- row sums as constants (an approximation to the full gradient)
   self.gradInput:resizeAs(gradOutput):copy(gradOutput)
   self.gradInput:cdiv(self.sum:resize(input:size(1),1):expandAs(self.output))
   self.gradInput:cdiv(self.prob:resize(1,input:size(2)):expandAs(input))
   return self.gradInput
end