
Commit 5e0a213

update for compatibility with sklearn 1.2+

1 parent a91641e · commit 5e0a213

12 files changed (+369, -289 lines)

imodels/discretization/discretizer.py (+44 -28)

@@ -100,10 +100,12 @@ def _validate_n_bins(self):
             )
             self.n_bins = np.full(n_features, orig_bins, dtype=int)
         else:
-            n_bins = check_array(orig_bins, dtype=int, copy=True, ensure_2d=False)
+            n_bins = check_array(orig_bins, dtype=int,
+                                 copy=True, ensure_2d=False)

         if n_bins.ndim > 1 or n_bins.shape[0] != n_features:
-            raise ValueError("n_bins must be a scalar or array of shape (n_features,).")
+            raise ValueError(
+                "n_bins must be a scalar or array of shape (n_features,).")

         bad_nbins_value = (n_bins < 2) | (n_bins != orig_bins)

@@ -136,12 +138,12 @@ def _validate_args(self):

         valid_encode = ('onehot', 'ordinal')
         if self.encode not in valid_encode:
-            raise ValueError("Valid options for 'encode' are {}. Got encode={!r} instead." \
+            raise ValueError("Valid options for 'encode' are {}. Got encode={!r} instead."
                              .format(valid_encode, self.encode))

         valid_strategy = ('uniform', 'quantile', 'kmeans')
         if (self.strategy not in valid_strategy):
-            raise ValueError("Valid options for 'strategy' are {}. Got strategy={!r} instead." \
+            raise ValueError("Valid options for 'strategy' are {}. Got strategy={!r} instead."
                              .format(valid_strategy, self.strategy))

     def _discretize_to_bins(self, x, bin_edges,

@@ -174,7 +176,8 @@ def _discretize_to_bins(self, x, bin_edges,

         if keep_pointwise_bins:
             # note: min and max values are used to define pointwise bins
-            pointwise_bins = np.unique(bin_edges[pd.Series(bin_edges).duplicated()])
+            pointwise_bins = np.unique(
+                bin_edges[pd.Series(bin_edges).duplicated()])
         else:
             pointwise_bins = np.array([])

@@ -183,7 +186,8 @@ def _discretize_to_bins(self, x, bin_edges,
         for idx, split in enumerate(unique_edges):
             if idx == (len(unique_edges) - 1):  # uppermost bin
                 if (idx == 0) & (split in pointwise_bins):
-                    indicator = x > split  # two bins total: (-inf, a], (a, inf)
+                    # two bins total: (-inf, a], (a, inf)
+                    indicator = x > split
                 else:
                     indicator = x >= split  # uppermost bin: [a, inf)
             else:

@@ -217,7 +221,8 @@ def _fit_preprocessing(self, X):

         # by default, discretize all numeric columns
         if len(self.dcols) == 0:
-            numeric_cols = [col for col in X.columns if is_numeric_dtype(X[col].dtype)]
+            numeric_cols = [
+                col for col in X.columns if is_numeric_dtype(X[col].dtype)]
             self.dcols_ = numeric_cols

         # error checking

@@ -255,7 +260,8 @@ def _transform_postprocessing(self, discretized_df, X):
         try:
             onehot_col_names = self.onehot_.get_feature_names_out(colnames)
         except:
-            onehot_col_names = self.onehot_.get_feature_names(colnames)  # older versions of sklearn
+            onehot_col_names = self.onehot_.get_feature_names(
+                colnames)  # older versions of sklearn
         discretized_df = self.onehot_.transform(discretized_df.astype(str))
         discretized_df = pd.DataFrame(discretized_df,
                                       columns=onehot_col_names,

@@ -353,7 +359,7 @@ def fit(self, X, y=None):
         disc_ordinal_df = pd.DataFrame(disc_ordinal_np, columns=self.dcols)
         disc_ordinal_df_str = disc_ordinal_df.astype(int).astype(str)

-        encoder = OneHotEncoder(drop=self.onehot_drop, sparse=False)
+        encoder = OneHotEncoder(drop=self.onehot_drop)  # , sparse=False)
         encoder.fit(disc_ordinal_df_str)
         self.encoder_ = encoder

@@ -382,7 +388,8 @@ def transform(self, X):

         # One-hot encode the ordinal DF
         disc_onehot_np = self.encoder_.transform(disc_ordinal_df_str)
-        disc_onehot = pd.DataFrame(disc_onehot_np, columns=self.encoder_.get_feature_names_out())
+        disc_onehot = pd.DataFrame(
+            disc_onehot_np, columns=self.encoder_.get_feature_names_out())

         # Name columns after the interval they represent (e.g. 0.1_to_0.5)
         for col, bin_edges in zip(self.dcols, self.discretizer_.bin_edges_):

@@ -525,7 +532,7 @@ def fit(self, X, y=None):

         # fit onehot encoded X if specified
         if self.encode == "onehot":
-            onehot = OneHotEncoder(drop=self.onehot_drop, sparse=False)
+            onehot = OneHotEncoder(drop=self.onehot_drop)  # , sparse=False)
             onehot.fit(discretized_df.astype(str))
             self.onehot_ = onehot

@@ -550,7 +557,8 @@ def transform(self, X):
         check_is_fitted(self)

         # transform using KBinsDiscretizer
-        discretized_df = self.discretizer_.transform(X[self.dcols_]).astype(int)
+        discretized_df = self.discretizer_.transform(
+            X[self.dcols_]).astype(int)
         discretized_df = pd.DataFrame(discretized_df,
                                       columns=self.dcols_,
                                       index=X.index)

@@ -669,7 +677,7 @@ def _validate_args(self):
         super()._validate_args()
         valid_backup_strategy = ('uniform', 'quantile', 'kmeans')
         if (self.backup_strategy not in valid_backup_strategy):
-            raise ValueError("Valid options for 'strategy' are {}. Got strategy={!r} instead." \
+            raise ValueError("Valid options for 'strategy' are {}. Got strategy={!r} instead."
                              .format(valid_backup_strategy, self.backup_strategy))

     def _get_rf_splits(self, col_names):

@@ -738,7 +746,8 @@ def _fit_rf(self, X, y=None):
         # provided rf model has not yet been trained
         if not check_is_fitted(self.rf_model):
             if y is None:
-                raise ValueError("Must provide y if rf_model has not been trained.")
+                raise ValueError(
+                    "Must provide y if rf_model has not been trained.")
             self.rf_model.fit(X, y)

         # get all random forest split points

@@ -785,12 +794,13 @@ def reweight_n_bins(self, X, y=None, by="nsplits"):
         if by == "nsplits":
             # each col gets at least 2 bins; remaining bins get
             # reallocated based on number of RF splits using that feature
-            n_rules = np.array([len(self.rf_splits[col]) for col in self.dcols_])
-            self.n_bins = np.round(n_rules / n_rules.sum() * \
+            n_rules = np.array([len(self.rf_splits[col])
+                                for col in self.dcols_])
+            self.n_bins = np.round(n_rules / n_rules.sum() *
                                    (total_bins - 2 * len(self.dcols_))) + 2
         else:
             valid_by = ('nsplits')
-            raise ValueError("Valid options for 'by' are {}. Got by={!r} instead." \
+            raise ValueError("Valid options for 'by' are {}. Got by={!r} instead."
                              .format(valid_by, by))

     def fit(self, X, y=None):

@@ -817,12 +827,12 @@ def fit(self, X, y=None):
         self._fit_rf(X=X, y=y)

         # features that were not used in the rf but need to be discretized
-        self.missing_rf_cols_ = list(set(self.dcols_) - \
+        self.missing_rf_cols_ = list(set(self.dcols_) -
                                      set(self.rf_splits.keys()))
         if len(self.missing_rf_cols_) > 0:
-            print("{} did not appear in random forest so were discretized via {} discretization" \
+            print("{} did not appear in random forest so were discretized via {} discretization"
                   .format(self.missing_rf_cols_, self.strategy))
-            missing_n_bins = np.array([self.n_bins[np.array(self.dcols_) == col][0] \
+            missing_n_bins = np.array([self.n_bins[np.array(self.dcols_) == col][0]
                                        for col in self.missing_rf_cols_])

             backup_discretizer = BasicDiscretizer(n_bins=missing_n_bins,

@@ -836,7 +846,8 @@ def fit(self, X, y=None):

         if self.encode == 'onehot':
             if len(self.missing_rf_cols_) > 0:
-                discretized_df = backup_discretizer.transform(X[self.missing_rf_cols_])
+                discretized_df = backup_discretizer.transform(
+                    X[self.missing_rf_cols_])
             else:
                 discretized_df = pd.DataFrame({}, index=X.index)

@@ -848,16 +859,19 @@ def fit(self, X, y=None):
             if self.strategy == "quantile":
                 q_values = np.linspace(0, 1, int(b) + 1)
                 bin_edges = np.quantile(self.rf_splits[col], q_values)
-            elif strategy == "uniform":
-                width = (max(self.rf_splits[col]) - min(self.rf_splits[col])) / b
-                bin_edges = width * np.arange(0, b + 1) + min(self.rf_splits[col])
+            elif self.strategy == "uniform":
+                width = (max(self.rf_splits[col]) -
+                         min(self.rf_splits[col])) / b
+                bin_edges = width * \
+                    np.arange(0, b + 1) + min(self.rf_splits[col])
             self.bin_edges_[col] = bin_edges
             if self.encode == 'onehot':
-                discretized_df[col] = self._discretize_to_bins(X[col], bin_edges)
+                discretized_df[col] = self._discretize_to_bins(
+                    X[col], bin_edges)

         # fit onehot encoded X if specified
         if self.encode == "onehot":
-            onehot = OneHotEncoder(drop=self.onehot_drop, sparse=False)
+            onehot = OneHotEncoder(drop=self.onehot_drop)  # , sparse=False)
             onehot.fit(discretized_df[self.dcols_].astype(str))
             self.onehot_ = onehot

@@ -883,7 +897,8 @@ def transform(self, X):

         # transform features that did not appear in RF
         if len(self.missing_rf_cols_) > 0:
-            discretized_df = self.backup_discretizer_.transform(X[self.missing_rf_cols_])
+            discretized_df = self.backup_discretizer_.transform(
+                X[self.missing_rf_cols_])
             discretized_df = pd.DataFrame(discretized_df,
                                           columns=self.missing_rf_cols_,
                                           index=X.index)

@@ -892,7 +907,8 @@ def transform(self, X):

         # do discretization based on rf split thresholds
         for col in self.bin_edges_.keys():
-            discretized_df[col] = self._discretize_to_bins(X[col], self.bin_edges_[col])
+            discretized_df[col] = self._discretize_to_bins(
+                X[col], self.bin_edges_[col])

         # return onehot encoded data if specified and
         # join discretized columns with rest of X
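
Throughout this file the commit replaces `OneHotEncoder(drop=..., sparse=False)` with `OneHotEncoder(drop=...)`: scikit-learn 1.2 renamed `sparse` to `sparse_output` (later releases drop `sparse` entirely), so the old keyword fails on new versions. Note that simply omitting the argument also reverts to the default sparse output. A minimal version-agnostic sketch, not part of this commit (the `make_dense_onehot` helper is hypothetical):

# hypothetical helper: request dense output under whichever keyword the
# installed scikit-learn understands (sparse_output for >=1.2, sparse before)
import inspect

from sklearn.preprocessing import OneHotEncoder


def make_dense_onehot(**kwargs):
    params = inspect.signature(OneHotEncoder.__init__).parameters
    key = 'sparse_output' if 'sparse_output' in params else 'sparse'
    kwargs[key] = False  # always dense, on any sklearn version
    return OneHotEncoder(**kwargs)


# usage mirroring the calls in this diff, e.g.:
# encoder = make_dense_onehot(drop=self.onehot_drop)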

imodels/util/data_util.py (+7 -4)

@@ -35,7 +35,8 @@
         "dataset_name": "readmission_clean",
         "data_source": "imodels",
     },  # big, 100k points
-    "adult": {"dataset_name": "1182", "data_source": "openml"},  # big, 1e6 points
+    # big, 1e6 points
+    "adult": {"dataset_name": "1182", "data_source": "openml"},
     # CDI classification
     "csi_pecarn": {"dataset_name": "csi_pecarn_pred", "data_source": "imodels"},
     "iai_pecarn": {"dataset_name": "iai_pecarn_pred", "data_source": "imodels"},

@@ -221,7 +222,8 @@ def _split(X, y, feature_names):
         return _split(_clean_features(X), y, _clean_feat_names(feature_names))
     elif data_source == "synthetic":
         if dataset_name == "friedman1":
-            X, y = sklearn.datasets.make_friedman1(n_samples=200, n_features=10)
+            X, y = sklearn.datasets.make_friedman1(
+                n_samples=200, n_features=10)
         elif dataset_name == "friedman2":
             X, y = sklearn.datasets.make_friedman2(n_samples=200)
         elif dataset_name == "friedman3":

@@ -234,7 +236,8 @@ def _split(X, y, feature_names):


 def _download_imodels_dataset(dataset_fname, data_path: str):
-    dataset_fname = dataset_fname.split("/")[-1]  # remove anything about the path
+    dataset_fname = dataset_fname.split(
+        "/")[-1]  # remove anything about the path
     download_path = f"https://raw.githubusercontent.com/csinva/imodels-data/master/data_cleaned/{dataset_fname}"
     r = requests.get(download_path)
     if r.status_code == 404:

@@ -253,7 +256,7 @@ def encode_categories(X, features, encoder=None):
     X_cat = pd.DataFrame({f: X.loc[:, f] for f in features})

     if encoder is None:
-        one_hot_encoder = OneHotEncoder(sparse=False, categories="auto")
+        one_hot_encoder = OneHotEncoder(categories="auto")
         X_one_hot = pd.DataFrame(one_hot_encoder.fit_transform(X_cat))
     else:
         one_hot_encoder = encoder
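
Because the new default is sparse output, wrapping the result of `fit_transform` in a `pd.DataFrame` (as `encode_categories` does) may need an explicit densify step on sklearn >= 1.2. A short illustrative sketch, with a toy frame standing in for `X_cat`:

import pandas as pd
import scipy.sparse
from sklearn.preprocessing import OneHotEncoder

# toy stand-in for the X_cat frame built inside encode_categories
X_cat = pd.DataFrame({"color": ["red", "blue", "red"]})

one_hot_encoder = OneHotEncoder(categories="auto")
X_enc = one_hot_encoder.fit_transform(X_cat)  # sparse matrix on sklearn >= 1.2
if scipy.sparse.issparse(X_enc):
    X_enc = X_enc.toarray()  # densify before wrapping in a DataFrame
X_one_hot = pd.DataFrame(X_enc, columns=one_hot_encoder.get_feature_names_out())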

imodels/util/extract.py (+36 -16)

@@ -7,16 +7,17 @@
     GradientBoostingClassifier, RandomForestClassifier
 from sklearn.tree import DecisionTreeRegressor
 from sklearn.utils.validation import check_array
-
+import inspect
 from imodels.util import rule, convert


 def extract_fpgrowth(X,
                      minsupport=0.1,
                      maxcardinality=2,
                      verbose=False) -> List[Tuple]:

-    itemsets_df = mlx.fpgrowth(X, min_support=minsupport, max_len=maxcardinality)
+    itemsets_df = mlx.fpgrowth(
+        X, min_support=minsupport, max_len=maxcardinality)
     itemsets_indices = [tuple(s[1]) for s in itemsets_df.values]
     itemsets = [np.array(X.columns)[list(inds)] for inds in itemsets_indices]
     itemsets = list(map(tuple, itemsets))

@@ -49,13 +50,15 @@ def extract_rulefit(X, y, feature_names,
             "RuleFit only works with GradientBoostingClassifier(), GradientBoostingRegressor(), "
             "RandomForestRegressor() or RandomForestClassifier()")

-    ## fit tree generator
+    # fit tree generator
     if not exp_rand_tree_size:  # simply fit with constant tree size
         tree_generator.fit(X, y)
     else:  # randomise tree size as per Friedman 2005 Sec 3.3
         np.random.seed(random_state)
-        tree_sizes = np.random.exponential(scale=tree_size - 2, size=n_estimators)
-        tree_sizes = np.asarray([2 + np.floor(tree_sizes[i_]) for i_ in np.arange(len(tree_sizes))], dtype=int)
+        tree_sizes = np.random.exponential(
+            scale=tree_size - 2, size=n_estimators)
+        tree_sizes = np.asarray([2 + np.floor(tree_sizes[i_])
+                                 for i_ in np.arange(len(tree_sizes))], dtype=int)
         tree_generator.set_params(warm_start=True)
         curr_est_ = 0
         for i_size in np.arange(len(tree_sizes)):

@@ -76,7 +79,7 @@ def extract_rulefit(X, y, feature_names,

     seen_rules = set()
     extracted_rules = []
-    for estimator in estimators_: 
+    for estimator in estimators_:
         for rule_value_pair in convert.tree_to_rules(estimator[0], np.array(feature_names), prediction_values=True):

             rule_obj = rule.Rule(rule_value_pair[0])

@@ -108,12 +111,21 @@ def extract_skope(X, y, feature_names,
         max_depths = [max_depths]

     for max_depth in max_depths:
+
+        # pass different key based on sklearn version
+        estimator = DecisionTreeRegressor(
+            max_depth=max_depth,
+            max_features=max_features,
+            min_samples_split=min_samples_split,
+
+        )
+        init_signature = inspect.signature(BaggingRegressor.__init__)
+        estimator_key = 'estimator' if 'estimator' in init_signature.parameters.keys(
+        ) else 'base_estimator'
+        kwargs = {
+            estimator_key: estimator,
+        }
         bagging_clf = BaggingRegressor(
-            estimator=DecisionTreeRegressor(
-                max_depth=max_depth,
-                max_features=max_features,
-                min_samples_split=min_samples_split
-            ),
             n_estimators=n_estimators,
             max_samples=max_samples,
             max_features=max_samples_features,

@@ -124,7 +136,8 @@ def extract_skope(X, y, feature_names,
             # warm_start=... XXX may be added to increase computation perf.
             n_jobs=n_jobs,
             random_state=random_state,
-            verbose=verbose
+            verbose=verbose,
+            **kwargs
         )
         ensembles.append(bagging_clf)

@@ -134,8 +147,8 @@ def extract_skope(X, y, feature_names,
         weights = sample_weight - sample_weight.min()
         contamination = float(sum(y)) / len(y)
         y_reg = (
-                pow(weights, 0.5) * 0.5 / contamination * (y > 0) -
-                pow((weights).mean(), 0.5) * (y == 0)
+            pow(weights, 0.5) * 0.5 / contamination * (y > 0) -
+            pow((weights).mean(), 0.5) * (y == 0)
         )
         y_reg = 1. / (1 + np.exp(-y_reg))  # sigmoid

@@ -153,10 +166,12 @@ def extract_skope(X, y, feature_names,

     extracted_rules = []
     for estimator, features in zip(estimators_, estimators_features_):
-        extracted_rules.append(convert.tree_to_rules(estimator, np.array(feature_names)[features]))
+        extracted_rules.append(convert.tree_to_rules(
+            estimator, np.array(feature_names)[features]))

     return extracted_rules, estimators_samples_, estimators_features_

+
 def extract_marginal_curves(clf, X, max_evals=100):
     """Uses predict_proba to compute marginal curves.
     Assumes clf is a classifier with a predict_proba method and that classifier is additive across features

@@ -193,3 +208,8 @@ def extract_marginal_curves(clf, X, max_evals=100):
         feature_vals_list.append(feature_vals)
         shape_function_vals_list.append(shape_function_vals.tolist())
     return feature_vals_list, shape_function_vals_list
+
+
+if __name__ == '__main__':
+    init_signature = inspect.signature(BaggingRegressor.__init__)
+    print('estimator' in init_signature.parameters.keys())
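
The `__main__` block added above doubles as a quick check of the version shim: scikit-learn 1.2 renamed `BaggingRegressor`'s `base_estimator` parameter to `estimator`, and the shim picks whichever keyword the installed version accepts. The same pattern can be exercised standalone; a minimal sketch mirroring the code in `extract_skope`:

# standalone sketch of the version shim used in extract_skope
import inspect

from sklearn.ensemble import BaggingRegressor
from sklearn.tree import DecisionTreeRegressor

init_signature = inspect.signature(BaggingRegressor.__init__)
estimator_key = ('estimator' if 'estimator' in init_signature.parameters
                 else 'base_estimator')
kwargs = {estimator_key: DecisionTreeRegressor(max_depth=3)}
bagging_clf = BaggingRegressor(n_estimators=10, **kwargs)
print(estimator_key)  # 'estimator' on sklearn >= 1.2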

setup.py (+2 -2)

@@ -13,7 +13,7 @@
     'pandas',
     'requests',  # used in c4.5
     'scipy',
-    'scikit-learn',  # 0.23+ only works on py3.6+
+    'scikit-learn>=1.2.0',  # recently updated this
     'tqdm',  # used in BART
 ]

@@ -26,7 +26,7 @@

 setuptools.setup(
     name="imodels",
-    version="1.4.1",
+    version="1.4.2",
     author="Chandan Singh, Keyan Nasseri, Matthew Epland, Yan Shuo Tan, Omer Ronen, Tiffany Tang, Abhineet Agarwal, Theo Saarinen, Bin Yu, and others",
     author_email="chandan_singh@berkeley.edu",
     description="Implementations of various interpretable models",
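
With the dependency floor raised to `scikit-learn>=1.2.0`, a quick runtime check can confirm an environment satisfies the new pin; a hedged sketch (uses the third-party `packaging` library, not part of this commit):

# sanity check that the installed scikit-learn meets the new >=1.2.0 pin
import sklearn
from packaging.version import Version

assert Version(sklearn.__version__) >= Version("1.2.0"), (
    "this imodels release expects scikit-learn >= 1.2.0")
print(sklearn.__version__)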

0 commit comments