From b9baa1ef9c6a971aa55beb9c1340b416affee901 Mon Sep 17 00:00:00 2001 From: Theo <49311372+Advueu963@users.noreply.github.com> Date: Fri, 25 Oct 2024 08:40:39 +0200 Subject: [PATCH] 190 add a tutorial notebook for data valuation (#243) --- docs/source/index.rst | 2 + docs/source/notebooks/data_valuation.ipynb | 792 +++++++++++++++++++++ docs/source/notebooks/sv_calculation.ipynb | 84 +++ 3 files changed, 878 insertions(+) create mode 100644 docs/source/notebooks/data_valuation.ipynb create mode 100644 docs/source/notebooks/sv_calculation.ipynb diff --git a/docs/source/index.rst b/docs/source/index.rst index 4aeb4953..f437ef7f 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -30,6 +30,7 @@ Contents :maxdepth: 1 :caption: TUTORIALS + notebooks/sv_calculation notebooks/shapiq_scikit_learn notebooks/treeshapiq_lightgbm notebooks/visualizing_shapley_interactions @@ -38,6 +39,7 @@ Contents notebooks/conditional_imputer notebooks/parallel_computation notebooks/benchmark_approximators + notebooks/data_valuation notebooks/core .. toctree:: diff --git a/docs/source/notebooks/data_valuation.ipynb b/docs/source/notebooks/data_valuation.ipynb new file mode 100644 index 00000000..e7b41a59 --- /dev/null +++ b/docs/source/notebooks/data_valuation.ipynb @@ -0,0 +1,792 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "source": [ + "# Shapiq for Data Valuation\n", + "On this page we demonstrate two examples of using `shapiq` for data valuation.\n", + "The first example demonstrates this for a synthetic dataset, and the second for a real dataset.\n", + "In data valuation, given a training and a test dataset, we are interested in evaluating the contribution of each training point to the model's performance on the test data." 
+ ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 1, + "outputs": [], + "source": [ + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "import shapiq\n", + "from sklearn.inspection import DecisionBoundaryDisplay\n", + "\n", + "# Vector Graphics\n", + "%matplotlib inline\n", + "import matplotlib_inline\n", + "from shapiq.plot._config import COLORS_K_SII, RED\n", + "\n", + "matplotlib_inline.backend_inline.set_matplotlib_formats(\"svg\")" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-10-22T16:04:37.434604Z", + "start_time": "2024-10-22T16:04:35.731310Z" + } + } + }, + { + "cell_type": "markdown", + "source": [ + "## Synthetic Data\n", + "In this example we generate a synthetic classification dataset with 2 features, and 22 samples.\n", + "The dataset consists of two classes, each with 11 samples.\n", + "The data is generated from two multivariate normal distributions with different means and covariances.\n", + "This is done in such a way that the two classes are linearly separable.\n" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "collapsed": true, + "ExecuteTime": { + "end_time": "2024-10-22T16:04:37.540707Z", + "start_time": "2024-10-22T16:04:37.438398Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": "
", + "image/svg+xml": "\n\n\n \n \n \n \n 2024-10-22T18:04:37.516569\n image/svg+xml\n \n \n Matplotlib v3.8.0, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n" + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "def plot_synthetic_data(ax, X_train, y_train, X_test, y_test, title):\n", + " ax.set_title(title)\n", + " ax.scatter(\n", + " 
X_train[:, 0],\n", + " X_train[:, 1],\n", + " c=[COLORS_K_SII[i] for i in y_train],\n", + " label=\"Training Points\",\n", + " marker=\"o\",\n", + " )\n", + " ax.scatter(\n", + " X_test[:, 0],\n", + " X_test[:, 1],\n", + " c=[COLORS_K_SII[i] for i in y_test],\n", + " label=\"Test Points\",\n", + " marker=\"x\",\n", + " )\n", + " # Manually create legend entries\n", + " handles = [\n", + " plt.Line2D(\n", + " [0],\n", + " [0],\n", + " marker=\"o\",\n", + " color=\"w\",\n", + " markerfacecolor=COLORS_K_SII[i],\n", + " markersize=10,\n", + " label=f\"Class {i} (Train)\",\n", + " )\n", + " for i in [1, 2]\n", + " ]\n", + " handles += [\n", + " plt.Line2D(\n", + " [0],\n", + " [0],\n", + " marker=\"x\",\n", + " linewidth=0,\n", + " color=COLORS_K_SII[i],\n", + " markerfacecolor=COLORS_K_SII[i],\n", + " markersize=10,\n", + " label=f\"Class {i} (Test)\",\n", + " )\n", + " for i in [1, 2]\n", + " ]\n", + "\n", + " ax.legend(handles=handles, loc=\"upper right\", title=\"Data Points\")\n", + "\n", + " ax.set_xlabel(\"Feature 1\")\n", + " ax.set_ylabel(\"Feature 2\")\n", + "\n", + "\n", + "# Meta information\n", + "n_samples = 11\n", + "n_classes = 2\n", + "classes = list(range(1, n_classes + 1))\n", + "random_state = 1337\n", + "np.random.seed(random_state)\n", + "\n", + "# parameters for toy data\n", + "means = [(3, 0), (-3, 0)]\n", + "covs = [np.diag([3, 2]), np.diag([3, 3.5])]\n", + "\n", + "# Construct the dataset\n", + "X = np.vstack(\n", + " [np.random.multivariate_normal(mean, cov, n_samples) for mean, cov in zip(means, covs)]\n", + ")\n", + "y = np.hstack([np.full(n_samples, i) for i in classes])\n", + "\n", + "# Build training and test set\n", + "n_samples_to_select = 10\n", + "random_indices = np.random.choice(X.shape[0], n_samples_to_select, replace=False)\n", + "X_test, y_test = X[random_indices], y[random_indices]\n", + "X_train, y_train = np.delete(X, random_indices, axis=0), np.delete(y, random_indices, axis=0)\n", + "fig, ax = plt.subplots()\n", + "\n", + 
"plot_synthetic_data(ax, X_train, y_train, X_test, y_test, \"Synthetic Classification Data\")" + ] + }, + { + "cell_type": "markdown", + "source": [ + "To apply `shapiq` approximators we need to reformulate the task of data valuation into a cooperative game $(N,\\nu)$.\n", + "We define $N$ as the set of training points $N = \\{1, \\ldots, n\\}$. The characteristic function $\\nu: 2^N \\rightarrow \\mathbb{R}$ then maps each coalition $S \\subseteq N$ to the accuracy the model achieves on the test points (crosses) when it is trained only on the training points in $S$.\n" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 3, + "outputs": [], + "source": [ + "class SyntheticDataValuation(shapiq.Game):\n", + " \"\"\"The synthetic data valuation tasked modeled as a cooperative game.\n", + " Args:\n", + " classifier: A classifier object that has the methods fit and score.\n", + " n_players: The number of players in the game.\n", + " X_test: The test data.\n", + " y_test: The test labels.\n", + " \"\"\"\n", + "\n", + " def __init__(self, classifier, n_players, X_train, y_train, X_test, y_test):\n", + " self.classifier = classifier\n", + " self.X_train = X_train\n", + " self.y_train = y_train\n", + " self.X_test = X_test\n", + " self.y_test = y_test\n", + "\n", + " empty_coalition_value = np.zeros((1, n_players), dtype=bool)\n", + " self.normalization_value = float(self.value_function(empty_coalition_value)[0])\n", + " super().__init__(n_players, normalization_value=self.normalization_value)\n", + "\n", + " def value_function(self, coalitions: np.ndarray) -> np.ndarray:\n", + " \"\"\"Compute the value of the coalitions.\n", + " Args:\n", + " coalitions: A numpy matrix of shape (n_coalitions, n_players)\n", + "\n", + " Returns:\n", + " A vector of the value of the coalition\n", + " \"\"\"\n", + " values = []\n", + " for coalition in coalitions:\n", + " tmp_X_train = self.X_train[coalition]\n", + " tmp_y_train = self.y_train[coalition]\n", + " if len(tmp_X_train) == 0:\n", + " # If 
the coalition is empty, the value is zero\n", + " value = 0\n", + " else:\n", + " unique_targets = np.unique(tmp_y_train)\n", + " if len(unique_targets) == 1:\n", + " # If we only have one class present in training data, we predict this class\n", + " value = np.mean((self.y_test == unique_targets[0]))\n", + " else:\n", + " # We have at least two classes, we fit the classifier\n", + " self.classifier.fit(tmp_X_train, tmp_y_train)\n", + " value = self.classifier.score(self.X_test, self.y_test)\n", + "\n", + " values.append(value)\n", + "\n", + " return np.array(values, dtype=float)" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-10-22T16:04:37.551682Z", + "start_time": "2024-10-22T16:04:37.549258Z" + } + } + }, + { + "cell_type": "markdown", + "source": [ + "As our model we choose `LinearSVC` from `sklearn`.\n", + "To get first insights into the data valuation game, we compute the values of the full and the empty coalition." + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 4, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Full coalition value: 1.0\n", + "Empty coalition value: 0.0\n" + ] + } + ], + "source": [ + "from sklearn.svm import LinearSVC\n", + "\n", + "classifier = LinearSVC()\n", + "n_players = X_train.shape[0]\n", + "data_valuation_game = SyntheticDataValuation(\n", + " classifier=classifier,\n", + " n_players=n_players,\n", + " X_train=X_train,\n", + " y_train=y_train,\n", + " X_test=X_test,\n", + " y_test=y_test,\n", + ")\n", + "\n", + "full_coalition = np.ones((1, n_players), dtype=bool)\n", + "empty_coalition = np.zeros((1, n_players), dtype=bool)\n", + "print(\"Full coalition value: \", data_valuation_game(full_coalition)[0])\n", + "print(\"Empty coalition value: \", data_valuation_game(empty_coalition)[0])" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-10-22T16:04:37.577065Z", + "start_time": 
"2024-10-22T16:04:37.554955Z" + } + } + }, + { + "cell_type": "markdown", + "source": [ + "The empty coalition value is $0.0$ as the model has no information about the data.\n", + "The full coalition value is $1.0$ as the model is trained on all data points, and they are linearly separable.\n", + "To illustrate this, we plot the decision boundary of the `LinearSVC` classifier for the training data and the corresponding test data." + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 5, + "outputs": [ + { + "data": { + "text/plain": "
", + "image/svg+xml": "\n\n\n \n \n \n \n 2024-10-22T18:04:37.637652\n image/svg+xml\n \n \n Matplotlib v3.8.0, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n" + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "fig, ax = plt.subplots()\n", + "\n", + "classifier.fit(X_train, y_train)\n", + "plot_synthetic_data(ax, X_train, y_train, X_test, y_test, \"Synthetic 
Classification Data\")\n", + "\n", + "DecisionBoundaryDisplay.from_estimator(\n", + " classifier,\n", + " X_train,\n", + " plot_method=\"contour\",\n", + " ax=ax,\n", + " levels=[-1, 0, 1],\n", + " linestyles=[\"--\", \"-\", \"--\"],\n", + " colors=[COLORS_K_SII[1], RED.hex, COLORS_K_SII[2]],\n", + " alpha=0.5,\n", + ")\n", + "ax.set_xlabel(\"Feature 1\")\n", + "ax.set_ylabel(\"Feature 2\")\n", + "ax.set_title(\"Decision Boundary\")\n", + "plt.show()" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-10-22T16:04:37.657117Z", + "start_time": "2024-10-22T16:04:37.567430Z" + } + } + }, + { + "cell_type": "markdown", + "source": [ + "### Computing Shapley Values\n", + "Now we can compute the Shapley values for the data valuation game.\n", + "Intuitively, the Shapley values should all be positive as each training point makes the model more aware of the natural boundary between the two classes." + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 6, + "outputs": [ + { + "data": { + "text/plain": "
", + "image/svg+xml": "\n\n\n \n \n \n \n 2024-10-22T18:04:38.986904\n image/svg+xml\n \n \n Matplotlib v3.8.0, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n" + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# 
Compute Shapley values with the ShapIQ approximator for the game function\n", + "exactComputer = shapiq.ExactComputer(n_players=n_players, game_fun=data_valuation_game)\n", + "sv_values = exactComputer(\"SV\")\n", + "sv_values.plot_stacked_bar(\n", + " title=\"Shapley Values for Synthetic (Training) Data\",\n", + " xlabel=\"Data Point\",\n", + " ylabel=\"Shapley Value\",\n", + ")\n", + "plt.show()" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-10-22T16:04:39.005886Z", + "start_time": "2024-10-22T16:04:37.657523Z" + } + } + }, + { + "cell_type": "markdown", + "source": [ + "The Shapley values are all positive, indicating that all data points have a positive impact on the model's performance.\n", + "Interestingly, the Shapley values indicate that the first five data points are more important than the others.\n", + "\n", + "To understand this better, notice that we have *four* blue test points and *six* orange test points.\n", + "If we are provided with a training set that contains only orange points, the model will have an accuracy of $0.6$.\n", + "On the other hand, if we are provided with a training set that contains only blue points, the model will have an accuracy of $0.4$.\n", + "Thus, having orange points in the training set is more important for the model's performance.\n", + "In other words, the orange points are more important with regard to accuracy.\n", + "These are exactly the first five data points, which are all orange.\n", + "\n" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 7, + "outputs": [ + { + "data": { + "text/plain": "
", + "image/svg+xml": "\n\n\n \n \n \n \n 2024-10-22T18:04:39.045802\n image/svg+xml\n \n \n Matplotlib v3.8.0, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n" + }, + "metadata": {}, + 
"output_type": "display_data" + } + ], + "source": [ + "fig, ax = plt.subplots()\n", + "\n", + "plot_synthetic_data(ax, X_train, y_train, X_test, y_test, \"Synthetic Classification Data\")\n", + "for i in range(5):\n", + " ax.annotate(\n", + " f\"Point {i+1}\",\n", + " (X_train[i, 0], X_train[i, 1]),\n", + " textcoords=\"offset points\",\n", + " xytext=(0, -10),\n", + " ha=\"center\",\n", + " )\n", + "\n", + "plt.show()" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-10-22T16:04:39.066692Z", + "start_time": "2024-10-22T16:04:39.013827Z" + } + } + }, + { + "cell_type": "markdown", + "source": [ + "### Corrupting the Data\n", + "We can now investigate the impact of corrupting the data on the Shapley values.\n", + "Currently the Shapley values are less interesting as we have a clear boundary between the two classes.\n", + "If we now corrupt the data by adding noise to the labels, the Shapley values should change and identify the corrupted samples." + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 8, + "outputs": [ + { + "data": { + "text/plain": "
", + "image/svg+xml": "\n\n\n \n \n \n \n 2024-10-22T18:04:39.183419\n image/svg+xml\n \n \n Matplotlib v3.8.0, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n 
\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n" + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from matplotlib import patches\n", + "\n", + "fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(12, 8))\n", + "\n", + "plot_synthetic_data(ax[0], X_train, y_train, X_test, y_test, \"Synthetic Classification Data\")\n", + "\n", + "corrupted_X_train = X_train.copy()\n", + "corruped_y_train = y_train.copy()\n", + "\n", + "corruped_y_train[5] = 1\n", + "corruped_y_train[2] = 2\n", + "plot_synthetic_data(\n", + " ax[1],\n", + " corrupted_X_train,\n", + " corruped_y_train,\n", + " X_test,\n", + " y_test,\n", + " \"Corrupted Synthetic Classification Data\",\n", + ")" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-10-22T16:04:39.216073Z", + "start_time": 
"2024-10-22T16:04:39.071637Z" + } + } + }, + { + "cell_type": "markdown", + "source": [ + "Let us now look at the Shapley values for the corrupted data.\n", + "The Shapley values should now identify the corrupted samples." + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 9, + "outputs": [ + { + "data": { + "text/plain": "
", + "image/svg+xml": "\n\n\n \n \n \n \n 2024-10-22T18:04:40.617300\n image/svg+xml\n \n \n Matplotlib v3.8.0, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n" + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "data_valuation_game = SyntheticDataValuation(\n", + " 
classifier=classifier,\n", + " n_players=n_players,\n", + " X_train=corrupted_X_train,\n", + " y_train=corruped_y_train,\n", + " X_test=X_test,\n", + " y_test=y_test,\n", + ")\n", + "\n", + "# Compute Shapley values with the shapiq ExactComputer for the game function\n", + "exactComputer = shapiq.ExactComputer(n_players=n_players, game_fun=data_valuation_game)\n", + "sv_values = exactComputer(\"SV\")\n", + "sv_values.plot_stacked_bar()\n", + "plt.show()" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-10-22T16:04:40.635663Z", + "start_time": "2024-10-22T16:04:39.216847Z" + } + } + }, + { + "cell_type": "markdown", + "source": [ + "With both corrupted samples identified by the Shapley values, we can now remove them from the training data and our model should perform better on the test data." + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 10, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Accuracy on test data before removing corrupted samples: 0.5\n", + "Accuracy on test data after removing corrupted samples: 1.0\n" + ] + } + ], + "source": [ + "classifier.fit(corrupted_X_train, corruped_y_train)\n", + "print(\"Accuracy on test data before removing corrupted samples: \", classifier.score(X_test, y_test))\n", + "\n", + "cleaned_X_train = np.delete(corrupted_X_train, [5, 2], axis=0)\n", + "cleaned_y_train = np.delete(corruped_y_train, [5, 2], axis=0)\n", + "classifier.fit(cleaned_X_train, cleaned_y_train)\n", + "print(\"Accuracy on test data after removing corrupted samples: \", classifier.score(X_test, y_test))" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-10-22T16:04:40.640022Z", + "start_time": "2024-10-22T16:04:40.636571Z" + } + } + }, + { + "cell_type": "markdown", + "source": [ + "To verify that this is sensible, we plot the decision boundary of the `LinearSVC` classifier for the corrupted and the cleaned training 
data, together with the corresponding test data." + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 11, + "outputs": [ + { + "data": { + "text/plain": "
", + "image/svg+xml": "\n\n\n \n \n \n \n 2024-10-22T18:04:40.745038\n image/svg+xml\n \n \n Matplotlib v3.8.0, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n 
\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n" + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "def plot_decision_boundary(ax, classifier, X_train, y_train, X_test, y_test):\n", + " classifier.fit(X_train, y_train)\n", + " plot_synthetic_data(ax, X_train, y_train, X_test, y_test, \"Synthetic Classification Data\")\n", + " DecisionBoundaryDisplay.from_estimator(\n", + " classifier,\n", + " X_train,\n", + " plot_method=\"contour\",\n", + " ax=ax,\n", + " levels=[-1, 0, 1],\n", + " linestyles=[\"--\", \"-\", \"--\"],\n", + " alpha=0.5,\n", + " )\n", + " ax.set_xlabel(\"Feature 1\")\n", + " ax.set_ylabel(\"Feature 2\")\n", + " ax.set_title(\"Decision Boundary\")\n", + "\n", + "\n", + "fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(12, 7))\n", + 
"fig.suptitle(\"Decision Boundary of Linear SVM\")\n", + "# Plot the decision boundary of the model with corrupted samples\n", + "plot_decision_boundary(ax[0], classifier, corrupted_X_train, corruped_y_train, X_test, y_test)\n", + "\n", + "# Plot the decision boundary of the model with removed corrupted samples\n", + "plot_decision_boundary(ax[1], classifier, cleaned_X_train, cleaned_y_train, X_test, y_test)\n", + "plt.show()" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-10-22T16:04:40.881221Z", + "start_time": "2024-10-22T16:04:40.641575Z" + } + } + }, + { + "cell_type": "markdown", + "source": [ + "## Real Data\n", + "We now demonstrate the data valuation for the [AdultCensus](../api/shapiq.datasets.rst) dataset.\n", + "Due to increasing runtime we choose a subset of the data, consisting of 200 samples.\n", + "Then we divide the data into training and test data at an 80/20 ratio." + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 12, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Players: 160\n" + ] + } + ], + "source": [ + "from sklearn.tree import DecisionTreeClassifier\n", + "from sklearn.model_selection import train_test_split\n", + "\n", + "X, y = shapiq.load_adult_census(to_numpy=True)\n", + "\n", + "X, y = X[:200], y[:200]\n", + "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=random_state)\n", + "classifier = DecisionTreeClassifier(random_state=random_state)\n", + "n_players = X_train.shape[0]\n", + "print(\"Players: \", n_players)" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-10-22T16:04:41.055176Z", + "start_time": "2024-10-22T16:04:40.879435Z" + } + } + }, + { + "cell_type": "code", + "execution_count": 13, + "outputs": [], + "source": [ + "data_valuation_game = SyntheticDataValuation(\n", + " classifier=classifier,\n", + " n_players=n_players,\n", + " 
X_train=X_train,\n", + " y_train=y_train,\n", + " X_test=X_test,\n", + " y_test=y_test,\n", + ")" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-10-22T16:04:41.056814Z", + "start_time": "2024-10-22T16:04:41.055761Z" + } + } + }, + { + "cell_type": "markdown", + "source": [ + "In the next step we show how different budgets influence the quality of approximation and the corresponding accuracy tradeoff." + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 14, + "outputs": [ + { + "data": { + "text/plain": "
", + "image/svg+xml": "\n\n\n \n \n \n \n 2024-10-22T18:05:15.036777\n image/svg+xml\n \n \n Matplotlib v3.8.0, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n" + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "budgets = [10, 100, 
1000, 5000]\n", + "erg = {}\n", + "for budget in budgets:\n", + " # Compute Shapley interactions with the SVARM approximator for the game function\n", + " approximator = shapiq.SVARM(n=n_players, random_state=random_state)\n", + " shapley_approx = approximator.approximate(budget=budget, game=data_valuation_game)\n", + "\n", + " # Sort the approximated values of each player and get the keys\n", + " players = np.array(range(0, n_players))\n", + " sv_values = shapley_approx.values[1:]\n", + " idx = np.argsort(sv_values)\n", + " sorted_players = players[idx]\n", + "\n", + " # Compute the accuracy of the model for different amount of removed samples\n", + " percent_removal = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]\n", + " accuracies = []\n", + " for p in percent_removal:\n", + " n_samples_to_remove = int(p * n_players)\n", + " removed_players = sorted_players[:n_samples_to_remove]\n", + " cleaned_X_train = np.delete(X_train, removed_players, axis=0)\n", + " cleaned_y_train = np.delete(y_train, removed_players, axis=0)\n", + " classifier.fit(cleaned_X_train, cleaned_y_train)\n", + " accuracies.append(classifier.score(X_test, y_test))\n", + " erg[budget] = (percent_removal, accuracies)\n", + "\n", + "# plot the results\n", + "\n", + "fig, ax = plt.subplots()\n", + "fig.suptitle(\"Accuracy of the model on the test data after removing samples\")\n", + "for i, (budget, (percent_removal, accuracies)) in enumerate(erg.items()):\n", + " ax.plot(\n", + " percent_removal,\n", + " accuracies,\n", + " label=f\"Budget: {budget}\",\n", + " marker=\"o\",\n", + " linestyle=\"-\",\n", + " color=COLORS_K_SII[i],\n", + " )\n", + "plt.xlabel(\"Percentage of removed samples\")\n", + "plt.ylabel(\"Accuracy\")\n", + "plt.legend()\n", + "plt.show()" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-10-22T16:05:15.069797Z", + "start_time": "2024-10-22T16:04:41.062860Z" + } + } + }, + { + "cell_type": "markdown", + "source": [ + "Intuitively increasing 
amount of removed data samples with low Shapley values should yield a better model performance.\n", + "Increasing the budget yields a clearer effect for lower percentages.\n", + "For very high percentages the effect is less pronounced as the model is already trained on the most important samples." + ], + "metadata": { + "collapsed": false + } + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/docs/source/notebooks/sv_calculation.ipynb b/docs/source/notebooks/sv_calculation.ipynb new file mode 100644 index 00000000..a7969431 --- /dev/null +++ b/docs/source/notebooks/sv_calculation.ipynb @@ -0,0 +1,84 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "source": [ + "# Shapley Value Calculation\n", + "A popular approach to tackle the problem of XAI is to use concepts from game theory, in particular cooperative game theory.\n", + "The most popular method is the **Shapley value**, named after Lloyd Shapley, who introduced it in 1951 in his work *\"Notes on the n-Person Game - II: The Value of an n-Person Game\"*.\n", + "\n", + "## Cooperative Game Theory\n", + "Cooperative game theory deals with the study of games in which players/participants can form groups to achieve a collective payoff. More formally, a cooperative game is defined as a tuple $(N,\\nu)$ where:\n", + "- $N$ is a finite set of players\n", + "- $\\nu$ is a characteristic function that maps every coalition of players to a real number, i.e.
$\\nu:2^N \\rightarrow \\mathbb{R}$\n", + "\n", + "Of particular interest is to find a concept that distributes the payoff of $\\nu(N)$ among the players, as it is assumed that the *grand coalition* $N$ is formed.\n", + "The distribution of the payoff among the players is called a *solution concept*.\n", + "\n", + "## Shapley Values: A Unique Solution Concept\n", + "Given a cooperative game $(N,\\nu)$, the Shapley value is a payoff vector dividing the total payoff $\\nu(N)$ among the players. The Shapley value of player $i$ is denoted by $\\phi_i(\\nu)$ and is defined as:\n", + "$$\n", + "\\phi_i(\\nu) := \\sum_{S \\subseteq N \\setminus \\{i\\}} \\frac{|S|!(|N|-|S|-1)!}{|N|!} [\\nu(S \\cup \\{i\\}) - \\nu(S)]\n", + "$$\n", + "and can be interpreted as the average marginal contribution of player $i$ across all possible permutations of the players.\n", + "Its popularity arises from the fact that it is the unique solution concept satisfying the following properties:\n", + "- **Efficiency**: The sum of the Shapley values equals the total payoff, i.e. $\\sum_{i \\in N} \\phi_i(\\nu) = \\nu(N)$\n", + "- **Symmetry**: If two players $i$ and $j$ are such that for all coalitions $S \\subseteq N \\setminus \\{i,j\\}$, $\\nu(S \\cup \\{i\\}) = \\nu(S \\cup \\{j\\})$, then $\\phi_i(\\nu) = \\phi_j(\\nu)$\n", + "- **Additivity**: For a game $(N,\\nu + \\mu)$ based on two games $(N,\\nu)$ and $(N,\\mu)$, the Shapley value of the sum of the games is the sum of the Shapley values, i.e.
$\\phi_i(\\nu + \\mu) = \\phi_i(\\nu) + \\phi_i(\\mu)$\n", + "- **Dummy Player**: If for a player $i$ it holds for all coalitions $S \\subseteq N \\setminus \\{i\\}$ that $\\nu(S \\cup \\{i\\}) - \\nu(S) = \\nu(\\{i\\})$, then $\\phi_i(\\nu) = \\nu(\\{i\\})$\n", + "\n", + "## Shapley Values: Cooking Game\n", + "To illustrate the concept of Shapley values, we consider a simple example of a cooking game.\n", + " The game consists of three players (cooks), Alice, Bob, and Charlie, who are cooking a meal together.\n", + " The characteristic function $\\nu$ maps each coalition of players to the quality of the meal." + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "markdown", + "source": [], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "markdown", + "source": [], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +}