@@ -223,6 +223,46 @@ def shap_features_to_significance(
223
223
return result_df_sorted
224
224
225
225
226
def iterative_shap_feature_reduction(
    shap_features: pd.DataFrame | dict[Any, pd.DataFrame],
    target: pd.Series,
    task: str,
    alpha: float = 1e-6,
) -> pd.DataFrame:
    """Rank features by repeatedly eliminating the one with the lowest t-value.

    Each round calls ``shap_features_to_significance`` on the features that
    are still present, records the row with the smallest ``"t-value"``, and
    drops that feature before the next round. The loop ends once no feature
    columns remain, so every feature contributes exactly one row, taken from
    the round in which it was eliminated.

    Args:
        shap_features: Either a single DataFrame with one column per feature,
            or a mapping whose values are such DataFrames. The mapping form
            is handled through ``.items()`` / ``.values()``, so a plain list
            is not supported. Only the first DataFrame of a mapping is
            inspected to decide whether any features are left; all frames
            are presumably expected to share the same columns (TODO confirm
            against ``create_shap_features``).
        target: Passed through unchanged to ``shap_features_to_significance``.
        task: Passed through unchanged to ``shap_features_to_significance``.
        alpha: Passed through unchanged to ``shap_features_to_significance``.

    Returns:
        A DataFrame built from the recorded rows, sorted by ``"t-value"`` in
        descending order. ``reset_index()`` is called without ``drop=True``,
        so the row labels the rows had in their originating significance
        tables are kept as an additional leading column.
    """
    eliminated_rows = []

    while True:
        # Re-fit on the surviving features only: the statistics of the
        # remaining features can change once a weaker one is removed.
        significance_df = shap_features_to_significance(shap_features, target, task, alpha)

        # The weakest feature of this round is recorded with the statistics
        # it had at the moment it was eliminated.
        weakest_row = significance_df.loc[significance_df["t-value"].idxmin()]
        eliminated_rows.append(weakest_row)

        feature_to_remove = weakest_row["feature name"]
        if isinstance(shap_features, pd.DataFrame):
            shap_features = shap_features.drop(columns=[feature_to_remove])
            remaining_columns = shap_features.columns
        else:
            # Mapping form: drop the same feature from every DataFrame.
            shap_features = {
                key: frame.drop(columns=[feature_to_remove])
                for key, frame in shap_features.items()
            }
            remaining_columns = next(iter(shap_features.values())).columns

        if len(remaining_columns) == 0:
            break

    return (
        pd.DataFrame(eliminated_rows)
        .sort_values(by="t-value", ascending=False)
        .reset_index()
    )
264
+
265
+
226
266
def shap_select (
227
267
tree_model : Any ,
228
268
validation_df : pd .DataFrame ,
@@ -274,8 +314,8 @@ def shap_select(
274
314
else :
275
315
shap_features = create_shap_features (tree_model , validation_df [feature_names ])
276
316
277
- # Compute statistical significance of each feature
278
- significance_df = shap_features_to_significance (shap_features , target , task , alpha )
317
+ # Compute statistical significance of each feature, recursively ablating
318
+ significance_df = iterative_shap_feature_reduction (shap_features , target , task , alpha )
279
319
280
320
# Add 'Selected' column based on the threshold
281
321
significance_df ["selected" ] = (
0 commit comments