Skip to content

Commit

Permalink
Merge pull request #253 from mmschlk/visualization_notebook
Browse files Browse the repository at this point in the history
Visualization notebook
  • Loading branch information
pwhofman authored Oct 24, 2024
2 parents 1d62801 + e17a450 commit d64ab0d
Show file tree
Hide file tree
Showing 22 changed files with 790 additions and 73 deletions.
1 change: 1 addition & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ Contents

notebooks/shapiq_scikit_learn
notebooks/treeshapiq_lightgbm
notebooks/visualizing_shapley_interactions
notebooks/language_model_game
notebooks/vision_transformer
notebooks/conditional_imputer
Expand Down
Binary file added docs/source/notebooks/2-SII_network.pdf
Binary file not shown.
Binary file added docs/source/notebooks/2-SII_si_graph.pdf
Binary file not shown.
Binary file added docs/source/notebooks/Moebius_network.pdf
Binary file not shown.
Binary file added docs/source/notebooks/Moebius_si_graph.pdf
Binary file not shown.
Binary file added docs/source/notebooks/SV_si_graph.pdf
Binary file not shown.
593 changes: 593 additions & 0 deletions docs/source/notebooks/visualizing_shapley_interactions.ipynb

Large diffs are not rendered by default.

38 changes: 26 additions & 12 deletions shapiq/interaction_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -579,20 +579,21 @@ def to_dict(self) -> dict:
"baseline_value": self.baseline_value,
}

def plot_network(self, **kwargs) -> tuple[plt.Figure, plt.Axes]:
def plot_network(self, show: bool = True, **kwargs) -> Optional[tuple[plt.Figure, plt.Axes]]:
"""Visualize InteractionValues on a graph.
For arguments, see shapiq.plot.network_plot().
Returns:
The figure and the axes of the plot if ``show`` is ``False``, otherwise ``None``.
"""
from shapiq import network_plot
from shapiq.plot.network import network_plot

if self.max_order > 1:
return network_plot(
first_order_values=self.get_n_order_values(1),
second_order_values=self.get_n_order_values(2),
show=show,
**kwargs,
)
else:
Expand All @@ -601,22 +602,32 @@ def plot_network(self, **kwargs) -> tuple[plt.Figure, plt.Axes]:
"but requires also 2-order values for the network plot."
)

def plot_stacked_bar(self, **kwargs) -> tuple[plt.Figure, plt.Axes]:
def plot_si_graph(self, show: bool = True, **kwargs) -> Optional[tuple[plt.Figure, plt.Axes]]:
"""Visualize InteractionValues as a SI graph.
For arguments, see shapiq.plot.si_graph_plot().
Returns:
The SI graph as a tuple containing the figure and the axes.
"""

from shapiq.plot.si_graph import si_graph_plot

return si_graph_plot(self, show=show, **kwargs)

def plot_stacked_bar(
self, show: bool = True, **kwargs
) -> Optional[tuple[plt.Figure, plt.Axes]]:
Visualize InteractionValues as a stacked bar plot.
For arguments, see shapiq.plot.stacked_bar_plot().
Returns:
matplotlib.pyplot.Figure, matplotlib.pyplot.Axes
The stacked bar plot as a tuple containing the figure and the axes.
"""
from shapiq import stacked_bar_plot

ret = stacked_bar_plot(
self,
**kwargs,
)

return ret
return stacked_bar_plot(self, show=show, **kwargs)

def plot_force(
self,
Expand All @@ -639,6 +650,9 @@ def plot_force(
matplotlib: Whether to return a ``matplotlib`` figure. Defaults to ``True``.
show: Whether to show the plot. Defaults to ``False``.
**kwargs: Keyword arguments passed to ``shap.plots.force()``.
Returns:
The force plot as a matplotlib figure (if show is ``False``).
"""
from shapiq import force_plot

Expand All @@ -655,7 +669,7 @@ def plot_waterfall(
self,
feature_names: Optional[np.ndarray] = None,
feature_values: Optional[np.ndarray] = None,
show: bool = False,
show: bool = True,
max_display: int = 10,
) -> Optional[plt.Axes]:
"""Draws interaction values on a waterfall plot.
Expand Down
4 changes: 4 additions & 0 deletions shapiq/plot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from .network import network_plot
from .si_graph import si_graph_plot
from .stacked_bar import stacked_bar_plot
from .utils import abbreviate_feature_names, get_interaction_values_and_feature_names
from .watefall import waterfall_plot

__all__ = [
Expand All @@ -14,4 +15,7 @@
"force_plot",
"bar_plot",
"waterfall_plot",
# utils
"abbreviate_feature_names",
"get_interaction_values_and_feature_names",
]
15 changes: 6 additions & 9 deletions shapiq/plot/bar.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@
def bar_plot(
list_of_interaction_values: list[InteractionValues],
feature_names: Optional[np.ndarray] = None,
show: bool = True,
show: bool = False,
**kwargs,
) -> plt.Axes:
) -> Optional[plt.Axes]:
"""Draws interaction values on a bar plot.
Requires the ``shap`` Python package to be installed.
Expand Down Expand Up @@ -55,11 +55,8 @@ def bar_plot(
feature_names=_labels,
)

