forked from clementfarabet/lua---nnx
-
Notifications
You must be signed in to change notification settings - Fork 0
/
SpatialMSECriterion.lua
132 lines (125 loc) · 5.58 KB
/
SpatialMSECriterion.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
-- Declare the torch class 'nn.SpatialMSECriterion' with 'nn.MSECriterion' as
-- its parent. `SpatialMSECriterion` receives the methods defined below, and
-- `parent` is used by those methods to chain to the parent implementation.
local SpatialMSECriterion, parent = torch.class('nn.SpatialMSECriterion', 'nn.MSECriterion')
--- Construct the criterion.
-- Runs the parent constructor, then hands the arguments to
-- xlua.unpack_class together with the option definitions below. The other
-- methods of this class read the options back as fields of `self`
-- (self.resampleTarget, self.nbGradients, self.sizeAverage,
-- self.ignoreClass).
--
-- Options:
--   resampleTarget  number, default 1
--   nbGradients     number, default -1 (-1: all, >=1: that many)
--   sizeAverage     boolean, default true
--   ignoreClass     number, or false (the default) to ignore no class
function SpatialMSECriterion:__init(...)
   parent.__init(self)
   xlua.unpack_class(self, {...},
      'nn.SpatialMSECriterion',
      'A spatial extension of the MSECriterion class.\n'
         ..' Provides a set of parameters to deal with spatial mini-batch training.',
      {arg='resampleTarget', type='number', help='ratio to resample target (target is a KxHxW tensor)', default=1},
      {arg='nbGradients', type='number', help='number of gradients to backpropagate (-1:all, >=1:nb)', default=-1},
      -- sizeAverage is a flag (default true, tested as a condition in
      -- forward()), so it is declared as a boolean rather than a number.
      {arg='sizeAverage', type='boolean', help='if true, forward() returns an average instead of a sum of errors', default=true},
      {arg='ignoreClass', type='number', help='all gradients for this class will be zeroed', default=false}
   )
end
--- Transform `target` so that it can be compared element-wise with `input`.
-- The adjusted target is cached in self.target (backward() reads it from
-- there) and is also returned.
-- @param input   tensor produced by the network; indexed here as a 3D
--                tensor whose dims 2 and 3 are the spatial dims
-- @param target  either a 2D map of classes or a 3D tensor
-- @return the adjusted target
function SpatialMSECriterion:adjustTarget(input, target)
   local sratio = self.resampleTarget

   -- (1) if target has 2 dims, it is assumed to be a map of target
   --     classes, one per point. It is converted into a 3D map of class
   --     distributions, to emulate a classical mean-square regression
   --     problem. The buffer is pre-filled with -1 (the negative target);
   --     the retarget routine is defined outside this file and presumably
   --     marks the entry of each point's class -- TODO confirm.
   if target:dim() == 2 then
      self.newtarget = self.newtarget or torch.Tensor()
      self.newtarget:resizeAs(input):fill(-1)
      input.nn.SpatialMSECriterion_retarget(self.newtarget, target)
      target = self.newtarget
   end

   -- (2) if the target map has an incorrect size, it is assumed to be at
   --     the original scale of the data (e.g. for dense classification
   --     problems, like scene parsing, the target map is at the resolution
   --     of the input image, whereas the input of this criterion is the
   --     output of some neural network and might be smaller). A centered
   --     window is cropped out of the target to correct for
   --     convolution-induced losses; step (3) corrects for
   --     downsampling/strides.
   --     Each spatial dimension is checked and cropped on its own, so a
   --     target that is too large in height only is cropped as well.
   for _, dim in ipairs({2, 3}) do
      local insize = input:size(dim)
      if target:size(dim) * sratio ~= insize then
         local length = insize / sratio
         local first = math.floor((target:size(dim) - (insize - 1) / sratio) / 2) + 1
         target = target:narrow(dim, first, length)
      end
   end

   -- (3) resample the target to the spatial size of the input, to
   --     compensate for downsampling/pooling operations.
   if sratio ~= 1 then
      local target_scaled = torch.Tensor(target:size(1), input:size(2), input:size(3))
      image.scale(target, target_scaled, 'simple')
      target = target_scaled
   end

   -- (4) last thing, optionally filter out one class. In the MSE
   --     regression setup, -1 is the negative target.
   --     NOTE: this writes into `target` in place; when steps (1) and (3)
   --     did not produce a new tensor, that is the caller's storage.
   if self.ignoreClass then
      target:select(1, self.ignoreClass):fill(-1)
   end

   self.target = target
   return target
