integrates intercept of xgb models directly into values of the TreeModel
mmschlk committed Nov 7, 2024
1 parent cb25de1 commit 42fb3da
Showing 3 changed files with 26 additions and 43 deletions.
shapiq/explainer/tree/conversion/xgboost.py (5 changes: 3 additions & 2 deletions)
@@ -96,6 +96,7 @@ def _convert_xgboost_tree_as_df(
 
     # pandas can't chill https://stackoverflow.com/q/77900971
     with pd.option_context("future.no_silent_downcasting", True):
+        values = tree_df["Gain"].values * scaling + intercept  # add intercept to all values
         tree_model = TreeModel(
             children_left=tree_df["Yes"]
             .replace(convert_node_str_to_int)
@@ -111,9 +112,9 @@ def _convert_xgboost_tree_as_df(
             .values,
             features=tree_df["Feature"].values,
             thresholds=tree_df["Split"].values,
-            values=tree_df["Gain"].values * scaling,  # values in non-leaf nodes are not used
+            values=values,  # values in non-leaf nodes are not used
             node_sample_weight=tree_df["Cover"].values,
-            empty_prediction=intercept,
+            empty_prediction=None,
             original_output_type=output_type,
         )
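
The numerical effect of this conversion change, sketched outside shapiq (all names and values are illustrative stand-ins for what the converter reads from the booster dump, not shapiq's API):

```python
# A minimal sketch: the xgboost intercept is now folded into every node value
# at conversion time instead of being carried separately as empty_prediction.
import numpy as np

gain = np.array([0.0, 1.5, -0.75, 2.25])  # per-node values from the booster dump
scaling = 1.0                              # output scaling factor (illustrative)
intercept = 0.5                            # xgboost base_score / intercept_

old_values = gain * scaling               # before: intercept kept in empty_prediction
new_values = gain * scaling + intercept   # after: intercept baked into the values

# for any leaf, the contribution to the model output is unchanged:
assert np.allclose(old_values + intercept, new_values)
```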

shapiq/explainer/tree/explainer.py (42 changes: 15 additions & 27 deletions)
@@ -14,31 +14,6 @@
 from .validation import validate_tree_model
 
 
-def set_baseline_value(model: Any, treeshapiq_explainers: list[TreeSHAPIQ]) -> float:
-    """Sets the baseline value for the interaction values.
-
-    Tries to set the baseline value for the interaction values from a model.
-
-    Args:
-        model: The model to explain.
-        treeshapiq_explainers: The TreeSHAP-IQ explainers.
-
-    Returns:
-        The baseline value for the interaction values.
-    """
-    # default value for the baseline provided by ensembles
-    # works for sklearn decision trees and random forests
-    baseline_value = sum([treeshapiq.empty_prediction for treeshapiq in treeshapiq_explainers])
-    try:  # xgboost models have base_score/intercept_
-        base_score = model.base_score
-        if base_score is None:
-            base_score = float(model.intercept_[0])
-        baseline_value = base_score if base_score is not None else baseline_value
-    except AttributeError:
-        pass
-    return baseline_value
-
-
 class TreeExplainer(Explainer):
     """
     The explainer for tree-based models using the TreeSHAP-IQ algorithm.
@@ -96,8 +71,7 @@ def __init__(
         self._treeshapiq_explainers: list[TreeSHAPIQ] = [
             TreeSHAPIQ(model=_tree, max_order=self._max_order, index=index) for _tree in self._trees
         ]
-
-        self.baseline_value = set_baseline_value(self.model, self._treeshapiq_explainers)
+        self.baseline_value = self._compute_baseline_value()
 
     def explain(self, x: np.ndarray) -> InteractionValues:
         # run treeshapiq for all trees
@@ -112,3 +86,17 @@ def explain(self, x: np.ndarray) -> InteractionValues:
         for i in range(1, len(interaction_values)):
             final_explanation += interaction_values[i]
         return final_explanation
+
+    def _compute_baseline_value(self) -> float:
+        """Computes the baseline value for the explainer.
+
+        The baseline value is the sum of the empty predictions of all trees in the ensemble.
+
+        Returns:
+            The baseline value for the explainer.
+        """
+
+        baseline_value = sum(
+            [treeshapiq.empty_prediction for treeshapiq in self._treeshapiq_explainers]
+        )
+        return baseline_value
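
A hedged end-to-end sketch of what this refactor enables: the baseline now falls out of the per-tree empty predictions alone, with no model-specific attribute probing (base_score / intercept_). Assumes shapiq exports TreeExplainer at the package level and xgboost is installed; data and model are made up:

```python
import numpy as np
import xgboost as xgb
from shapiq import TreeExplainer  # assuming the package-level export

X = np.random.rand(100, 7)
y = X @ np.arange(7) + 3.0
model = xgb.XGBRegressor(n_estimators=5, max_depth=2).fit(X, y)

explainer = TreeExplainer(model=model, max_order=1, index="SV")
# baseline_value == sum of treeshapiq.empty_prediction over all trees
print(explainer.baseline_value)

iv = explainer.explain(x=X[0])  # InteractionValues aggregated over all trees
sv = iv.get_n_order_values(1)   # first-order (Shapley) values
```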
tests/tests_explainer/tests_tree_explainer/test_tree_explainer.py (22 changes: 8 additions & 14 deletions)
@@ -135,10 +135,8 @@ def test_xgboost_reg(xgb_reg_model, background_reg_data):
     # explainer_shap = shap.TreeExplainer(model=xgb_reg_model)
     # x_explain_shap = background_reg_data[explanation_instance].reshape(1, -1)
     # sv_shap = explainer_shap.shap_values(x_explain_shap)[0]
-    # baseline_shap = explainer_shap.expected_value
     sv_shap = [-2.555832, 28.50987, 1.7708225, -7.8653603, 10.7955885, -0.1877861, 4.549199]
     sv_shap = np.asarray(sv_shap)
-    baseline_shap = -2.5668228
 
     # compute with shapiq
     explainer_shapiq = TreeExplainer(model=xgb_reg_model, max_order=1, index="SV")
@@ -147,12 +145,11 @@ def test_xgboost_reg(xgb_reg_model, background_reg_data):
     sv_shapiq_values = sv_shapiq.get_n_order_values(1)
     baseline_shapiq = sv_shapiq.baseline_value
 
-    assert baseline_shap == pytest.approx(baseline_shapiq, rel=1e-4)
-    assert np.allclose(sv_shap, sv_shapiq_values, rtol=1e-4)
+    assert np.allclose(sv_shap, sv_shapiq_values, rtol=1e-5)
 
     # get prediction of the model
     prediction = xgb_reg_model.predict(x_explain_shapiq.reshape(1, -1))
-    assert prediction == pytest.approx(baseline_shapiq + np.sum(sv_shapiq_values), rel=1e-2)
+    assert prediction == pytest.approx(baseline_shapiq + np.sum(sv_shapiq_values), rel=1e-5)
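
The final assertion is the Shapley efficiency check: the prediction should equal the baseline plus the sum of the first-order values, and with the intercept inside the tree values the tolerance can presumably be tightened from 1e-2 to 1e-5. A self-contained toy illustration of the identity being tested (numbers invented, not from the test):

```python
import numpy as np

baseline = 3.0                      # stands in for explainer.baseline_value
phi = np.array([0.4, -1.2, 2.1])    # first-order Shapley values
prediction = baseline + phi.sum()   # efficiency: f(x) = phi_0 + sum_i phi_i
assert np.isclose(prediction, 4.3, rtol=1e-5)
```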


def test_xgboost_clf(xgb_clf_model, background_clf_data):
@@ -173,7 +170,6 @@ def test_xgboost_clf(xgb_clf_model, background_clf_data):
     # print(sv_shap)
     sv = [-0.00545454, -0.15837783, -0.17675081, -0.24213657, 0.00247543, 0.00988865, -0.01564346]
     sv_shap = np.array(sv)
-    baseline_shap = 0.5
 
     # compute with shapiq
     explainer_shapiq = TreeExplainer(
@@ -184,14 +180,14 @@ def test_xgboost_clf(xgb_clf_model, background_clf_data):
     sv_shapiq_values = sv_shapiq.get_n_order_values(1)
     baseline_shapiq = sv_shapiq.baseline_value
 
-    assert baseline_shap == pytest.approx(baseline_shapiq, rel=1e-4)
-    assert np.allclose(sv_shap, sv_shapiq_values, rtol=1e-4)
+    # assert baseline_shap == pytest.approx(baseline_shapiq, rel=1e-4)
+    assert np.allclose(sv_shap, sv_shapiq_values, rtol=1e-5)
 
     # get prediction of the model (as the log odds)
     prediction = xgb_clf_model.predict(x_explain_shapiq.reshape(1, -1), output_margin=True)[0][
         class_label
     ]
-    assert prediction == pytest.approx(baseline_shapiq + np.sum(sv_shapiq_values), rel=2e-2)
+    assert prediction == pytest.approx(baseline_shapiq + np.sum(sv_shapiq_values), rel=1e-5)


def test_xgboost_shap_error(xgb_clf_model, background_clf_data):
@@ -220,7 +216,6 @@ def test_xgboost_shap_error(xgb_clf_model, background_clf_data):
     # print(baseline_shap)
     sv = [-0.00163636, 0.05099502, -0.13182959, -0.44538185, 0.00428653, -0.04872373, -0.01370917]
     sv_shap = np.array(sv)
-    baseline_shap = 0.5
 
     # setup shapiq TreeSHAP
     explainer_shapiq = TreeExplainer(
@@ -230,9 +225,8 @@ def test_xgboost_shap_error(xgb_clf_model, background_clf_data):
     sv_shapiq = explainer_shapiq.explain(x=x_explain_shapiq)
     sv_shapiq_values = sv_shapiq.get_n_order_values(1)
 
-    # the baseline scores should be the same as with SHAP but the values should be different
-    assert baseline_shap == pytest.approx(sv_shapiq.baseline_value, rel=1e-4)
-    assert not np.allclose(sv_shap, sv_shapiq_values, rtol=1e-4)
+    # the SHAP sv values should be different from the shapiq values
+    assert not np.allclose(sv_shap, sv_shapiq_values, rtol=1e-5)
 
     # when we round the model thresholds of the xgb model (thresholds decide whether a feature is
     # used or not) -> then suddenly the shap and shapiq values are the same, which points to the
@@ -247,4 +241,4 @@
     sv_shapiq_rounded_values = sv_shapiq_rounded.get_n_order_values(1)
 
     # now the values surprisingly are the same
-    assert np.allclose(sv_shap, sv_shapiq_rounded_values, rtol=1e-4)
+    assert np.allclose(sv_shap, sv_shapiq_rounded_values, rtol=1e-5)
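
The rounding workaround above hinges on float32 threshold precision; the actual rounding code is collapsed in this view, but the underlying effect can be reproduced in isolation (illustrative only, rounding to four decimals is an assumption, not shapiq or xgboost code):

```python
import numpy as np

x = 0.3                          # feature value as float64
stored = float(np.float32(0.3))  # threshold as xgboost stores it: 0.30000001192...
rounded = round(stored, 4)       # threshold after rounding: 0.3

print(x < stored)   # True  -> instance goes down the "yes" branch
print(x < rounded)  # False -> instance goes down the "no" branch
```

Depending on whether the stored or the rounded threshold is used, the same instance traverses a different path, which is why rounding can make the shap and shapiq attributions coincide.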
