-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathFeatureImportanceSelector.py
156 lines (127 loc) · 6.01 KB
/
FeatureImportanceSelector.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import sys
import numpy as np
import pandas as pd
from pyspark import keyword_only
from pyspark.ml.base import Estimator
from pyspark.ml.param import Params, Param, TypeConverters
from pyspark.ml.feature import VectorSlicer
from pyspark.ml.param.shared import HasOutputCol
def ExtractFeatureImp(featureImp, dataset, featuresCol):
    """
    Map feature-importance scores from a fitted tree-based model back to column names.

    Reads the ``ml_attr`` metadata attached to the assembled features column
    (e.g. by VectorAssembler), so ``dataset`` must carry that metadata.

    Parameters
    ----------
    featureImp : indexable of float
        Feature importances, e.g. ``model.featureImportances``; indexed by
        feature position.
    dataset : pyspark.sql.DataFrame
        DataFrame whose schema holds the feature metadata.
    featuresCol : str
        Name of the assembled features column.

    Returns
    -------
    pandas.DataFrame
        Columns ``idx``, ``name``, ``score``, sorted by descending ``score``.

    Example
    -------
    rf = RandomForestClassifier(featuresCol="features")
    mod = rf.fit(train)
    ExtractFeatureImp(mod.featureImportances, train, "features")
    """
    # attrs maps attribute group ("numeric", "binary", ...) to a list of
    # {"idx": ..., "name": ...} dicts; flatten every group into one list.
    # Look the metadata up once and use extend() instead of repeated
    # list + list concatenation (which is O(n^2)).
    attrs = dataset.schema[featuresCol].metadata["ml_attr"]["attrs"]
    list_extract = []
    for group in attrs:
        list_extract.extend(attrs[group])
    varlist = pd.DataFrame(list_extract)
    varlist['score'] = varlist['idx'].apply(lambda x: featureImp[x])
    return varlist.sort_values('score', ascending=False)
