Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

275 remove shap as optional dependency #296

Merged
merged 24 commits into from
Jan 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
5da6304
bar_plot without shap
Advueu963 Jan 2, 2025
6de1175
bar_plot without shap
Advueu963 Jan 4, 2025
360a4f1
forceplot without shap
Advueu963 Jan 5, 2025
40e5334
Made waterfall plot without shap.
Advueu963 Jan 5, 2025
b5eea32
Introduces formal_value function from shap to have identical looking …
Advueu963 Jan 5, 2025
f7f96f5
Removed errors in plot functions.
Advueu963 Jan 5, 2025
7283bbc
Refactor force.
Advueu963 Jan 5, 2025
ca42da9
Refactored abbreviation and label creation
Advueu963 Jan 5, 2025
7526c7c
Refactor force.
Advueu963 Jan 5, 2025
f03416f
Refactor bar.
Advueu963 Jan 5, 2025
85ba7fa
Add concrete test for waterfall
Advueu963 Jan 5, 2025
7f01854
Removed conversion function from shapiq to shap as not necessary anymore
Advueu963 Jan 5, 2025
de60c64
Removed imports of conversion function `get_interaction_values_and_fe…
Advueu963 Jan 5, 2025
d1279fd
updated bar plot and added aggregation of InteractionValues object
mmschlk Jan 8, 2025
d09911c
Merge branch 'main' into 275-remove-shap-as-optional-dependency
mmschlk Jan 8, 2025
4370dc5
Merge branch 'refs/heads/main' into 275-remove-shap-as-optional-depen…
mmschlk Jan 10, 2025
56e3ef6
updates aggregation method and finishes work on bar plot
mmschlk Jan 10, 2025
28d2e83
updated force_plot
mmschlk Jan 10, 2025
4113d17
updated test for the bar plot
mmschlk Jan 10, 2025
0fb57df
updated force plot test
mmschlk Jan 10, 2025
31283a1
updated tests for plots
mmschlk Jan 10, 2025
2be8f3d
Merge branch 'main' into 275-remove-shap-as-optional-dependency
mmschlk Jan 10, 2025
ddcf5c7
removes shap from requirements
mmschlk Jan 10, 2025
7fcafa2
removed call to shap in test
mmschlk Jan 10, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ ruff==0.8.4
scikit-image==0.25.0
scikit-learn==1.6.0
scipy==1.14.1
shap==0.46.0
tqdm==4.67.1
torch==2.5.1
torchvision==0.20.1
Expand Down
132 changes: 121 additions & 11 deletions shapiq/interaction_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import copy
import os
import pickle
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Optional, Union
from warnings import warn
Expand Down Expand Up @@ -630,6 +631,25 @@ def to_dict(self) -> dict:
"baseline_value": self.baseline_value,
}

def aggregate(
self, others: Sequence["InteractionValues"], aggregation: str = "mean"
) -> "InteractionValues":
"""Aggregates InteractionValues objects using a specific aggregation method.

Args:
others: A list of InteractionValues objects to aggregate.
aggregation: The aggregation method to use. Defaults to ``"mean"``. Other options are
``"median"``, ``"sum"``, ``"max"``, and ``"min"``.

Returns:
The aggregated InteractionValues object.

Note:
For documentation on the aggregation methods, see the ``aggregate_interaction_values()``
function.
"""
return aggregate_interaction_values([self, *others], aggregation)

def plot_network(self, show: bool = True, **kwargs) -> Optional[tuple[plt.Figure, plt.Axes]]:
"""Visualize InteractionValues on a graph.

Expand Down Expand Up @@ -682,18 +702,13 @@ def plot_stacked_bar(
def plot_force(
self,
feature_names: Optional[np.ndarray] = None,
feature_values: Optional[np.ndarray] = None,
matplotlib=True,
show: bool = True,
abbreviate: bool = True,
**kwargs,
) -> Optional[plt.Figure]:
"""Visualize InteractionValues on a force plot.

For arguments, see shapiq.plots.force_plot().

Requires the ``shap`` Python package to be installed.

Args:
feature_names: The feature names used for plotting. If no feature names are provided, the
feature indices are used instead. Defaults to ``None``.
Expand All @@ -710,18 +725,14 @@ def plot_force(

return force_plot(
self,
feature_values=feature_values,
feature_names=feature_names,
matplotlib=matplotlib,
show=show,
abbreviate=abbreviate,
**kwargs,
)

def plot_waterfall(
self,
feature_names: Optional[np.ndarray] = None,
feature_values: Optional[np.ndarray] = None,
show: bool = True,
abbreviate: bool = True,
max_display: int = 10,
Expand All @@ -743,11 +754,10 @@ def plot_waterfall(

return waterfall_plot(
self,
feature_values=feature_values,
feature_names=feature_names,
show=show,
abbreviate=abbreviate,
max_display=max_display,
abbreviate=abbreviate,
)

def plot_sentence(
Expand Down Expand Up @@ -779,3 +789,103 @@ def plot_upset(self, show: bool = True, **kwargs) -> Optional[plt.Figure]:
from shapiq.plot.upset import upset_plot

return upset_plot(self, show=show, **kwargs)


def aggregate_interaction_values(
interaction_values: Sequence[InteractionValues],
aggregation: str = "mean",
) -> InteractionValues:
"""Aggregates InteractionValues objects using a specific aggregation method.

Args:
interaction_values: A list of InteractionValues objects to aggregate.
aggregation: The aggregation method to use. Defaults to ``"mean"``. Other options are
``"median"``, ``"sum"``, ``"max"``, and ``"min"``.

Returns:
The aggregated InteractionValues object.

Example:
>>> iv1 = InteractionValues(
... values=np.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6]),
... interaction_lookup={(0,): 0, (1,): 1, (2,): 2, (0, 1): 3, (0, 2): 4, (1, 2): 5},
... index="SII",
... max_order=2,
... n_players=3,
... min_order=1,
... baseline_value=0.0,
... )
>>> iv2 = InteractionValues(
... values=np.array([0.2, 0.3, 0.4, 0.5, 0.6]), # this iv is missing the (1, 2) value
... interaction_lookup={(0,): 0, (1,): 1, (2,): 2, (0, 1): 3, (0, 2): 4}, # no (1, 2)
... index="SII",
... max_order=2,
... n_players=3,
... min_order=1,
... baseline_value=1.0,
... )
>>> aggregate_interaction_values([iv1, iv2], "mean")
InteractionValues(
index=SII, max_order=2, min_order=1, estimated=True, estimation_budget=None,
n_players=3, baseline_value=0.5,
Top 10 interactions:
(1, 2): 0.60
(0, 2): 0.35
(0, 1): 0.25
(0,): 0.15
(1,): 0.25
(2,): 0.35
)
Note:
The index of the aggregated InteractionValues object is set to the index of the first
InteractionValues object in the list.

Raises:
ValueError: If the aggregation method is not supported.
"""

def _aggregate(vals: list[float], method: str) -> float:
"""Does the actual aggregation of the values."""
if method == "mean":
return np.mean(vals)
elif method == "median":
return np.median(vals)
elif method == "sum":
return np.sum(vals)
elif method == "max":
return np.max(vals)
elif method == "min":
return np.min(vals)
else:
raise ValueError(f"Aggregation method {method} is not supported.")

# get all keys from all InteractionValues objects
all_keys = set()
for iv in interaction_values:
all_keys.update(iv.interaction_lookup.keys())
all_keys = sorted(all_keys)

# aggregate the values
new_values = np.zeros(len(all_keys), dtype=float)
new_lookup = {}
for i, key in enumerate(all_keys):
new_lookup[key] = i
values = [iv[key] for iv in interaction_values]
new_values[i] = _aggregate(values, aggregation)

max_order = max([iv.max_order for iv in interaction_values])
min_order = min([iv.min_order for iv in interaction_values])
n_players = max([iv.n_players for iv in interaction_values])
baseline_value = _aggregate([iv.baseline_value for iv in interaction_values], aggregation)

return InteractionValues(
values=new_values,
index=interaction_values[0].index,
max_order=max_order,
n_players=n_players,
min_order=min_order,
interaction_lookup=new_lookup,
estimated=True,
estimation_budget=None,
baseline_value=baseline_value,
)
3 changes: 1 addition & 2 deletions shapiq/plot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from .si_graph import si_graph_plot
from .stacked_bar import stacked_bar_plot
from .upset import upset_plot
from .utils import abbreviate_feature_names, get_interaction_values_and_feature_names
from .utils import abbreviate_feature_names
from .watefall import waterfall_plot

__all__ = [
Expand All @@ -21,5 +21,4 @@
"upset_plot",
# utils
"abbreviate_feature_names",
"get_interaction_values_and_feature_names",
]
Loading
Loading