diff --git a/docs/examples/Predict_Contributions.ipynb b/docs/examples/Predict_Contributions.ipynb new file mode 100644 index 0000000..54cac6e --- /dev/null +++ b/docs/examples/Predict_Contributions.ipynb @@ -0,0 +1,1153 @@ +{ + "cells": [ + { + "metadata": {}, + "cell_type": "markdown", + "source": [ + "# Show how to extract prediction contributions for each distribution parameter\n", + "\n", + "This example shows how to get the contribution of every feature for each distributional parameter for a given data set. This allows inferences similar to those one might get from SHAP, but the values come from LightGBM's internal workings. For example, we can use the output to find, for a given prediction, which features have the most impact on a given distributional parameter.\n", + "\n", + "These contributions are created before the response function is applied. As such, in the case of the identity response function, the sum of the contributions for a given row of data should equal the parameter value.\n" + ], + "id": "bf95ab4267d5a34" + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": [ + "# Imports\n", + "\n", + "First, we import necessary functions. 
" + ], + "id": "bbea43740b87eb" + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-09-26T09:33:07.019505Z", + "start_time": "2024-09-26T09:33:01.235342Z" + } + }, + "cell_type": "code", + "source": [ + "import numpy as np\n", + "\n", + "from lightgbmlss.model import *\n", + "from lightgbmlss.distributions.Gaussian import *\n", + "from lightgbmlss.datasets.data_loader import load_simulated_gaussian_data\n", + "from scipy.stats import norm\n", + "\n", + "import plotnine\n", + "from plotnine import *\n", + "\n", + "plotnine.options.figure_size = (12, 8)" + ], + "id": "b5f2d07ce70bb24b", + "outputs": [], + "execution_count": 1 + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": "# Data", + "id": "bd7bba77a5e0fa2f" + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-09-26T09:33:07.067189Z", + "start_time": "2024-09-26T09:33:07.019505Z" + } + }, + "cell_type": "code", + "source": [ + "train, test = load_simulated_gaussian_data()\n", + "\n", + "X_train, y_train = train.filter(regex=\"x\"), train[\"y\"].values\n", + "X_test, y_test = test.filter(regex=\"x\"), test[\"y\"].values\n", + "\n", + "dtrain = lgb.Dataset(X_train, label=y_train)" + ], + "id": "1062b4b851a12bc9", + "outputs": [], + "execution_count": 2 + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": [ + "# Get a Trained Model\n", + "\n", + "As this example is about the uses of a trained model, we won't dwell on hyper-parameter searching. We will also use a Gaussian distribution, as the response function of its loc parameter is the identity function; this will allow us to more easily compare the outputs of a standard parameter prediction to a contributions prediction." 
+ ], + "id": "170feafe1dccf85c" + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-09-26T09:58:18.921694Z", + "start_time": "2024-09-26T09:57:36.453028Z" + } + }, + "cell_type": "code", + "source": [ + "lgblss = LightGBMLSS(\n", + " Gaussian()\n", + ")\n", + "lgblss.train(\n", + " params=dict(),\n", + " train_set=dtrain\n", + ")\n", + "\n", + "param_dict = {\n", + " \"eta\": [\"float\", {\"low\": 1e-5, \"high\": 1, \"log\": True}],\n", + " \"max_depth\": [\"int\", {\"low\": 1, \"high\": 10, \"log\": False}],\n", + " \"num_leaves\": [\"int\", {\"low\": 255, \"high\": 255, \"log\": False}], # set to constant for this example\n", + " \"min_data_in_leaf\": [\"int\", {\"low\": 20, \"high\": 20, \"log\": False}], # set to constant for this example\n", + " \"min_gain_to_split\": [\"float\", {\"low\": 1e-8, \"high\": 40, \"log\": False}],\n", + " \"min_sum_hessian_in_leaf\": [\"float\", {\"low\": 1e-8, \"high\": 500, \"log\": True}],\n", + " \"subsample\": [\"float\", {\"low\": 0.2, \"high\": 1.0, \"log\": False}],\n", + " \"feature_fraction\": [\"float\", {\"low\": 0.2, \"high\": 1.0, \"log\": False}],\n", + " \"boosting\": [\"categorical\", [\"gbdt\"]],\n", + "}\n", + "\n", + "np.random.seed(123)\n", + "opt_param = lgblss.hyper_opt(param_dict,\n", + " dtrain,\n", + " num_boost_round=100, # Number of boosting iterations.\n", + " nfold=5, # Number of cv-folds.\n", + " early_stopping_rounds=20, # Number of early-stopping rounds\n", + " max_minutes=10, # Time budget in minutes, i.e., stop study after the given number of minutes.\n", + " n_trials=30 , # The number of trials. 
If this argument is set to None, there is no limitation on the number of trials.\n", + " silence=True, # Controls the verbosity of the trail, i.e., user can silence the outputs of the trail.\n", + " seed=123, # Seed used to generate cv-folds.\n", + " hp_seed=123 # Seed for random number generator used in the Bayesian hyperparameter search.\n", + " )\n", + "\n", + "np.random.seed(123)\n", + "\n", + "opt_params = opt_param.copy()\n", + "n_rounds = opt_params[\"opt_rounds\"]\n", + "del opt_params[\"opt_rounds\"]\n", + "\n", + "# Train Model with optimized hyperparameters\n", + "lgblss.train(opt_params,\n", + " dtrain,\n", + " num_boost_round=n_rounds\n", + " )\n" + ], + "id": "f45c868160f1f08b", + "outputs": [ + { + "data": { + "text/plain": [ + " 0%| | 0/30 [00:00\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", 
+ " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
FeatureContributionx_truex_noise1x_noise2x_noise3x_noise4x_noise5x_noise6x_noise7x_noise8x_noise9x_noise10Const
00.00.00.00.00.00.00.00.00.00.00.09.979979
10.00.00.00.00.00.00.00.00.00.00.09.979979
20.00.00.00.00.00.00.00.00.00.00.09.979979
30.00.00.00.00.00.00.00.00.00.00.09.979979
40.00.00.00.00.00.00.00.00.00.00.09.979979
.......................................
29950.00.00.00.00.00.00.00.00.00.00.09.979979
29960.00.00.00.00.00.00.00.00.00.00.09.979979
29970.00.00.00.00.00.00.00.00.00.00.09.979979
29980.00.00.00.00.00.00.00.00.00.00.09.979979
29990.00.00.00.00.00.00.00.00.00.00.09.979979
\n", + "

3000 rows × 12 columns

\n", + "" + ] + }, + "execution_count": 28, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 28 + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": "### Show contributions for each feature for scale parameter", + "id": "eaf2ad3ecc736152" + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-09-26T09:58:55.637299Z", + "start_time": "2024-09-26T09:58:55.621654Z" + } + }, + "cell_type": "code", + "source": "pred_param_contributions.xs(\"scale\", axis=1, level=\"distribution_arg\")\n", + "id": "c5453f7e5a378096", + "outputs": [ + { + "data": { + "text/plain": [ + "FeatureContribution x_true x_noise1 x_noise2 x_noise3 x_noise4 \\\n", + "0 0.410550 0.002106 0.0 0.0 0.000034 \n", + "1 0.411261 0.000684 0.0 0.0 -0.000340 \n", + "2 -0.597674 0.002106 0.0 0.0 -0.000340 \n", + "3 0.848748 0.002832 0.0 0.0 0.000034 \n", + "4 0.414522 0.001565 0.0 0.0 0.000866 \n", + "... ... ... ... ... ... \n", + "2995 0.411230 0.002832 0.0 0.0 -0.000340 \n", + "2996 0.380649 0.002106 0.0 0.0 -0.000340 \n", + "2997 -0.597582 0.001647 0.0 0.0 0.000034 \n", + "2998 -0.607346 -0.001425 0.0 0.0 -0.001143 \n", + "2999 0.410550 0.002106 0.0 0.0 0.000034 \n", + "\n", + "FeatureContribution x_noise5 x_noise6 x_noise7 x_noise8 x_noise9 \\\n", + "0 0.000197 0.004102 -0.000127 0.0 -0.000608 \n", + "1 0.000197 0.004813 -0.000127 0.0 -0.000608 \n", + "2 0.000197 0.004102 -0.000127 0.0 -0.000608 \n", + "3 0.000197 0.001399 -0.000127 0.0 -0.000608 \n", + "4 0.000123 0.002716 -0.004167 0.0 0.053916 \n", + "... ... ... ... ... ... 
\n", + "2995 0.000197 0.002135 -0.000127 0.0 -0.000608 \n", + "2996 0.000197 0.004400 -0.000127 0.0 -0.000608 \n", + "2997 0.000197 -0.004547 -0.000127 0.0 -0.000700 \n", + "2998 0.000887 0.002013 0.003194 0.0 -0.000029 \n", + "2999 0.000197 0.004102 -0.000127 0.0 -0.000608 \n", + "\n", + "FeatureContribution x_noise10 Const \n", + "0 -0.000503 0.653589 \n", + "1 -0.000129 0.653589 \n", + "2 -0.000129 0.653589 \n", + "3 0.001529 0.653589 \n", + "4 0.001894 0.653589 \n", + "... ... ... \n", + "2995 0.000432 0.653589 \n", + "2996 0.002376 0.653589 \n", + "2997 -0.000893 0.653589 \n", + "2998 -0.004395 0.653589 \n", + "2999 -0.000503 0.653589 \n", + "\n", + "[3000 rows x 12 columns]" + ], + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
FeatureContributionx_truex_noise1x_noise2x_noise3x_noise4x_noise5x_noise6x_noise7x_noise8x_noise9x_noise10Const
00.4105500.0021060.00.00.0000340.0001970.004102-0.0001270.0-0.000608-0.0005030.653589
10.4112610.0006840.00.0-0.0003400.0001970.004813-0.0001270.0-0.000608-0.0001290.653589
2-0.5976740.0021060.00.0-0.0003400.0001970.004102-0.0001270.0-0.000608-0.0001290.653589
30.8487480.0028320.00.00.0000340.0001970.001399-0.0001270.0-0.0006080.0015290.653589
40.4145220.0015650.00.00.0008660.0001230.002716-0.0041670.00.0539160.0018940.653589
.......................................
29950.4112300.0028320.00.0-0.0003400.0001970.002135-0.0001270.0-0.0006080.0004320.653589
29960.3806490.0021060.00.0-0.0003400.0001970.004400-0.0001270.0-0.0006080.0023760.653589
2997-0.5975820.0016470.00.00.0000340.000197-0.004547-0.0001270.0-0.000700-0.0008930.653589
2998-0.607346-0.0014250.00.0-0.0011430.0008870.0020130.0031940.0-0.000029-0.0043950.653589
29990.4105500.0021060.00.00.0000340.0001970.004102-0.0001270.0-0.000608-0.0005030.653589
\n", + "

3000 rows × 12 columns

\n", + "
" + ] + }, + "execution_count": 29, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 29 + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": "# Show Mean Feature Impact for Data set", + "id": "394e64d247168fa0" + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-09-26T10:02:54.858851Z", + "start_time": "2024-09-26T10:02:54.838744Z" + } + }, + "cell_type": "code", + "source": [ + "sum_of_contributions_column = \"SumOfContributions\"\n", + "mean_parameter_contribution = pred_param_contributions.abs().mean().unstack(\"distribution_arg\")\n", + "mean_parameter_contribution[sum_of_contributions_column] = mean_parameter_contribution.sum(1)\n", + "\n", + "mean_parameter_contribution.sort_values(sum_of_contributions_column, ascending=False).drop(columns=sum_of_contributions_column)\n" + ], + "id": "54d4970cf1957735", + "outputs": [ + { + "data": { + "text/plain": [ + "distribution_arg loc scale\n", + "FeatureContribution \n", + "Const 9.97998 0.653589\n", + "x_true 0.00000 0.591846\n", + "x_noise6 0.00000 0.004865\n", + "x_noise7 0.00000 0.004410\n", + "x_noise1 0.00000 0.003991\n", + "x_noise10 0.00000 0.002688\n", + "x_noise9 0.00000 0.002582\n", + "x_noise4 0.00000 0.001666\n", + "x_noise5 0.00000 0.000585\n", + "x_noise2 0.00000 0.000000\n", + "x_noise3 0.00000 0.000000\n", + "x_noise8 0.00000 0.000000" + ], + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
distribution_arglocscale
FeatureContribution
Const9.979980.653589
x_true0.000000.591846
x_noise60.000000.004865
x_noise70.000000.004410
x_noise10.000000.003991
x_noise100.000000.002688
x_noise90.000000.002582
x_noise40.000000.001666
x_noise50.000000.000585
x_noise20.000000.000000
x_noise30.000000.000000
x_noise80.000000.000000
\n", + "
" + ] + }, + "execution_count": 36, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 36 + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": "# Get correlation between contributions for the scale parameter ", + "id": "f7c73f303f04d4ff" + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": "", + "id": "8d5dc9e448d5c322" + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-09-26T10:07:21.265976Z", + "start_time": "2024-09-26T10:07:21.249801Z" + } + }, + "cell_type": "code", + "source": "pred_param_contributions.xs(\"scale\", axis=1, level=\"distribution_arg\").corr().dropna(how=\"all\").dropna(axis=1,how=\"all\")\n", + "id": "f331d8603042908", + "outputs": [ + { + "data": { + "text/plain": [ + "FeatureContribution x_true x_noise1 x_noise4 x_noise5 x_noise6 \\\n", + "FeatureContribution \n", + "x_true 1.000000 0.007743 -0.001227 -0.047812 -0.021568 \n", + "x_noise1 0.007743 1.000000 -0.006627 -0.022206 0.136683 \n", + "x_noise4 -0.001227 -0.006627 1.000000 -0.015965 -0.030661 \n", + "x_noise5 -0.047812 -0.022206 -0.015965 1.000000 0.006217 \n", + "x_noise6 -0.021568 0.136683 -0.030661 0.006217 1.000000 \n", + "x_noise7 0.015344 0.002144 0.474505 0.021826 0.029863 \n", + "x_noise9 0.024361 -0.006972 0.013089 0.016001 0.009558 \n", + "x_noise10 0.035479 0.012114 -0.035713 -0.001433 0.028450 \n", + "\n", + "FeatureContribution x_noise7 x_noise9 x_noise10 \n", + "FeatureContribution \n", + "x_true 0.015344 0.024361 0.035479 \n", + "x_noise1 0.002144 -0.006972 0.012114 \n", + "x_noise4 0.474505 0.013089 -0.035713 \n", + "x_noise5 0.021826 0.016001 -0.001433 \n", + "x_noise6 0.029863 0.009558 0.028450 \n", + "x_noise7 1.000000 0.023556 -0.015318 \n", + "x_noise9 0.023556 1.000000 -0.030408 \n", + "x_noise10 -0.015318 -0.030408 1.000000 " + ], + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
FeatureContributionx_truex_noise1x_noise4x_noise5x_noise6x_noise7x_noise9x_noise10
FeatureContribution
x_true1.0000000.007743-0.001227-0.047812-0.0215680.0153440.0243610.035479
x_noise10.0077431.000000-0.006627-0.0222060.1366830.002144-0.0069720.012114
x_noise4-0.001227-0.0066271.000000-0.015965-0.0306610.4745050.013089-0.035713
x_noise5-0.047812-0.022206-0.0159651.0000000.0062170.0218260.016001-0.001433
x_noise6-0.0215680.136683-0.0306610.0062171.0000000.0298630.0095580.028450
x_noise70.0153440.0021440.4745050.0218260.0298631.0000000.023556-0.015318
x_noise90.024361-0.0069720.0130890.0160010.0095580.0235561.000000-0.030408
x_noise100.0354790.012114-0.035713-0.0014330.028450-0.015318-0.0304081.000000
\n", + "
" + ] + }, + "execution_count": 44, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 44 + }, + { + "metadata": {}, + "cell_type": "code", + "outputs": [], + "execution_count": null, + "source": "", + "id": "ae0a0247ad688b42" + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}