"""
.. module:: cluster_predict
:synopsis: clustering by set generation algorithm
.. moduleauthor:: Jiaming Shen
"""
import torch
import numpy as np
from tqdm import tqdm


def multiple_set_single_instance_prediction(model, sets, instance, size_optimized=False):
    """ Apply the given model to predict the probability of adding a single instance into each of the given sets

    :param model: a trained SynSetMine model
    :type model: SSPM
    :param sets: a list of sets, each containing element indices
    :type sets: list
    :param instance: a single instance, represented by its element index
    :type instance: int
    :param size_optimized: whether to optimize the multiple-set-single-instance prediction process. Set this
        parameter to True if the sizes of the sets in 'sets' vary a lot and a single huge set exists among them
    :type size_optimized: bool
    :return:

        - scores of the given sets, (batch_size, 1)
        - scores of the given sets, each unioned with the instance, (batch_size, 1)
        - the probability of adding the instance into the corresponding set, (batch_size, 1)

    :rtype: tuple
    """
    if not size_optimized:  # when no single big cluster exists, there is no need for size optimization
        return _multiple_set_single_instance_prediction(model, sets, instance)
    else:
        if len(sets) <= 10:  # for a handful of sets, batching them directly is cheap enough
            return _multiple_set_single_instance_prediction(model, sets, instance)
        set_sizes = [len(ele) for ele in sets]
        tmp = sorted(enumerate(set_sizes), key=lambda x: x[1])  # (old index, set_size)
        n2o = {n: ele[0] for n, ele in enumerate(tmp)}  # new index -> old index
        o2n = {n2o[n]: n for n in n2o}  # old index -> new index
        sorted_set_sizes = [ele[1] for ele in tmp]
        # the binning method bins="auto" combines the 'sturges' and 'fd' estimators; another choice is
        # bins="sturges", which generates fewer (hence larger) bins
        # c.f.: https://docs.scipy.org/doc/numpy/reference/generated/numpy.histogram_bin_edges.html#numpy.histogram_bin_edges
        _, bin_edges = np.histogram(sorted_set_sizes, bins="auto")
        inds = np.digitize(sorted_set_sizes, bin_edges)
        sorted_setScores = []
        sorted_setInstSumScores = []
        sorted_positive_prob = []
        cur_ind = inds[0]
        cur_set = [sets[tmp[0][0]]]
        for i in range(1, len(inds)):
            if inds[i] == cur_ind:
                cur_set.append(sets[tmp[i][0]])
            else:  # the current bin is complete; run prediction on it and start a new bin
                cur_setScores, cur_setInstSumScores, cur_positive_prob = _multiple_set_single_instance_prediction(
                    model, cur_set, instance)
                sorted_setScores += cur_setScores
                sorted_setInstSumScores += cur_setInstSumScores
                sorted_positive_prob += cur_positive_prob
                cur_ind = inds[i]
                cur_set = [sets[tmp[i][0]]]
        if len(cur_set) > 0:  # flush the last bin
            cur_setScores, cur_setInstSumScores, cur_positive_prob = _multiple_set_single_instance_prediction(
                model, cur_set, instance)
            sorted_setScores += cur_setScores
            sorted_setInstSumScores += cur_setInstSumScores
            sorted_positive_prob += cur_positive_prob
        assert len(sets) == len(sorted_positive_prob), "Mismatch after binning optimization"
        # map the bin-sorted results back to the original ordering of the sets
        setScores = []
        setInstSumScores = []
        positive_prob = []
        for o in range(len(sets)):
            setScores.append(sorted_setScores[o2n[o]])
            setInstSumScores.append(sorted_setInstSumScores[o2n[o]])
            positive_prob.append(sorted_positive_prob[o2n[o]])
        return setScores, setInstSumScores, positive_prob
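

# A minimal sketch, not part of the original API: it illustrates the size-binning
# trick used above. Sets are sorted by size, and np.histogram/np.digitize group
# similar sizes into the same batch, so one huge set does not force every batch
# to be padded to its length. The sizes below are made-up illustration values;
# call _size_binning_sketch() to inspect the resulting (size, bin_id) pairs.
def _size_binning_sketch():
    set_sizes = sorted([1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 50])
    _, bin_edges = np.histogram(set_sizes, bins="auto")
    bin_ids = np.digitize(set_sizes, bin_edges)
    # sets that share a bin id would be batched together; the lone size-50 set
    # lands in its own bin instead of padding every other set to length 50
    return list(zip(set_sizes, bin_ids.tolist()))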


def _multiple_set_single_instance_prediction(model, sets, instance):
    model.eval()
    # generate tensors
    batch_size = len(sets)
    max_set_size = max([len(ele) for ele in sets])
    batch_set_tensor = np.zeros([batch_size, max_set_size], dtype=np.int64)  # np.int was removed in NumPy >= 1.24
    for row_id, row in enumerate(sets):  # left-align each set and zero-pad it to max_set_size
        batch_set_tensor[row_id][:len(row)] = row
    batch_set_tensor = torch.from_numpy(batch_set_tensor)  # (batch_size, max_set_size)
    batch_inst_tensor = torch.tensor(instance).unsqueeze(0).expand(batch_size, 1)  # (batch_size, 1)
    batch_set_tensor = batch_set_tensor.to(model.device)
    batch_inst_tensor = batch_inst_tensor.to(model.device)
    # inference
    setScores, setInstSumScores, prediction = model.predict(batch_set_tensor, batch_inst_tensor)
    # move results to CPU and convert the prediction for each set-instance pair into a probability
    positive_prob = prediction.squeeze(-1).detach()
    positive_prob = list(positive_prob.cpu().numpy())
    setScores = setScores.squeeze(-1).detach()
    setScores = list(setScores.cpu().numpy())
    setInstSumScores = setInstSumScores.squeeze(-1).detach()
    setInstSumScores = list(setInstSumScores.cpu().numpy())
    model.train()  # restore training mode
    return setScores, setInstSumScores, positive_prob


def set_generation(model, vocab, threshold=0.5, eid2ename=None, size_opt_clus=False, max_K=None, verbose=False):
    """ Set Generation Algorithm

    :param model: a trained set-instance classifier
    :type model: SSPM
    :param vocab: a list of elements to be clustered, each represented by its element index
    :type vocab: list
    :param threshold: the probability threshold for determining whether to create a new singleton cluster
    :type threshold: float
    :param eid2ename: a dictionary mapping each element index to its (human-readable) name
    :type eid2ename: dict
    :param size_opt_clus: whether to size-optimize the multiple-set-single-instance prediction process
    :type size_opt_clus: bool
    :param max_K: maximum number of clusters. If None, this number is inferred automatically
    :type max_K: int
    :param verbose: whether to print all intermediate results
    :type verbose: bool
    :return: a list of detected clusters
    :rtype: list
    """
    model.eval()
    clusters = []  # will be a list of lists
    candidate_pool = vocab
    if verbose:
        print("{}\t{}".format("vocab", [eid2ename[eid] for eid in vocab]))
        g = tqdm(range(len(candidate_pool)), desc="Cluster prediction (aggressive one pass)...")
    else:
        g = range(len(candidate_pool))
    for i in g:
        inst = candidate_pool[i]
        if i == 0:  # the first element always seeds a new cluster
            cluster = [inst]
            clusters.append(cluster)
        else:
            setScores, setInstSumScores, cluster_probs = multiple_set_single_instance_prediction(
                model, clusters, inst, size_optimized=size_opt_clus
            )
            # find the existing cluster that the instance fits best
            best_matching_existing_cluster_idx = -1
            best_matching_existing_cluster_prob = 0.0
            for cid, cluster_prob in enumerate(cluster_probs):
                if cluster_prob > best_matching_existing_cluster_prob:
                    best_matching_existing_cluster_prob = cluster_prob
                    best_matching_existing_cluster_idx = cid
            if verbose:
                print("Current Cluster Pool:",
                      [(cid, [eid2ename[ele] for ele in cluster]) for cid, cluster in enumerate(clusters)])
                print("-" * 20)
                print("Entity: {:<30} best_prob = {:<8} Best-matching Cluster: {:<80} (cid={})".format(
                    eid2ename[inst], best_matching_existing_cluster_prob,
                    str([eid2ename[eid] for eid in clusters[best_matching_existing_cluster_idx]]),
                    best_matching_existing_cluster_idx))
            if max_K and len(clusters) >= max_K:  # the cluster budget is used up, so always merge
                clusters[best_matching_existing_cluster_idx].append(inst)
                if verbose:
                    print("!!! Add Entity In")
            else:
                # either add this instance to an existing cluster or create a new singleton cluster
                if best_matching_existing_cluster_prob > threshold:
                    clusters[best_matching_existing_cluster_idx].append(inst)
                    if verbose:
                        print("!!! Add Entity In")
                else:
                    new_cluster = [inst]
                    clusters.append(new_cluster)
            if verbose:
                print("-" * 120)
    model.train()
    return clusters
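

if __name__ == "__main__":
    # A minimal smoke-test sketch, not part of the original module: it wires a
    # mock set-instance classifier into set_generation so the control flow can
    # be exercised without a trained SSPM model. The mock class, its random
    # scores, and the toy vocabulary below are all illustrative assumptions.
    class _MockSetInstanceModel:
        device = torch.device("cpu")

        def eval(self):
            pass

        def train(self):
            pass

        def predict(self, batch_set_tensor, batch_inst_tensor):
            # return random set scores, set-union scores, and membership
            # probabilities with the (batch_size, 1) shapes the caller expects
            batch_size = batch_set_tensor.size(0)
            return torch.rand(batch_size, 1), torch.rand(batch_size, 1), torch.rand(batch_size, 1)

    mock_model = _MockSetInstanceModel()
    toy_vocab = list(range(10))  # ten elements, identified by their indices
    toy_clusters = set_generation(mock_model, toy_vocab, threshold=0.5)
    print("clusters from mock model:", toy_clusters)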