-
Notifications
You must be signed in to change notification settings - Fork 2
/
utils_graphics.py
300 lines (221 loc) · 6.87 KB
/
utils_graphics.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
from decimal import Decimal
# from numba import autojit
import numpy as np
# from externals.six.moves import range
def remove_newline_in_dotfile(file_path):
    """Strip escaped newlines from the edge lines of a DOT file, in place.

    Every line that contains ``--`` (DOT's undirected-edge operator) has each
    occurrence of the literal two-character sequence backslash-n followed by
    a space removed. All other lines are written back unchanged.

    Parameters
    ----------
    file_path : str
        Path of the DOT file to rewrite.

    Returns
    -------
    None
    """
    with open(file_path, 'r') as src:
        original = src.readlines()
    cleaned = [ln.replace('\\n ', '') if '--' in ln else ln for ln in original]
    with open(file_path, 'w') as dst:
        dst.writelines(cleaned)
def powerset(s):
    """Yield every subset of a list or set, each as a list.

    Subsets come out in binary-counter order: the k-th yielded subset holds
    the j-th element of ``s`` exactly when bit j of k is set, so the empty
    subset is first and the full one is last.
    """
    size = len(s)
    for bits in range(2 ** size):
        yield [elem for pos, elem in enumerate(s) if (bits >> pos) & 1]
def remove_zeros(r, axis):
    """Drop all-zero rows or all-zero columns of a 2d array.

    Parameters
    ----------
    r : numpy.ndarray
        2d array to filter.
    axis : int
        0 removes rows whose entries are all zero; 1 removes columns whose
        entries are all zero.

    Returns
    -------
    numpy.ndarray
        The filtered array.

    Raises
    ------
    ValueError
        If ``axis`` is neither 0 nor 1.
    """
    if axis == 0:
        keep_rows = ~np.all(r == 0, axis=1)
        return r[keep_rows]
    if axis == 1:
        keep_cols = ~np.all(r == 0, axis=0)
        return r[:, keep_cols]
    raise ValueError('not a correct axis that we can use.')
def scientific_notation(num, precision=2):
    """Convert a number into scientific notation

    Parameters
    ----------
    num : float
        Any number accepted by ``decimal.Decimal`` (int, float, str, Decimal).
    precision : int, optional
        Number of digits after the decimal point of the mantissa.
        Default is 2.

    Returns
    -------
    output : str
        String representation of the number, e.g. '1.23E+4'. Zero is
        rendered as '0.00E+0' (with a leading '-' for negative zero).
    """
    value = Decimal(num)
    if value.is_zero():
        # Decimal's 'E' formatting shifts the exponent of zero by the
        # precision ('0.00E+2' for precision 2), so build the conventional
        # form explicitly.
        mantissa = '0.' + '0' * precision if precision > 0 else '0'
        sign = '-' if value.is_signed() else ''
        return sign + mantissa + 'E+0'
    return '{:.{prec}E}'.format(value, prec=precision)
def bayes_boot_probs(n):
    """Draw Bayesian-bootstrap case weights.

    Samples ``n`` i.i.d. Exponential(1) variates from numpy's global random
    state and normalises them to sum to one.

    Parameters
    ----------
    n : int
        Number of Bayesian bootstrap samples

    Returns
    -------
    p : 1d array-like
        Array of ``n`` sampling probabilities summing to 1
    """
    draws = np.random.exponential(scale=1.0, size=n)
    total = draws.sum()
    return draws / total
# @autojit(nopython=True, cache=True, nogil=True)
def auc_score(y_true, y_prob):
    """Compute the area under the ROC curve (AUC) for binary labels.

    The AUC is the probability that a randomly chosen positive sample gets a
    higher score than a randomly chosen negative one, with tied scores
    counted as one half (the Mann-Whitney U statistic divided by
    ``n_pos * n_neg``).

    Parameters
    ----------
    y_true : 1d array-like
        Binary labels: 1 for positive samples, 0 for negative samples.
    y_prob : 1d array-like
        Scores or predicted probabilities of the positive class; only their
        ordering matters.

    Returns
    -------
    auc : float
        AUC in [0, 1], or ``nan`` when ``y_true`` does not contain both
        classes (the score is undefined then).

    Raises
    ------
    ValueError
        If ``y_true`` and ``y_prob`` have different shapes.
    """
    labels = np.asarray(y_true, dtype=float).ravel()
    scores = np.asarray(y_prob).ravel()
    if labels.shape != scores.shape:
        raise ValueError('y_true and y_prob must have the same length')

    n = labels.size
    n_pos = labels.sum()
    n_neg = n - n_pos
    if n_pos == 0 or n_neg == 0:
        return float('nan')

    # Average 1-based rank of every score. Tied scores share the mean of the
    # ranks they span, so ties contribute 1/2 instead of depending on the
    # (arbitrary) order argsort puts equal values in.
    order = np.argsort(scores, kind='mergesort')
    sorted_scores = scores[order]
    starts_group = np.r_[True, sorted_scores[1:] != sorted_scores[:-1]]
    group_ids = np.cumsum(starts_group) - 1
    group_start = np.flatnonzero(starts_group)
    group_end = np.r_[group_start[1:], n]
    ranks = np.empty(n)
    ranks[order] = ((group_start + group_end + 1) / 2.0)[group_ids]

    # Rank-sum of the positives minus its smallest possible value equals the
    # number of correctly ordered (negative, positive) pairs.
    u_stat = np.dot(ranks, labels) - n_pos * (n_pos + 1) / 2.0
    return float(u_stat / (n_pos * n_neg))
