16
16
from collections import defaultdict
17
17
import pandas as pd
18
18
import json
19
+ from sklearn .preprocessing import StandardScaler
19
20
20
21
import imodels
21
22
from interpret .glassbox import ExplainableBoostingClassifier , ExplainableBoostingRegressor
@@ -37,7 +38,8 @@ def __init__(
37
38
multitask = True ,
38
39
interactions = 0.95 ,
39
40
linear_penalty = 'ridge' ,
40
- onehot_prior = True ,
41
+ onehot_prior = False ,
42
+ renormalize_features = False ,
41
43
random_state = 42 ,
42
44
):
43
45
"""
@@ -54,6 +56,7 @@ def __init__(
54
56
self .random_state = random_state
55
57
self .interactions = interactions
56
58
self .onehot_prior = onehot_prior
59
+ self .renormalize_features = renormalize_features
57
60
58
61
# override ebm_kwargs
59
62
ebm_kwargs ['random_state' ] = random_state
@@ -90,9 +93,12 @@ def fit(self, X, y, sample_weight=None):
90
93
self .term_names_list_ = [
91
94
ebm_ .term_names_ for ebm_ in self .ebms_ ]
92
95
self .term_names_ = sum (self .term_names_list_ , [])
93
-
94
96
feats = self ._extract_ebm_features (X )
95
97
98
+ if self .renormalize_features :
99
+ self .scaler_ = StandardScaler ()
100
+ feats = self .scaler_ .fit_transform (feats )
101
+
96
102
# fit a linear model to the features
97
103
if self .linear_penalty == 'ridge' :
98
104
self .lin_model = RidgeCV (alphas = np .logspace (- 2 , 3 , 7 ))
@@ -126,13 +132,16 @@ def _extract_ebm_features(self, X):
126
132
feats [:, offset : offset + n_features_ebm_num ] = \
127
133
self .ebms_ [ebm_num ].eval_terms (X )
128
134
offset += n_features_ebm_num
135
+
129
136
return feats
130
137
131
138
def predict (self , X ):
132
139
check_is_fitted (self )
133
140
X = check_array (X , accept_sparse = False )
134
141
if hasattr (self , 'ebms_' ):
135
142
feats = self ._extract_ebm_features (X )
143
+ if hasattr (self , 'scaler_' ):
144
+ feats = self .scaler_ .transform (feats )
136
145
return self .lin_model .predict (feats )
137
146
else :
138
147
return self .ebm_ .predict (X )
@@ -183,7 +192,7 @@ def test_multitask_extraction():
183
192
184
193
185
194
if __name__ == "__main__" :
186
- test_multitask_extraction ()
195
+ # test_multitask_extraction()
187
196
# X, y, feature_names = imodels.get_clean_dataset("heart")
188
197
X , y , feature_names = imodels .get_clean_dataset ("bike_sharing" )
189
198
# X, y, feature_names = imodels.get_clean_dataset("diabetes")
@@ -199,9 +208,10 @@ def test_multitask_extraction():
199
208
for gam in tqdm ([
200
209
# AdaBoostRegressor(estimator=MultiTaskGAMRegressor(
201
210
# multitask=True), n_estimators=2),
202
- MultiTaskGAMRegressor (multitask = False , onehot_prior = True ),
203
- MultiTaskGAMRegressor (multitask = False , onehot_prior = False ),
204
- MultiTaskGAMRegressor (multitask = True ),
211
+ # MultiTaskGAMRegressor(multitask=True, onehot_prior=True),
212
+ # MultiTaskGAMRegressor(multitask=True, onehot_prior=False),
213
+ MultiTaskGAMRegressor (multitask = True , renormalize_features = True ),
214
+ MultiTaskGAMRegressor (multitask = True , renormalize_features = False ),
205
215
# ExplainableBoostingRegressor(n_jobs=1, interactions=0)
206
216
]):
207
217
np .random .seed (42 )
0 commit comments