forked from clementfarabet/lua---nnx
-
Notifications
You must be signed in to change notification settings - Fork 0
/
SpatialGraph.lua
69 lines (59 loc) · 2.43 KB
/
SpatialGraph.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
local SpatialGraph, parent = torch.class('nn.SpatialGraph', 'nn.Module')

-- Help text handed to xlua.unpack_class in SpatialGraph:__init.
-- The tensor layout described here follows SpatialGraph:forward, which sizes
-- the output as (connex / 2) x input:size(2) x input:size(3), i.e. the
-- plane dimension comes first.
local help_desc =
[[Creates an edge-weighted graph from a set of N feature
maps.
The input is a 3D tensor nInputPlane x height x width, the
output is a 3D tensor 2 x height x width. The first slice
of the output contains horizontal edges, the second vertical
edges.
The input features are assumed to be >= 0.
More precisely:
+ dist == 'euclid' and normalize == true: the input features should
also be <= 1, to produce properly normalized distances (btwn 0 and 1);
+ dist == 'cosine': the input features do not need to be bounded,
as the cosine dissimilarity normalizes with respect to each vector.
It must be used with normalize == false (normalized cosine is not supported).
An epsilon is automatically added, so that components that are == 0
are properly considered as being similar.
]]
--- Constructor.
-- Arguments are parsed by xlua.unpack_class (see help_desc):
--   dist      : 'euclid' or 'cosine' (default 'euclid')
--   normalize : boolean (default true); not allowed together with 'cosine'
--   connex    : number (default 4); only 4 is accepted
-- On success, self.dist and self.normalize are replaced by the integer codes
-- that :write()/:read() serialize (dist: 0 = euclid, 1 = cosine;
-- normalize: 1 = true, 0 = false).
function SpatialGraph:__init(...)
   parent.__init(self)
   xlua.unpack_class(
      self, {...},
      'nn.SpatialGraph', help_desc,
      {arg='dist', type='string', help='distance metric to use', default='euclid'},
      {arg='normalize', type='boolean', help='normalize euclidean distances btwn 0 and 1 (assumes input range to be btwn 0 and 1)', default=true},
      {arg='connex', type='number', help='connexity', default=4}
   )
   if self.connex ~= 4 then
      xlua.error('4 is the only connexity supported, for now', 'nn.SpatialGraph',self.usage)
   end
   -- Translate the metric name into the integer code stored on the module.
   local dist_code = ({euclid = 0, cosine = 1})[self.dist]
   if dist_code == nil then
      xlua.error("dist must be either 'euclid' or 'cosine'", 'nn.SpatialGraph', self.usage)
   end
   -- Validate the combination on the code *before* self.dist is overwritten:
   -- comparing the converted number against the string 'cosine' would never
   -- match and the guard would silently be skipped.
   if dist_code == 1 and self.normalize then
      xlua.error('normalized cosine is not supported for now [just because I couldnt figure out the gradient :-)]',
                 'nn.SpatialGraph', self.usage)
   end
   self.dist = dist_code
   self.normalize = (self.normalize and 1) or 0
end
--- Forward pass.
-- Sizes self.output to (connex / 2) planes while keeping the input's
-- dimensions 2 and 3. It then hands off to the native SpatialGraph_forward
-- routine looked up through input.nn, which presumably fills self.output
-- (TODO confirm against the C implementation).
-- @param input 3D tensor; dimension 1 is taken to be the feature planes
-- @return self.output
function SpatialGraph:forward(input)
   local nplanes = self.connex / 2
   local d2, d3 = input:size(2), input:size(3)
   self.output:resize(nplanes, d2, d3)
   local backend = input.nn
   backend.SpatialGraph_forward(self, input)
   return self.output
end
--- Backward pass: gradient with respect to the input.
-- Sizes self.gradInput like input, then calls the native
-- SpatialGraph_backward routine looked up through input.nn.
-- @param input      input tensor (presumably the one given to :forward())
-- @param gradOutput gradient with respect to self.output
-- @return self.gradInput
function SpatialGraph:backward(input, gradOutput)
   local grad = self.gradInput
   grad:resizeAs(input)
   local backend = input.nn
   backend.SpatialGraph_backward(self, input, gradOutput)
   return self.gradInput
end
--- Serialize the module.
-- Writes the parent state first, then three integers in the fixed order
-- connex, dist, normalize. :read() consumes them in the same order.
-- @param file torch file object to write into
function SpatialGraph:write(file)
   parent.write(self, file)
   for _, field in ipairs({'connex', 'dist', 'normalize'}) do
      file:writeInt(self[field])
   end
end
--- Deserialize the module.
-- Reads the parent state, then restores connex, dist and normalize as
-- integers, in the same order :write() emits them.
-- @param file torch file object to read from
function SpatialGraph:read(file)
   parent.read(self, file)
   for _, field in ipairs({'connex', 'dist', 'normalize'}) do
      self[field] = file:readInt()
   end
end