if show:
ax = shap.plots.bar(explanation, **kwargs, show=False)
ax.set_xlabel("mean(|Shapley Interaction value|)")
plt.show()
else:
ax = shap.plots.bar(explanation, **kwargs, show=False)
ax.set_xlabel("mean(|Shapley Interaction value|)")
ax = shap.plots.bar(explanation, **kwargs, show=False)
ax.set_xlabel("mean(|Shapley Interaction value|)")
if not show:
return ax
plt.show()
37 changes: 13 additions & 24 deletions shapiq/plot/force.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def force_plot(
feature_names: Optional[np.ndarray] = None,
feature_values: Optional[np.ndarray] = None,
matplotlib: bool = True,
show: bool = True,
show: bool = False,
**kwargs,
) -> Optional[plt.Figure]:
"""Draws interaction values on a force plot.
Expand All @@ -37,26 +37,15 @@ def force_plot(
check_import_module("shap")
import shap

if interaction_values.max_order == 1:
return shap.plots.force(
base_value=np.array([interaction_values.baseline_value], dtype=float), # must be array
shap_values=interaction_values.get_n_order_values(1),
features=feature_values,
feature_names=feature_names,
matplotlib=matplotlib,
show=show,
**kwargs,
)
else:
_shap_values, _labels = get_interaction_values_and_feature_names(
interaction_values, feature_names, feature_values
)

return shap.plots.force(
base_value=np.array([interaction_values.baseline_value], dtype=float), # must be array
shap_values=np.array(_shap_values),
feature_names=_labels,
matplotlib=matplotlib,
show=show,
**kwargs,
)
_shap_values, _labels = get_interaction_values_and_feature_names(
interaction_values, feature_names, feature_values
)

return shap.plots.force(
base_value=np.array([interaction_values.baseline_value], dtype=float), # must be array
shap_values=np.array(_shap_values),
feature_names=_labels,
matplotlib=matplotlib,
show=show,
**kwargs,
)
11 changes: 8 additions & 3 deletions shapiq/plot/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ def network_plot(
center_image_size: Optional[float] = 0.6,
draw_legend: bool = True,
center_text: Optional[str] = None,
) -> tuple[plt.Figure, plt.Axes]:
show: bool = False,
) -> Optional[tuple[plt.Figure, plt.Axes]]:
"""Draws the interaction network.
An interaction network is a graph where the nodes represent the features and the edges represent
Expand Down Expand Up @@ -59,9 +60,11 @@ def network_plot(
center_image_size: The size of the center image. Defaults to ``0.6``.
draw_legend: Whether to draw the legend. Defaults to ``True``.
center_text: The text to be displayed in the center of the network. Defaults to ``None``.
show: Whether to show the plot. Defaults to ``False``. If ``False``, the figure and the axis
containing the plot are returned, otherwise ``None``.
Returns:
The figure and the axis containing the plot.
The figure and the axis containing the plot if ``show=False``.
"""
fig, axis = plt.subplots(figsize=(6, 6))
axis.axis("off")
Expand Down Expand Up @@ -175,7 +178,9 @@ def network_plot(
if draw_legend:
_add_legend_to_axis(axis)

return fig, axis
if not show:
return fig, axis
plt.show()


def _add_weight_to_edges_in_graph(
Expand Down
38 changes: 28 additions & 10 deletions shapiq/plot/si_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
ADJUST_NODE_ALPHA = True


__all__ = ["si_graph_plot"]


def _normalize_value(
value: float, max_value: float, base_value: float, cubic_scaling: bool = False
) -> float:
Expand Down Expand Up @@ -310,14 +313,14 @@ def _adjust_position(

def si_graph_plot(
interaction_values: InteractionValues,
graph: Union[list[tuple], nx.Graph],
graph: Optional[Union[list[tuple], nx.Graph]] = None,
n_interactions: Optional[int] = None,
draw_threshold: float = 0.0,
random_seed: int = 42,
size_factor: float = 1.0,
plot_explanation: bool = True,
compactness: float = 1.0,
label_mapping: Optional[dict] = None,
compactness: float = 1e10,
feature_names: Optional[list] = None,
cubic_scaling: bool = False,
pos: Optional[dict] = None,
node_size_scaling: float = 1.0,
Expand All @@ -326,7 +329,8 @@ def si_graph_plot(
spring_k: Optional[float] = None,
interaction_direction: Optional[str] = None,
node_area_scaling: bool = False,
) -> tuple[plt.figure, plt.axis]:
show: bool = False,
) -> Optional[tuple[plt.figure, plt.axis]]:
"""Plots the interaction values as an explanation graph.
An explanation graph is an undirected graph where the nodes represent players and the edges
Expand All @@ -338,7 +342,8 @@ def si_graph_plot(
interaction_values: The interaction values to plot.
graph: The underlying graph structure as a list of edge tuples or a networkx graph. If a
networkx graph is provided, the nodes are used as the players and the edges are used as
the connections between the players.
the connections between the players. Defaults to ``None``, which creates a graph with
all nodes from the interaction values without any edges between them.
n_interactions: The number of interactions to plot. If ``None``, all interactions are plotted
according to the draw_threshold.
draw_threshold: The threshold to draw an edge (i.e. only draw explanations with an
Expand All @@ -351,8 +356,8 @@ def si_graph_plot(
compactness: A scaling factor for the underlying spring layout. A higher compactness value
will move the interactions closer to the graph nodes. If your graph looks weird, try
adjusting this value, e.g. ``[0.1, 1.0, 10.0, 100.0, 1000.0]``. Defaults to ``1e10``.
label_mapping: A mapping from the player/node indices to the player label. If ``None``, the
player indices are used as labels. Defaults to ``None``.
feature_names: A list of feature names to use for the nodes in the graph. If ``None``,
the feature indices are used instead. Defaults to ``None``.
cubic_scaling: Whether to scale the size of explanations cubically (``True``) or linearly
(``False``, default). Cubic scaling puts more emphasis on larger interactions in the plot.
Defaults to ``False``.
Expand All @@ -372,14 +377,19 @@ def si_graph_plot(
interactions are plotted. Possible values are ``"positive"`` and
``"negative"``. Defaults to ``None``.
node_area_scaling: TODO add docstring.
show: Whether to show or return the plot. Defaults to ``False``.
Returns:
The figure and axis of the plot.
The figure and axis of the plot if ``show`` is ``False``. Otherwise, ``None``.
"""

normal_node_size = NORMAL_NODE_SIZE * node_size_scaling
base_size = BASE_SIZE * node_size_scaling

label_mapping = None
if feature_names is not None:
label_mapping = {i: feature_names[i] for i in range(len(feature_names))}

# fill the original graph with the edges and nodes
if isinstance(graph, nx.Graph):
original_graph = graph
Expand All @@ -389,7 +399,7 @@ def si_graph_plot(
for node in graph_nodes:
node_label = label_mapping.get(node, node) if label_mapping is not None else node
original_graph.nodes[node]["label"] = node_label
else:
elif isinstance(graph, list):
original_graph, graph_nodes = nx.Graph(), []
for edge in graph:
original_graph.add_edge(*edge)
Expand All @@ -399,6 +409,12 @@ def si_graph_plot(
original_graph.add_node(edge[0], label=nodel_labels[0])
original_graph.add_node(edge[1], label=nodel_labels[1])
graph_nodes.extend([edge[0], edge[1]])
else: # graph is considered None
original_graph = nx.Graph()
graph_nodes = list(range(interaction_values.n_players))
for node in graph_nodes:
node_label = label_mapping.get(node, node) if label_mapping is not None else node
original_graph.add_node(node, label=node_label)

if n_interactions is not None:
# get the top n interactions
Expand Down Expand Up @@ -500,4 +516,6 @@ def si_graph_plot(
ax.set_aspect("equal", adjustable="datalim") # make y- and x-axis scales equal
ax.axis("off") # remove axis

return fig, ax
if not show:
return fig, ax
plt.show()
6 changes: 5 additions & 1 deletion shapiq/plot/stacked_bar.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def stacked_bar_plot(
title: Optional[str] = None,
xlabel: Optional[str] = None,
ylabel: Optional[str] = None,
show: bool = False,
):
"""Plot the n-SII values for a given instance.
Expand All @@ -43,6 +44,7 @@ def stacked_bar_plot(
title (str): The title of the plot.
xlabel (str): The label of the x-axis.
ylabel (str): The label of the y-axis.
show (bool): Whether to show the plot. Defaults to ``False``.
Returns:
tuple[matplotlib.figure.Figure, matplotlib.axes.Axes]: A tuple containing the figure and
Expand Down Expand Up @@ -147,4 +149,6 @@ def stacked_bar_plot(

plt.tight_layout()

return fig, axis
if not show:
return fig, axis
plt.show()
Loading

0 comments on commit d64ab0d

Please sign in to comment.