forked from clementfarabet/lua---nnx
-
Notifications
You must be signed in to change notification settings - Fork 0
/
SpatialNormalization.lua
316 lines (292 loc) · 11 KB
/
SpatialNormalization.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
-- nn.SpatialNormalization: local contrast normalization over 2D feature maps.
-- Registers the class with torch and inherits from nn.Module.
local SpatialNormalization, parent = torch.class('nn.SpatialNormalization','nn.Module')
-- Help text displayed by xlua.unpack when the constructor is called with
-- bad arguments (see __init).
-- NOTE(review): the leading '?' characters in the string below look like
-- mis-encoded bullet glyphs from the original source -- confirm against
-- upstream before "fixing", since this text is emitted at runtime.
local help_desc =
[[a spatial (2D) contrast normalizer
? computes the local mean and local std deviation
across all input features, using the given 2D kernel
? the local mean is then removed from all maps, and the std dev
used to divide the inputs, with a threshold
? if no threshold is given, the global std dev is used
? weight replication is used to preserve sizes (this is
better than zero-padding, but more costly to compute, use
nn.ContrastNormalization to use zero-padding)
? two 1D kernels can be used instead of a single 2D kernel. This
is beneficial to integrate information over large neiborhoods.
]]
-- Usage example appended to the help text.
local help_example =
[[EX:
-- create a spatial normalizer, with a 9x9 gaussian kernel
-- works on 8 input feature maps, therefore the mean+dev will
-- be estimated on 8x9x9 cubes
stimulus = lab.randn(8,500,500)
gaussian = image.gaussian(9)
mod = nn.SpatialNormalization(gaussian, 8)
result = mod:forward(stimulus)]]
-- Constructor.
-- Named/positional args (parsed by xlua.unpack):
--   nInputPlane (number, required): number of input feature maps
--   kernel (torch.Tensor | table): a KxK 2D kernel, or a {1xK, Kx1} pair of
--     separable 1D kernels applied in sequence
--   threshold (number, optional): fixed divisor floor; when nil the floor is
--     adapted per-input in forward() from the mean local std deviation
-- Builds two convolution pipelines (self.convo for the local mean,
-- self.convostd for the local variance) plus the element-wise modules used
-- to assemble the normalization in forward()/backward().
function SpatialNormalization:__init(...) -- kernel for weighted mean | nb of features
parent.__init(self)
-- get args
local args, nf, ker, thres
= xlua.unpack(
{...},
'nn.SpatialNormalization',
help_desc .. '\n' .. help_example,
{arg='nInputPlane', type='number', help='number of input maps', req=true},
{arg='kernel', type='torch.Tensor | table', help='a KxK filtering kernel or two {1xK, Kx1} 1D kernels'},
{arg='threshold', type='number', help='threshold, for division [default = adaptive]'}
)
-- check args
if not ker then
xerror('please provide kernel(s)', 'nn.SpatialNormalization', args.usage)
end
self.kernel = ker
-- a table kernel means two separable 1D passes: ker (first), ker2 (second)
local ker2
if type(ker) == 'table' then
ker2 = ker[2]
ker = ker[1]
end
self.nfeatures = nf
self.fixedThres = thres
-- padding values (half the kernel size, so convolution output matches input size)
self.padW = math.floor(ker:size(2)/2)
self.padH = math.floor(ker:size(1)/2)
self.kerWisPair = 0
self.kerHisPair = 0
-- padding values for 2nd kernel
if ker2 then
self.pad2W = math.floor(ker2:size(2)/2)
self.pad2H = math.floor(ker2:size(1)/2)
else
self.pad2W = 0
self.pad2H = 0
end
self.ker2WisPair = 0
self.ker2HisPair = 0
-- normalize kernel so it sums to 1 (weighted average rather than weighted sum)
ker:div(ker:sum())
if ker2 then ker2:div(ker2:sum()) end
-- manage the case where ker is even size (for padding issue):
-- an even kernel cannot be padded symmetrically, so one side gets one less
if (ker:size(2)/2 == math.floor(ker:size(2)/2)) then
print ('Warning, kernel width is even -> not symetric padding')
self.kerWisPair = 1
end
if (ker:size(1)/2 == math.floor(ker:size(1)/2)) then
print ('Warning, kernel height is even -> not symetric padding')
self.kerHisPair = 1
end
if (ker2 and ker2:size(2)/2 == math.floor(ker2:size(2)/2)) then
print ('Warning, kernel width is even -> not symetric padding')
self.ker2WisPair = 1
end
if (ker2 and ker2:size(1)/2 == math.floor(ker2:size(1)/2)) then
print ('Warning, kernel height is even -> not symetric padding')
self.ker2HisPair = 1
end
-- create convolution for computing the mean:
-- pad -> per-map convolution -> sum across maps -> replicate back to nf maps,
-- i.e. every output map holds the mean pooled over ALL input maps
local convo1 = nn.Sequential()
convo1:add(nn.SpatialPadding(self.padW,self.padW-self.kerWisPair,
self.padH,self.padH-self.kerHisPair))
-- NOTE: this ctable stays in scope for the whole function and is reused
-- below when building convostd1/convostd2 (the one in the `if ker2` branch
-- shadows it only inside that branch)
local ctable = nn.tables.oneToOne(nf)
convo1:add(nn.SpatialConvolutionMap(ctable,ker:size(2),ker:size(1)))
convo1:add(nn.Sum(1))
convo1:add(nn.Replicate(nf))
-- set kernel: copy the (normalized) user kernel into every filter
local fb = convo1.modules[2].weight
for i=1,fb:size(1) do fb[i]:copy(ker) end
-- set bias to 0
convo1.modules[2].bias:zero()
-- 2nd ker ?
if ker2 then
local convo2 = nn.Sequential()
convo2:add(nn.SpatialPadding(self.pad2W,self.pad2W-self.ker2WisPair,
self.pad2H,self.pad2H-self.ker2HisPair))
local ctable = nn.tables.oneToOne(nf)
convo2:add(nn.SpatialConvolutionMap(ctable,ker2:size(2),ker2:size(1)))
convo2:add(nn.Sum(1))
convo2:add(nn.Replicate(nf))
-- set kernel
local fb = convo2.modules[2].weight
for i=1,fb:size(1) do fb[i]:copy(ker2) end
-- set bias to 0
convo2.modules[2].bias:zero()
-- convo is a double convo now:
local convopack = nn.Sequential()
convopack:add(convo1)
convopack:add(convo2)
self.convo = convopack
else
self.convo = convo1
end
-- create convolution for computing the meanstd (same structure as self.convo,
-- but applied to the squared zero-mean input to estimate the local variance)
local convostd1 = nn.Sequential()
convostd1:add(nn.SpatialPadding(self.padW,self.padW-self.kerWisPair,
self.padH,self.padH-self.kerHisPair))
convostd1:add(nn.SpatialConvolutionMap(ctable,ker:size(2),ker:size(1)))
convostd1:add(nn.Sum(1))
convostd1:add(nn.Replicate(nf))
-- set kernel
local fb = convostd1.modules[2].weight
for i=1,fb:size(1) do fb[i]:copy(ker) end
-- set bias to 0
convostd1.modules[2].bias:zero()
-- 2nd ker ?
if ker2 then
local convostd2 = nn.Sequential()
convostd2:add(nn.SpatialPadding(self.pad2W,self.pad2W-self.ker2WisPair,
self.pad2H,self.pad2H-self.ker2HisPair))
convostd2:add(nn.SpatialConvolutionMap(ctable,ker2:size(2),ker2:size(1)))
convostd2:add(nn.Sum(1))
convostd2:add(nn.Replicate(nf))
-- set kernel
local fb = convostd2.modules[2].weight
for i=1,fb:size(1) do fb[i]:copy(ker2) end
-- set bias to 0
convostd2.modules[2].bias:zero()
-- convo is a double convo now:
local convopack = nn.Sequential()
convopack:add(convostd1)
convopack:add(convostd2)
self.convostd = convopack
else
self.convostd = convostd1
end
-- other operation: element-wise modules composed in forward()
self.squareMod = nn.Square()
self.sqrtMod = nn.Sqrt()
self.subtractMod = nn.CSubTable()
self.meanDiviseMod = nn.CDivTable()
self.stdDiviseMod = nn.CDivTable()
self.diviseMod = nn.CDivTable()
self.thresMod = nn.Threshold()
-- some tempo states (intermediate tensors cached for backward())
self.coef = torch.Tensor(1,1)
self.inConvo = torch.Tensor()
self.inMean = torch.Tensor()
self.inputZeroMean = torch.Tensor()
self.inputZeroMeanSq = torch.Tensor()
self.inConvoVar = torch.Tensor()
self.inVar = torch.Tensor()
self.inStdDev = torch.Tensor()
self.thstd = torch.Tensor()
end
-- Forward pass: output = (input - localMean) / max(localStdDev, threshold).
-- Accepts a 2D (H,W) or 3D (nfeatures,H,W) tensor; a 2D input is promoted
-- to a single-map 3D tensor. All intermediates are kept on self for backward().
function SpatialNormalization:forward(input)
-- auto switch to 3-channel
self.input = input
if (input:nDimension() == 2) then
self.input = input:clone():resize(1,input:size(1),input:size(2))
end
-- recompute coef only if necessary:
-- self.coef is the convolution of an all-ones tensor, i.e. the per-pixel
-- weight mass near borders; dividing by it corrects padding border effects.
-- It only depends on the spatial size, so it is cached until the size changes.
if (self.input:size(3) ~= self.coef:size(2)) or (self.input:size(2) ~= self.coef:size(1)) then
local intVals = self.input.new(self.nfeatures,self.input:size(2),self.input:size(3)):fill(1)
self.coef = self.convo:forward(intVals)
-- clone: self.convo:forward returns its internal buffer, which the next
-- forward call below would otherwise overwrite
self.coef = self.coef:clone()
end
-- compute mean (border-corrected), then subtract it from the input
self.inConvo = self.convo:forward(self.input)
self.inMean = self.meanDiviseMod:forward{self.inConvo,self.coef}
self.inputZeroMean = self.subtractMod:forward{self.input,self.inMean}
-- compute std dev: sqrt of the (border-corrected) local mean of squares
self.inputZeroMeanSq = self.squareMod:forward(self.inputZeroMean)
self.inConvoVar = self.convostd:forward(self.inputZeroMeanSq)
self.inStdDevNotUnit = self.sqrtMod:forward(self.inConvoVar)
self.inStdDev = self.stdDiviseMod:forward({self.inStdDevNotUnit,self.coef})
-- divisor floor: fixed threshold if given, else adaptive (mean std dev,
-- never below 1e-3 to avoid division blow-up)
local meanstd = self.inStdDev:mean()
self.thresMod.threshold = self.fixedThres or math.max(meanstd,1e-3)
self.thresMod.val = self.fixedThres or math.max(meanstd,1e-3)
self.stdDev = self.thresMod:forward(self.inStdDev)
--remove std dev
self.diviseMod:forward{self.inputZeroMean,self.stdDev}
self.output = self.diviseMod.output
return self.output
end
-- Backward pass: backpropagates gradOutput through the module graph built
-- in forward(), reusing the intermediate states cached there.
-- NOTE(review): assumes forward() was called on the same input immediately
-- before -- the cached self.inputZeroMean/self.stdDev/etc. must match.
function SpatialNormalization:backward(input, gradOutput)
-- auto switch to 3-channel
self.input = input
if (input:nDimension() == 2) then
self.input = input:clone():resize(1,input:size(1),input:size(2))
end
self.gradInput:resizeAs(self.input):zero()
-- backprop all, in reverse order of forward():
-- division -> threshold -> coef division -> sqrt -> variance convolution
local gradDiv = self.diviseMod:backward({self.inputZeroMean,self.stdDev},gradOutput)
local gradThres = gradDiv[2]
local gradZeroMean = gradDiv[1]
local gradinStdDev = self.thresMod:backward(self.inStdDev,gradThres)
-- [2] (grad wrt self.coef) is discarded: coef is a constant, not an input
local gradstdDiv = self.stdDiviseMod:backward({self.inStdDevNotUnit,self.coef},gradinStdDev)
local gradinStdDevNotUnit = gradstdDiv[1]
local gradinConvoVar = self.sqrtMod:backward(self.inConvoVar,gradinStdDevNotUnit)
local gradinputZeroMeanSq = self.convostd:backward(self.inputZeroMeanSq,gradinConvoVar)
-- accumulate the std-dev path gradient into the zero-mean branch
gradZeroMean:add(self.squareMod:backward(self.inputZeroMean,gradinputZeroMeanSq))
local gradDiff = self.subtractMod:backward({self.input,self.inMean},gradZeroMean)
local gradinMean = gradDiff[2]
local gradinConvoNotUnit = self.meanDiviseMod:backward({self.inConvo,self.coef},gradinMean)
local gradinConvo = gradinConvoNotUnit[1]
-- first part of the gradInput: direct path (input - mean)
self.gradInput:add(gradDiff[1])
-- second part of the gradInput: through the mean convolution
self.gradInput:add(self.convo:backward(self.input,gradinConvo))
return self.gradInput
end
-- Propagate a tensor-type change to this container and all wrapped submodules.
-- Overrides nn.Module:type so the internal convolution pipelines and the
-- element-wise modules are converted along with the container itself.
-- Returns self, for chaining.
function SpatialNormalization:type(type)
parent.type(self,type)
-- convert every stateful submodule, in the same order as before
local submodules = {
self.convo, self.meanDiviseMod, self.subtractMod,
self.squareMod, self.convostd, self.sqrtMod,
self.stdDiviseMod, self.thresMod, self.diviseMod,
}
for _, module in ipairs(submodules) do
module:type(type)
end
return self
end
-- Serialization: writes all configuration and submodules to `file`.
-- Field order must stay in lockstep with SpatialNormalization:read below.
function SpatialNormalization:write(file)
parent.write(self,file)
file:writeObject(self.kernel)
file:writeInt(self.nfeatures)
file:writeInt(self.padW)
file:writeInt(self.padH)
file:writeInt(self.kerWisPair)
file:writeInt(self.kerHisPair)
file:writeObject(self.convo)
file:writeObject(self.convostd)
file:writeObject(self.squareMod)
file:writeObject(self.sqrtMod)
file:writeObject(self.subtractMod)
file:writeObject(self.meanDiviseMod)
file:writeObject(self.stdDiviseMod)
file:writeObject(self.thresMod)
file:writeObject(self.diviseMod)
file:writeObject(self.coef)
-- second-kernel padding fields only exist when a {1D,1D} kernel pair was used
if type(self.kernel) == 'table' then
file:writeInt(self.pad2W)
file:writeInt(self.pad2H)
file:writeInt(self.ker2WisPair)
file:writeInt(self.ker2HisPair)
end
-- NOTE(review): writeInt truncates a fractional threshold (e.g. 1e-3 -> 0),
-- and 0 doubles as the "no fixed threshold" sentinel in read(). Fixing this
-- requires changing write AND read together (format change) -- flagged only.
file:writeInt(self.fixedThres or 0)
end
-- Deserialization: reads fields in the exact order written by
-- SpatialNormalization:write above.
function SpatialNormalization:read(file)
parent.read(self,file)
self.kernel = file:readObject()
self.nfeatures = file:readInt()
self.padW = file:readInt()
self.padH = file:readInt()
self.kerWisPair = file:readInt()
self.kerHisPair = file:readInt()
self.convo = file:readObject()
self.convostd = file:readObject()
self.squareMod = file:readObject()
self.sqrtMod = file:readObject()
self.subtractMod = file:readObject()
self.meanDiviseMod = file:readObject()
self.stdDiviseMod = file:readObject()
self.thresMod = file:readObject()
self.diviseMod = file:readObject()
self.coef = file:readObject()
-- second-kernel padding fields only exist when a {1D,1D} kernel pair was used
if type(self.kernel) == 'table' then
self.pad2W = file:readInt()
self.pad2H = file:readInt()
self.ker2WisPair = file:readInt()
self.ker2HisPair = file:readInt()
end
-- 0 is the "no fixed threshold" sentinel written by write()
-- NOTE(review): a genuinely-fixed threshold of 0 (or any fractional value,
-- truncated by writeInt) is therefore lost on a save/load round-trip.
self.fixedThres = file:readInt()
if self.fixedThres == 0 then self.fixedThres = nil end
end