diff --git a/README.md b/README.md
index 1933886..166d3ee 100644
--- a/README.md
+++ b/README.md
@@ -1,19 +1,111 @@
## Overview
-`shap-select` implements a heuristic to do fast feature selection for tabular regression and classification models.
+`shap-select` implements a heuristic for fast feature selection for tabular regression and classification models.
-The basic idea is running a linear or logistic regression of the target on the Shapley values on the validation set,
+The basic idea is to run a linear or logistic regression of the target on the Shapley values
+of the original features, computed on the validation set,
discarding the features with negative coefficients, and ranking/filtering the rest according to their
statistical significance. For motivation and details, see the [example notebook](https://github.com/transferwise/shap-select/blob/main/docs/Quick%20feature%20selection%20through%20regression%20on%20Shapley%20values.ipynb).
Earlier packages using Shapley values for feature selection exist; the advantages of this one are:
* Regression on the **validation set** to combat overfitting
-* A single pass regression, not an iterative approach
+* Only a single fit of the original model needed
* A single intuitive hyperparameter for feature selection: statistical significance
* Bonferroni correction for multiclass classification
+* Collinearity among the (Shapley value) features addressed by repeated (linear/logistic) regression
+
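+For intuition, here is a minimal sketch of the regression step for a regression task. It assumes
+a fitted tree model `model` and a validation set `X_val`, `y_val` (illustrative names, not part
+of the package API); the package itself also handles classification, multiclass targets with
+Bonferroni correction, and the repeated re-fitting mentioned above:
+
+```python
+import pandas as pd
+import shap
+import statsmodels.api as sm
+
+# Shapley values of the original features, computed on the validation set
+explainer = shap.TreeExplainer(model)
+shap_df = pd.DataFrame(explainer.shap_values(X_val), columns=X_val.columns)
+
+# Regress the target on the Shapley values and collect coefficient significance
+ols = sm.OLS(y_val, sm.add_constant(shap_df)).fit()
+ranking = pd.DataFrame({"coefficient": ols.params, "p-value": ols.pvalues}).drop("const")
+
+# Keep features whose coefficients are positive and statistically significant
+selected = ranking[(ranking["coefficient"] > 0) & (ranking["p-value"] < 0.05)]
+```
+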
## Usage
```python
from shap_select import shap_select
# Here model is any model supported by the shap library, fitted on a different (train) dataset
# Task can be regression, binary, or multiclass
selected_features_df = shap_select(model, X_val, y_val, task="multiclass", threshold=0.05)
-```
\ No newline at end of file
+```
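+
+The call returns a DataFrame ranking the features by statistical significance, for example:
+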
+| feature name | t-value   | stat.significance | coefficient | selected |
+|--------------|-----------|-------------------|-------------|----------|
+| x5           | 20.211299 | 0.000000          | 1.052030    | 1        |
+| x4           | 18.315144 | 0.000000          | 0.952416    | 1        |
+| x3           | 6.835690  | 0.000000          | 1.098154    | 1        |
+| x2           | 6.457140  | 0.000000          | 1.044842    | 1        |
+| x1           | 5.530556  | 0.000000          | 0.917242    | 1        |
+| x6           | 2.390868  | 0.016827          | 1.497983    | 1        |
+| x7           | 0.901098  | 0.367558          | 2.865508    | 0        |
+| x8           | 0.563214  | 0.573302          | 1.933632    | 0        |
+| x9           | -1.607814 | 0.107908          | -4.537098   | -1       |
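+
+In the `selected` column, 1 marks features that pass the significance threshold, 0 marks
+insignificant ones, and -1 marks features with negative coefficients, which are discarded
+outright.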
diff --git a/requirements.txt b/requirements.txt
index 1aca3b7..bbb015c 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,6 +1,4 @@
pandas
-scikit_learn
scipy
shap
statsmodels
-numpy
diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000..af99f3d
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,38 @@
+from setuptools import find_packages, setup
+
+with open("README.md") as f:
+    long_description = f.read()
+
+setup(
+    name="shap-select",
+    version="0.1.0",
+    description="Heuristic for fast feature selection for tabular regression/classification models using Shapley values",
+    long_description=long_description,
+    long_description_content_type="text/markdown",
+    author="Wise Plc",
+    url="https://github.com/transferwise/shap-select",
+    classifiers=[
+        "Programming Language :: Python :: 3 :: Only",
+        "Programming Language :: Python :: 3.7",
+        "Programming Language :: Python :: 3.8",
+        "Programming Language :: Python :: 3.9",
+        "Programming Language :: Python :: 3.10",
+        "Programming Language :: Python :: 3.11",
+        "Programming Language :: Python :: 3.12",
+    ],
+    install_requires=[
+        "pandas",
+        "scipy>=1.8.0",
+        "shap",
+        "statsmodels",
+    ],
+    extras_require={
+        "test": ["flake8", "pytest", "pytest-cov"],
+    },
+    packages=find_packages(
+        include=["shap_select", "shap_select.*"],
+        exclude=["tests*"],
+    ),
+    include_package_data=True,
+    keywords="shap-select",
+)
diff --git a/shap_select/select.py b/shap_select/select.py
index 2c2229c..be092dd 100644
--- a/shap_select/select.py
+++ b/shap_select/select.py
@@ -1,10 +1,7 @@
from typing import Any, Tuple, List, Dict
import pandas as pd
-import numpy as np
import statsmodels.api as sm
-from sklearn.linear_model import Lasso, LogisticRegression
-from sklearn.preprocessing import StandardScaler
import scipy.stats as stats
import shap
@@ -112,7 +109,9 @@ def multi_classifier_significance(
    # Iterate through each class and perform binary classification (one-vs-all)
    for cls, feature_df in shap_features.items():
        binary_target = (target == cls).astype(int)
-        significance_df = binary_classifier_significance(feature_df, binary_target, alpha)
+        significance_df = binary_classifier_significance(
+            feature_df, binary_target, alpha
+        )
        significance_dfs.append(significance_df)

    # Combine results into a single DataFrame with the max significance value for each feature
@@ -227,14 +226,16 @@ def iterative_shap_feature_reduction(
    shap_features: pd.DataFrame | List[pd.DataFrame],
    target: pd.Series,
    task: str,
-    alpha: float=1e-6,
+    alpha: float = 1e-6,
) -> pd.DataFrame:
    collected_rows = []  # List to store the rows we collect during each iteration

    features_left = True
    while features_left:
        # Call the original shap_features_to_significance function
-        significance_df = shap_features_to_significance(shap_features, target, task, alpha)
+        significance_df = shap_features_to_significance(
+            shap_features, target, task, alpha
+        )

        # Find the feature with the lowest t-value
        min_t_value_row = significance_df.loc[significance_df["t-value"].idxmin()]
@@ -315,7 +316,9 @@ def shap_select(
    shap_features = create_shap_features(tree_model, validation_df[feature_names])

    # Compute statistical significance of each feature, iteratively ablating
-    significance_df = iterative_shap_feature_reduction(shap_features, target, task, alpha)
+    significance_df = iterative_shap_feature_reduction(
+        shap_features, target, task, alpha
+    )

    # Add 'Selected' column based on the threshold
    significance_df["selected"] = (