class FeatureImpSelector(Estimator, HasOutputCol):
    """
    Selects features for training using feature-importance scores.

    Fits the supplied tree-based estimator once, ranks features by the fitted
    model's ``featureImportances``, and returns a :class:`VectorSlicer` that
    keeps either the top ``numTopFeatures`` features (default) or every
    feature whose score exceeds ``threshold``.

    The estimator must be a DecisionTree, RandomForest or GBT classifier or
    regressor; ``featuresCol`` is inferred from the estimator.
    """

    # Tree-based estimators that expose ``featureImportances`` after fitting.
    _SUPPORTED_ESTIMATORS = frozenset([
        'DecisionTreeClassifier', 'DecisionTreeRegressor',
        'RandomForestClassifier', 'RandomForestRegressor',
        'GBTClassifier', 'GBTRegressor',
    ])

    estimator = Param(Params._dummy(), "estimator",
                      "tree-based estimator whose feature importances drive the selection")
    selectorType = Param(Params._dummy(), "selectorType",
                         "The selector type of the FeatureImpSelector. " +
                         "Supported options: numTopFeatures (default), threshold",
                         typeConverter=TypeConverters.toString)
    numTopFeatures = \
        Param(Params._dummy(), "numTopFeatures",
              "Number of features that selector will select, ordered by descending feature imp score. " +
              "If the number of features is < numTopFeatures, then this will select " +
              "all features.", typeConverter=TypeConverters.toInt)
    threshold = Param(Params._dummy(), "threshold", "The lowest feature imp score for features to be kept.",
                      typeConverter=TypeConverters.toFloat)

    @keyword_only
    def __init__(self, estimator = None, selectorType = "numTopFeatures",
                 numTopFeatures = 20, threshold = 0.01, outputCol = "features"):
        """
        __init__(self, estimator=None, selectorType="numTopFeatures", \
                 numTopFeatures=20, threshold=0.01, outputCol="features")
        """
        super(FeatureImpSelector, self).__init__()
        self._setDefault(selectorType="numTopFeatures", numTopFeatures=20, threshold=0.01)
        kwargs = self._input_kwargs
        self._set(**kwargs)

    @keyword_only
    def setParams(self, estimator = None, selectorType = "numTopFeatures",
                  numTopFeatures = 20, threshold = 0.01, outputCol = "features"):
        """
        setParams(self, estimator=None, selectorType="numTopFeatures", \
                  numTopFeatures=20, threshold=0.01, outputCol="features")
        Sets params for this FeatureImpSelector.
        """
        # @keyword_only is required here: it stores THIS call's arguments in
        # self._input_kwargs. Without it, whatever kwargs a previous decorated
        # call (e.g. __init__) left behind would be silently re-applied.
        kwargs = self._input_kwargs
        return self._set(**kwargs)

    def setEstimator(self, value):
        """
        Sets the value of :py:attr:`estimator`.
        """
        return self._set(estimator=value)

    def getEstimator(self):
        """
        Gets the value of estimator or its default value.
        """
        return self.getOrDefault(self.estimator)

    def setSelectorType(self, value):
        """
        Sets the value of :py:attr:`selectorType`.
        """
        return self._set(selectorType=value)

    def getSelectorType(self):
        """
        Gets the value of selectorType or its default value.
        """
        return self.getOrDefault(self.selectorType)

    def setNumTopFeatures(self, value):
        """
        Sets the value of :py:attr:`numTopFeatures`.
        Only applicable when selectorType = "numTopFeatures".
        """
        return self._set(numTopFeatures=value)

    def getNumTopFeatures(self):
        """
        Gets the value of numTopFeatures or its default value.
        """
        return self.getOrDefault(self.numTopFeatures)

    def setThreshold(self, value):
        """
        Sets the value of :py:attr:`threshold`.
        Only applicable when selectorType = "threshold".
        """
        return self._set(threshold=value)

    def getThreshold(self):
        """
        Gets the value of threshold or its default value.
        """
        return self.getOrDefault(self.threshold)

    def _fit(self, dataset):
        """
        Fit the wrapped estimator and build the selecting VectorSlicer.

        :param dataset: input :class:`pyspark.sql.DataFrame`; its features
            column must carry ``ml_attr`` metadata (e.g. from VectorAssembler).
        :return: a :class:`VectorSlicer` that keeps only the selected indices.
        :raises NameError: if the estimator type is unsupported or
            ``selectorType`` is invalid. (NameError is kept, rather than the
            more conventional ValueError, so existing callers that catch it
            keep working.)
        """
        est = self.getOrDefault(self.estimator)
        nfeatures = self.getOrDefault(self.numTopFeatures)
        threshold = self.getOrDefault(self.threshold)
        selectorType = self.getOrDefault(self.selectorType)
        outputCol = self.getOrDefault(self.outputCol)

        if est.__class__.__name__ not in self._SUPPORTED_ESTIMATORS:
            raise NameError("Estimator must be a DecisionTree, RandomForest "
                            "or GBT classifier/regressor")

        # Fit once to obtain the importances. No mod.transform() pass is
        # needed: ExtractFeatureImp only reads the features column's schema
        # metadata, which the input dataset already carries.
        mod = est.fit(dataset)
        varlist = ExtractFeatureImp(mod.featureImportances, dataset, est.getFeaturesCol())

        if selectorType == "numTopFeatures":
            # Top-n by descending score; selects all features if fewer exist.
            varidx = [int(i) for i in varlist['idx'][:nfeatures]]
        elif selectorType == "threshold":
            varidx = [int(i) for i in varlist.loc[varlist['score'] > threshold, 'idx']]
        else:
            raise NameError("Invalid selectorType")

        # Slice the original feature vector down to the selected indices.
        # int() casts above ensure plain Python ints for the Param converter.
        return VectorSlicer(inputCol=est.getFeaturesCol(),
                            outputCol=outputCol,
                            indices=varidx)