def logger(name, message):
    """Print a message to stdout in the form "[NAME] message".

    Parameters
    ----------
    name : str
        Short title of the message (for example, train or test); it is
        upper-cased inside the brackets.
    message : str
        Main description to be displayed in the terminal.

    Returns
    -------
    None
    """
    tag = name.upper()
    print(f'[{tag}] {message}')
def estimate_margin(y_probs, y_true):
    """Estimate the per-sample margin of a forest ensemble.

    Note : This function is similar to ``margin`` in R's randomForest package.

    For every sample the margin is the probability assigned to the true
    class minus the largest probability assigned to any other class, so a
    positive margin means the true class has the highest probability.

    Parameters
    ----------
    y_probs : 2d array-like, shape (n_samples, n_classes)
        Predicted probabilities where each row represents the predicted
        class distribution for a sample and each column corresponds to an
        estimated class probability.
    y_true : 1d array-like of int, shape (n_samples,)
        Column index of the true class for each row of ``y_probs``.

    Returns
    -------
    margin : 1d numpy.ndarray, shape (n_samples,)
        Margin of every sample. With a single class there is no competing
        class and every margin is 0.
    """
    y_probs = np.asarray(y_probs, dtype=float)
    n, p = y_probs.shape
    rows = np.arange(n, dtype=int)

    # Probability of the correct class
    true_probs = y_probs[rows, y_true]
    if p == 1:
        return true_probs - true_probs

    # Maximum probability over the incorrect classes: hide the true class
    # behind -inf on a copy (the caller's array is left untouched).
    others = y_probs.copy()
    others[rows, y_true] = -np.inf
    other_probs = others.max(axis=1)

    # Margin is P(y == j) - max(P(y != j))
    return true_probs - other_probs
def assert_array_rank(X, rank):
    """Validate that ``X`` is a numpy array with exactly ``rank`` dimensions.

    Parameters
    ----------
    X : array-like
        Object to check.
    rank : int
        Required number of dimensions.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If ``X`` is not a ``numpy.ndarray`` or its number of dimensions
        differs from ``rank``.
    """
    if isinstance(X, np.ndarray):
        if X.ndim == rank:
            return
        raise ValueError('You must pass in a {}-rank array!'.format(rank))
    raise ValueError('You must pass in a numpy array!')
def assert_string_type(X, name):
    """Validate that the array ``X`` has a unicode string dtype.

    Parameters
    ----------
    X : numpy.ndarray
        Array to check; its ``dtype`` attribute is read.
    name : str
        Name of the input, used in the error message.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If ``X.dtype`` is not a sub-dtype of ``numpy.str_``.
    """
    holds_strings = np.issubdtype(X.dtype, np.str_)
    if holds_strings:
        return
    raise ValueError('{} must contain only strings!'.format(name))
def sample_from_dict(distrib):
    """Choose a key from the distribution, with probability proportional to its value.

    The weights need not be normalised: they are divided by their sum.
    Randomness comes from numpy's global random state.

    Parameters
    ----------
    distrib : dict
        Dictionary mapping keys to non-negative weights (probabilities or
        counts).

    Returns
    -------
    item : key of dict
        A chosen key from the dictionary, returned as the original key
        object.

    Raises
    ------
    ValueError
        If ``distrib`` is empty, if its weights do not sum to a positive
        number, or (from ``numpy.random.choice``) if a weight is negative.
    """
    if not distrib:
        raise ValueError('distrib must contain at least one item')
    keys = list(distrib)
    weights = np.array([distrib[k] for k in keys], dtype=float)
    total = weights.sum()
    if not total > 0:
        raise ValueError('the values of distrib must sum to a positive number')
    # Sample an index rather than the key itself. np.random.choice would
    # coerce the keys into one numpy array, which fails for tuple keys and
    # turns mixed-type keys into strings.
    idx = np.random.choice(len(keys), p=weights / total)
    return keys[idx]
