diff --git a/docs/examples/Predict_Contributions.ipynb b/docs/examples/Predict_Contributions.ipynb
new file mode 100644
index 0000000..54cac6e
--- /dev/null
+++ b/docs/examples/Predict_Contributions.ipynb
@@ -0,0 +1,1153 @@
+{
+ "cells": [
+ {
+ "metadata": {},
+ "cell_type": "markdown",
+ "source": [
+ "# Show how to extract prediction contributions for each distribution parameter\n",
+ "\n",
+ "This example shows how to get the contribution of every feature for each distributional parameter for a given data set. This allows similar inferences as one might get from SHAP but comes from lightGBM's internal workings. We can use output for example to get for a given prediction which features are causing the most impact to a given distributional parameter.\n",
+ "\n",
+ "These contributions are created before the response function is applied. As such in the case of the identity function, for a given row of data the sum of the contributions should equal the parameter value.\n"
+ ],
+ "id": "bf95ab4267d5a34"
+ },
+ {
+ "metadata": {},
+ "cell_type": "markdown",
+ "source": [
+ "# Imports\n",
+ "\n",
+ "First, we import necessary functions. "
+ ],
+ "id": "bbea43740b87eb"
+ },
+ {
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-09-26T09:33:07.019505Z",
+ "start_time": "2024-09-26T09:33:01.235342Z"
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "import numpy as np\n",
+ "\n",
+ "from lightgbmlss.model import *\n",
+ "from lightgbmlss.distributions.Gaussian import *\n",
+ "from lightgbmlss.datasets.data_loader import load_simulated_gaussian_data\n",
+ "from scipy.stats import norm\n",
+ "\n",
+ "import plotnine\n",
+ "from plotnine import *\n",
+ "\n",
+ "plotnine.options.figure_size = (12, 8)"
+ ],
+ "id": "b5f2d07ce70bb24b",
+ "outputs": [],
+ "execution_count": 1
+ },
+ {
+ "metadata": {},
+ "cell_type": "markdown",
+ "source": "# Data",
+ "id": "bd7bba77a5e0fa2f"
+ },
+ {
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-09-26T09:33:07.067189Z",
+ "start_time": "2024-09-26T09:33:07.019505Z"
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "train, test = load_simulated_gaussian_data()\n",
+ "\n",
+ "X_train, y_train = train.filter(regex=\"x\"), train[\"y\"].values\n",
+ "X_test, y_test = test.filter(regex=\"x\"), test[\"y\"].values\n",
+ "\n",
+ "dtrain = lgb.Dataset(X_train, label=y_train)"
+ ],
+ "id": "1062b4b851a12bc9",
+ "outputs": [],
+ "execution_count": 2
+ },
+ {
+ "metadata": {},
+ "cell_type": "markdown",
+ "source": [
+ "# Get a Trained Model\n",
+ "\n",
+ "As this example is about th uses of a trained model, we wont do any hyper-parameter searching. We will also use a Gaussian distribution as the response function of the loc parameter is the identity function, this will allow us to more easily compare the outputs of a standard parameter prediction to a contributions prediction."
+ ],
+ "id": "170feafe1dccf85c"
+ },
+ {
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-09-26T09:58:18.921694Z",
+ "start_time": "2024-09-26T09:57:36.453028Z"
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "lgblss = LightGBMLSS(\n",
+ " Gaussian()\n",
+ ")\n",
+ "lgblss.train(\n",
+ " params=dict(),\n",
+ " train_set=dtrain\n",
+ ")\n",
+ "\n",
+ "param_dict = {\n",
+ " \"eta\": [\"float\", {\"low\": 1e-5, \"high\": 1, \"log\": True}],\n",
+ " \"max_depth\": [\"int\", {\"low\": 1, \"high\": 10, \"log\": False}],\n",
+ " \"num_leaves\": [\"int\", {\"low\": 255, \"high\": 255, \"log\": False}], # set to constant for this example\n",
+ " \"min_data_in_leaf\": [\"int\", {\"low\": 20, \"high\": 20, \"log\": False}], # set to constant for this example\n",
+ " \"min_gain_to_split\": [\"float\", {\"low\": 1e-8, \"high\": 40, \"log\": False}],\n",
+ " \"min_sum_hessian_in_leaf\": [\"float\", {\"low\": 1e-8, \"high\": 500, \"log\": True}],\n",
+ " \"subsample\": [\"float\", {\"low\": 0.2, \"high\": 1.0, \"log\": False}],\n",
+ " \"feature_fraction\": [\"float\", {\"low\": 0.2, \"high\": 1.0, \"log\": False}],\n",
+ " \"boosting\": [\"categorical\", [\"gbdt\"]],\n",
+ "}\n",
+ "\n",
+ "np.random.seed(123)\n",
+ "opt_param = lgblss.hyper_opt(param_dict,\n",
+ " dtrain,\n",
+ " num_boost_round=100, # Number of boosting iterations.\n",
+ " nfold=5, # Number of cv-folds.\n",
+ " early_stopping_rounds=20, # Number of early-stopping rounds\n",
+ " max_minutes=10, # Time budget in minutes, i.e., stop study after the given number of minutes.\n",
+ " n_trials=30 , # The number of trials. If this argument is set to None, there is no limitation on the number of trials.\n",
+ " silence=True, # Controls the verbosity of the trail, i.e., user can silence the outputs of the trail.\n",
+ " seed=123, # Seed used to generate cv-folds.\n",
+ " hp_seed=123 # Seed for random number generator used in the Bayesian hyperparameter search.\n",
+ " )\n",
+ "\n",
+ "np.random.seed(123)\n",
+ "\n",
+ "opt_params = opt_param.copy()\n",
+ "n_rounds = opt_params[\"opt_rounds\"]\n",
+ "del opt_params[\"opt_rounds\"]\n",
+ "\n",
+ "# Train Model with optimized hyperparameters\n",
+ "lgblss.train(opt_params,\n",
+ " dtrain,\n",
+ " num_boost_round=n_rounds\n",
+ " )\n"
+ ],
+ "id": "f45c868160f1f08b",
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ " 0%| | 0/30 [00:00, ?it/s]"
+ ],
+ "application/vnd.jupyter.widget-view+json": {
+ "version_major": 2,
+ "version_minor": 0,
+ "model_id": "fe4d614b202c4516931ba7cd69d2c733"
+ }
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "Hyper-Parameter Optimization successfully finished.\n",
+ " Number of finished trials: 30\n",
+ " Best trial:\n",
+ " Value: 2.0839056900194977\n",
+ " Params: \n",
+ " eta: 0.042322345196562056\n",
+ " max_depth: 3\n",
+ " num_leaves: 255\n",
+ " min_data_in_leaf: 20\n",
+ " min_gain_to_split: 10.495083287505906\n",
+ " min_sum_hessian_in_leaf: 4.025662198099785e-06\n",
+ " subsample: 0.41879883505881144\n",
+ " feature_fraction: 0.7628021535153005\n",
+ " boosting: gbdt\n",
+ " opt_rounds: 72\n"
+ ]
+ }
+ ],
+ "execution_count": 25
+ },
+ {
+ "metadata": {},
+ "cell_type": "markdown",
+ "source": "# Get parameter predictions and parameter contribution predictions",
+ "id": "3c8358d79ec85438"
+ },
+ {
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-09-26T09:58:28.354001Z",
+ "start_time": "2024-09-26T09:58:28.322325Z"
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "pred_params = lgblss.predict(X_test, pred_type=\"parameters\")\n",
+ "pred_param_contributions = lgblss.predict(X_test, pred_type=\"contributions\")\n"
+ ],
+ "id": "c0bab6ad5807cd8d",
+ "outputs": [],
+ "execution_count": 26
+ },
+ {
+ "metadata": {},
+ "cell_type": "markdown",
+ "source": "As location parameter uses identity function the sum of these predictions should equal the value in pred_params. However this is not true for the scale params, as response functions have not been applied when contributions are created.",
+ "id": "c39c97f68b1ba929"
+ },
+ {
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-09-26T09:58:30.937621Z",
+ "start_time": "2024-09-26T09:58:30.930606Z"
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "sum_of_contributions = pred_param_contributions.groupby(level=\"distribution_arg\", axis=1).sum()\n",
+ "location_values_are_all_close = np.allclose(pred_params[\"loc\"], sum_of_contributions[\"loc\"])\n",
+ "scale_values_are_all_close = np.allclose(pred_params[\"scale\"], sum_of_contributions[\"scale\"])\n",
+ "\n",
+ "\n",
+ "print(f\"{location_values_are_all_close=}\")\n",
+ "print(f\"{scale_values_are_all_close=}\")\n"
+ ],
+ "id": "87e216c88a4ff947",
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "location_values_are_all_close=True\n",
+ "scale_values_are_all_close=False\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "C:\\Users\\SimonRobertPike\\AppData\\Local\\Temp\\ipykernel_47316\\1135199838.py:1: FutureWarning: DataFrame.groupby with axis=1 is deprecated. Do `frame.T.groupby(...)` without axis instead.\n"
+ ]
+ }
+ ],
+ "execution_count": 27
+ },
+ {
+ "metadata": {},
+ "cell_type": "markdown",
+ "source": "### Show contributions for each feature for location parameter",
+ "id": "90e4b7d2544afd58"
+ },
+ {
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-09-26T09:58:40.765411Z",
+ "start_time": "2024-09-26T09:58:40.738780Z"
+ }
+ },
+ "cell_type": "code",
+ "source": "pred_param_contributions.xs(\"loc\", axis=1, level=\"distribution_arg\")",
+ "id": "1b6b7013a1f7e957",
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "FeatureContribution x_true x_noise1 x_noise2 x_noise3 x_noise4 x_noise5 \\\n",
+ "0 0.0 0.0 0.0 0.0 0.0 0.0 \n",
+ "1 0.0 0.0 0.0 0.0 0.0 0.0 \n",
+ "2 0.0 0.0 0.0 0.0 0.0 0.0 \n",
+ "3 0.0 0.0 0.0 0.0 0.0 0.0 \n",
+ "4 0.0 0.0 0.0 0.0 0.0 0.0 \n",
+ "... ... ... ... ... ... ... \n",
+ "2995 0.0 0.0 0.0 0.0 0.0 0.0 \n",
+ "2996 0.0 0.0 0.0 0.0 0.0 0.0 \n",
+ "2997 0.0 0.0 0.0 0.0 0.0 0.0 \n",
+ "2998 0.0 0.0 0.0 0.0 0.0 0.0 \n",
+ "2999 0.0 0.0 0.0 0.0 0.0 0.0 \n",
+ "\n",
+ "FeatureContribution x_noise6 x_noise7 x_noise8 x_noise9 x_noise10 \\\n",
+ "0 0.0 0.0 0.0 0.0 0.0 \n",
+ "1 0.0 0.0 0.0 0.0 0.0 \n",
+ "2 0.0 0.0 0.0 0.0 0.0 \n",
+ "3 0.0 0.0 0.0 0.0 0.0 \n",
+ "4 0.0 0.0 0.0 0.0 0.0 \n",
+ "... ... ... ... ... ... \n",
+ "2995 0.0 0.0 0.0 0.0 0.0 \n",
+ "2996 0.0 0.0 0.0 0.0 0.0 \n",
+ "2997 0.0 0.0 0.0 0.0 0.0 \n",
+ "2998 0.0 0.0 0.0 0.0 0.0 \n",
+ "2999 0.0 0.0 0.0 0.0 0.0 \n",
+ "\n",
+ "FeatureContribution Const \n",
+ "0 9.979979 \n",
+ "1 9.979979 \n",
+ "2 9.979979 \n",
+ "3 9.979979 \n",
+ "4 9.979979 \n",
+ "... ... \n",
+ "2995 9.979979 \n",
+ "2996 9.979979 \n",
+ "2997 9.979979 \n",
+ "2998 9.979979 \n",
+ "2999 9.979979 \n",
+ "\n",
+ "[3000 rows x 12 columns]"
+ ],
+ "text/html": [
+ "
\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " FeatureContribution | \n",
+ " x_true | \n",
+ " x_noise1 | \n",
+ " x_noise2 | \n",
+ " x_noise3 | \n",
+ " x_noise4 | \n",
+ " x_noise5 | \n",
+ " x_noise6 | \n",
+ " x_noise7 | \n",
+ " x_noise8 | \n",
+ " x_noise9 | \n",
+ " x_noise10 | \n",
+ " Const | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 9.979979 | \n",
+ "
\n",
+ " \n",
+ " 1 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 9.979979 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 9.979979 | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 9.979979 | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 9.979979 | \n",
+ "
\n",
+ " \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ "
\n",
+ " \n",
+ " 2995 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 9.979979 | \n",
+ "
\n",
+ " \n",
+ " 2996 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 9.979979 | \n",
+ "
\n",
+ " \n",
+ " 2997 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 9.979979 | \n",
+ "
\n",
+ " \n",
+ " 2998 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 9.979979 | \n",
+ "
\n",
+ " \n",
+ " 2999 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 9.979979 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
3000 rows × 12 columns
\n",
+ "
"
+ ]
+ },
+ "execution_count": 28,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "execution_count": 28
+ },
+ {
+ "metadata": {},
+ "cell_type": "markdown",
+ "source": "### Show contributions for each feature for scale parameter",
+ "id": "eaf2ad3ecc736152"
+ },
+ {
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-09-26T09:58:55.637299Z",
+ "start_time": "2024-09-26T09:58:55.621654Z"
+ }
+ },
+ "cell_type": "code",
+ "source": "pred_param_contributions.xs(\"scale\", axis=1, level=\"distribution_arg\")\n",
+ "id": "c5453f7e5a378096",
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "FeatureContribution x_true x_noise1 x_noise2 x_noise3 x_noise4 \\\n",
+ "0 0.410550 0.002106 0.0 0.0 0.000034 \n",
+ "1 0.411261 0.000684 0.0 0.0 -0.000340 \n",
+ "2 -0.597674 0.002106 0.0 0.0 -0.000340 \n",
+ "3 0.848748 0.002832 0.0 0.0 0.000034 \n",
+ "4 0.414522 0.001565 0.0 0.0 0.000866 \n",
+ "... ... ... ... ... ... \n",
+ "2995 0.411230 0.002832 0.0 0.0 -0.000340 \n",
+ "2996 0.380649 0.002106 0.0 0.0 -0.000340 \n",
+ "2997 -0.597582 0.001647 0.0 0.0 0.000034 \n",
+ "2998 -0.607346 -0.001425 0.0 0.0 -0.001143 \n",
+ "2999 0.410550 0.002106 0.0 0.0 0.000034 \n",
+ "\n",
+ "FeatureContribution x_noise5 x_noise6 x_noise7 x_noise8 x_noise9 \\\n",
+ "0 0.000197 0.004102 -0.000127 0.0 -0.000608 \n",
+ "1 0.000197 0.004813 -0.000127 0.0 -0.000608 \n",
+ "2 0.000197 0.004102 -0.000127 0.0 -0.000608 \n",
+ "3 0.000197 0.001399 -0.000127 0.0 -0.000608 \n",
+ "4 0.000123 0.002716 -0.004167 0.0 0.053916 \n",
+ "... ... ... ... ... ... \n",
+ "2995 0.000197 0.002135 -0.000127 0.0 -0.000608 \n",
+ "2996 0.000197 0.004400 -0.000127 0.0 -0.000608 \n",
+ "2997 0.000197 -0.004547 -0.000127 0.0 -0.000700 \n",
+ "2998 0.000887 0.002013 0.003194 0.0 -0.000029 \n",
+ "2999 0.000197 0.004102 -0.000127 0.0 -0.000608 \n",
+ "\n",
+ "FeatureContribution x_noise10 Const \n",
+ "0 -0.000503 0.653589 \n",
+ "1 -0.000129 0.653589 \n",
+ "2 -0.000129 0.653589 \n",
+ "3 0.001529 0.653589 \n",
+ "4 0.001894 0.653589 \n",
+ "... ... ... \n",
+ "2995 0.000432 0.653589 \n",
+ "2996 0.002376 0.653589 \n",
+ "2997 -0.000893 0.653589 \n",
+ "2998 -0.004395 0.653589 \n",
+ "2999 -0.000503 0.653589 \n",
+ "\n",
+ "[3000 rows x 12 columns]"
+ ],
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " FeatureContribution | \n",
+ " x_true | \n",
+ " x_noise1 | \n",
+ " x_noise2 | \n",
+ " x_noise3 | \n",
+ " x_noise4 | \n",
+ " x_noise5 | \n",
+ " x_noise6 | \n",
+ " x_noise7 | \n",
+ " x_noise8 | \n",
+ " x_noise9 | \n",
+ " x_noise10 | \n",
+ " Const | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 | \n",
+ " 0.410550 | \n",
+ " 0.002106 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.000034 | \n",
+ " 0.000197 | \n",
+ " 0.004102 | \n",
+ " -0.000127 | \n",
+ " 0.0 | \n",
+ " -0.000608 | \n",
+ " -0.000503 | \n",
+ " 0.653589 | \n",
+ "
\n",
+ " \n",
+ " 1 | \n",
+ " 0.411261 | \n",
+ " 0.000684 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " -0.000340 | \n",
+ " 0.000197 | \n",
+ " 0.004813 | \n",
+ " -0.000127 | \n",
+ " 0.0 | \n",
+ " -0.000608 | \n",
+ " -0.000129 | \n",
+ " 0.653589 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " -0.597674 | \n",
+ " 0.002106 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " -0.000340 | \n",
+ " 0.000197 | \n",
+ " 0.004102 | \n",
+ " -0.000127 | \n",
+ " 0.0 | \n",
+ " -0.000608 | \n",
+ " -0.000129 | \n",
+ " 0.653589 | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " 0.848748 | \n",
+ " 0.002832 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.000034 | \n",
+ " 0.000197 | \n",
+ " 0.001399 | \n",
+ " -0.000127 | \n",
+ " 0.0 | \n",
+ " -0.000608 | \n",
+ " 0.001529 | \n",
+ " 0.653589 | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " 0.414522 | \n",
+ " 0.001565 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.000866 | \n",
+ " 0.000123 | \n",
+ " 0.002716 | \n",
+ " -0.004167 | \n",
+ " 0.0 | \n",
+ " 0.053916 | \n",
+ " 0.001894 | \n",
+ " 0.653589 | \n",
+ "
\n",
+ " \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ "
\n",
+ " \n",
+ " 2995 | \n",
+ " 0.411230 | \n",
+ " 0.002832 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " -0.000340 | \n",
+ " 0.000197 | \n",
+ " 0.002135 | \n",
+ " -0.000127 | \n",
+ " 0.0 | \n",
+ " -0.000608 | \n",
+ " 0.000432 | \n",
+ " 0.653589 | \n",
+ "
\n",
+ " \n",
+ " 2996 | \n",
+ " 0.380649 | \n",
+ " 0.002106 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " -0.000340 | \n",
+ " 0.000197 | \n",
+ " 0.004400 | \n",
+ " -0.000127 | \n",
+ " 0.0 | \n",
+ " -0.000608 | \n",
+ " 0.002376 | \n",
+ " 0.653589 | \n",
+ "
\n",
+ " \n",
+ " 2997 | \n",
+ " -0.597582 | \n",
+ " 0.001647 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.000034 | \n",
+ " 0.000197 | \n",
+ " -0.004547 | \n",
+ " -0.000127 | \n",
+ " 0.0 | \n",
+ " -0.000700 | \n",
+ " -0.000893 | \n",
+ " 0.653589 | \n",
+ "
\n",
+ " \n",
+ " 2998 | \n",
+ " -0.607346 | \n",
+ " -0.001425 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " -0.001143 | \n",
+ " 0.000887 | \n",
+ " 0.002013 | \n",
+ " 0.003194 | \n",
+ " 0.0 | \n",
+ " -0.000029 | \n",
+ " -0.004395 | \n",
+ " 0.653589 | \n",
+ "
\n",
+ " \n",
+ " 2999 | \n",
+ " 0.410550 | \n",
+ " 0.002106 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.000034 | \n",
+ " 0.000197 | \n",
+ " 0.004102 | \n",
+ " -0.000127 | \n",
+ " 0.0 | \n",
+ " -0.000608 | \n",
+ " -0.000503 | \n",
+ " 0.653589 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
3000 rows × 12 columns
\n",
+ "
"
+ ]
+ },
+ "execution_count": 29,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "execution_count": 29
+ },
+ {
+ "metadata": {},
+ "cell_type": "markdown",
+ "source": "# Show Mean Feature Impact for Data set",
+ "id": "394e64d247168fa0"
+ },
+ {
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-09-26T10:02:54.858851Z",
+ "start_time": "2024-09-26T10:02:54.838744Z"
+ }
+ },
+ "cell_type": "code",
+ "source": [
+ "sum_of_contributions_column = \"SumOfContributions\"\n",
+ "mean_parameter_contribution = pred_param_contributions.abs().mean().unstack(\"distribution_arg\")\n",
+ "mean_parameter_contribution[sum_of_contributions_column] = mean_parameter_contribution.sum(1)\n",
+ "\n",
+ "mean_parameter_contribution.sort_values(sum_of_contributions_column, ascending=False).drop(columns=sum_of_contributions_column)\n"
+ ],
+ "id": "54d4970cf1957735",
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "distribution_arg loc scale\n",
+ "FeatureContribution \n",
+ "Const 9.97998 0.653589\n",
+ "x_true 0.00000 0.591846\n",
+ "x_noise6 0.00000 0.004865\n",
+ "x_noise7 0.00000 0.004410\n",
+ "x_noise1 0.00000 0.003991\n",
+ "x_noise10 0.00000 0.002688\n",
+ "x_noise9 0.00000 0.002582\n",
+ "x_noise4 0.00000 0.001666\n",
+ "x_noise5 0.00000 0.000585\n",
+ "x_noise2 0.00000 0.000000\n",
+ "x_noise3 0.00000 0.000000\n",
+ "x_noise8 0.00000 0.000000"
+ ],
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " distribution_arg | \n",
+ " loc | \n",
+ " scale | \n",
+ "
\n",
+ " \n",
+ " FeatureContribution | \n",
+ " | \n",
+ " | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " Const | \n",
+ " 9.97998 | \n",
+ " 0.653589 | \n",
+ "
\n",
+ " \n",
+ " x_true | \n",
+ " 0.00000 | \n",
+ " 0.591846 | \n",
+ "
\n",
+ " \n",
+ " x_noise6 | \n",
+ " 0.00000 | \n",
+ " 0.004865 | \n",
+ "
\n",
+ " \n",
+ " x_noise7 | \n",
+ " 0.00000 | \n",
+ " 0.004410 | \n",
+ "
\n",
+ " \n",
+ " x_noise1 | \n",
+ " 0.00000 | \n",
+ " 0.003991 | \n",
+ "
\n",
+ " \n",
+ " x_noise10 | \n",
+ " 0.00000 | \n",
+ " 0.002688 | \n",
+ "
\n",
+ " \n",
+ " x_noise9 | \n",
+ " 0.00000 | \n",
+ " 0.002582 | \n",
+ "
\n",
+ " \n",
+ " x_noise4 | \n",
+ " 0.00000 | \n",
+ " 0.001666 | \n",
+ "
\n",
+ " \n",
+ " x_noise5 | \n",
+ " 0.00000 | \n",
+ " 0.000585 | \n",
+ "
\n",
+ " \n",
+ " x_noise2 | \n",
+ " 0.00000 | \n",
+ " 0.000000 | \n",
+ "
\n",
+ " \n",
+ " x_noise3 | \n",
+ " 0.00000 | \n",
+ " 0.000000 | \n",
+ "
\n",
+ " \n",
+ " x_noise8 | \n",
+ " 0.00000 | \n",
+ " 0.000000 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ]
+ },
+ "execution_count": 36,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "execution_count": 36
+ },
+ {
+ "metadata": {},
+ "cell_type": "markdown",
+ "source": "# Get correlation between contributions for the scale parameter ",
+ "id": "f7c73f303f04d4ff"
+ },
+ {
+ "metadata": {},
+ "cell_type": "markdown",
+ "source": "",
+ "id": "8d5dc9e448d5c322"
+ },
+ {
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-09-26T10:07:21.265976Z",
+ "start_time": "2024-09-26T10:07:21.249801Z"
+ }
+ },
+ "cell_type": "code",
+ "source": "pred_param_contributions.xs(\"scale\", axis=1, level=\"distribution_arg\").corr().dropna(how=\"all\").dropna(axis=1,how=\"all\")\n",
+ "id": "f331d8603042908",
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "FeatureContribution x_true x_noise1 x_noise4 x_noise5 x_noise6 \\\n",
+ "FeatureContribution \n",
+ "x_true 1.000000 0.007743 -0.001227 -0.047812 -0.021568 \n",
+ "x_noise1 0.007743 1.000000 -0.006627 -0.022206 0.136683 \n",
+ "x_noise4 -0.001227 -0.006627 1.000000 -0.015965 -0.030661 \n",
+ "x_noise5 -0.047812 -0.022206 -0.015965 1.000000 0.006217 \n",
+ "x_noise6 -0.021568 0.136683 -0.030661 0.006217 1.000000 \n",
+ "x_noise7 0.015344 0.002144 0.474505 0.021826 0.029863 \n",
+ "x_noise9 0.024361 -0.006972 0.013089 0.016001 0.009558 \n",
+ "x_noise10 0.035479 0.012114 -0.035713 -0.001433 0.028450 \n",
+ "\n",
+ "FeatureContribution x_noise7 x_noise9 x_noise10 \n",
+ "FeatureContribution \n",
+ "x_true 0.015344 0.024361 0.035479 \n",
+ "x_noise1 0.002144 -0.006972 0.012114 \n",
+ "x_noise4 0.474505 0.013089 -0.035713 \n",
+ "x_noise5 0.021826 0.016001 -0.001433 \n",
+ "x_noise6 0.029863 0.009558 0.028450 \n",
+ "x_noise7 1.000000 0.023556 -0.015318 \n",
+ "x_noise9 0.023556 1.000000 -0.030408 \n",
+ "x_noise10 -0.015318 -0.030408 1.000000 "
+ ],
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " FeatureContribution | \n",
+ " x_true | \n",
+ " x_noise1 | \n",
+ " x_noise4 | \n",
+ " x_noise5 | \n",
+ " x_noise6 | \n",
+ " x_noise7 | \n",
+ " x_noise9 | \n",
+ " x_noise10 | \n",
+ "
\n",
+ " \n",
+ " FeatureContribution | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " x_true | \n",
+ " 1.000000 | \n",
+ " 0.007743 | \n",
+ " -0.001227 | \n",
+ " -0.047812 | \n",
+ " -0.021568 | \n",
+ " 0.015344 | \n",
+ " 0.024361 | \n",
+ " 0.035479 | \n",
+ "
\n",
+ " \n",
+ " x_noise1 | \n",
+ " 0.007743 | \n",
+ " 1.000000 | \n",
+ " -0.006627 | \n",
+ " -0.022206 | \n",
+ " 0.136683 | \n",
+ " 0.002144 | \n",
+ " -0.006972 | \n",
+ " 0.012114 | \n",
+ "
\n",
+ " \n",
+ " x_noise4 | \n",
+ " -0.001227 | \n",
+ " -0.006627 | \n",
+ " 1.000000 | \n",
+ " -0.015965 | \n",
+ " -0.030661 | \n",
+ " 0.474505 | \n",
+ " 0.013089 | \n",
+ " -0.035713 | \n",
+ "
\n",
+ " \n",
+ " x_noise5 | \n",
+ " -0.047812 | \n",
+ " -0.022206 | \n",
+ " -0.015965 | \n",
+ " 1.000000 | \n",
+ " 0.006217 | \n",
+ " 0.021826 | \n",
+ " 0.016001 | \n",
+ " -0.001433 | \n",
+ "
\n",
+ " \n",
+ " x_noise6 | \n",
+ " -0.021568 | \n",
+ " 0.136683 | \n",
+ " -0.030661 | \n",
+ " 0.006217 | \n",
+ " 1.000000 | \n",
+ " 0.029863 | \n",
+ " 0.009558 | \n",
+ " 0.028450 | \n",
+ "
\n",
+ " \n",
+ " x_noise7 | \n",
+ " 0.015344 | \n",
+ " 0.002144 | \n",
+ " 0.474505 | \n",
+ " 0.021826 | \n",
+ " 0.029863 | \n",
+ " 1.000000 | \n",
+ " 0.023556 | \n",
+ " -0.015318 | \n",
+ "
\n",
+ " \n",
+ " x_noise9 | \n",
+ " 0.024361 | \n",
+ " -0.006972 | \n",
+ " 0.013089 | \n",
+ " 0.016001 | \n",
+ " 0.009558 | \n",
+ " 0.023556 | \n",
+ " 1.000000 | \n",
+ " -0.030408 | \n",
+ "
\n",
+ " \n",
+ " x_noise10 | \n",
+ " 0.035479 | \n",
+ " 0.012114 | \n",
+ " -0.035713 | \n",
+ " -0.001433 | \n",
+ " 0.028450 | \n",
+ " -0.015318 | \n",
+ " -0.030408 | \n",
+ " 1.000000 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ]
+ },
+ "execution_count": 44,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "execution_count": 44
+ },
+ {
+ "metadata": {},
+ "cell_type": "code",
+ "outputs": [],
+ "execution_count": null,
+ "source": "",
+ "id": "ae0a0247ad688b42"
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 2
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython2",
+ "version": "2.7.6"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}