Skip to content

Commit 448a79f

Browse files
Merge pull request #13 from transferwise/recursive
To combat potential collinearity issues, eliminate variables from the regression one at a time
2 parents 7325a88 + 925a196 commit 448a79f

File tree

3 files changed

+43
-3
lines changed

3 files changed

+43
-3
lines changed

LICENSE

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@
186186
same "printed page" as the copyright notice for easier
187187
identification within third-party archives.
188188

189-
Copyright [2024] [Wise PLC]
189+
Copyright 2024 Wise PLC
190190

191191
Licensed under the Apache License, Version 2.0 (the "License");
192192
you may not use this file except in compliance with the License.

docs/bug.py

Whitespace-only changes.

shap_select/select.py

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,46 @@ def shap_features_to_significance(
223223
return result_df_sorted
224224

225225

226+
def iterative_shap_feature_reduction(
    shap_features: pd.DataFrame | Dict[str, pd.DataFrame],
    target: pd.Series,
    task: str,
    alpha: float = 1e-6,
) -> pd.DataFrame:
    """Recursively ablate features to combat collinearity in the significance regression.

    Repeatedly fits the significance regression (via ``shap_features_to_significance``),
    removes the feature with the lowest t-value, and records that feature's row.
    Eliminating one variable at a time keeps collinear features from masking each
    other's significance in a single joint fit.

    Args:
        shap_features: SHAP values per feature — a single DataFrame, or (for
            multiclass tasks) a dict mapping class label to a DataFrame with
            identical columns.  Not mutated; column drops build new objects.
        target: Target series aligned with the SHAP value rows.
        task: Task name forwarded to ``shap_features_to_significance``.
        alpha: Regularization strength forwarded to the significance regression.

    Returns:
        DataFrame with one row per feature (the row captured at the iteration in
        which that feature was eliminated), sorted by descending t-value.
    """
    collected_rows = []  # one significance row per eliminated feature

    features_left = True
    while features_left:
        # Re-fit the significance regression on the features that remain.
        significance_df = shap_features_to_significance(shap_features, target, task, alpha)

        # Identify and remember the least significant feature this round.
        # NOTE(review): this ranks by raw (signed) t-value, not |t-value| —
        # a strongly negative t still counts as "least significant"; confirm intended.
        min_t_value_row = significance_df.loc[significance_df["t-value"].idxmin()]
        collected_rows.append(min_t_value_row)

        # Drop that feature and check whether any columns remain.
        feature_to_remove = min_t_value_row["feature name"]
        if isinstance(shap_features, pd.DataFrame):
            shap_features = shap_features.drop(columns=[feature_to_remove])
            features_left = len(shap_features.columns)
        else:
            # Multiclass case: one DataFrame per class, all sharing the same columns.
            shap_features = {
                k: v.drop(columns=[feature_to_remove]) for k, v in shap_features.items()
            }
            features_left = len(next(iter(shap_features.values())).columns)

    # Assemble the collected rows; drop=True avoids leaking a stray "index" column
    # into the result consumed by shap_select.
    result_df = (
        pd.DataFrame(collected_rows)
        .sort_values(by="t-value", ascending=False)
        .reset_index(drop=True)
    )

    return result_df
265+
226266
def shap_select(
227267
tree_model: Any,
228268
validation_df: pd.DataFrame,
@@ -274,8 +314,8 @@ def shap_select(
274314
else:
275315
shap_features = create_shap_features(tree_model, validation_df[feature_names])
276316

277-
# Compute statistical significance of each feature
278-
significance_df = shap_features_to_significance(shap_features, target, task, alpha)
317+
# Compute statistical significance of each feature, recursively ablating
318+
significance_df = iterative_shap_feature_reduction(shap_features, target, task, alpha)
279319

280320
# Add 'Selected' column based on the threshold
281321
significance_df["selected"] = (

0 commit comments

Comments
 (0)