def getNull(model, strtype='U5'):
    """Build an array of empty strings, one per feature name of ``model``.

    Parameters
    ----------
    model : Qnet object
        The Qnet model; only its ``feature_names`` attribute is read.
    strtype : str or numpy dtype, optional
        Dtype of the returned numpy array. Default is 'U5'.

    Returns
    -------
    numpy.ndarray
        1d array of length ``len(model.feature_names)`` filled with ''.
    """
    blanks = [''] * len(model.feature_names)
    return np.asarray(blanks).astype(strtype)
def find_matching_indices(A, B):
    """Return the positions in ``A`` whose element is contained in ``B``.

    Membership is tested with ``value in B``, so its meaning follows the type
    of ``B`` (e.g. a substring test when ``B`` is a str).

    Parameters
    ----------
    A : iterable
        Values to look up, in order.
    B : container
        Collection to test membership against.

    Returns
    -------
    list of int
        Indices ``i`` (ascending) such that ``A[i] in B``.
    """
    return [pos for pos, value in enumerate(A) if value in B]
import pygraphviz as pgv
import re
import os
import glob
def big_enough(dot_file, big_enough_threshold=-1, frac_threshold=.25):
    """Decide whether a DOT tree is worth rendering.

    Counts the non-leaf nodes that ``analyze_dot_file`` reports as leading
    to a summed leaf ``Frac`` value above ``frac_threshold``, and compares
    that count with ``big_enough_threshold``.

    Parameters
    ----------
    dot_file : str or path-like
        DOT file to inspect; converted with ``str()``.
    big_enough_threshold : int, optional
        The tree qualifies when strictly more than this many nodes are
        reported. With the default of -1 every parseable tree qualifies.
    frac_threshold : float, optional
        Frac threshold passed to ``analyze_dot_file`` as ``fracThreshold``.
        Default is 0.25.

    Returns
    -------
    bool
    """
    _, heavy_nodes = analyze_dot_file(str(dot_file),
                                      fracThreshold=frac_threshold)
    return len(heavy_nodes) > big_enough_threshold
def analyze_dot_file(dot_file, fracThreshold=0.0):
    """Find the non-leaf nodes of a DOT tree whose subtree carries a large Frac mass.

    A leaf's value is the number following ``Frac: `` in its ``label``
    attribute (0 when no such text is present). A non-leaf node's value is
    the sum of the values of its children, following out-edges.

    Parameters
    ----------
    dot_file : str
        Path of the DOT file, loaded with ``pygraphviz.AGraph``.
    fracThreshold : float, optional
        Non-leaf nodes whose value is strictly greater than this are
        reported. Default is 0.0.

    Returns
    -------
    (bool, list of str)
        ``(False, [])`` when the graph has at most one non-leaf node.
        Otherwise ``(True, labels)``, where ``labels`` are the ``label``
        attributes of the qualifying non-leaf nodes, in ``graph.nodes()``
        order.
    """
    graph = pgv.AGraph(dot_file)
    non_leaf_nodes = [node for node in graph.nodes() if graph.out_degree(node) > 0]
    if len(non_leaf_nodes) <= 1:
        return False, []

    frac_pattern = re.compile(r'Frac: ([0-9\.]+)')
    # Memo of node -> summed Frac value, so each subtree is evaluated once
    # even though every non-leaf node is queried.
    subtree_frac = {}

    def dfs(node):
        if node in subtree_frac:
            return subtree_frac[node]
        if graph.out_degree(node) == 0:  # leaf node
            match = frac_pattern.search(node.attr['label'] or '')
            total = float(match.group(1)) if match is not None else 0
        else:  # non-leaf node: sum over children
            total = 0
            for edge in graph.out_edges(node):
                total += dfs(graph.get_node(edge[1]))
        subtree_frac[node] = total
        return total

    nodes_leading_to_big_frac = [node.attr['label'] for node in non_leaf_nodes
                                 if dfs(node) > fracThreshold]
    return True, nodes_leading_to_big_frac
def drawtrees(dotfiles, prog='dot', format='pdf', big_enough_threshold=-1):
    """Render every sufficiently large DOT tree to a file next to its source.

    Each output is written to the input path with its extension replaced by
    ``format`` (e.g. ``trees/t1.dot`` -> ``trees/t1.pdf``).

    Parameters
    ----------
    dotfiles : iterable of str or path-like
        DOT files to render.
    prog : str, optional
        Graphviz layout program passed to ``AGraph.draw``. Default is 'dot'.
    format : str, optional
        Output format and file extension. Default is 'pdf'.
    big_enough_threshold : int, optional
        Passed to ``big_enough``; files for which it returns False are
        skipped. Default is -1.

    Returns
    -------
    None
    """
    for dot_file in dotfiles:
        dot_path = str(dot_file)
        if not big_enough(dot_path, big_enough_threshold):
            continue
        graph = pgv.AGraph(dot_path)
        # Swap only the extension. This leaves directory or file names that
        # contain "dot" intact and never writes over the input file.
        out_path = os.path.splitext(dot_path)[0] + '.' + format
        graph.draw(out_path, prog=prog, format=format)