-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathkmeans_feature_imp.py
68 lines (56 loc) · 3.25 KB
/
kmeans_feature_imp.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
from sklearn.cluster import KMeans
import numpy as np
class KMeansInterp(KMeans):
    """K-Means clusterer that also reports per-cluster feature importances.

    After ``fit``, ``feature_importances_`` maps every cluster label
    (``0 .. n_clusters - 1``) to a list of ``(feature_name, weight)`` pairs
    sorted by descending weight.

    Parameters
    ----------
    ordered_feature_names : sequence of str
        Feature names, in the same order as the columns of the data given
        to ``fit``. Its length must equal the number of fitted features.
    feature_importance_method : {'wcss_min', 'unsup2sup'}, default='wcss_min'
        ``'wcss_min'`` ranks features by the absolute value of each cluster
        centroid's coordinates. ``'unsup2sup'`` trains one
        ``RandomForestClassifier`` per cluster (cluster members vs. the
        rest) and uses that classifier's ``feature_importances_``.
    **kwargs
        Forwarded unchanged to ``sklearn.cluster.KMeans.__init__``.
    """

    _FEATURE_IMPORTANCE_METHODS = ("wcss_min", "unsup2sup")

    def __init__(self, ordered_feature_names, feature_importance_method='wcss_min', **kwargs):
        # NOTE(review): KMeans hyper-parameters arrive through **kwargs and are
        # not named in this signature -- confirm that sklearn.base.clone /
        # get_params round-trip them as expected before using this estimator
        # inside pipelines or grid searches.
        super().__init__(**kwargs)
        self.feature_importance_method = feature_importance_method
        self.ordered_feature_names = ordered_feature_names

    def fit(self, X, y=None, sample_weight=None):
        """Fit k-means on ``X`` and compute ``feature_importances_``.

        Parameters
        ----------
        X : array-like
            Training data, passed to ``KMeans.fit``.
        y : ignored
            Passed through to ``KMeans.fit``.
        sample_weight : array-like, optional
            Passed through to ``KMeans.fit``.

        Returns
        -------
        self

        Raises
        ------
        ValueError
            If ``feature_importance_method`` is not one of the supported
            methods, or if ``ordered_feature_names`` does not have one entry
            per fitted feature.
        """
        # Reject an unknown method up front instead of after the (possibly
        # expensive) clustering has already run.
        if self.feature_importance_method not in self._FEATURE_IMPORTANCE_METHODS:
            raise ValueError(
                f"{self.feature_importance_method} is not available. "
                "Please choose from ['wcss_min' , 'unsup2sup']"
            )

        super().fit(X=X, y=y, sample_weight=sample_weight)

        # n_features_in_ only exists once the parent fit has run.
        if len(self.ordered_feature_names) != self.n_features_in_:
            raise ValueError(
                f"Model is fitted on {self.n_features_in_} but ordered_feature_names = {len(self.ordered_feature_names)}"
            )

        if self.feature_importance_method == "wcss_min":
            self.feature_importances_ = self.get_feature_imp_wcss_min()
        else:
            self.feature_importances_ = self.get_feature_imp_unsup2sup(X)
        return self

    def get_feature_imp_wcss_min(self):
        """Rank features per cluster by absolute centroid coordinate.

        Returns
        -------
        dict
            ``{label: [(feature_name, weight), ...]}`` where ``weight`` is the
            absolute value of the centroid coordinate for that feature and
            each list is sorted by descending weight.
        """
        abs_centroids = np.abs(self.cluster_centers_)
        # Per-centroid column indices, largest magnitude first.
        ranked_idx = abs_centroids.argsort(axis=1)[:, ::-1]

        feature_names = self.ordered_feature_names
        cluster_feature_weights = {}
        for label, order in enumerate(ranked_idx):
            names = [feature_names[idx] for idx in order]
            weights = abs_centroids[label][order]
            cluster_feature_weights[label] = list(zip(names, weights))
        return cluster_feature_weights

    def get_feature_imp_unsup2sup(self, X):
        """Rank features per cluster with a one-vs-rest random forest.

        For every cluster label a ``RandomForestClassifier`` is trained on
        ``X`` to separate the samples assigned to that cluster (``labels_``)
        from all other samples; its ``feature_importances_`` become the
        weights.

        Parameters
        ----------
        X : array-like
            The data the model was fitted on (rows must align with
            ``labels_``).

        Returns
        -------
        dict
            ``{label: [(feature_name, weight), ...]}`` with each list sorted
            by descending weight.

        Raises
        ------
        ImportError
            If ``sklearn.ensemble.RandomForestClassifier`` cannot be imported.
        """
        try:
            from sklearn.ensemble import RandomForestClassifier
        except ImportError as err:
            # Chain the original error so its details stay in the traceback.
            raise ImportError(
                "Please install scikit-learn. "
                "'unsup2sup' method requires using a classifier "
                "and depends on 'sklearn.ensemble.RandomForestClassifier'"
            ) from err

        feature_names = np.array(self.ordered_feature_names)
        cluster_feature_weights = {}
        for label in range(self.n_clusters):
            # 1 for samples assigned to this cluster, 0 for every other sample.
            is_member = (self.labels_ == label).astype(int)
            # Reuse the estimator's random_state so results are reproducible
            # whenever the caller seeded the clustering.
            clf = RandomForestClassifier(random_state=self.random_state)
            clf.fit(X, is_member)

            importances = np.asarray(clf.feature_importances_)
            order = np.argsort(importances)[::-1]
            cluster_feature_weights[label] = list(
                zip(feature_names[order], importances[order])
            )
        return cluster_feature_weights