Commit ec52209

support renormalize_features in multi-task gam

1 parent: 657216f
1 file changed: +16 −6 lines

imodels/algebraic/gam_multitask.py

+16-6
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
 from collections import defaultdict
 import pandas as pd
 import json
+from sklearn.preprocessing import StandardScaler

 import imodels
 from interpret.glassbox import ExplainableBoostingClassifier, ExplainableBoostingRegressor
@@ -37,7 +38,8 @@ def __init__(
         multitask=True,
         interactions=0.95,
         linear_penalty='ridge',
-        onehot_prior=True,
+        onehot_prior=False,
+        renormalize_features=False,
         random_state=42,
     ):
         """
@@ -54,6 +56,7 @@
         self.random_state = random_state
         self.interactions = interactions
         self.onehot_prior = onehot_prior
+        self.renormalize_features = renormalize_features

         # override ebm_kwargs
         ebm_kwargs['random_state'] = random_state
@@ -90,9 +93,12 @@ def fit(self, X, y, sample_weight=None):
             self.term_names_list_ = [
                 ebm_.term_names_ for ebm_ in self.ebms_]
             self.term_names_ = sum(self.term_names_list_, [])
-
             feats = self._extract_ebm_features(X)

+            if self.renormalize_features:
+                self.scaler_ = StandardScaler()
+                feats = self.scaler_.fit_transform(feats)
+
             # fit a linear model to the features
             if self.linear_penalty == 'ridge':
                 self.lin_model = RidgeCV(alphas=np.logspace(-2, 3, 7))
@@ -126,13 +132,16 @@ def _extract_ebm_features(self, X):
             feats[:, offset: offset + n_features_ebm_num] = \
                 self.ebms_[ebm_num].eval_terms(X)
             offset += n_features_ebm_num
+
         return feats

     def predict(self, X):
         check_is_fitted(self)
         X = check_array(X, accept_sparse=False)
         if hasattr(self, 'ebms_'):
             feats = self._extract_ebm_features(X)
+            if hasattr(self, 'scaler_'):
+                feats = self.scaler_.transform(feats)
             return self.lin_model.predict(feats)
         else:
             return self.ebm_.predict(X)
@@ -183,7 +192,7 @@ def test_multitask_extraction():


 if __name__ == "__main__":
-    test_multitask_extraction()
+    # test_multitask_extraction()
     # X, y, feature_names = imodels.get_clean_dataset("heart")
     X, y, feature_names = imodels.get_clean_dataset("bike_sharing")
     # X, y, feature_names = imodels.get_clean_dataset("diabetes")
@@ -199,9 +208,10 @@ def test_multitask_extraction():
     for gam in tqdm([
         # AdaBoostRegressor(estimator=MultiTaskGAMRegressor(
         #     multitask=True), n_estimators=2),
-        MultiTaskGAMRegressor(multitask=False, onehot_prior=True),
-        MultiTaskGAMRegressor(multitask=False, onehot_prior=False),
-        MultiTaskGAMRegressor(multitask=True),
+        # MultiTaskGAMRegressor(multitask=True, onehot_prior=True),
+        # MultiTaskGAMRegressor(multitask=True, onehot_prior=False),
+        MultiTaskGAMRegressor(multitask=True, renormalize_features=True),
+        MultiTaskGAMRegressor(multitask=True, renormalize_features=False),
         # ExplainableBoostingRegressor(n_jobs=1, interactions=0)
     ]):
         np.random.seed(42)

0 commit comments

Comments
 (0)