-
Notifications
You must be signed in to change notification settings - Fork 13
/
Copy pathfind.lua
34 lines (30 loc) · 837 Bytes
/
find.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
--- Find all elements of `tensor` that are equal to `val`.
--
-- Without `dim`: returns a flat table holding the 1-based linear position
-- (in `tensor:apply` visiting order) of every matching element.
--
-- With `dim` (2-D tensors only): returns one table per slice, each holding
-- the 1-based indices of the matches inside that slice, in ascending order.
--   * `dim == 2`: one table per row, holding column indices.
--   * `dim == 1`: one table per column, holding row indices.
-- Slices without a match contribute an empty table, so the result has one
-- entry for every row (dim 2) or every column (dim 1).
--
-- Elements are compared with `==`, so a NaN `val` never matches.
--
-- @param tensor torch Tensor to search
-- @param val value compared with `==` against each element
-- @param[opt] dim 1 or 2; when given, `tensor` must be 2-D
-- @treturn table flat table of positions, or a table of per-slice tables
function torch.find(tensor, val, dim)
   local indice = {}
   if dim then
      assert(tensor:dim() == 2, "torch.find dim arg only supports matrices for now")
      assert(dim == 1 or dim == 2, "torch.find only supports dim=1 or dim=2 for now")
      -- dim=1 is dim=2 applied to the transposed view: the rows of
      -- tensor:t() are the columns of tensor (no data is copied).
      local src = tensor
      if dim == 1 then
         src = tensor:t()
      end
      local sliceLen = src:size(2)
      local i = 1
      local slice = {}
      -- apply visits elements with the last dimension fastest, so every
      -- `sliceLen` consecutive calls cover exactly one row of `src`.
      src:apply(function(x)
         if x == val then
            slice[#slice + 1] = i
         end
         if i == sliceLen then
            indice[#indice + 1] = slice
            slice = {}
            i = 1
         else
            i = i + 1
         end
         -- returning nothing (nil) leaves the element unchanged
      end)
   else
      local i = 1
      tensor:apply(function(x)
         if x == val then
            indice[#indice + 1] = i
         end
         i = i + 1
         -- returning nothing (nil) leaves the element unchanged
      end)
   end
   return indice
end