tidy up requirements, add setup.py, prettify README.md #14

Merged · 3 commits · Oct 3, 2024

Changes from all commits
100 changes: 96 additions & 4 deletions README.md
@@ -1,19 +1,111 @@
## Overview
-`shap-select` implements a heuristic to do fast feature selection for tabular regression and classification models.
+`shap-select` implements a heuristic for fast feature selection, for tabular regression and classification models.

-The basic idea is running a linear or logistic regression of the target on the Shapley values on the validation set,
+The basic idea is running a linear or logistic regression of the target on the Shapley values of
+the original features, on the validation set,
discarding the features with negative coefficients, and ranking/filtering the rest according to their
statistical significance. For motivation and details, see the [example notebook](https://github.com/transferwise/shap-select/blob/main/docs/Quick%20feature%20selection%20through%20regression%20on%20Shapley%20values.ipynb).
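
To make the idea concrete, here is a minimal sketch of the heuristic for a regression task (a sketch, not the package's actual implementation), assuming a fitted tree model `model` and a held-out validation set `X_val`, `y_val`:

```python
import pandas as pd
import shap
import statsmodels.api as sm

# Shapley values of the original features, computed on the validation set
explainer = shap.TreeExplainer(model)
shap_df = pd.DataFrame(explainer.shap_values(X_val), columns=X_val.columns)

# Regress the target on the Shapley values of the features
ols = sm.OLS(y_val, sm.add_constant(shap_df)).fit()

# Discard negative coefficients; keep the rest if statistically significant
keep = [
    c for c in X_val.columns
    if ols.params[c] > 0 and ols.pvalues[c] < 0.05
]
```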

Earlier packages using Shapley values for feature selection exist; the advantages of this one are:
* Regression on the **validation set** to combat overfitting
* A single-pass regression, not an iterative approach
* Only a single fit of the original model needed
* A single intuitive hyperparameter for feature selection: statistical significance
* Bonferroni correction for multiclass classification (see the sketch after this list)
* Addressing collinearity of (Shapley value) features by repeated (linear/logistic) regression
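
For multiclass classification, the Bonferroni correction can be pictured roughly as follows. This is a sketch under assumptions, not the package's code: `shap_per_class` is a hypothetical dict mapping each class to a DataFrame of Shapley values, and the exact rule the package uses to combine per-class p-values may differ.

```python
import pandas as pd
import statsmodels.api as sm

def multiclass_significance(shap_per_class: dict, y_val: pd.Series) -> pd.Series:
    """One-vs-all logistic regressions on Shapley values, Bonferroni-corrected."""
    pvals = []
    for cls, shap_df in shap_per_class.items():
        binary_target = (y_val == cls).astype(int)
        logit = sm.Logit(binary_target, sm.add_constant(shap_df)).fit(disp=0)
        pvals.append(logit.pvalues.drop("const"))
    # Best p-value per feature across classes, multiplied by the number of
    # one-vs-all tests (classic Bonferroni), capped at 1
    best = pd.concat(pvals, axis=1).min(axis=1)
    return (best * len(shap_per_class)).clip(upper=1.0)
```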

## Usage
```python
from shap_select import shap_select
# Here model is any model supported by the shap library, fitted on a different (train) dataset
# Task can be regression, binary, or multiclass
selected_features_df = shap_select(model, X_val, y_val, task="multiclass", threshold=0.05)
```

|   | feature name | t-value   | stat.significance | coefficient | selected |
|---|--------------|-----------|-------------------|-------------|----------|
| 0 | x5           | 20.211299 | 0.000000          | 1.052030    | 1        |
| 1 | x4           | 18.315144 | 0.000000          | 0.952416    | 1        |
| 2 | x3           | 6.835690  | 0.000000          | 1.098154    | 1        |
| 3 | x2           | 6.457140  | 0.000000          | 1.044842    | 1        |
| 4 | x1           | 5.530556  | 0.000000          | 0.917242    | 1        |
| 5 | x6           | 2.390868  | 0.016827          | 1.497983    | 1        |
| 6 | x7           | 0.901098  | 0.367558          | 2.865508    | 0        |
| 7 | x8           | 0.563214  | 0.573302          | 1.933632    | 0        |
| 8 | x9           | -1.607814 | 0.107908          | -4.537098   | -1       |
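
Judging by the example output above, the `selected` column encodes the decision (1 = selected, 0 = not significant, -1 = negative coefficient), so the surviving features can be pulled out with ordinary pandas; `X_train` below is assumed to be the original training frame:

```python
selected = selected_features_df.loc[
    selected_features_df["selected"] > 0, "feature name"
].tolist()
X_train_selected = X_train[selected]  # retrain on the selected columns only
```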


2 changes: 0 additions & 2 deletions requirements.txt
@@ -1,6 +1,4 @@
pandas
-scikit_learn
scipy
shap
statsmodels
-numpy
38 changes: 38 additions & 0 deletions setup.py
@@ -0,0 +1,38 @@
from setuptools import find_packages, setup

with open("README.md") as f:
    long_description = f.read()

setup(
    name="shap-select",
    version="0.1.0",
    description="Heuristic for quick feature selection for tabular regression/classification using shapley values",
    long_description=long_description,
    long_description_content_type="text/markdown",
    author="Wise Plc",
    url="https://github.com/transferwise/shap-select",
    classifiers=[
        "Programming Language :: Python :: 3 :: Only",
        "Programming Language :: Python :: 3.7",
        "Programming Language :: Python :: 3.8",
        "Programming Language :: Python :: 3.9",
        "Programming Language :: Python :: 3.10",
        "Programming Language :: Python :: 3.11",
        "Programming Language :: Python :: 3.12",
    ],
    install_requires=[
        "pandas",
        "scipy>=1.8.0",
        "shap",
        "statsmodels",
    ],
    extras_require={
        "test": ["flake8", "pytest", "pytest-cov"],
    },
    packages=find_packages(
        include=["shap_select", "shap_select.*"],
        exclude=["tests*"],
    ),
    include_package_data=True,
    keywords="shap-select",
)
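
With this `setup.py` in place, the package should install in the usual way: `pip install .` from the repository root, or `pip install ".[test]"` to also pull in the test dependencies.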
17 changes: 10 additions & 7 deletions shap_select/select.py
@@ -1,10 +1,7 @@
from typing import Any, Tuple, List, Dict

import pandas as pd
-import numpy as np
import statsmodels.api as sm
-from sklearn.linear_model import Lasso, LogisticRegression
-from sklearn.preprocessing import StandardScaler
import scipy.stats as stats
import shap
@@ -112,7 +109,9 @@ def multi_classifier_significance(
    # Iterate through each class and perform binary classification (one-vs-all)
    for cls, feature_df in shap_features.items():
        binary_target = (target == cls).astype(int)
-        significance_df = binary_classifier_significance(feature_df, binary_target, alpha)
+        significance_df = binary_classifier_significance(
+            feature_df, binary_target, alpha
+        )
        significance_dfs.append(significance_df)

    # Combine results into a single DataFrame with the max significance value for each feature
@@ -227,14 +226,16 @@ def iterative_shap_feature_reduction(
    shap_features: pd.DataFrame | List[pd.DataFrame],
    target: pd.Series,
    task: str,
-    alpha: float=1e-6,
+    alpha: float = 1e-6,
) -> pd.DataFrame:
    collected_rows = []  # List to store the rows we collect during each iteration

    features_left = True
    while features_left:
        # Call the original shap_features_to_significance function
-        significance_df = shap_features_to_significance(shap_features, target, task, alpha)
+        significance_df = shap_features_to_significance(
+            shap_features, target, task, alpha
+        )

        # Find the feature with the lowest t-value
        min_t_value_row = significance_df.loc[significance_df["t-value"].idxmin()]
@@ -315,7 +316,9 @@ def shap_select(
    shap_features = create_shap_features(tree_model, validation_df[feature_names])

    # Compute statistical significance of each feature, recursively ablating
-    significance_df = iterative_shap_feature_reduction(shap_features, target, task, alpha)
+    significance_df = iterative_shap_feature_reduction(
+        shap_features, target, task, alpha
+    )

    # Add 'Selected' column based on the threshold
    significance_df["selected"] = (