forked from torch/torch7
-
Notifications
You must be signed in to change notification settings - Fork 0
/
hist.lua
123 lines (115 loc) · 3.14 KB
/
hist.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
--
-- rudimentary histogram display on the command line.
--
-- Author: Marco Scoffier
-- Date :
-- Mod : Oct 21, 2011
-- + made 80 columns default
-- + save index of max bin in h.max not pointer to bin
--
--- Render a histogram table (as built by torch.histc) as an ASCII bar chart.
-- torch.histc installs this as the __tostring metamethod of its result.
-- @param h histogram table: h[i].nb is the count of bin i, h[i].val its
--          center value, and h.min / h.max the indices of the bins with the
--          smallest / largest count.
-- @param barHeight number of text rows used for the bars (default 10)
-- @return the chart as a multi-line string
function torch.histc__tostring(h, barHeight)
   barHeight = barHeight or 10
   local nbins = #h
   local top = h[h.max].nb
   -- Each row spans `incr` counts. The count range is split into
   -- barHeight+1 slices, so the lowest slice [0, incr) is never drawn.
   local incr = top/(barHeight+1)
   local toph, topm, topl = '|', ':', '.'
   local bar, blank = '|', ' '
   local yaxis = '--------:'

   -- accumulate pieces and join once, instead of O(n^2) `..` chains
   local buf = {}
   local function put(s) buf[#buf+1] = s end

   -- total of all bin counts, reported after the 'nsamples:' label
   local total = 0
   for j = 1, nbins do total = total + h[j].nb end

   put(string.format('nsamples:%d', total))
   put(string.format(' min:(bin:%d/#%d/cntr:%2.2f) max:(bin:%d/#%d/cntr:%2.2f)\n',
                     h.min,h[h.min].nb,h[h.min].val,
                     h.max,h[h.max].nb,h[h.max].val))
   put(yaxis .. string.rep('-', nbins) .. '\n')

   local started = {}   -- started[j] becomes true once bin j's bar has begun
   local lo = top - incr -- lower edge of the current row's count bracket
   for _ = 1, barHeight do
      -- y axis label: lower edge of this row
      put(string.format('%1.2e:', lo))
      for j = 1, nbins do
         local nb = h[j].nb
         if started[j] then
            put(bar)
         elseif nb <= 0 or nb < lo then
            -- empty bins never get a bar (this also avoids 0/0 when every
            -- bin is empty and incr is 0)
            put(blank)
         else
            -- the bar's top falls in this row: pick a glyph by how much
            -- of the row the count fills
            started[j] = true
            local fill = (nb - lo) / incr
            if fill > 0.66 then
               put(toph)
            elseif fill > 0.33 then
               put(topm)
            else
               put(topl)
            end
         end
      end
      put('\n')
      lo = lo - incr
   end

   -- x axis: a tick every 6 columns, starting at bin 2
   put(yaxis)
   for j = 1, nbins do
      if (j - 2) % 6 == 0 then
         put('^')
      else
         put('-')
      end
   end
   -- bin-center labels under the ticks (6 chars each: sign + '%1.2f ')
   put('\ncenters ')
   for j = 1, nbins do
      if (j - 2) % 6 == 0 then
         local v = h[j].val
         if v < 0 then
            put('-')
         else
            put('+')
         end
         put(string.format('%1.2f ', math.abs(v)))
      end
   end
   return table.concat(buf)
end
--- Compute the histogram of a tensor.
-- Usage: torch.histc(tensor [, nBins] [, min] [, max] [, raw])
-- @param tensor input tensor; every element is binned
-- @param nBins number of bins (default 80 - 8 = 72)
-- @param min lower edge of the binned range (default tensor:min())
-- @param max upper edge of the binned range (default tensor:max())
-- @param raw if true, return only the count tensor
-- @return when raw: the count tensor `hist`. Otherwise a table with
--         .raw (the count tensor), .bins, .binwidth, entries [i] = {val=center,
--         nb=count}, .max / .min (indices of the largest / smallest bin),
--         and a __tostring metamethod that draws it.
function torch.histc(...)
   -- get args
   local args = {...}
   local tensor = args[1] or error('usage: torch.histc (tensor [, nBins] [, min] [, max] [, raw])')
   local bins = args[2] or 80 - 8
   local min = args[3] or tensor:min()
   local max = args[4] or tensor:max()
   local raw = args[5] or false
   -- A zero-width range (e.g. a constant tensor) would divide by zero below
   -- and yield NaN bin indices and a zero bin width; widen it by 1 on each
   -- side so the values land in the middle bin.
   if min == max then
      min = min - 1
      max = max + 1
   end
   -- compute histogram
   local hist = torch.zeros(bins)
   local ten = torch.Tensor(tensor:nElement()):copy(tensor)
   -- map each value to a 1-based bin index; the -1e-6 keeps x == max in
   -- bin `bins` rather than bin bins+1
   ten:add(-min):div(max-min):mul(bins - 1e-6):floor():add(1)
   -- C helper (not shown here) that accumulates the bin indices into hist
   ten.torch._histc(ten, hist, bins)
   -- return raw histogram (no extra info)
   if raw then return hist end
   -- wrap counts with bin centers and summary indices
   local cleanhist = {}
   cleanhist.raw = hist
   local _,mx = torch.max(cleanhist.raw)
   local _,mn = torch.min(cleanhist.raw)
   cleanhist.bins = bins
   cleanhist.binwidth = (max-min)/bins
   for i = 1,bins do
      cleanhist[i] = {
         val = min + (i-0.5)*cleanhist.binwidth,
         nb = hist[i],
      }
   end
   cleanhist.max = mx[1]
   cleanhist.min = mn[1]
   -- print function
   setmetatable(cleanhist, {__tostring=torch.histc__tostring})
   return cleanhist
end