end
--- Compute the criterion value for `input` against `target`.
-- @param input   tensor produced by the network
-- @param target  raw target; it is passed through adjustTarget() first
-- @return self.output: the mean of the dense error map when
--         self.sizeAverage is set, its sum otherwise
function SpatialMSECriterion:forward(input,target)
   -- bring the target in line with the input: classes -> distributions,
   -- crop for convolution losses, resample for strides, ignore a class
   local adjusted = self:adjustTarget(input, target)

   -- the dense error map has as many entries as the input; it is
   -- allocated lazily and kept between calls
   if not self.fullOutput then
      self.fullOutput = torch.Tensor()
   end
   self.fullOutput:resizeAs(input)

   -- dense errors are computed by the routine registered on the input's
   -- tensor type (defined outside this file)
   input.nn.SpatialMSECriterion_forward(self, input, adjusted)

   -- reduce the dense errors to a scalar, by averaging or accumulation
   local errors = self.fullOutput
   if self.sizeAverage then
      self.output = errors:mean()
   else
      self.output = errors:sum()
   end
   return self.output
end
--- Compute the gradient of the criterion with respect to `input`.
-- Depending on self.nbGradients:
--   -1    the dense gradient map is returned;
--    1    only the gradient at the central spatial location is kept;
--   else  gradients at nbGradients randomly drawn spatial locations are
--         kept (locations may repeat).
-- All other entries of the returned tensor are zero.
-- @param input   tensor produced by the network
-- @param target  raw target; only used when no adjusted target has been
--                cached by adjustTarget() (normally done in forward())
-- @return self.gradInput, resized like `input`
function SpatialMSECriterion:backward(input,target)
   -- (1) retrieve the adjusted target; adjust it now if forward() was not
   --     called beforehand
   target = self.target or self:adjustTarget(input, target)

   -- (2) resize input gradient map
   self.gradInput:resizeAs(input):zero()

   -- (3) compute input gradients, based on the nbGradients param
   if self.nbGradients == -1 then
      -- dense gradients, written straight into gradInput
      input.nn.SpatialMSECriterion_backward(self, input, target, self.gradInput)
      return self.gradInput
   end

   -- sparse gradients: compute the dense map into a scratch buffer (reused
   -- across calls), then copy the selected locations into gradInput
   self.fullGradInput = self.fullGradInput or torch.Tensor()
   self.fullGradInput:resizeAs(input):zero()
   input.nn.SpatialMSECriterion_backward(self, input, target, self.fullGradInput)

   local height = self.gradInput:size(2)
   local width = self.gradInput:size(3)

   -- copy the gradient vector at column x (dim 3), row y (dim 2)
   local function copyAt(x, y)
      self.gradInput:select(3,x):select(2,y):copy(self.fullGradInput:select(3,x):select(2,y))
   end

   if self.nbGradients == 1 then
      -- only 1 gradient is kept, sampled in the center
      copyAt(math.ceil(width/2), math.ceil(height/2))
   else
      -- only N gradients are kept, sampled at random locations; x indexes
      -- dim 3 and y indexes dim 2, so each is drawn from its own extent
      for i = 1,self.nbGradients do
         local x = math.random(1, width)
         local y = math.random(1, height)
         copyAt(x, y)
      end
   end
   return self.gradInput
end
--- Serialize the criterion to `file`.
-- After the parent's data, three values are written, in the order that
-- read() consumes them: resampleTarget (double), nbGradients (int) and
-- ignoreClass (int).
-- @param file  file object providing writeDouble/writeInt
function SpatialMSECriterion:write(file)
   parent.write(self, file)
   file:writeDouble(self.resampleTarget)
   file:writeInt(self.nbGradients)
   -- read() always consumes one int for ignoreClass and maps -1 to false,
   -- so exactly one int is always written: the class index when one is
   -- set, -1 otherwise.
   file:writeInt(self.ignoreClass or -1)
end
--- Restore the criterion from `file`.
-- After the parent's data, reads resampleTarget (double), nbGradients
-- (int) and ignoreClass (int), in that order, then resets the dense
-- error buffer to an empty tensor.
-- @param file  file object providing readDouble/readInt
function SpatialMSECriterion:read(file)
   parent.read(self, file)
   self.resampleTarget = file:readDouble()
   self.nbGradients = file:readInt()

   -- a stored value of -1 stands for "no class ignored"
   local storedClass = file:readInt()
   if storedClass == -1 then
      self.ignoreClass = false
   else
      self.ignoreClass = storedClass
   end

   self.fullOutput = torch.Tensor()
end