diff --git a/docs/_static/tutorials/performance_calculation/multiclass/business_value.svg b/docs/_static/tutorials/performance_calculation/multiclass/business_value.svg new file mode 100644 index 000000000..f4ce3569b --- /dev/null +++ b/docs/_static/tutorials/performance_calculation/multiclass/business_value.svg @@ -0,0 +1 @@ +May 2020Jul 2020Sep 2020Nov 2020Jan 20211.31.41.51.61.71.81.92MetricAlertThresholdRealized performanceTimeBusiness ValueRealized Business ValueReferenceAnalysis \ No newline at end of file diff --git a/docs/_static/tutorials/performance_estimation/multiclass/business_value.svg b/docs/_static/tutorials/performance_estimation/multiclass/business_value.svg new file mode 100644 index 000000000..6a5283e1a --- /dev/null +++ b/docs/_static/tutorials/performance_estimation/multiclass/business_value.svg @@ -0,0 +1 @@ +May 2020Jul 2020Sep 2020Nov 2020Jan 20211.61.71.81.92MetricAlertThresholdConfidence bandEstimated performance (CBPE)TimeBusiness ValueEstimated Business ValueReferenceAnalysis \ No newline at end of file diff --git a/docs/example_notebooks/Tutorial - Calculating Business Value - Multiclass Classification.ipynb b/docs/example_notebooks/Tutorial - Calculating Business Value - Multiclass Classification.ipynb new file mode 100644 index 000000000..16a0524ad --- /dev/null +++ b/docs/example_notebooks/Tutorial - Calculating Business Value - Multiclass Classification.ipynb @@ -0,0 +1,876 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
idacq_channelapp_behavioral_scorerequested_credit_limitapp_channelcredit_bureau_scorestated_incomeis_customertimestampy_pred_proba_prepaid_cardy_pred_proba_highstreet_cardy_pred_proba_upmarket_cardy_predy_true
00Partner31.808232350web30915000True2020-05-02 02:01:300.970.030.00prepaid_cardprepaid_card
11Partner24.382568500mobile41823000True2020-05-02 02:03:330.870.130.00prepaid_cardprepaid_card
22Partner2-0.787575400web50724000False2020-05-02 02:04:490.470.350.18prepaid_cardupmarket_card
\n", + "
" + ], + "text/plain": [ + " id acq_channel app_behavioral_score requested_credit_limit app_channel \\\n", + "0 0 Partner3 1.808232 350 web \n", + "1 1 Partner2 4.382568 500 mobile \n", + "2 2 Partner2 -0.787575 400 web \n", + "\n", + " credit_bureau_score stated_income is_customer timestamp \\\n", + "0 309 15000 True 2020-05-02 02:01:30 \n", + "1 418 23000 True 2020-05-02 02:03:33 \n", + "2 507 24000 False 2020-05-02 02:04:49 \n", + "\n", + " y_pred_proba_prepaid_card y_pred_proba_highstreet_card \\\n", + "0 0.97 0.03 \n", + "1 0.87 0.13 \n", + "2 0.47 0.35 \n", + "\n", + " y_pred_proba_upmarket_card y_pred y_true \n", + "0 0.00 prepaid_card prepaid_card \n", + "1 0.00 prepaid_card prepaid_card \n", + "2 0.18 prepaid_card upmarket_card " + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import nannyml as nml\n", + "from IPython.display import display\n", + "\n", + "reference_df, analysis_df, analysis_target_df = nml.load_synthetic_multiclass_classification_dataset()\n", + "\n", + "analysis_df = analysis_df.merge(analysis_target_df, on='id', how='left')\n", + "\n", + "display(reference_df.head(3))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "+----+------+---------------+------------------------+--------------------------+---------------+-----------------------+-----------------+---------------+---------------------+-----------------------------+--------------------------------+------------------------------+--------------+---------------+\n", + "| | id | acq_channel | app_behavioral_score | requested_credit_limit | app_channel | credit_bureau_score | stated_income | is_customer | timestamp | y_pred_proba_prepaid_card | y_pred_proba_highstreet_card | y_pred_proba_upmarket_card | y_pred | y_true |\n", + "+====+======+===============+========================+==========================+===============+=======================+=================+===============+=====================+=============================+================================+==============================+==============+===============+\n", + "| 0 | 0 | Partner3 | 1.80823 | 350 | web | 309 | 15000 | True | 2020-05-02 02:01:30 | 0.97 | 0.03 | 0 | prepaid_card | prepaid_card |\n", + "+----+------+---------------+------------------------+--------------------------+---------------+-----------------------+-----------------+---------------+---------------------+-----------------------------+--------------------------------+------------------------------+--------------+---------------+\n", + "| 1 | 1 | Partner2 | 4.38257 | 500 | mobile | 418 | 23000 | True | 2020-05-02 02:03:33 | 0.87 | 0.13 | 0 | prepaid_card | prepaid_card |\n", + "+----+------+---------------+------------------------+--------------------------+---------------+-----------------------+-----------------+---------------+---------------------+-----------------------------+--------------------------------+------------------------------+--------------+---------------+\n", + "| 2 | 2 | Partner2 | -0.787575 | 400 | web | 507 | 24000 | False | 2020-05-02 02:04:49 | 0.47 | 0.35 | 0.18 | prepaid_card | upmarket_card |\n", + "+----+------+---------------+------------------------+--------------------------+---------------+-----------------------+-----------------+---------------+---------------------+-----------------------------+--------------------------------+------------------------------+--------------+---------------+\n" + ] + } + ], + "source": [ + "print(reference_df.head(3).to_markdown(tablefmt=\"grid\"))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# matrix can be provided as a list of lists or a numpy array\n", + "business_value_matrix = [\n", + " [1, 0, -1],\n", + " [0, 1, 0],\n", + " [-1, 0, 1]\n", + "]\n", + "calc = nml.PerformanceCalculator(\n", + " y_pred_proba={\n", + " 'prepaid_card': 'y_pred_proba_prepaid_card',\n", + " 'highstreet_card': 'y_pred_proba_highstreet_card',\n", + " 'upmarket_card': 'y_pred_proba_upmarket_card'\n", + " },\n", + " y_pred='y_pred',\n", + " y_true='y_true',\n", + " timestamp_column_name='timestamp',\n", + " problem_type='classification_multiclass',\n", + " metrics=['business_value'],\n", + " business_value_matrix = business_value_matrix,\n", + " normalize_business_value='per_prediction',\n", + " chunk_size=6000\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "calc.fit(reference_df)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
chunkbusiness_value
keychunk_indexstart_indexend_indexstart_dateend_dateperiodtargets_missing_ratesampling_errorvalueupper_thresholdlower_thresholdalert
0[0:5999]0059992020-09-01 03:10:012020-09-13 16:15:10analysis0.00.0080472.0012202.0503161.963201False
1[6000:11999]16000119992020-09-13 16:15:322020-09-25 19:48:42analysis0.00.0080472.0441362.0503161.963201False
2[12000:17999]212000179992020-09-25 19:50:042020-10-08 02:53:47analysis0.00.0080472.0185322.0503161.963201False
3[18000:23999]318000239992020-10-08 02:57:342020-10-20 15:48:19analysis0.00.0080472.0185422.0503161.963201False
4[24000:29999]424000299992020-10-20 15:49:062020-11-01 22:04:40analysis0.00.0080472.0169322.0503161.963201False
5[30000:35999]530000359992020-11-01 22:04:592020-11-14 03:55:33analysis0.00.0080471.2892142.0503161.963201True
6[36000:41999]636000419992020-11-14 03:55:492020-11-26 09:19:06analysis0.00.0080471.3100692.0503161.963201True
7[42000:47999]742000479992020-11-26 09:19:222020-12-08 14:33:56analysis0.00.0080471.3297242.0503161.963201True
8[48000:53999]848000539992020-12-08 14:34:252020-12-20 18:30:30analysis0.00.0080471.3240452.0503161.963201True
9[54000:59999]954000599992020-12-20 18:31:092021-01-01 22:57:55analysis0.00.0080471.3162352.0503161.963201True
\n", + "
" + ], + "text/plain": [ + " chunk \\\n", + " key chunk_index start_index end_index start_date \n", + "0 [0:5999] 0 0 5999 2020-09-01 03:10:01 \n", + "1 [6000:11999] 1 6000 11999 2020-09-13 16:15:32 \n", + "2 [12000:17999] 2 12000 17999 2020-09-25 19:50:04 \n", + "3 [18000:23999] 3 18000 23999 2020-10-08 02:57:34 \n", + "4 [24000:29999] 4 24000 29999 2020-10-20 15:49:06 \n", + "5 [30000:35999] 5 30000 35999 2020-11-01 22:04:59 \n", + "6 [36000:41999] 6 36000 41999 2020-11-14 03:55:49 \n", + "7 [42000:47999] 7 42000 47999 2020-11-26 09:19:22 \n", + "8 [48000:53999] 8 48000 53999 2020-12-08 14:34:25 \n", + "9 [54000:59999] 9 54000 59999 2020-12-20 18:31:09 \n", + "\n", + " business_value \\\n", + " end_date period targets_missing_rate sampling_error value \n", + "0 2020-09-13 16:15:10 analysis 0.0 0.008047 2.001220 \n", + "1 2020-09-25 19:48:42 analysis 0.0 0.008047 2.044136 \n", + "2 2020-10-08 02:53:47 analysis 0.0 0.008047 2.018532 \n", + "3 2020-10-20 15:48:19 analysis 0.0 0.008047 2.018542 \n", + "4 2020-11-01 22:04:40 analysis 0.0 0.008047 2.016932 \n", + "5 2020-11-14 03:55:33 analysis 0.0 0.008047 1.289214 \n", + "6 2020-11-26 09:19:06 analysis 0.0 0.008047 1.310069 \n", + "7 2020-12-08 14:33:56 analysis 0.0 0.008047 1.329724 \n", + "8 2020-12-20 18:30:30 analysis 0.0 0.008047 1.324045 \n", + "9 2021-01-01 22:57:55 analysis 0.0 0.008047 1.316235 \n", + "\n", + " \n", + " upper_threshold lower_threshold alert \n", + "0 2.050316 1.963201 False \n", + "1 2.050316 1.963201 False \n", + "2 2.050316 1.963201 False \n", + "3 2.050316 1.963201 False \n", + "4 2.050316 1.963201 False \n", + "5 2.050316 1.963201 True \n", + "6 2.050316 1.963201 True \n", + "7 2.050316 1.963201 True \n", + "8 2.050316 1.963201 True \n", + "9 2.050316 1.963201 True " + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "results = calc.calculate(analysis_df)\n", + "display(results.filter(period='analysis').to_df())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "jupyter": { + "outputs_hidden": true + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "+----+---------------+-----------------+-----------------+---------------+---------------------+---------------------+------------+--------------------------+--------------------+-----------+---------------------+---------------------+-----------+\n", + "| | | chunk | | | | | | | | | business_value | | | | |\n", + "| | | key | | chunk_index | | start_index | | end_index | | start_date | | end_date | | period | | targets_missing_rate | | sampling_error | | value | | upper_threshold | | lower_threshold | | alert |\n", + "+====+===============+=================+=================+===============+=====================+=====================+============+==========================+====================+===========+=====================+=====================+===========+\n", + "| 0 | [0:5999] | 0 | 0 | 5999 | 2020-09-01 03:10:01 | 2020-09-13 16:15:10 | analysis | 0 | 0.00804747 | 2.00122 | 2.05032 | 1.9632 | False |\n", + "+----+---------------+-----------------+-----------------+---------------+---------------------+---------------------+------------+--------------------------+--------------------+-----------+---------------------+---------------------+-----------+\n", + "| 1 | [6000:11999] | 1 | 6000 | 11999 | 2020-09-13 16:15:32 | 2020-09-25 19:48:42 | analysis | 0 | 0.00804747 | 2.04414 | 2.05032 | 1.9632 | False |\n", + "+----+---------------+-----------------+-----------------+---------------+---------------------+---------------------+------------+--------------------------+--------------------+-----------+---------------------+---------------------+-----------+\n", + "| 2 | [12000:17999] | 2 | 12000 | 17999 | 2020-09-25 19:50:04 | 2020-10-08 02:53:47 | analysis | 0 | 0.00804747 | 2.01853 | 2.05032 | 1.9632 | False |\n", + "+----+---------------+-----------------+-----------------+---------------+---------------------+---------------------+------------+--------------------------+--------------------+-----------+---------------------+---------------------+-----------+\n", + "| 3 | [18000:23999] | 3 | 18000 | 23999 | 2020-10-08 02:57:34 | 2020-10-20 15:48:19 | analysis | 0 | 0.00804747 | 2.01854 | 2.05032 | 1.9632 | False |\n", + "+----+---------------+-----------------+-----------------+---------------+---------------------+---------------------+------------+--------------------------+--------------------+-----------+---------------------+---------------------+-----------+\n", + "| 4 | [24000:29999] | 4 | 24000 | 29999 | 2020-10-20 15:49:06 | 2020-11-01 22:04:40 | analysis | 0 | 0.00804747 | 2.01693 | 2.05032 | 1.9632 | False |\n", + "+----+---------------+-----------------+-----------------+---------------+---------------------+---------------------+------------+--------------------------+--------------------+-----------+---------------------+---------------------+-----------+\n", + "| 5 | [30000:35999] | 5 | 30000 | 35999 | 2020-11-01 22:04:59 | 2020-11-14 03:55:33 | analysis | 0 | 0.00804747 | 1.28921 | 2.05032 | 1.9632 | True |\n", + "+----+---------------+-----------------+-----------------+---------------+---------------------+---------------------+------------+--------------------------+--------------------+-----------+---------------------+---------------------+-----------+\n", + "| 6 | [36000:41999] | 6 | 36000 | 41999 | 2020-11-14 03:55:49 | 2020-11-26 09:19:06 | analysis | 0 | 0.00804747 | 1.31007 | 2.05032 | 1.9632 | True |\n", + "+----+---------------+-----------------+-----------------+---------------+---------------------+---------------------+------------+--------------------------+--------------------+-----------+---------------------+---------------------+-----------+\n", + "| 7 | [42000:47999] | 7 | 42000 | 47999 | 2020-11-26 09:19:22 | 2020-12-08 14:33:56 | analysis | 0 | 0.00804747 | 1.32972 | 2.05032 | 1.9632 | True |\n", + "+----+---------------+-----------------+-----------------+---------------+---------------------+---------------------+------------+--------------------------+--------------------+-----------+---------------------+---------------------+-----------+\n", + "| 8 | [48000:53999] | 8 | 48000 | 53999 | 2020-12-08 14:34:25 | 2020-12-20 18:30:30 | analysis | 0 | 0.00804747 | 1.32404 | 2.05032 | 1.9632 | True |\n", + "+----+---------------+-----------------+-----------------+---------------+---------------------+---------------------+------------+--------------------------+--------------------+-----------+---------------------+---------------------+-----------+\n", + "| 9 | [54000:59999] | 9 | 54000 | 59999 | 2020-12-20 18:31:09 | 2021-01-01 22:57:55 | analysis | 0 | 0.00804747 | 1.31623 | 2.05032 | 1.9632 | True |\n", + "+----+---------------+-----------------+-----------------+---------------+---------------------+---------------------+------------+--------------------------+--------------------+-----------+---------------------+---------------------+-----------+\n" + ] + } + ], + "source": [ + "from docs.utils import print_multi_index_markdown\n", + "print_multi_index_markdown(results.filter(period='analysis').to_df())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
chunkbusiness_value
keychunk_indexstart_indexend_indexstart_dateend_dateperiodtargets_missing_ratesampling_errorvalueupper_thresholdlower_thresholdalert
0[0:5999]0059992020-05-02 02:01:302020-05-14 12:25:35reference0.00.0080472.0092582.0503161.963201False
1[6000:11999]16000119992020-05-14 12:29:252020-05-26 18:27:42reference0.00.0080472.0049992.0503161.963201False
2[12000:17999]212000179992020-05-26 18:31:062020-06-07 19:55:45reference0.00.0080472.0147652.0503161.963201False
3[18000:23999]318000239992020-06-07 19:58:392020-06-19 19:42:20reference0.00.0080471.9891842.0503161.963201False
4[24000:29999]424000299992020-06-19 19:44:142020-07-02 01:58:05reference0.00.0080472.0243722.0503161.963201False
5[30000:35999]530000359992020-07-02 02:06:562020-07-14 08:14:04reference0.00.0080471.9909782.0503161.963201False
6[36000:41999]636000419992020-07-14 08:14:082020-07-26 12:55:42reference0.00.0080471.9922582.0503161.963201False
7[42000:47999]742000479992020-07-26 12:57:372020-08-07 16:32:15reference0.00.0080472.0245382.0503161.963201False
8[48000:53999]848000539992020-08-07 16:33:442020-08-20 00:06:08reference0.00.0080471.9908222.0503161.963201False
9[54000:59999]954000599992020-08-20 00:07:582020-09-01 03:03:23reference0.00.0080472.0264092.0503161.963201False
\n", + "
" + ], + "text/plain": [ + " chunk \\\n", + " key chunk_index start_index end_index start_date \n", + "0 [0:5999] 0 0 5999 2020-05-02 02:01:30 \n", + "1 [6000:11999] 1 6000 11999 2020-05-14 12:29:25 \n", + "2 [12000:17999] 2 12000 17999 2020-05-26 18:31:06 \n", + "3 [18000:23999] 3 18000 23999 2020-06-07 19:58:39 \n", + "4 [24000:29999] 4 24000 29999 2020-06-19 19:44:14 \n", + "5 [30000:35999] 5 30000 35999 2020-07-02 02:06:56 \n", + "6 [36000:41999] 6 36000 41999 2020-07-14 08:14:08 \n", + "7 [42000:47999] 7 42000 47999 2020-07-26 12:57:37 \n", + "8 [48000:53999] 8 48000 53999 2020-08-07 16:33:44 \n", + "9 [54000:59999] 9 54000 59999 2020-08-20 00:07:58 \n", + "\n", + " business_value \\\n", + " end_date period targets_missing_rate sampling_error \n", + "0 2020-05-14 12:25:35 reference 0.0 0.008047 \n", + "1 2020-05-26 18:27:42 reference 0.0 0.008047 \n", + "2 2020-06-07 19:55:45 reference 0.0 0.008047 \n", + "3 2020-06-19 19:42:20 reference 0.0 0.008047 \n", + "4 2020-07-02 01:58:05 reference 0.0 0.008047 \n", + "5 2020-07-14 08:14:04 reference 0.0 0.008047 \n", + "6 2020-07-26 12:55:42 reference 0.0 0.008047 \n", + "7 2020-08-07 16:32:15 reference 0.0 0.008047 \n", + "8 2020-08-20 00:06:08 reference 0.0 0.008047 \n", + "9 2020-09-01 03:03:23 reference 0.0 0.008047 \n", + "\n", + " \n", + " value upper_threshold lower_threshold alert \n", + "0 2.009258 2.050316 1.963201 False \n", + "1 2.004999 2.050316 1.963201 False \n", + "2 2.014765 2.050316 1.963201 False \n", + "3 1.989184 2.050316 1.963201 False \n", + "4 2.024372 2.050316 1.963201 False \n", + "5 1.990978 2.050316 1.963201 False \n", + "6 1.992258 2.050316 1.963201 False \n", + "7 2.024538 2.050316 1.963201 False \n", + "8 1.990822 2.050316 1.963201 False \n", + "9 2.026409 2.050316 1.963201 False " + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "display(results.filter(period='reference').to_df())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "jupyter": { + "outputs_hidden": true + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "+----+---------------+-----------------+-----------------+---------------+---------------------+---------------------+------------+--------------------------+--------------------+-----------+---------------------+---------------------+-----------+\n", + "| | | chunk | | | | | | | | | business_value | | | | |\n", + "| | | key | | chunk_index | | start_index | | end_index | | start_date | | end_date | | period | | targets_missing_rate | | sampling_error | | value | | upper_threshold | | lower_threshold | | alert |\n", + "+====+===============+=================+=================+===============+=====================+=====================+============+==========================+====================+===========+=====================+=====================+===========+\n", + "| 0 | [0:5999] | 0 | 0 | 5999 | 2020-05-02 02:01:30 | 2020-05-14 12:25:35 | reference | 0 | 0.00804747 | 2.00926 | 2.05032 | 1.9632 | False |\n", + "+----+---------------+-----------------+-----------------+---------------+---------------------+---------------------+------------+--------------------------+--------------------+-----------+---------------------+---------------------+-----------+\n", + "| 1 | [6000:11999] | 1 | 6000 | 11999 | 2020-05-14 12:29:25 | 2020-05-26 18:27:42 | reference | 0 | 0.00804747 | 2.005 | 2.05032 | 1.9632 | False |\n", + "+----+---------------+-----------------+-----------------+---------------+---------------------+---------------------+------------+--------------------------+--------------------+-----------+---------------------+---------------------+-----------+\n", + "| 2 | [12000:17999] | 2 | 12000 | 17999 | 2020-05-26 18:31:06 | 2020-06-07 19:55:45 | reference | 0 | 0.00804747 | 2.01476 | 2.05032 | 1.9632 | False |\n", + "+----+---------------+-----------------+-----------------+---------------+---------------------+---------------------+------------+--------------------------+--------------------+-----------+---------------------+---------------------+-----------+\n", + "| 3 | [18000:23999] | 3 | 18000 | 23999 | 2020-06-07 19:58:39 | 2020-06-19 19:42:20 | reference | 0 | 0.00804747 | 1.98918 | 2.05032 | 1.9632 | False |\n", + "+----+---------------+-----------------+-----------------+---------------+---------------------+---------------------+------------+--------------------------+--------------------+-----------+---------------------+---------------------+-----------+\n", + "| 4 | [24000:29999] | 4 | 24000 | 29999 | 2020-06-19 19:44:14 | 2020-07-02 01:58:05 | reference | 0 | 0.00804747 | 2.02437 | 2.05032 | 1.9632 | False |\n", + "+----+---------------+-----------------+-----------------+---------------+---------------------+---------------------+------------+--------------------------+--------------------+-----------+---------------------+---------------------+-----------+\n", + "| 5 | [30000:35999] | 5 | 30000 | 35999 | 2020-07-02 02:06:56 | 2020-07-14 08:14:04 | reference | 0 | 0.00804747 | 1.99098 | 2.05032 | 1.9632 | False |\n", + "+----+---------------+-----------------+-----------------+---------------+---------------------+---------------------+------------+--------------------------+--------------------+-----------+---------------------+---------------------+-----------+\n", + "| 6 | [36000:41999] | 6 | 36000 | 41999 | 2020-07-14 08:14:08 | 2020-07-26 12:55:42 | reference | 0 | 0.00804747 | 1.99226 | 2.05032 | 1.9632 | False |\n", + "+----+---------------+-----------------+-----------------+---------------+---------------------+---------------------+------------+--------------------------+--------------------+-----------+---------------------+---------------------+-----------+\n", + "| 7 | [42000:47999] | 7 | 42000 | 47999 | 2020-07-26 12:57:37 | 2020-08-07 16:32:15 | reference | 0 | 0.00804747 | 2.02454 | 2.05032 | 1.9632 | False |\n", + "+----+---------------+-----------------+-----------------+---------------+---------------------+---------------------+------------+--------------------------+--------------------+-----------+---------------------+---------------------+-----------+\n", + "| 8 | [48000:53999] | 8 | 48000 | 53999 | 2020-08-07 16:33:44 | 2020-08-20 00:06:08 | reference | 0 | 0.00804747 | 1.99082 | 2.05032 | 1.9632 | False |\n", + "+----+---------------+-----------------+-----------------+---------------+---------------------+---------------------+------------+--------------------------+--------------------+-----------+---------------------+---------------------+-----------+\n", + "| 9 | [54000:59999] | 9 | 54000 | 59999 | 2020-08-20 00:07:58 | 2020-09-01 03:03:23 | reference | 0 | 0.00804747 | 2.02641 | 2.05032 | 1.9632 | False |\n", + "+----+---------------+-----------------+-----------------+---------------+---------------------+---------------------+------------+--------------------------+--------------------+-----------+---------------------+---------------------+-----------+\n" + ] + } + ], + "source": [ + "print_multi_index_markdown(results.filter(period='reference').to_df())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "figure = results.plot()\n", + "figure.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "figure = results.plot()\n", + "figure.write_image(f'../_static/tutorials/performance_calculation/multiclass/business_value.svg')\n", + "\n", + "# tutorial-perf-est-guide-binary-class-car-loan-analysis" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.9" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/docs/example_notebooks/Tutorial - Estimating Business Value - Binary Classification.ipynb b/docs/example_notebooks/Tutorial - Estimating Business Value - Binary Classification.ipynb index 244d2529a..f4912c9f4 100644 --- a/docs/example_notebooks/Tutorial - Estimating Business Value - Binary Classification.ipynb +++ b/docs/example_notebooks/Tutorial - Estimating Business Value - Binary Classification.ipynb @@ -127,6 +127,7 @@ "execution_count": null, "id": "16b16c8b", "metadata": { + "collapsed": false, "jupyter": { "outputs_hidden": false } @@ -541,6 +542,7 @@ "execution_count": null, "id": "0c5e9902", "metadata": { + "collapsed": false, "jupyter": { "outputs_hidden": false } @@ -570,7 +572,7 @@ ], "metadata": { "kernelspec": { - "display_name": "EMD", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -584,7 +586,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.8" + "version": "3.11.9" }, "vscode": { "interpreter": { diff --git a/docs/example_notebooks/Tutorial - Estimating Business Value - Multiclass Classification.ipynb b/docs/example_notebooks/Tutorial - Estimating Business Value - Multiclass Classification.ipynb new file mode 100644 index 000000000..fe54e8103 --- /dev/null +++ b/docs/example_notebooks/Tutorial - Estimating Business Value - Multiclass Classification.ipynb @@ -0,0 +1,587 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
idacq_channelapp_behavioral_scorerequested_credit_limitapp_channelcredit_bureau_scorestated_incomeis_customertimestampy_pred_proba_prepaid_cardy_pred_proba_highstreet_cardy_pred_proba_upmarket_cardy_predy_true
00Partner31.808232350web30915000True2020-05-02 02:01:300.970.030.00prepaid_cardprepaid_card
11Partner24.382568500mobile41823000True2020-05-02 02:03:330.870.130.00prepaid_cardprepaid_card
22Partner2-0.787575400web50724000False2020-05-02 02:04:490.470.350.18prepaid_cardupmarket_card
\n", + "
" + ], + "text/plain": [ + " id acq_channel app_behavioral_score requested_credit_limit app_channel \\\n", + "0 0 Partner3 1.808232 350 web \n", + "1 1 Partner2 4.382568 500 mobile \n", + "2 2 Partner2 -0.787575 400 web \n", + "\n", + " credit_bureau_score stated_income is_customer timestamp \\\n", + "0 309 15000 True 2020-05-02 02:01:30 \n", + "1 418 23000 True 2020-05-02 02:03:33 \n", + "2 507 24000 False 2020-05-02 02:04:49 \n", + "\n", + " y_pred_proba_prepaid_card y_pred_proba_highstreet_card \\\n", + "0 0.97 0.03 \n", + "1 0.87 0.13 \n", + "2 0.47 0.35 \n", + "\n", + " y_pred_proba_upmarket_card y_pred y_true \n", + "0 0.00 prepaid_card prepaid_card \n", + "1 0.00 prepaid_card prepaid_card \n", + "2 0.18 prepaid_card upmarket_card " + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import nannyml as nml\n", + "from IPython.display import display\n", + "\n", + "reference_df, analysis_df, _ = nml.load_synthetic_multiclass_classification_dataset()\n", + "\n", + "display(reference_df.head(3))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "jupyter": { + "outputs_hidden": true + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "+----+------+---------------+------------------------+--------------------------+---------------+-----------------------+-----------------+---------------+---------------------+-----------------------------+--------------------------------+------------------------------+--------------+---------------+\n", + "| | id | acq_channel | app_behavioral_score | requested_credit_limit | app_channel | credit_bureau_score | stated_income | is_customer | timestamp | y_pred_proba_prepaid_card | y_pred_proba_highstreet_card | y_pred_proba_upmarket_card | y_pred | y_true |\n", + "+====+======+===============+========================+==========================+===============+=======================+=================+===============+=====================+=============================+================================+==============================+==============+===============+\n", + "| 0 | 0 | Partner3 | 1.80823 | 350 | web | 309 | 15000 | True | 2020-05-02 02:01:30 | 0.97 | 0.03 | 0 | prepaid_card | prepaid_card |\n", + "+----+------+---------------+------------------------+--------------------------+---------------+-----------------------+-----------------+---------------+---------------------+-----------------------------+--------------------------------+------------------------------+--------------+---------------+\n", + "| 1 | 1 | Partner2 | 4.38257 | 500 | mobile | 418 | 23000 | True | 2020-05-02 02:03:33 | 0.87 | 0.13 | 0 | prepaid_card | prepaid_card |\n", + "+----+------+---------------+------------------------+--------------------------+---------------+-----------------------+-----------------+---------------+---------------------+-----------------------------+--------------------------------+------------------------------+--------------+---------------+\n", + "| 2 | 2 | Partner2 | -0.787575 | 400 | web | 507 | 24000 | False | 2020-05-02 02:04:49 | 0.47 | 0.35 | 0.18 | prepaid_card | upmarket_card |\n", + "+----+------+---------------+------------------------+--------------------------+---------------+-----------------------+-----------------+---------------+---------------------+-----------------------------+--------------------------------+------------------------------+--------------+---------------+\n" + ] + } + ], + "source": [ + "print(reference_df.head(3).to_markdown(tablefmt=\"grid\"))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# matrix can be provided as a list of lists or a numpy array\n", + "business_value_matrix = [\n", + " [1, 0, -1],\n", + " [0, 1, 0],\n", + " [-1, 0, 1]\n", + "]\n", + "estimator = nml.CBPE(\n", + " y_pred_proba={\n", + " 'prepaid_card': 'y_pred_proba_prepaid_card',\n", + " 'highstreet_card': 'y_pred_proba_highstreet_card',\n", + " 'upmarket_card': 'y_pred_proba_upmarket_card'},\n", + " y_pred='y_pred',\n", + " y_true='y_true',\n", + " timestamp_column_name='timestamp',\n", + " problem_type='classification_multiclass',\n", + " metrics=['business_value'],\n", + " business_value_matrix=business_value_matrix,\n", + " normalize_business_value=\"per_prediction\",\n", + " chunk_size=6000,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "estimator.fit(reference_df)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
chunkbusiness_value
keychunk_indexstart_indexend_indexstart_dateend_dateperiodvaluesampling_errorrealizedupper_confidence_boundarylower_confidence_boundaryupper_thresholdlower_thresholdalert
0[0:5999]0059992020-09-01 03:10:012020-09-13 16:15:10analysis2.0086170.008047NaN2.0327601.9844752.0503161.963201False
1[6000:11999]16000119992020-09-13 16:15:322020-09-25 19:48:42analysis2.0167090.008047NaN2.0408511.9925662.0503161.963201False
2[12000:17999]212000179992020-09-25 19:50:042020-10-08 02:53:47analysis2.0251520.008047NaN2.0492942.0010102.0503161.963201False
3[18000:23999]318000239992020-10-08 02:57:342020-10-20 15:48:19analysis2.0189280.008047NaN2.0430701.9947862.0503161.963201False
4[24000:29999]424000299992020-10-20 15:49:062020-11-01 22:04:40analysis2.0065210.008047NaN2.0306641.9823792.0503161.963201False
5[30000:35999]530000359992020-11-01 22:04:592020-11-14 03:55:33analysis1.5644430.008047NaN1.5885851.5403002.0503161.963201True
6[36000:41999]636000419992020-11-14 03:55:492020-11-26 09:19:06analysis1.5684600.008047NaN1.5926031.5443182.0503161.963201True
7[42000:47999]742000479992020-11-26 09:19:222020-12-08 14:33:56analysis1.5620410.008047NaN1.5861831.5378982.0503161.963201True
8[48000:53999]848000539992020-12-08 14:34:252020-12-20 18:30:30analysis1.5668660.008047NaN1.5910091.5427242.0503161.963201True
9[54000:59999]954000599992020-12-20 18:31:092021-01-01 22:57:55analysis1.5742500.008047NaN1.5983921.5501072.0503161.963201True
\n", + "
" + ], + "text/plain": [ + " chunk \\\n", + " key chunk_index start_index end_index start_date \n", + "0 [0:5999] 0 0 5999 2020-09-01 03:10:01 \n", + "1 [6000:11999] 1 6000 11999 2020-09-13 16:15:32 \n", + "2 [12000:17999] 2 12000 17999 2020-09-25 19:50:04 \n", + "3 [18000:23999] 3 18000 23999 2020-10-08 02:57:34 \n", + "4 [24000:29999] 4 24000 29999 2020-10-20 15:49:06 \n", + "5 [30000:35999] 5 30000 35999 2020-11-01 22:04:59 \n", + "6 [36000:41999] 6 36000 41999 2020-11-14 03:55:49 \n", + "7 [42000:47999] 7 42000 47999 2020-11-26 09:19:22 \n", + "8 [48000:53999] 8 48000 53999 2020-12-08 14:34:25 \n", + "9 [54000:59999] 9 54000 59999 2020-12-20 18:31:09 \n", + "\n", + " business_value \\\n", + " end_date period value sampling_error realized \n", + "0 2020-09-13 16:15:10 analysis 2.008617 0.008047 NaN \n", + "1 2020-09-25 19:48:42 analysis 2.016709 0.008047 NaN \n", + "2 2020-10-08 02:53:47 analysis 2.025152 0.008047 NaN \n", + "3 2020-10-20 15:48:19 analysis 2.018928 0.008047 NaN \n", + "4 2020-11-01 22:04:40 analysis 2.006521 0.008047 NaN \n", + "5 2020-11-14 03:55:33 analysis 1.564443 0.008047 NaN \n", + "6 2020-11-26 09:19:06 analysis 1.568460 0.008047 NaN \n", + "7 2020-12-08 14:33:56 analysis 1.562041 0.008047 NaN \n", + "8 2020-12-20 18:30:30 analysis 1.566866 0.008047 NaN \n", + "9 2021-01-01 22:57:55 analysis 1.574250 0.008047 NaN \n", + "\n", + " \\\n", + " upper_confidence_boundary lower_confidence_boundary upper_threshold \n", + "0 2.032760 1.984475 2.050316 \n", + "1 2.040851 1.992566 2.050316 \n", + "2 2.049294 2.001010 2.050316 \n", + "3 2.043070 1.994786 2.050316 \n", + "4 2.030664 1.982379 2.050316 \n", + "5 1.588585 1.540300 2.050316 \n", + "6 1.592603 1.544318 2.050316 \n", + "7 1.586183 1.537898 2.050316 \n", + "8 1.591009 1.542724 2.050316 \n", + "9 1.598392 1.550107 2.050316 \n", + "\n", + " \n", + " lower_threshold alert \n", + "0 1.963201 False \n", + "1 1.963201 False \n", + "2 1.963201 False \n", + "3 1.963201 False \n", + "4 1.963201 False \n", + "5 1.963201 True \n", + "6 1.963201 True \n", + "7 1.963201 True \n", + "8 1.963201 True \n", + "9 1.963201 True " + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "results = estimator.estimate(analysis_df)\n", + "display(results.filter(period='analysis').to_df())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "jupyter": { + "outputs_hidden": true + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "+----+---------------+-----------------+-----------------+---------------+---------------------+---------------------+------------+--------------------+--------------------+--------------+-------------------------------+-------------------------------+---------------------+---------------------+-----------+\n", + "| | | chunk | | | | | | | | business_value | | | | | | | |\n", + "| | | key | | chunk_index | | start_index | | end_index | | start_date | | end_date | | period | | value | | sampling_error | | realized | | upper_confidence_boundary | | lower_confidence_boundary | | upper_threshold | | lower_threshold | | alert |\n", + "+====+===============+=================+=================+===============+=====================+=====================+============+====================+====================+==============+===============================+===============================+=====================+=====================+===========+\n", + "| 0 | [0:5999] | 0 | 0 | 5999 | 2020-09-01 03:10:01 | 2020-09-13 16:15:10 | analysis | 2.00862 | 0.00804747 | nan | 2.03276 | 1.98448 | 2.05032 | 1.9632 | False |\n", + "+----+---------------+-----------------+-----------------+---------------+---------------------+---------------------+------------+--------------------+--------------------+--------------+-------------------------------+-------------------------------+---------------------+---------------------+-----------+\n", + "| 1 | [6000:11999] | 1 | 6000 | 11999 | 2020-09-13 16:15:32 | 2020-09-25 19:48:42 | analysis | 2.01671 | 0.00804747 | nan | 2.04085 | 1.99257 | 2.05032 | 1.9632 | False |\n", + "+----+---------------+-----------------+-----------------+---------------+---------------------+---------------------+------------+--------------------+--------------------+--------------+-------------------------------+-------------------------------+---------------------+---------------------+-----------+\n", + "| 2 | [12000:17999] | 2 | 12000 | 17999 | 2020-09-25 19:50:04 | 2020-10-08 02:53:47 | analysis | 2.02515 | 0.00804747 | nan | 2.04929 | 2.00101 | 2.05032 | 1.9632 | False |\n", + "+----+---------------+-----------------+-----------------+---------------+---------------------+---------------------+------------+--------------------+--------------------+--------------+-------------------------------+-------------------------------+---------------------+---------------------+-----------+\n", + "| 3 | [18000:23999] | 3 | 18000 | 23999 | 2020-10-08 02:57:34 | 2020-10-20 15:48:19 | analysis | 2.01893 | 0.00804747 | nan | 2.04307 | 1.99479 | 2.05032 | 1.9632 | False |\n", + "+----+---------------+-----------------+-----------------+---------------+---------------------+---------------------+------------+--------------------+--------------------+--------------+-------------------------------+-------------------------------+---------------------+---------------------+-----------+\n", + "| 4 | [24000:29999] | 4 | 24000 | 29999 | 2020-10-20 15:49:06 | 2020-11-01 22:04:40 | analysis | 2.00652 | 0.00804747 | nan | 2.03066 | 1.98238 | 2.05032 | 1.9632 | False |\n", + "+----+---------------+-----------------+-----------------+---------------+---------------------+---------------------+------------+--------------------+--------------------+--------------+-------------------------------+-------------------------------+---------------------+---------------------+-----------+\n", + "| 5 | [30000:35999] | 5 | 30000 | 35999 | 2020-11-01 22:04:59 | 2020-11-14 03:55:33 | analysis | 1.56444 | 0.00804747 | nan | 1.58858 | 1.5403 | 2.05032 | 1.9632 | True |\n", + "+----+---------------+-----------------+-----------------+---------------+---------------------+---------------------+------------+--------------------+--------------------+--------------+-------------------------------+-------------------------------+---------------------+---------------------+-----------+\n", + "| 6 | [36000:41999] | 6 | 36000 | 41999 | 2020-11-14 03:55:49 | 2020-11-26 09:19:06 | analysis | 1.56846 | 0.00804747 | nan | 1.5926 | 1.54432 | 2.05032 | 1.9632 | True |\n", + "+----+---------------+-----------------+-----------------+---------------+---------------------+---------------------+------------+--------------------+--------------------+--------------+-------------------------------+-------------------------------+---------------------+---------------------+-----------+\n", + "| 7 | [42000:47999] | 7 | 42000 | 47999 | 2020-11-26 09:19:22 | 2020-12-08 14:33:56 | analysis | 1.56204 | 0.00804747 | nan | 1.58618 | 1.5379 | 2.05032 | 1.9632 | True |\n", + "+----+---------------+-----------------+-----------------+---------------+---------------------+---------------------+------------+--------------------+--------------------+--------------+-------------------------------+-------------------------------+---------------------+---------------------+-----------+\n", + "| 8 | [48000:53999] | 8 | 48000 | 53999 | 2020-12-08 14:34:25 | 2020-12-20 18:30:30 | analysis | 1.56687 | 0.00804747 | nan | 1.59101 | 1.54272 | 2.05032 | 1.9632 | True |\n", + "+----+---------------+-----------------+-----------------+---------------+---------------------+---------------------+------------+--------------------+--------------------+--------------+-------------------------------+-------------------------------+---------------------+---------------------+-----------+\n", + "| 9 | [54000:59999] | 9 | 54000 | 59999 | 2020-12-20 18:31:09 | 2021-01-01 22:57:55 | analysis | 1.57425 | 0.00804747 | nan | 1.59839 | 1.55011 | 2.05032 | 1.9632 | True |\n", + "+----+---------------+-----------------+-----------------+---------------+---------------------+---------------------+------------+--------------------+--------------------+--------------+-------------------------------+-------------------------------+---------------------+---------------------+-----------+\n" + ] + } + ], + "source": [ + "from docs.utils import print_multi_index_markdown\n", + "print_multi_index_markdown(results.filter(period='analysis').to_df())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "metric_fig = results.plot()\n", + "metric_fig.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "metric_fig = results.plot()\n", + "metric_fig.write_image(file=f\"../_static/tutorials/performance_estimation/multiclass/business_value.svg\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.9" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/docs/glossary.rst b/docs/glossary.rst index 6f664acfb..03b12845e 100644 --- a/docs/glossary.rst +++ b/docs/glossary.rst @@ -16,12 +16,17 @@ Glossary Note that alerts are not raised during the reference :term:`Data Period`. Business Value Matrix - A matrix that is used to calculate the business value of a model. For binary classification, - the matrix is a 2x2 matrix with the following cells: true positive cost, true negative cost, - false positive cost, false negative cost. The business value of a model is calculated as the + A matrix that is used to calculate the business value of a model. The format of the + business value matrix must be specified so that each element represents the business + value of it's respective confusion matrix element. Hence the element on the i-th row and j-column of the + business value matrix tells us the value of the i-th target when we have predicted the j-th value. + It can be provided as a list of lists or a numpy array. The business value of a model is calculated as the sum of the products of the values in the matrix and the corresponding cells in the confusion matrix. + For more information about the business value matrix, + check out the :ref:`Business Value "How it Works" page`. + Butterfly dataset A dataset used in :ref:`how-multiv-drift` to give an example where univariate drift statistics are insufficient in detecting complex data drifts in multidimensional @@ -96,7 +101,7 @@ Glossary periods - they contain all the observations and predictions from a single hour, day, month etc. depending on the selected interval. They can also be size-based so that each chunk contains *n* observations or number-based so the whole data is split into *k* chunks. In each case chronology of data between chunks is - maintained. + maintained. To better understand how to create chunks with NannyML check out the :ref:`chunking tutorial`. Data Period A data period is a subset of the data used to monitor a model. NannyML expects the provided data to be in one of two data periods. diff --git a/docs/how_it_works/business_value.rst b/docs/how_it_works/business_value.rst index 2f590b2f5..35d4d9eae 100644 --- a/docs/how_it_works/business_value.rst +++ b/docs/how_it_works/business_value.rst @@ -9,15 +9,14 @@ monetary or business oriented outcomes. In this page, we will discuss how the **business_value** metric works under the hood. Introduction to Business Value --------------------------------------- +------------------------------ -The **business_value** metric offers a way to quantify -the value of a model in terms of the +The **business_value** metric offers a way to quantify the value of a model in terms of the business's own metrics. At the core, if the business value (or cost) of each outcome in the :term:`confusion matrix` is known, then the business value of a -model can either be *calculated* using the realized :term:`confusion matrix` if -the ground truth labels are available or *estimated* using the -estimated :term:`confusion matrix` if the ground truth labels are not available. +model can either be *calculated* using the :ref:`realized Performance Calculator` if +the ground truth labels are available or *estimated* using :ref:`Performance Estimation` +if the ground truth labels are not available. More specifically, we know that each prediction made by a binary classification models can be one of four outcomes: @@ -50,10 +49,16 @@ We can formalize the intuition above as follows: \text{business value} = \sum_{i=1}^{n} \sum_{j=1}^{n} \text{business_value}_{i,j} \times \text{confusion_matrix}_{i,j} -where :math:`\text{business_value}_{i,j}` is the business value of a cell in the :term:`confusion matrix`, and :math:`\text{confusion_matrix}_{i,j}` is the count of observations -in that cell of the :term:`confusion matrix`. We use the `sklearn confusion matrix representation`_ that assuming label 0 is negative and label 1 is positive. +where :math:`\text{business_value}_{i,j}` is the business value of a cell in the +:term:`confusion matrix`, and :math:`\text{confusion_matrix}_{i,j}` is the count of +observations in that cell of the :term:`confusion matrix`. Using the confusion +matrix notation the element on the i-th row and j-column of the business value matrix tells us the value +of the i-th target when we have predicted the j-th value. -Since we are in the binary classification case, :math:`n=2`, and the :term:`confusion matrix` is: +For binary classification this formula is easier to manage hence we will use it as an example. Classificatio problems +with more classes follow the same pattern. +Using the `sklearn confusion matrix convention`_ we designate label 0 as negative and label 1 as positive. +Hence we can write the :term:`confusion matrix` as: .. math:: @@ -62,7 +67,10 @@ Since we are in the binary classification case, :math:`n=2`, and the :term:`conf \text{# of false negatives} & \text{# of true positives} \end{bmatrix} -And the :term:`business value matrix` is: +Note that target values are represented by rows and predicted values are represented by columns. +This means that the first row contains values that have resulted in the negative outcome +while the first column contains values that were predicted with negative label. +The correspondings :term:`business value matrix` is: .. math:: @@ -80,22 +88,27 @@ The business value of a binary classification model can thus be generally expres + (\text{value of a false negative}) \cdot (\text{# of false negatives}) \\ + (\text{value of a true positive}) \cdot (\text{# of true positives}) -Calculation of Business Value For Binary Classification -------------------------------------------------------- +Calculation of Business Value For Classification +------------------------------------------------ When the ground truth labels are available, the business value of a model can be calculated by using the -values from the realized :term:`confusion matrix`, and then using the business value formula above to calculate -the business value. +values from the realized :term:`confusion matrix`, +and then using the business value formula above to calculate the business value. -For a tutorial on how to calculate the business value of a model, see our :ref:`business-value-calculation` tutorial. +For a tutorial on how to calculate the business value of a model, +see our :ref:`business-value-calculation` and :ref:`multiclass-business-value-calculation` tutorials. -Estimation of Business Value For Binary Classification ------------------------------------------------------- -In cases where ground truth labels of the data are unavailable, we can still estimate the business value of a model. This is done by using the -:term:`CBPE (Confidence-Based Performance Estimation)` algorithm to estimate the :term:`confusion matrix`, and then using the business value formula above to obtain a business value estimate. -To read more about the :term:`CBPE (Confidence-Based Performance Estimation)` algorithm, see our :ref:`performance estimation deep dive`. +Estimation of Business Value For Classification +----------------------------------------------- -For a tutorial on how to estimate the business value of a model, see our :ref:`business-value-estimation` tutorial. +In cases where ground truth labels of the data are unavailable, we can still estimate the business value of a model. +This is done by using the :term:`CBPE (Confidence-Based Performance Estimation)` algorithm to estimate the +:term:`confusion matrix`, and then using the business value formula above to obtain a business value estimate. +To read more about the :term:`CBPE (Confidence-Based Performance Estimation)` algorithm, +see our :ref:`performance estimation deep dive`. + +For a tutorial on how to estimate the business value of a model, see our :ref:`business-value-estimation` +and :ref:`multiclasss-business-value-estimation` tutorials. Normalization ------------- @@ -113,4 +126,4 @@ Check out the :ref:`business-value-calculation` tutorial and the :ref:`business- for examples of how to normalize the business value metric. -.. _`sklearn confusion matrix representation`: https://scikit-learn.org/stable/modules/generated/sklearn.metrics.confusion_matrix.html +.. _`sklearn confusion matrix convention`: https://scikit-learn.org/stable/modules/generated/sklearn.metrics.confusion_matrix.html diff --git a/docs/tutorials/performance_calculation/binary_performance_calculation/business_value_calculation.rst b/docs/tutorials/performance_calculation/binary_performance_calculation/business_value_calculation.rst index 3d989554f..856a08922 100644 --- a/docs/tutorials/performance_calculation/binary_performance_calculation/business_value_calculation.rst +++ b/docs/tutorials/performance_calculation/binary_performance_calculation/business_value_calculation.rst @@ -1,8 +1,9 @@ .. _business-value-calculation: -======================================================================================== +==================================================== Calculating Business Value for Binary Classification -======================================================================================== +==================================================== + This tutorial explains how to use NannyML to calculate business value for binary classification models. @@ -14,7 +15,7 @@ models. .. _business-value-calculation-binary-just-the-code: Just The Code ----------------- +------------- .. nbimport:: :path: ./example_notebooks/Tutorial - Calculating Business Value - Binary Classification.ipynb @@ -22,7 +23,7 @@ Just The Code Walkthrough --------------- +----------- For simplicity this guide is based on a synthetic dataset included in the library, where the monitored model predicts whether a customer will repay a loan to buy a car. @@ -132,17 +133,16 @@ calculated metric. The results can be plotted for visual inspection. Our plot contains several key elements. -* *The purple step plot* shows the performance in each chunk of the analysis period. Thick squared point - markers indicate the middle of these chunks. - -* *The blue step plot* shows the performance in each chunk of the reference period. Thick squared point markers indicate +* *The blue step plot* shows the performance in each chunk of the provided data. Thick squared point markers indicate the middle of these chunks. -* *The gray vertical line* splits the reference and analysis periods. +* *The gray vertical line* splits the reference and analysis data periods. -* *The red horizontal dashed lines* show upper and lower thresholds for alerting purposes. +* *The red horizontal dashed lines* show upper and lower thresholds that indicate the range of + expected performance values. -* *The red diamond-shaped point markers* in the middle of a chunk indicate that an alert has been raised. Alerts are caused by the performance crossing the upper or lower threshold. +* *The red diamond-shaped point markers* in the middle of a chunk indicate that an alert has been raised. + Alerts are caused by the performance crossing the upper or lower threshold. .. nbimport:: :path: ./example_notebooks/Tutorial - Calculating Business Value - Binary Classification.ipynb diff --git a/docs/tutorials/performance_calculation/binary_performance_calculation/confusion_matrix_calculation.rst b/docs/tutorials/performance_calculation/binary_performance_calculation/confusion_matrix_calculation.rst index 59976015f..e729b8d44 100644 --- a/docs/tutorials/performance_calculation/binary_performance_calculation/confusion_matrix_calculation.rst +++ b/docs/tutorials/performance_calculation/binary_performance_calculation/confusion_matrix_calculation.rst @@ -1,8 +1,8 @@ .. _confusion-matrix-calculation: -======================================================================================== +=============================================================== Calculating Confusion Matrix Elements for Binary Classification -======================================================================================== +=============================================================== This tutorial explains how to use NannyML to calculate the :term:`confusion matrix` for binary classification models. @@ -15,7 +15,7 @@ models. .. _confusion-matrix-calculation-binary-just-the-code: Just The Code ----------------- +------------- .. nbimport:: :path: ./example_notebooks/Tutorial - Calculating Confusion Matrix - Binary Classification.ipynb @@ -23,7 +23,7 @@ Just The Code Walkthrough --------------- +----------- For simplicity this guide is based on a synthetic dataset included in the library, where the monitored model predicts whether a customer will repay a loan to buy a car. @@ -125,17 +125,16 @@ calculated metric. The results can be plotted for visual inspection. Our plot contains several key elements. -* *The purple step plot* shows the performance in each chunk of the analysis period. Thick squared point - markers indicate the middle of these chunks. - -* *The blue step plot* shows the performance in each chunk of the reference period. Thick squared point markers indicate +* *The blue step plot* shows the performance in each chunk of the provided data. Thick squared point markers indicate the middle of these chunks. -* *The gray vertical line* splits the reference and analysis periods. +* *The gray vertical line* splits the reference and analysis data periods. -* *The red horizontal dashed lines* show upper and lower thresholds for alerting purposes. +* *The red horizontal dashed lines* show upper and lower thresholds that indicate the range of + expected performance values. -* *The red diamond-shaped point markers* in the middle of a chunk indicate that an alert has been raised. Alerts are caused by the performance crossing the upper or lower threshold. +* *The red diamond-shaped point markers* in the middle of a chunk indicate that an alert has been raised. + Alerts are caused by the performance crossing the upper or lower threshold. .. nbimport:: :path: ./example_notebooks/Tutorial - Calculating Confusion Matrix - Binary Classification.ipynb diff --git a/docs/tutorials/performance_calculation/binary_performance_calculation/standard_metric_calculation.rst b/docs/tutorials/performance_calculation/binary_performance_calculation/standard_metric_calculation.rst index 9b2ea90c4..d994b86fb 100644 --- a/docs/tutorials/performance_calculation/binary_performance_calculation/standard_metric_calculation.rst +++ b/docs/tutorials/performance_calculation/binary_performance_calculation/standard_metric_calculation.rst @@ -1,8 +1,8 @@ .. _standard-metric-calculation: -======================================================================================== +================================================================== Calculating Standard Performance Metrics for Binary Classification -======================================================================================== +================================================================== This tutorial explains how to use NannyML to calculate standard performance metrics for binary classification models. @@ -15,7 +15,7 @@ models. .. _standard-metric-calculation-binary-just-the-code: Just The Code ----------------- +------------- .. nbimport:: :path: ./example_notebooks/Tutorial - Calculating Standard Metrics - Binary Classification.ipynb @@ -28,7 +28,7 @@ Just The Code - To learn how :class:`~nannyml.thresholds.ConstantThreshold` works and to set up custom threshold check out the :ref:`thresholds tutorial ` Walkthrough --------------- +----------- For simplicity this guide is based on a synthetic dataset included in the library, where the monitored model predicts whether a customer will repay a loan to buy a car. @@ -123,17 +123,16 @@ Apart from chunk-related data, the results data have a set of columns for each c The results can be plotted for visual inspection. Our plot contains several key elements. -* *The purple step plot* shows the performance in each chunk of the analysis period. Thick squared point - markers indicate the middle of these chunks. - -* *The blue step plot* shows the performance in each chunk of the reference period. Thick squared point markers indicate +* *The blue step plot* shows the performance in each chunk of the provided data. Thick squared point markers indicate the middle of these chunks. -* *The gray vertical line* splits the reference and analysis periods. +* *The gray vertical line* splits the reference and analysis data periods. -* *The red horizontal dashed lines* show upper and lower thresholds for alerting purposes. +* *The red horizontal dashed lines* show upper and lower thresholds that indicate the range of + expected performance values. -* *The red diamond-shaped point markers* in the middle of a chunk indicate that an alert has been raised. Alerts are caused by the performance crossing the upper or lower threshold. +* *The red diamond-shaped point markers* in the middle of a chunk indicate that an alert has been raised. + Alerts are caused by the performance crossing the upper or lower threshold. .. nbimport:: :path: ./example_notebooks/Tutorial - Calculating Standard Metrics - Binary Classification.ipynb diff --git a/docs/tutorials/performance_calculation/multiclass_performance_calculation.rst b/docs/tutorials/performance_calculation/multiclass_performance_calculation.rst index 708924365..7b469f0c4 100644 --- a/docs/tutorials/performance_calculation/multiclass_performance_calculation.rst +++ b/docs/tutorials/performance_calculation/multiclass_performance_calculation.rst @@ -12,17 +12,23 @@ We currently support the following **standard** metrics for multiclass classific * **recall** * **specificity** * **accuracy** + * **average_precision** For more information about estimating these metrics, refer to the :ref:`multiclass-standard-metric-calculation` section. We also support the following *complex* metric for multiclass classification performance calculation: * **confusion_matrix** + * **business_value:** a metric that combines the components of the confusion matrix using + user-specified weights for each element, allowing for a connection between model performance and + business results. -For more information about estimating this metrics, refer to the :ref:`multiclass-confusion-matrix-estimation` section. +For more information about calculating these metrics, refer to the :ref:`multiclass-confusion-matrix-calculation` +and :ref:`multiclass-business-value-calculation` sections. .. toctree:: :maxdepth: 2 multiclass_performance_calculation/standard_metric_calculation multiclass_performance_calculation/confusion_matrix_calculation + multiclass_performance_calculation/business_value_calculation diff --git a/docs/tutorials/performance_calculation/multiclass_performance_calculation/business_value_calculation.rst b/docs/tutorials/performance_calculation/multiclass_performance_calculation/business_value_calculation.rst new file mode 100644 index 000000000..4e085661d --- /dev/null +++ b/docs/tutorials/performance_calculation/multiclass_performance_calculation/business_value_calculation.rst @@ -0,0 +1,172 @@ +.. _multiclass-business-value-calculation: + +======================================================== +Calculating Business Value for Multiclass Classification +======================================================== + +This tutorial explains how to use NannyML to calculate business value for multiclass classification +models. + +.. note:: + The following example uses :term:`timestamps`. + These are optional but have an impact on the way data is chunked and results are plotted. + You can read more about them in the :ref:`data requirements`. + +.. _business-value-calculation-multiclass-just-the-code: + +Just The Code +------------- + +.. nbimport:: + :path: ./example_notebooks/Tutorial - Calculating Business Value - Multiclass Classification.ipynb + :cells: 1 3 4 5 7 9 + + +Walkthrough +----------- + +For simplicity this guide is based on a synthetic dataset where the monitored model predicts +which type of credit card product new customers should be assigned to. +Check out :ref:`Credit Card Dataset` to learn more about this dataset. + +In order to monitor a model, NannyML needs to learn about it from a reference dataset. +Then it can monitor the data that is subject to actual analysis, provided as the analysis dataset. +You can read more about this in our section on :ref:`data periods`. + +The ``analysis_targets`` dataframe contains the target results of the analysis period. +This is kept separate in the synthetic data because it is +not used during :ref:`performance estimation`. But it is required to calculate performance, +so the first thing we need to in this case is set up the right data in the right dataframes. + +The analysis target values are joined on the analysis frame by their index. +Your dataset may already contain the **target** column, so you may skip this join. + +.. nbimport:: + :path: ./example_notebooks/Tutorial - Calculating Business Value - Multiclass Classification.ipynb + :cells: 1 + +.. nbtable:: + :path: ./example_notebooks/Tutorial - Calculating Business Value - Multiclass Classification.ipynb + :cell: 2 + +Next a :class:`~nannyml.performance_calculation.calculator.PerformanceCalculator` is created with +the following parameter specifications: + + - **y_pred_proba:** a dictionary that maps the class names to the name of the column in the reference data + that contains the predicted probabilities for that class. + - **y_pred:** the name of the column in the reference data that + contains the predicted classes. + - **y_true:** the name of the column in the reference data that + contains the true classes. + - **timestamp_column_name (Optional):** the name of the column in the reference data that + contains timestamps. + - **problem_type:** the type of problem being monitored. In this example we + will monitor a binary classification problem. + - **metrics:** a list of metrics to calculate. In this example we + will calculate the ``business_value`` metric. + - **business_value_matrix:** A matrix that specifies the value of each corresponding cell in the confusion matrix. + - **normalize_business_value (Optional):** how to normalize the business value. + The normalization options are: + + * **None** : returns the total value per chunk + * **"per_prediction"** : returns the total value for the chunk divided by the number of observations + in a given chunk. + + - **chunk_size (Optional):** the number of observations in each chunk of data + used to calculate performance. For more information about + :term:`chunking` other chunking options check out the :ref:`chunking tutorial`. + - **thresholds (Optional):** the thresholds used to calculate the alert flag. For more information about + thresholds, check out the :ref:`thresholds tutorial`. + +.. nbimport:: + :path: ./example_notebooks/Tutorial - Calculating Business Value - Multiclass Classification.ipynb + :cells: 3 + +.. note:: + When calculating **business_value**, the ``business_value_matrix`` parameter is required. + A :term:`business value matrix` is a nxn matrix that specifies the value of each cell in the confusion matrix. + The format of the business value matrix must be specified so that each element represents the business + value of it's respective confusion matrix element. Hence the element on the i-th row and j-column of the + business value matrix tells us the value of the i-th target when we have predicted the j-th value. + It can be provided as a list of lists or a numpy array. + For more information about the business value matrix, + check out the :ref:`Business Value "How it Works" page`. + +The new :class:`~nannyml.performance_calculation.calculator.PerformanceCalculator` is fitted using the +:meth:`~nannyml.performance_calculation.calculator.PerformanceCalculator.fit` method on the **reference** data. + +.. nbimport:: + :path: ./example_notebooks/Tutorial - Calculating Business Value - Multiclass Classification.ipynb + :cells: 4 + +The fitted :class:`~nannyml.performance_calculation.calculator.PerformanceCalculator` can then be used to calculate +realized performance metrics on all data which has target values available with the +:meth:`~nannyml.performance_calculation.calculator.PerformanceCalculator.calculate` method. +NannyML can output a dataframe that contains all the results of the analysis data. + +.. nbimport:: + :path: ./example_notebooks/Tutorial - Calculating Business Value - Multiclass Classification.ipynb + :cells: 5 + +.. nbtable:: + :path: ./example_notebooks/Tutorial - Calculating Business Value - Multiclass Classification.ipynb + :cell: 6 + +The results from the reference data are also available. + +.. nbimport:: + :path: ./example_notebooks/Tutorial - Calculating Business Value - Multiclass Classification.ipynb + :cells: 7 + +.. nbtable:: + :path: ./example_notebooks/Tutorial - Calculating Business Value - Multiclass Classification.ipynb + :cell: 8 + +Apart from chunk and period-related columns, the results data have a set of columns for each +calculated metric. + + - **targets_missing_rate** - the fraction of missing target data. + - **value** - the realized metric value for a specific chunk. + - **sampling_error** - the estimate of the :term:`Sampling Error`. + - **upper_threshold** and **lower_threshold** - crossing these thresholds will raise an alert on significant + performance change. The thresholds are calculated based on the actual performance of the monitored model on chunks in + the **reference** partition. The thresholds are 3 standard deviations away from the mean performance calculated on + chunks. They are calculated during **fit** phase. + - **alert** - flag indicating potentially significant performance change. ``True`` if estimated performance crosses + upper or lower threshold. + +The results can be plotted for visual inspection. Our plot contains several key elements. + +* *The blue step plot* shows the performance in each chunk of the provided data. Thick squared point markers indicate + the middle of these chunks. + +* *The gray vertical line* splits the reference and analysis data periods. + +* *The red horizontal dashed lines* show upper and lower thresholds that indicate the range of + expected performance values. + +* *The red diamond-shaped point markers* in the middle of a chunk indicate that an alert has been raised. + Alerts are caused by the performance crossing the upper or lower threshold. + +.. nbimport:: + :path: ./example_notebooks/Tutorial - Calculating Business Value - Multiclass Classification.ipynb + :cells: 9 + +.. image:: /_static/tutorials/performance_calculation/multiclass/business_value.svg + +Additional information such as the chunk index range and chunk date range (if timestamps were provided) is shown in the hover for each chunk (these are +interactive plots, though only static views are included here). + +Insights +-------- + +After reviewing the performance calculation results, we should be able to clearly see how the business value +provided by the model while it is in use. Depending on the results we may report them or need to investigate +further. + + +What's Next +----------- + +If we decide further investigation is needed, the :ref:`Data Drift` functionality can help us to see +what feature changes may be contributing to any performance changes. diff --git a/docs/tutorials/performance_calculation/multiclass_performance_calculation/confusion_matrix_calculation.rst b/docs/tutorials/performance_calculation/multiclass_performance_calculation/confusion_matrix_calculation.rst index ea4e6787b..4a11d17f6 100644 --- a/docs/tutorials/performance_calculation/multiclass_performance_calculation/confusion_matrix_calculation.rst +++ b/docs/tutorials/performance_calculation/multiclass_performance_calculation/confusion_matrix_calculation.rst @@ -125,17 +125,16 @@ calculated metric. The results can be plotted for visual inspection. Our plot contains several key elements. -* *The purple step plot* shows the performance in each chunk of the analysis period. Thick squared point - markers indicate the middle of these chunks. - -* *The blue step plot* shows the performance in each chunk of the reference period. Thick squared point markers indicate +* *The blue step plot* shows the performance in each chunk of the provided data. Thick squared point markers indicate the middle of these chunks. -* *The gray vertical line* splits the reference and analysis periods. +* *The gray vertical line* splits the reference and analysis data periods. -* *The red horizontal dashed lines* show upper and lower thresholds for alerting purposes. +* *The red horizontal dashed lines* show upper and lower thresholds that indicate the range of + expected performance values. -* *The red diamond-shaped point markers* in the middle of a chunk indicate that an alert has been raised. Alerts are caused by the performance crossing the upper or lower threshold. +* *The red diamond-shaped point markers* in the middle of a chunk indicate that an alert has been raised. + Alerts are caused by the performance crossing the upper or lower threshold. .. nbimport:: :path: ./example_notebooks/Tutorial - Calculating Confusion Matrix - Multiclass Classification.ipynb diff --git a/docs/tutorials/performance_calculation/multiclass_performance_calculation/standard_metric_calculation.rst b/docs/tutorials/performance_calculation/multiclass_performance_calculation/standard_metric_calculation.rst index d0694874a..eecc411d6 100644 --- a/docs/tutorials/performance_calculation/multiclass_performance_calculation/standard_metric_calculation.rst +++ b/docs/tutorials/performance_calculation/multiclass_performance_calculation/standard_metric_calculation.rst @@ -61,6 +61,7 @@ The following metrics are currently supported: - ``recall`` - macro-averaged - ``specificity`` - macro-averaged - ``accuracy`` +- ``average_precision`` - macro-averaged For more information on metrics, check the :mod:`~nannyml.performance_calculation.metrics` module. @@ -109,7 +110,18 @@ Apart from chunk-related data, the results data have a set of columns for each c - **alert** - flag indicating potentially significant performance change. ``True`` if estimated performance crosses upper or lower threshold. -The results can be plotted for visual inspection: +The results can be plotted for visual inspection. Our plot contains several key elements. + +* *The blue step plot* shows the performance in each chunk of the provided data. Thick squared point markers indicate + the middle of these chunks. + +* *The gray vertical line* splits the reference and analysis data periods. + +* *The red horizontal dashed lines* show upper and lower thresholds that indicate the range of + expected performance values. + +* *The red diamond-shaped point markers* in the middle of a chunk indicate that an alert has been raised. + Alerts are caused by the performance crossing the upper or lower threshold. .. nbimport:: :path: ./example_notebooks/Tutorial - Realized Performance - Multiclass Classification.ipynb diff --git a/docs/tutorials/performance_estimation/binary_performance_estimation/business_value_estimation.rst b/docs/tutorials/performance_estimation/binary_performance_estimation/business_value_estimation.rst index beff3271d..25cb8f34d 100644 --- a/docs/tutorials/performance_estimation/binary_performance_estimation/business_value_estimation.rst +++ b/docs/tutorials/performance_estimation/binary_performance_estimation/business_value_estimation.rst @@ -128,17 +128,20 @@ that was estimated: These results can be also plotted. Our plots contains several key elements. -* *The purple step plot* shows the estimated performance in each chunk of the analysis period. Thick squared point +* The purple dashed step plot shows the estimated performance in each chunk of the provided data. Thick squared point markers indicate the middle of these chunks. -* *The low-saturated purple area* around the estimated performance in the analysis period corresponds to the :term:`confidence band` which is - calculated as the estimated performance +/- 3 times the estimated :term:`Sampling Error`. +* The black vertical line splits the reference and analysis periods. -* *The gray vertical line* splits the reference and analysis periods. +* *The low-saturated purple area* around the estimated performance in the analysis period corresponds to the + :term:`confidence band` which is calculated as the estimated performance +/- 3 times the + estimated :term:`Sampling Error`. -* *The red horizontal dashed lines* show upper and lower thresholds for alerting purposes. +* *The red horizontal dashed lines* show upper and lower thresholds that indicate the range of + expected performance values. -* *The red diamond-shaped point markers* in the middle of a chunk indicate that an alert has been raised. Alerts are caused by the estimated performance crossing the upper or lower threshold. +* *The red diamond-shaped point markers* in the middle of a chunk indicate that an alert has been raised. + Alerts are caused by the estimated performance crossing the upper or lower threshold. .. nbimport:: :path: ./example_notebooks/Tutorial - Estimating Business Value - Binary Classification.ipynb diff --git a/docs/tutorials/performance_estimation/binary_performance_estimation/confusion_matrix_estimation.rst b/docs/tutorials/performance_estimation/binary_performance_estimation/confusion_matrix_estimation.rst index e1d957106..14ffa386d 100644 --- a/docs/tutorials/performance_estimation/binary_performance_estimation/confusion_matrix_estimation.rst +++ b/docs/tutorials/performance_estimation/binary_performance_estimation/confusion_matrix_estimation.rst @@ -129,17 +129,20 @@ that was estimated: These results can be also plotted. Our plot contains several key elements. -* *The purple step plot* shows the estimated performance in each chunk of the analysis period. Thick squared point +* The purple dashed step plot shows the estimated performance in each chunk of the provided data. Thick squared point markers indicate the middle of these chunks. -* *The low-saturated purple area* around the estimated performance in the analysis period corresponds to the :term:`confidence band` which is - calculated as the estimated performance +/- 3 times the estimated :term:`Sampling Error`. +* The black vertical line splits the reference and analysis periods. -* *The gray vertical line* splits the reference and analysis periods. +* *The low-saturated purple area* around the estimated performance in the analysis period corresponds to the + :term:`confidence band` which is calculated as the estimated performance +/- 3 times the + estimated :term:`Sampling Error`. -* *The red horizontal dashed lines* show upper and lower thresholds for alerting purposes. +* *The red horizontal dashed lines* show upper and lower thresholds that indicate the range of + expected performance values. -* *The red diamond-shaped point markers* in the middle of a chunk indicate that an alert has been raised. Alerts are caused by the estimated performance crossing the upper or lower threshold. +* *The red diamond-shaped point markers* in the middle of a chunk indicate that an alert has been raised. + Alerts are caused by the estimated performance crossing the upper or lower threshold. .. nbimport:: :path: ./example_notebooks/Tutorial - Estimating Confusion Matrix - Binary Classification.ipynb diff --git a/docs/tutorials/performance_estimation/binary_performance_estimation/standard_metric_estimation.rst b/docs/tutorials/performance_estimation/binary_performance_estimation/standard_metric_estimation.rst index f6acf359a..031279a1f 100644 --- a/docs/tutorials/performance_estimation/binary_performance_estimation/standard_metric_estimation.rst +++ b/docs/tutorials/performance_estimation/binary_performance_estimation/standard_metric_estimation.rst @@ -125,17 +125,20 @@ that was estimated: These results can be also plotted. Our plot contains several key elements. -* *The purple step plot* shows the estimated performance in each chunk of the analysis period. Thick squared point +* The purple dashed step plot shows the estimated performance in each chunk of the provided data. Thick squared point markers indicate the middle of these chunks. -* *The low-saturated purple area* around the estimated performance in the analysis period corresponds to the :term:`confidence band` which is - calculated as the estimated performance +/- 3 times the estimated :term:`Sampling Error`. +* The black vertical line splits the reference and analysis periods. -* *The gray vertical line* splits the reference and analysis periods. +* *The low-saturated purple area* around the estimated performance in the analysis period corresponds to the + :term:`confidence band` which is calculated as the estimated performance +/- 3 times the + estimated :term:`Sampling Error`. -* *The red horizontal dashed lines* show upper and lower thresholds for alerting purposes. +* *The red horizontal dashed lines* show upper and lower thresholds that indicate the range of + expected performance values. -* *The red diamond-shaped point markers* in the middle of a chunk indicate that an alert has been raised. Alerts are caused by the estimated performance crossing the upper or lower threshold. +* *The red diamond-shaped point markers* in the middle of a chunk indicate that an alert has been raised. + Alerts are caused by the estimated performance crossing the upper or lower threshold. .. nbimport:: :path: ./example_notebooks/Tutorial - Estimating Standard Performance Metrics - Binary Classification.ipynb diff --git a/docs/tutorials/performance_estimation/multiclass_performance_estimation.rst b/docs/tutorials/performance_estimation/multiclass_performance_estimation.rst index 9640c1d40..7c4d28e69 100644 --- a/docs/tutorials/performance_estimation/multiclass_performance_estimation.rst +++ b/docs/tutorials/performance_estimation/multiclass_performance_estimation.rst @@ -27,3 +27,4 @@ refer to the :ref:`multiclass-confusion-matrix-estimation` section. multiclass_performance_estimation/standard_metric_estimation multiclass_performance_estimation/confusion_matrix_estimation + multiclass_performance_estimation/business_value_estimation diff --git a/docs/tutorials/performance_estimation/multiclass_performance_estimation/business_value_estimation.rst b/docs/tutorials/performance_estimation/multiclass_performance_estimation/business_value_estimation.rst new file mode 100644 index 000000000..dae8d3f76 --- /dev/null +++ b/docs/tutorials/performance_estimation/multiclass_performance_estimation/business_value_estimation.rst @@ -0,0 +1,169 @@ +.. _multiclasss-business-value-estimation: + +======================================================= +Estimating Business Value for Multiclass Classification +======================================================= + +This tutorial explains how to use NannyML to estimate business value for multiclass classification +models in the absence of target data. To find out how CBPE estimates performance metrics, +read the :ref:`explanation of Confidence-based Performance Estimation`. + +.. note:: + The following example uses :term:`timestamps`. + These are optional but have an impact on the way data is chunked and results are plotted. + You can read more about them in the :ref:`data requirements`. + +.. _business-value-estimation-multiclass-just-the-code: + +Just The Code +------------- + +.. nbimport:: + :path: ./example_notebooks/Tutorial - Estimating Business Value - Multiclass Classification.ipynb + :cells: 1 3 4 5 7 + + +Walkthrough +----------- + +For simplicity this guide is based on a synthetic dataset where the monitored model predicts +which type of credit card product new customers should be assigned to. +Check out :ref:`Credit Card Dataset` to learn more about this dataset. + +In order to monitor a model, NannyML needs to learn about it from a reference dataset. Then it can monitor the data that is subject to actual analysis, provided as the analysis dataset. +You can read more about this in our section on :ref:`data periods`. + +We start by loading the dataset we'll be using: + +.. nbimport:: + :path: ./example_notebooks/Tutorial - Estimating Business Value - Multiclass Classification.ipynb + :cells: 1 + +.. nbtable:: + :path: ./example_notebooks/Tutorial - Estimating Business Value - Multiclass Classification.ipynb + :cell: 2 + +Next we create the Confidence-based Performance Estimation +(:class:`~nannyml.performance_estimation.confidence_based.cbpe.CBPE`) +estimator. To initialize an estimator that estimates **business_value**, we specify the following +parameters: + + - **y_pred_proba:** the name of the column in the reference data that + contains the predicted probabilities. + - **y_pred:** the name of the column in the reference data that + contains the predicted classes. + - **y_true:** the name of the column in the reference data that + contains the true classes. + - **timestamp_column_name (Optional):** the name of the column in the reference data that + contains timestamps. + - **metrics:** a list of metrics to estimate. In this example we + will estimate the ``business_value`` metric. + - **chunk_size (Optional):** the number of observations in each chunk of data + used to estimate performance. For more information about + :term:`chunking` configurations check out the :ref:`chunking tutorial`. + - **problem_type:** the type of problem being monitored. In this example we + will monitor a multiclass classification problem. + - **business_value_matrix:** A matrix that specifies the value of each corresponding cell in the confusion matrix. + - **normalize_business_value (Optional):** how to normalize the business value. + The normalization options are: + + * **None** : returns the total value per chunk + * **"per_prediction"** : returns the total value for the chunk divided by the number of observations + in a given chunk. + + - **thresholds (Optional):** the thresholds used to calculate the alert flag. For more information about + thresholds, check out the :ref:`thresholds tutorial`. + +.. note:: + When calculating **business_value**, the ``business_value_matrix`` parameter is required. + A :term:`business value matrix` is a nxn matrix that specifies the value of each cell in the confusion matrix. + The format of the business value matrix must be specified so that each element represents the business + value of it's respective confusion matrix element. Hence the element on the i-th row and j-column of the + business value matrix tells us the value of the i-th target when we have predicted the j-th value. + It can be provided as a list of lists or a numpy array. + For more information about the business value matrix, + check out the :ref:`Business Value "How it Works" page`. + +.. nbimport:: + :path: ./example_notebooks/Tutorial - Estimating Business Value - Multiclass Classification.ipynb + :cells: 3 + +The :class:`~nannyml.performance_estimation.confidence_based.cbpe.CBPE` +estimator is then fitted using the +:meth:`~nannyml.performance_estimation.confidence_based.cbpe.CBPE.fit` method on the ``reference`` data. + +.. nbimport:: + :path: ./example_notebooks/Tutorial - Estimating Business Value - Multiclass Classification.ipynb + :cells: 4 + +The fitted ``estimator`` can be used to estimate performance on other data, for which performance cannot be calculated. +Typically, this would be used on the latest production data where target is missing. In our example this is +the ``analysis_df`` data. + +NannyML can then output a dataframe that contains all the results. Let's have a look at the results for analysis period +only. + +.. nbimport:: + :path: ./example_notebooks/Tutorial - Estimating Business Value - Multiclass Classification.ipynb + :cells: 5 + +.. nbtable:: + :path: ./example_notebooks/Tutorial - Estimating Business Value - Multiclass Classification.ipynb + :cell: 6 + +Apart from chunk-related data, the results data have the following columns for each metric +that was estimated: + + - **value** - the estimate of a metric for a specific chunk. + - **sampling_error** - the estimate of the :term:`sampling error`. + - **realized** - when **target** values are available for a chunk, the realized performance metric will also + be calculated and included within the results. + - **upper_confidence_boundary** and **lower_confidence_boundary** - These values show the :term:`confidence band` of the relevant metric + and are equal to estimated value +/- 3 times the estimated :term:`sampling error`. + - **upper_threshold** and **lower_threshold** - crossing these thresholds will raise an alert on significant + performance change. The thresholds are calculated based on the actual performance of the monitored model on chunks in + the **reference** partition. The thresholds are 3 standard deviations away from the mean performance calculated on + the reference chunks. + The thresholds are calculated during **fit** phase. + - **alert** - flag indicating potentially significant performance change. ``True`` if estimated performance crosses + upper or lower threshold. + +These results can be also plotted. Our plots contains several key elements. + +* The purple dashed step plot shows the estimated performance in each chunk of the provided data. Thick squared point + markers indicate the middle of these chunks. + +* The black vertical line splits the reference and analysis periods. + +* *The low-saturated purple area* around the estimated performance in the analysis period corresponds to the + :term:`confidence band` which is calculated as the estimated performance +/- 3 times the + estimated :term:`Sampling Error`. + +* *The red horizontal dashed lines* show upper and lower thresholds that indicate the range of + expected performance values. + +* *The red diamond-shaped point markers* in the middle of a chunk indicate that an alert has been raised. + Alerts are caused by the estimated performance crossing the upper or lower threshold. + +.. nbimport:: + :path: ./example_notebooks/Tutorial - Estimating Business Value - Multiclass Classification.ipynb + :cells: 7 + +.. image:: ../../../_static/tutorials/performance_estimation/multiclass/business_value.svg + +Additional information such as the chunk index range and chunk date range (if timestamps were provided) is shown in the hover for each chunk (these are +interactive plots, though only static views are included here). + +Insights +-------- + +After reviewing the performance estimation results, we should be able to see any indications of performance change that +NannyML has detected based upon the model's inputs and outputs alone. + + +What's next +----------- + +The :ref:`Data Drift` functionality can help us to understand whether data drift is causing the performance problem. +When the target values become available we can +:ref:`compared realized and estimated business value results`. diff --git a/docs/tutorials/performance_estimation/multiclass_performance_estimation/confusion_matrix_estimation.rst b/docs/tutorials/performance_estimation/multiclass_performance_estimation/confusion_matrix_estimation.rst index 571bb1672..faa9b7767 100644 --- a/docs/tutorials/performance_estimation/multiclass_performance_estimation/confusion_matrix_estimation.rst +++ b/docs/tutorials/performance_estimation/multiclass_performance_estimation/confusion_matrix_estimation.rst @@ -131,17 +131,20 @@ that was estimated: These results can be also plotted. Our plot contains several key elements. -* *The purple step plot* shows the estimated performance in each chunk of the analysis period. Thick squared point +* The purple dashed step plot shows the estimated performance in each chunk of the provided data. Thick squared point markers indicate the middle of these chunks. -* *The low-saturated purple area* around the estimated performance in the analysis period corresponds to the :term:`confidence band` which is - calculated as the estimated performance +/- 3 times the estimated :term:`Sampling Error`. +* The black vertical line splits the reference and analysis periods. -* *The gray vertical line* splits the reference and analysis periods. +* *The low-saturated purple area* around the estimated performance in the analysis period corresponds to the + :term:`confidence band` which is calculated as the estimated performance +/- 3 times the + estimated :term:`Sampling Error`. -* *The red horizontal dashed lines* show upper and lower thresholds for alerting purposes. +* *The red horizontal dashed lines* show upper and lower thresholds that indicate the range of + expected performance values. -* *The red diamond-shaped point markers* in the middle of a chunk indicate that an alert has been raised. Alerts are caused by the estimated performance crossing the upper or lower threshold. +* *The red diamond-shaped point markers* in the middle of a chunk indicate that an alert has been raised. + Alerts are caused by the estimated performance crossing the upper or lower threshold. .. nbimport:: :path: ./example_notebooks/Tutorial - Estimating Confusion Matrix - Multiclass Classification.ipynb diff --git a/docs/tutorials/performance_estimation/multiclass_performance_estimation/standard_metric_estimation.rst b/docs/tutorials/performance_estimation/multiclass_performance_estimation/standard_metric_estimation.rst index 75ae505fd..0b2556666 100644 --- a/docs/tutorials/performance_estimation/multiclass_performance_estimation/standard_metric_estimation.rst +++ b/docs/tutorials/performance_estimation/multiclass_performance_estimation/standard_metric_estimation.rst @@ -61,6 +61,7 @@ chunking check out the :ref:`chunking tutorial` and it's :ref:`advance - ``recall`` - macro-averaged - ``specificity`` - macro-averaged - ``accuracy`` + - ``average_precision`` - macro-averaged .. nbimport:: @@ -106,17 +107,20 @@ that was estimated: These results can be also plotted. Our plot contains several key elements. -* The purple dashed step plot shows the estimated performance in each chunk of the analysis period. Thick squared point +* The purple dashed step plot shows the estimated performance in each chunk of the provided data. Thick squared point markers indicate the middle of these chunks. * The black vertical line splits the reference and analysis periods. -* The low-saturated colored area around the estimated performance indicates the :ref:`sampling error`. +* *The low-saturated purple area* around the estimated performance in the analysis period corresponds to the + :term:`confidence band` which is calculated as the estimated performance +/- 3 times the + estimated :term:`Sampling Error`. -* The red horizontal dashed lines show upper and lower thresholds for alerting purposes. +* *The red horizontal dashed lines* show upper and lower thresholds that indicate the range of + expected performance values. -* If the estimated performance crosses the upper or lower threshold an alert is raised which is indicated with a red - diamond-shaped point marker in the middle of the chunk. +* *The red diamond-shaped point markers* in the middle of a chunk indicate that an alert has been raised. + Alerts are caused by the estimated performance crossing the upper or lower threshold. Description of tabular results above explains how the :term:`confidence bands` and thresholds are calculated. Additional information is shown in the hover (these are diff --git a/nannyml/performance_calculation/calculator.py b/nannyml/performance_calculation/calculator.py index fb35a784e..e0a390352 100644 --- a/nannyml/performance_calculation/calculator.py +++ b/nannyml/performance_calculation/calculator.py @@ -165,10 +165,11 @@ def __init__( observations for each true class. If 'predicted', the confusion matrix will be normalized by the total number of observations for each predicted class. business_value_matrix: Optional[Union[List, np.ndarray]], default=None - A matrix containing the business costs for each combination of true and predicted class. - The i-th row and j-th column entry of the matrix contains the business cost for predicting the - i-th class as the j-th class. The matrix must have the same number of rows and columns as the number - of classes in the problem. + A nxn matrix that specifies the value of each cell in the confusion matrix. + The format of the business value matrix must be specified so that each element represents the business + value of it's respective confusion matrix element. Hence the element on the i-th row and j-column of the + business value matrix tells us the value of the i-th target while we predicted the j-th value. + It can be provided as a list of lists or a numpy array. normalize_business_value: str, default=None Determines how the business value will be normalized. Allowed values are None and 'per_prediction'. If None, the business value will not be normalized and the value diff --git a/nannyml/performance_calculation/metrics/multiclass_classification.py b/nannyml/performance_calculation/metrics/multiclass_classification.py index 76d21ebfd..9d1ee098b 100644 --- a/nannyml/performance_calculation/metrics/multiclass_classification.py +++ b/nannyml/performance_calculation/metrics/multiclass_classification.py @@ -16,6 +16,7 @@ precision_score, recall_score, roc_auc_score, + average_precision_score, ) from sklearn.preprocessing import LabelBinarizer, label_binarize @@ -39,6 +40,10 @@ recall_sampling_error_components, specificity_sampling_error, specificity_sampling_error_components, + average_precision_sampling_error_components, + average_precision_sampling_error, + business_value_sampling_error_components, + business_value_sampling_error, ) from nannyml.thresholds import Threshold, calculate_threshold_values @@ -101,7 +106,7 @@ def _fit(self, reference_data: pd.DataFrame): _list_missing([self.y_true] + self.class_probability_columns, list(reference_data.columns)) reference_data, empty = common_nan_removal( reference_data[[self.y_true] + self.class_probability_columns], - [self.y_true] + self.class_probability_columns + [self.y_true] + self.class_probability_columns, ) if empty: self._sampling_error_components = [(np.NaN, 0) for clasz in self.classes] @@ -115,7 +120,8 @@ def _fit(self, reference_data: pd.DataFrame): "targets." ) raise InvalidArgumentsException( - "y_pred_proba class and class probabilities dictionary does not match reference data.") + "y_pred_proba class and class probabilities dictionary does not match reference data." + ) # sampling error binarized_y_true = list(label_binarize(reference_data[self.y_true], classes=self.classes).T) @@ -909,3 +915,269 @@ def get_chunk_record(self, chunk_data: pd.DataFrame) -> Dict[str, Union[float, b ) or (self.alert_thresholds is not None and (chunk_record[f"{column_name}"] > upper_threshold)) return chunk_record + + +@MetricFactory.register(metric='average_precision', use_case=ProblemType.CLASSIFICATION_MULTICLASS) +class MulticlassClassificationAP(Metric): + """Average Precision metric.""" + + y_pred_proba: Dict[str, str] + + def __init__( + self, + y_true: str, + y_pred: str, + threshold: Threshold, + y_pred_proba: Dict[str, str], + **kwargs, + ): + """Creates a new AP instance. + + Parameters + ---------- + y_true: str + The name of the column containing target values. + y_pred: str + The name of the column containing your model predictions. + threshold: Threshold + The Threshold instance that determines how the lower and upper threshold values will be calculated. + y_pred_proba: Union[str, Dict[str, str]] + Name(s) of the column(s) containing your model output. + + - For binary classification, pass a single string refering to the model output column. + - For multiclass classification, pass a dictionary that maps a class string to the column name \ + containing model outputs for that class. + """ + super().__init__( + name='average_precision', + y_true=y_true, + y_pred=y_pred, + threshold=threshold, + y_pred_proba=y_pred_proba, + lower_threshold_limit=0, + upper_threshold_limit=1, + components=[("Average Precision", "average_precision")], + ) + # FIXME: Should we check the y_pred_proba argument here to ensure it's a dict? + self.y_pred_proba: Dict[str, str] + + # sampling error + self._sampling_error_components: List[Tuple] = [] + + # classes and class probability columns + self.classes: List[str] = [""] + self.class_probability_columns: List[str] + + def __str__(self): + """Get string representation of metric.""" + return "average_precision" + + def _fit(self, reference_data: pd.DataFrame): + # set up sorted classes and prob_column_names to use across metric class + self.classes = class_labels(self.y_pred_proba) + self.class_probability_columns = [self.y_pred_proba[clazz] for clazz in self.classes] + _list_missing([self.y_true] + self.class_probability_columns, list(reference_data.columns)) + reference_data, empty = common_nan_removal( + reference_data[[self.y_true] + self.class_probability_columns], + [self.y_true] + self.class_probability_columns, + ) + if empty: + self._sampling_error_components = [(np.NaN, 0) for class_col in self.class_probability_columns] + else: + # sampling error + binarized_y_true = list(label_binarize(reference_data[self.y_true], classes=self.classes).T) + y_pred_proba = [reference_data[self.y_pred_proba[clazz]].T for clazz in self.classes] + self._sampling_error_components = average_precision_sampling_error_components( + y_true_reference=binarized_y_true, y_pred_proba_reference=y_pred_proba + ) + + def _calculate(self, data: pd.DataFrame): + if not isinstance(self.y_pred_proba, Dict): + raise InvalidArgumentsException( + f"'y_pred_proba' is of type {type(self.y_pred_proba)}\n" + f"multiclass use cases require 'y_pred_proba' to " + "be a dictionary mapping classes to columns." + ) + + # class_y_pred_proba_columns = model_output_column_names(self.y_pred_proba) + _list_missing([self.y_true] + self.class_probability_columns, data) + data, empty = common_nan_removal( + data[[self.y_true] + self.class_probability_columns], [self.y_true] + self.class_probability_columns + ) + if empty: + warnings.warn(f"Too many missing values, cannot calculate {self.display_name}. " f"Returning NaN.") + return np.NaN + + y_true = data[self.y_true] + y_pred_proba = data[self.class_probability_columns] + + if y_true.nunique() <= 1: + warnings.warn( + f"'{self.y_true}' only contains a single class for chunk, cannot calculate {self.display_name}. " + "Returning NaN." + ) + return np.NaN + else: + # https://scikit-learn.org/stable/modules/model_evaluation.html#precision-recall-f-measure-metrics + # average_precision_score always performs OVR averaging + return average_precision_score(y_true, y_pred_proba, average='macro') + + def _sampling_error(self, data: pd.DataFrame) -> float: + _list_missing([self.y_true] + self.class_probability_columns, data) + data, empty = common_nan_removal( + data[[self.y_true] + self.class_probability_columns], [self.y_true] + self.class_probability_columns + ) + if empty: + warnings.warn( + f"Too many missing values, cannot calculate {self.display_name} sampling error. " f"Returning NaN." + ) + return np.NaN + else: + return average_precision_sampling_error(self._sampling_error_components, data) + + +@MetricFactory.register(metric='business_value', use_case=ProblemType.CLASSIFICATION_MULTICLASS) +class MulticlassClassificationBusinessValue(Metric): + """Business Value metric.""" + + y_pred: str + y_pred_proba: Dict[str, str] + + def __init__( + self, + y_true: str, + y_pred: str, + threshold: Threshold, + business_value_matrix: Union[List, np.ndarray], + normalize_business_value: Optional[str] = None, + y_pred_proba: Optional[Dict[str, str]] = None, + **kwargs, + ): + """Creates a new Business Value instance. + + Parameters + ---------- + y_true: str + The name of the column containing target values. + y_pred: str + The name of the column containing your model predictions. + threshold: Threshold + The Threshold instance that determines how the lower and upper threshold values will be calculated. + business_value_matrix: Union[List, np.ndarray] + A nxn matrix that specifies the value of each cell in the confusion matrix. + The format of the business value matrix must be specified so that each element represents the business + value of it's respective confusion matrix element. Hence the element on the i-th row and j-column of the + business value matrix tells us the value of the i-th target while we predicted the j-th value. + It can be provided as a list of lists or a numpy array. + normalize_business_value: Optional[str], default=None + Determines how the business value will be normalized. Allowed values are None and 'per_prediction'. + y_pred_proba: Optional[str], default=None + Name(s) of the column(s) containing your model output. For binary classification, pass a single string + refering to the model output column. + """ + if normalize_business_value not in [None, "per_prediction"]: + raise InvalidArgumentsException( + f"normalize_business_value must be None or 'per_prediction', but got {normalize_business_value}" + ) + + super().__init__( + name='business_value', + y_true=y_true, + y_pred=y_pred, + y_pred_proba=y_pred_proba, + threshold=threshold, + components=[('Business Value', 'business_value')], + ) + + if business_value_matrix is None: + raise ValueError("business_value_matrix must be provided for 'business_value' metric") + + if not (isinstance(business_value_matrix, np.ndarray) or isinstance(business_value_matrix, list)): + raise ValueError( + f"business_value_matrix must be a numpy array or a list, but got {type(business_value_matrix)}" + ) + + if isinstance(business_value_matrix, list): + business_value_matrix = np.array(business_value_matrix) + _rows, _columns = business_value_matrix.shape + if _rows != _columns: + raise InvalidArgumentsException( + f"business_value_matrix is not a square matrix but has shape: {(_rows, _columns)}" + ) + + self.business_value_matrix = business_value_matrix + self.normalize_business_value: Optional[str] = normalize_business_value + + # sampling error + self._sampling_error_components: Tuple = () + + # if y_pred_proba is provided uses this to get information about number of classes in the problem. + if y_pred_proba: + if not isinstance(self.y_pred_proba, Dict): + raise InvalidArgumentsException( + f"'y_pred_proba' is of type {type(self.y_pred_proba)}\n" + f"multiclass use cases require 'y_pred_proba' to " + "be a dictionary mapping classes to columns." + ) + self.y_pred_proba: Dict[str, str] = y_pred_proba + self.classes: List[str] = class_labels(self.y_pred_proba) + + def __str__(self): + """Get string representation of metric.""" + return "business_value" + + def _fit(self, reference_data: pd.DataFrame): + _list_missing([self.y_true, self.y_pred], list(reference_data.columns)) + data, empty = common_nan_removal(reference_data[[self.y_true, self.y_pred]], [self.y_true, self.y_pred]) + if empty: + self._sampling_error_components = np.NaN, self.normalize_business_value + else: + # get class number from y_pred_proba if provided otherwise from reference y_true + # this way the code will work even if some classes are missing from reference + # provided the business value matrix is constructed correctly. + if self.classes: + num_classes = len(self.classes) + _classes = self.classes + else: + num_classes = reference_data[self.y_true].nunique() + _classes = sorted(list(reference_data[self.y_true].unique)) + if num_classes != self.business_value_matrix.shape[0]: + raise InvalidArgumentsException( + f"business_value_matrix has shape {self.business_value_matrix.shape} " + "but we have {num_classes} classes!" + ) + self._sampling_error_components = business_value_sampling_error_components( + y_true_reference=data[self.y_true], + y_pred_reference=data[self.y_pred], + business_value_matrix=self.business_value_matrix, + classes=_classes, + normalize_business_value=self.normalize_business_value, + ) + + def _calculate(self, data: pd.DataFrame): + _list_missing([self.y_true, self.y_pred], list(data.columns)) + data, empty = common_nan_removal(data[[self.y_true, self.y_pred]], [self.y_true, self.y_pred]) + if empty: + warnings.warn(f"'{self.y_true}' contains no data, cannot calculate business value. Returning NaN.") + return np.NaN + + y_true = data[self.y_true] + y_pred = data[self.y_pred] + + cm = confusion_matrix(y_true, y_pred, labels=self.classes) + if self.normalize_business_value == 'per_prediction': + with np.errstate(all="ignore"): + cm = cm / cm.sum(axis=0, keepdims=True) + cm = np.nan_to_num(cm) + + return (self.business_value_matrix * cm).sum() + + def _sampling_error(self, data: pd.DataFrame) -> float: + data, empty = common_nan_removal(data[[self.y_true, self.y_pred]], [self.y_true, self.y_pred]) + if empty: + warnings.warn( + f"Too many missing values, cannot calculate {self.display_name} sampling error. " "Returning NaN." + ) + return np.NaN + else: + return business_value_sampling_error(self._sampling_error_components, data) diff --git a/nannyml/performance_estimation/confidence_based/cbpe.py b/nannyml/performance_estimation/confidence_based/cbpe.py index 9739841c2..5a74d2477 100644 --- a/nannyml/performance_estimation/confidence_based/cbpe.py +++ b/nannyml/performance_estimation/confidence_based/cbpe.py @@ -148,8 +148,8 @@ def __init__( 'recall': StandardDeviationThreshold(), 'specificity': StandardDeviationThreshold(), 'accuracy': StandardDeviationThreshold(), - 'confusion_matrix': StandardDeviationThreshold(), # only for binary classification - 'business_value': StandardDeviationThreshold(), # only for binary classification + 'confusion_matrix': StandardDeviationThreshold(), + 'business_value': StandardDeviationThreshold(), } A dictionary allowing users to set a custom threshold for each method. It links a `Threshold` subclass @@ -171,9 +171,11 @@ def __init__( - 'predicted' - the confusion matrix will be normalized by the total number of observations for each \ predicted class. business_value_matrix: Optional[Union[List, np.ndarray]], default=None - A 2x2 matrix that specifies the value of each cell in the confusion matrix. - The format of the business value matrix must be specified as [[value_of_TN, value_of_FP], \ - [value_of_FN, value_of_TP]]. Required when estimating the 'business_value' metric. + A nxn matrix that specifies the value of each cell in the confusion matrix. + The format of the business value matrix must be specified so that each element represents the business + value of it's respective confusion matrix element. Hence the element on the i-th row and j-column of the + business value matrix tells us the value of the i-th target while we predicted the j-th value. + It can be provided as a list of lists or a numpy array. normalize_business_value: str, default=None Determines how the business value will be normalized. Allowed values are None and 'per_prediction'. @@ -357,6 +359,7 @@ def _estimate(self, data: pd.DataFrame, *args, **kwargs) -> Result: data = data.copy(deep=True) if self.problem_type == ProblemType.CLASSIFICATION_BINARY: + assert isinstance(self.y_pred_proba, str) required_cols = [self.y_pred_proba] if self.y_pred is not None: required_cols.append(self.y_pred) @@ -366,10 +369,10 @@ def _estimate(self, data: pd.DataFrame, *args, **kwargs) -> Result: # https://github.com/NannyML/nannyml/issues/98 data[f'uncalibrated_{self.y_pred_proba}'] = data[self.y_pred_proba] - assert isinstance(self.y_pred_proba, str) if self.needs_calibration: data[self.y_pred_proba] = self.calibrator.calibrate(data[self.y_pred_proba]) else: + assert isinstance(self.y_pred_proba, Dict) _list_missing([self.y_pred] + model_output_column_names(self.y_pred_proba), data) # We need uncalibrated data to calculate the realized performance on. @@ -377,7 +380,6 @@ def _estimate(self, data: pd.DataFrame, *args, **kwargs) -> Result: for class_proba in model_output_column_names(self.y_pred_proba): data[f'uncalibrated_{class_proba}'] = data[class_proba] - assert isinstance(self.y_pred_proba, Dict) data = _calibrate_predicted_probabilities(data, self.y_true, self.y_pred_proba, self._calibrators) chunks = self.chunker.split(data) diff --git a/nannyml/performance_estimation/confidence_based/metrics.py b/nannyml/performance_estimation/confidence_based/metrics.py index e7e0fde61..b5242a344 100644 --- a/nannyml/performance_estimation/confidence_based/metrics.py +++ b/nannyml/performance_estimation/confidence_based/metrics.py @@ -2354,7 +2354,8 @@ def _fit(self, reference_data: pd.DataFrame): "targets." ) raise InvalidArgumentsException( - "y_pred_proba class and class probabilities dictionary does not match reference data.") + "y_pred_proba class and class probabilities dictionary does not match reference data." + ) # sampling error binarized_y_true = list(label_binarize(reference_data[self.y_true], classes=self.classes).T) y_pred_proba = [reference_data['uncalibrated_' + self.y_pred_proba[clazz]].T for clazz in self.classes] @@ -2402,7 +2403,7 @@ def _sampling_error(self, data: pd.DataFrame) -> float: data, empty = common_nan_removal(data[needed_columns], needed_columns) if empty: warnings.warn( - f"Too many missing values, cannot calculate {self.display_name} sampling error. " f"Returning NaN." + f"Too many missing values, cannot calculate {self.display_name} sampling error. Returning NaN." ) return np.NaN else: @@ -3329,3 +3330,286 @@ def _sampling_error(self, data: pd.DataFrame) -> float: def _realized_performance(self, data: pd.DataFrame) -> float: return 0.0 + + +@MetricFactory.register('average_precision', ProblemType.CLASSIFICATION_MULTICLASS) +class MulticlassClassificationAP(Metric): + """CBPE multiclass classification AP Metric Class.""" + + def __init__( + self, + y_pred_proba: ModelOutputsType, + y_pred: str, + y_true: str, + chunker: Chunker, + threshold: Threshold, + timestamp_column_name: Optional[str] = None, + **kwargs, + ): + """Initialize CBPE multiclass classification AP Metric Class.""" + super().__init__( + name='average_precision', + y_pred_proba=y_pred_proba, + y_pred=y_pred, + y_true=y_true, + timestamp_column_name=timestamp_column_name, + chunker=chunker, + threshold=threshold, + components=[('Average Precision', 'average_precision')], + ) + # FIXME: Should we check the y_pred_proba argument here to ensure it's a dict? + self.y_pred_proba: Dict[str, str] + + # sampling error + self._sampling_error_components: List[Tuple] = [] + + # classes and class probability columns + self.classes: List[str] + self.class_probability_columns: List[str] + self.class_uncalibrated_y_pred_proba_columns: List[str] + + def _fit(self, reference_data: pd.DataFrame): + # set up sorted classes and prob_column_names to use across metric class + self.classes = class_labels(self.y_pred_proba) + self.class_probability_columns = [self.y_pred_proba[clazz] for clazz in self.classes] + self.class_uncalibrated_y_pred_proba_columns = ['uncalibrated_' + el for el in self.class_probability_columns] + + _list_missing([self.y_true] + self.class_uncalibrated_y_pred_proba_columns, list(reference_data.columns)) + # filter nans here + reference_data, empty = common_nan_removal( + reference_data[[self.y_true] + self.class_uncalibrated_y_pred_proba_columns], + [self.y_true] + self.class_uncalibrated_y_pred_proba_columns, + ) + if empty: + self._sampling_error_components = [(np.NaN, 0) for clazz in self.classes] + else: + # sampling error + binarized_y_true = list(label_binarize(reference_data[self.y_true], classes=self.classes).T) + y_pred_proba = [reference_data['uncalibrated_' + self.y_pred_proba[clazz]].T for clazz in self.classes] + self._sampling_error_components = mse.average_precision_sampling_error_components( + y_true_reference=binarized_y_true, y_pred_proba_reference=y_pred_proba + ) + + def _estimate(self, data: pd.DataFrame): + needed_columns = self.class_probability_columns + self.class_uncalibrated_y_pred_proba_columns + try: + data, empty = common_nan_removal(data, needed_columns) + except InvalidArgumentsException as ex: + if "not all present in provided data columns" in str(ex): + self._logger.debug(str(ex)) + return np.NaN + else: + raise ex + if empty: + self._logger.debug(f"Not enough data to compute estimated {self.display_name}.") + warnings.warn(f"Not enough data to compute estimated {self.display_name}.") + return np.NaN + + _, y_pred_probas, _ = _get_binarized_multiclass_predictions(data, self.y_pred, self.y_pred_proba) + _, y_pred_probas_uncalibrated, _ = _get_multiclass_uncalibrated_predictions( + data, self.y_pred, self.y_pred_proba + ) + ovr_estimates = [] + for el in range(len(y_pred_probas)): + ovr_estimates.append( + estimate_ap( + # sorting according to classes is/should_be the same across + # _get_binarized_multiclass_predictions and _get_multiclass_uncalibrated_predictions + y_pred_probas[el], + y_pred_probas_uncalibrated.iloc[:, el], + ) + ) + multiclass_ap = np.mean(ovr_estimates) + return multiclass_ap + + def _sampling_error(self, data: pd.DataFrame) -> float: + needed_columns = self.class_probability_columns + self.class_uncalibrated_y_pred_proba_columns + _list_missing(needed_columns, data) + data, empty = common_nan_removal(data[needed_columns], needed_columns) + if empty: + warnings.warn( + f"Too many missing values, cannot calculate {self.display_name} sampling error. " f"Returning NaN." + ) + return np.NaN + else: + return mse.average_precision_sampling_error(self._sampling_error_components, data) + + def _realized_performance(self, data: pd.DataFrame) -> float: + try: + data, empty = common_nan_removal(data, [self.y_true] + self.class_uncalibrated_y_pred_proba_columns) + except InvalidArgumentsException as ex: + if "not all present in provided data columns" in str(ex): + self._logger.debug(str(ex)) + return np.NaN + else: + raise ex + if empty: + warnings.warn(f"Too many missing values, cannot calculate {self.display_name}. " f"Returning NaN.") + return np.NaN + + y_true = data[self.y_true] + if y_true.nunique() <= 1: + warnings.warn("Too few unique values present in 'y_true', returning NaN as realized AP.") + return np.NaN + + _, y_pred_probas, _ = _get_multiclass_uncalibrated_predictions(data, self.y_pred, self.y_pred_proba) + + # https://scikit-learn.org/stable/modules/model_evaluation.html#precision-recall-f-measure-metrics + # average_precision_score always performs OVR averaging + return average_precision_score(y_true, y_pred_probas, average='macro') + + +@MetricFactory.register('business_value', ProblemType.CLASSIFICATION_MULTICLASS) +class MulticlassClassificationBusinessValue(Metric): + """CBPE multiclass classification Business Value Metric Class.""" + + y_pred_proba: Dict[str, str] + + def __init__( + self, + y_pred_proba: Dict[str, str], + y_pred: str, + y_true: str, + chunker: Chunker, + threshold: Threshold, + business_value_matrix: Union[List, np.ndarray], + normalize_business_value: Optional[str] = None, + timestamp_column_name: Optional[str] = None, + **kwargs, + ): + """Initialize CBPE multiclass classification Business Value Metric Class.""" + super().__init__( + name='business_value', + y_pred_proba=y_pred_proba, + y_pred=y_pred, + y_true=y_true, + timestamp_column_name=timestamp_column_name, + chunker=chunker, + threshold=threshold, + components=[('Business Value', 'business_value')], + ) + + if business_value_matrix is None: + raise ValueError("business_value_matrix must be provided for 'business_value' metric") + + if not (isinstance(business_value_matrix, np.ndarray) or isinstance(business_value_matrix, list)): + raise ValueError( + f"business_value_matrix must be a numpy array or a list, but got {type(business_value_matrix)}" + ) + + if isinstance(business_value_matrix, list): + business_value_matrix = np.array(business_value_matrix) + _rows, _columns = business_value_matrix.shape + if _rows != _columns: + raise InvalidArgumentsException( + f"business_value_matrix is not a square matrix but has shape: {(_rows, _columns)}" + ) + + self.business_value_matrix = business_value_matrix + self.normalize_business_value: Optional[str] = normalize_business_value + + self.classes: List[str] = class_labels(self.y_pred_proba) + self.class_probability_columns: List[str] + + # sampling error + self._sampling_error_components: Tuple = () + + def _fit(self, reference_data: pd.DataFrame): + _list_missing([self.y_true, self.y_pred], list(reference_data.columns)) + data, empty = common_nan_removal(reference_data[[self.y_true, self.y_pred]], [self.y_true, self.y_pred]) + if empty: + self._sampling_error_components = np.NaN, self.normalize_business_value + else: + num_classes = len(self.classes) + if num_classes != self.business_value_matrix.shape[0]: + raise InvalidArgumentsException( + f"business_value_matrix has shape {self.business_value_matrix.shape} " + f"but we have {num_classes} classes!" + ) + self._sampling_error_components = mse.business_value_sampling_error_components( + y_true_reference=data[self.y_true], + y_pred_reference=data[self.y_pred], + business_value_matrix=self.business_value_matrix, + classes=self.classes, + normalize_business_value=self.normalize_business_value, + ) + + self.class_probability_columns = [self.y_pred_proba[clazz] for clazz in self.classes] + + def _estimate(self, data: pd.DataFrame): + needed_columns = self.class_probability_columns + [self.y_pred] + try: + data, empty = common_nan_removal(data, needed_columns) + except InvalidArgumentsException as ex: + if "not all present in provided data columns" in str(ex): + self._logger.warning(str(ex)) + return np.NaN + else: + raise ex + + if empty: + self._logger.warning(f"Not enough data to compute estimated {self.display_name}.") + warnings.warn(f"Not enough data to compute estimated {self.display_name}.") + return np.NaN + + # TODO: put in a function? Also for MC CM. + y_pred_proba = {key: data[value] for key, value in self.y_pred_proba.items()} + y_pred = data[self.y_pred] + num_classes = len(self.classes) + est_confusion_matrix = np.zeros((num_classes, num_classes)) + # CM elements are properly ordered because y_pred_proba items are selected from self.classes[index] + for i in range(num_classes): + for j in range(num_classes): + est_confusion_matrix[i, j] = np.sum( + np.where( + (y_pred == self.classes[j]), + y_pred_proba[self.classes[i]], + 0, + ) + ) + + if self.normalize_business_value == 'per_prediction': + with np.errstate(all="ignore"): + est_confusion_matrix = est_confusion_matrix / est_confusion_matrix.sum(axis=0, keepdims=True) + est_confusion_matrix = np.nan_to_num(est_confusion_matrix) + + return (self.business_value_matrix * est_confusion_matrix).sum() + + def _sampling_error(self, data: pd.DataFrame) -> float: + needed_columns = self.class_probability_columns + [self.y_pred] + _list_missing(needed_columns, data) + data, empty = common_nan_removal(data[needed_columns], needed_columns) + if empty: + _message = f"Too many missing values, cannot calculate {self.display_name} sampling error. Returning NaN." + self._logger.warning(_message) + warnings.warn(_message) + return np.NaN + else: + return mse.business_value_sampling_error(self._sampling_error_components, data) + + def _realized_performance(self, data: pd.DataFrame) -> float: + try: + _list_missing([self.y_true, self.y_pred], data) + except InvalidArgumentsException as ex: + if "missing required columns" in str(ex): + self._logger.info(str(ex)) + return np.NaN + else: + raise ex + data, empty = common_nan_removal(data[[self.y_true, self.y_pred]], [self.y_true, self.y_pred]) + if empty: + _message = f"'{self.y_true}' contains no data, cannot calculate business value. Returning NaN." + self._logger.info(_message) + warnings.warn(_message) + return np.NaN + + y_true = data[self.y_true] + y_pred = data[self.y_pred] + + cm = confusion_matrix(y_true, y_pred, labels=self.classes) + if self.normalize_business_value == 'per_prediction': + with np.errstate(all="ignore"): + cm = cm / cm.sum(axis=0, keepdims=True) + cm = np.nan_to_num(cm) + + return (self.business_value_matrix * cm).sum() diff --git a/nannyml/sampling_error/binary_classification.py b/nannyml/sampling_error/binary_classification.py index c23231773..e5ccc2f3a 100644 --- a/nannyml/sampling_error/binary_classification.py +++ b/nannyml/sampling_error/binary_classification.py @@ -816,6 +816,8 @@ def business_value_sampling_error_components( Predictions for the reference dataset. business_value_matrix: np.ndarray A 2x2 matrix of values for the business problem. + normalize_business_value: Optional[str], default=None + Determines how the business value will be normalized. Allowed values are None and 'per_prediction'. Returns ------- components: tuple diff --git a/nannyml/sampling_error/multiclass_classification.py b/nannyml/sampling_error/multiclass_classification.py index 7b0bba596..9466b4a21 100644 --- a/nannyml/sampling_error/multiclass_classification.py +++ b/nannyml/sampling_error/multiclass_classification.py @@ -2,11 +2,20 @@ # Jakub Bialek # # License: Apache Software License 2.0 -from typing import List, Tuple, Union + +"""Module containing functions to estimate sampling error for multiclass classification metrics.""" + +from typing import List, Tuple, Union, Optional import numpy as np import pandas as pd -from sklearn.metrics import confusion_matrix +from sklearn.metrics import confusion_matrix, average_precision_score + + +# How many experiments to perform when doing resampling to approximate sampling error. +N_EXPERIMENTS = 50 +# Max resample size - we don't need full reference if it is too big. +MAX_RESAMPLE_SIZE = 50_000 def _standard_deviation_of_variances(components: List[Tuple], data) -> float: @@ -16,8 +25,7 @@ def _standard_deviation_of_variances(components: List[Tuple], data) -> float: def auroc_sampling_error_components(y_true_reference: List[pd.Series], y_pred_proba_reference: List[pd.Series]): - """ - Calculate sampling error components for AUROC using reference data. + """Calculate sampling error components for AUROC using reference data. The ``y_true_reference`` and ``y_pred_proba_reference`` lists represent the binarized target values and model probabilities. The order of the Series in both lists should both match the list of class labels present. @@ -68,13 +76,14 @@ def _get_class_components(y_true, y_pred_proba): def auroc_sampling_error(sampling_error_components, data) -> float: - """ - Calculate the AUROC sampling error for a chunk of data. + """Calculate the AUROC sampling error for a chunk of data. Parameters ---------- - sampling_error_components : a set of parameters that were derived from reference data. - data : the (analysis) data you want to calculate or estimate a metric for. + sampling_error_components: + a set of parameters that were derived from reference data. + data: + the (analysis) data you want to calculate or estimate a metric for. Returns ------- @@ -82,13 +91,13 @@ def auroc_sampling_error(sampling_error_components, data) -> float: """ class_variances = [c[0] / (len(data) * c[1]) for c in sampling_error_components] + # Experiments showed that std of class variances underestimated sampling error by 20% so we manually adjust result multiclass_std = np.sqrt(np.sum(class_variances)) / len(class_variances) * 1.2 return multiclass_std def f1_sampling_error_components(y_true_reference: List[pd.Series], y_pred_reference: List[pd.Series]): - """ - Calculate sampling error components for F1 using reference data. + """Calculate sampling error components for F1 using reference data. The ``y_true_reference`` and ``y_pred_proba_reference`` lists represent the binarized target values and model probabilities. The order of the Series in both lists should both match the list of class labels present. @@ -127,13 +136,14 @@ def _get_class_components(y_true, y_pred): def f1_sampling_error(sampling_error_components: List[Tuple], data) -> float: - """ - Calculate the F1 sampling error for a chunk of data. + """Calculate the F1 sampling error for a chunk of data. Parameters ---------- - sampling_error_components : a set of parameters that were derived from reference data. - data : the (analysis) data you want to calculate or estimate a metric for. + sampling_error_components: + a set of parameters that were derived from reference data. + data: + the (analysis) data you want to calculate or estimate a metric for. Returns ------- @@ -144,8 +154,7 @@ def f1_sampling_error(sampling_error_components: List[Tuple], data) -> float: def precision_sampling_error_components(y_true_reference: List[pd.Series], y_pred_reference: List[pd.Series]): - """ - Calculate sampling error components for precision using reference data. + """Calculate sampling error components for precision using reference data. The ``y_true_reference`` and ``y_pred_proba_reference`` lists represent the binarized target values and model probabilities. The order of the Series in both lists should both match the list of class labels present. @@ -182,13 +191,14 @@ def _get_class_components(y_true, y_pred): def precision_sampling_error(sampling_error_components: List[Tuple], data) -> float: - """ - Calculate the precision sampling error for a chunk of data. + """Calculate the precision sampling error for a chunk of data. Parameters ---------- - sampling_error_components : a set of parameters that were derived from reference data. - data : the (analysis) data you want to calculate or estimate a metric for. + sampling_error_components: + a set of parameters that were derived from reference data. + data: + the (analysis) data you want to calculate or estimate a metric for. Returns ------- @@ -199,8 +209,7 @@ def precision_sampling_error(sampling_error_components: List[Tuple], data) -> fl def recall_sampling_error_components(y_true_reference: List[pd.Series], y_pred_reference: List[pd.Series]): - """ - Calculate sampling error components for recall using reference data. + """Calculate sampling error components for recall using reference data. The ``y_true_reference`` and ``y_pred_proba_reference`` lists represent the binarized target values and model probabilities. The order of the Series in both lists should both match the list of class labels present. @@ -236,13 +245,14 @@ def _get_class_components(y_true, y_pred): def recall_sampling_error(sampling_error_components: List[Tuple], data) -> float: - """ - Calculate the recall sampling error for a chunk of data. + """Calculate the recall sampling error for a chunk of data. Parameters ---------- - sampling_error_components : a set of parameters that were derived from reference data. - data : the (analysis) data you want to calculate or estimate a metric for. + sampling_error_components: + a set of parameters that were derived from reference data. + data: + the (analysis) data you want to calculate or estimate a metric for. Returns ------- @@ -253,8 +263,7 @@ def recall_sampling_error(sampling_error_components: List[Tuple], data) -> float def specificity_sampling_error_components(y_true_reference: List[pd.Series], y_pred_reference: List[pd.Series]): - """ - Calculate sampling error components for specificity using reference data. + """Calculate sampling error components for specificity using reference data. The ``y_true_reference`` and ``y_pred_proba_reference`` lists represent the binarized target values and model probabilities. The order of the Series in both lists should both match the list of class labels present. @@ -290,13 +299,14 @@ def _get_class_components(y_true, y_pred): def specificity_sampling_error(sampling_error_components: List[Tuple], data) -> float: - """ - Calculate the specificity sampling error for a chunk of data. + """Calculate the specificity sampling error for a chunk of data. Parameters ---------- - sampling_error_components : a set of parameters that were derived from reference data. - data : the (analysis) data you want to calculate or estimate a metric for. + sampling_error_components: + a set of parameters that were derived from reference data. + data: + the (analysis) data you want to calculate or estimate a metric for. Returns ------- @@ -307,8 +317,7 @@ def specificity_sampling_error(sampling_error_components: List[Tuple], data) -> def accuracy_sampling_error_components(y_true_reference: List[pd.Series], y_pred_reference: List[pd.Series]): - """ - Calculate sampling error components for accuracy using reference data. + """Calculate sampling error components for accuracy using reference data. The ``y_true_reference`` and ``y_pred_proba_reference`` lists represent the binarized target values and model probabilities. The order of the Series in both lists should both match the list of class labels present. @@ -332,13 +341,14 @@ def accuracy_sampling_error_components(y_true_reference: List[pd.Series], y_pred def accuracy_sampling_error(sampling_error_components: Tuple, data) -> float: - """ - Calculate the accuracy sampling error for a chunk of data. + """Calculate the accuracy sampling error for a chunk of data. Parameters ---------- - sampling_error_components : a set of parameters that were derived from reference data. - data : the (analysis) data you want to calculate or estimate a metric for. + sampling_error_components: + a set of parameters that were derived from reference data. + data: + the (analysis) data you want to calculate or estimate a metric for. Returns ------- @@ -351,6 +361,7 @@ def accuracy_sampling_error(sampling_error_components: Tuple, data) -> float: def multiclass_confusion_matrix_sampling_error_components( y_true_reference: List[pd.Series], y_pred_reference: List[pd.Series], normalize_confusion_matrix: Union[str, None] ): + """Calculate sampling error components for CM using reference data.""" cm = confusion_matrix(y_true_reference, y_pred_reference) true_marginal = cm.sum(axis=1)[:, None] @@ -391,6 +402,7 @@ def multiclass_confusion_matrix_sampling_error_components( def multiclass_confusion_matrix_sampling_error(sampling_error_components: Tuple, data): + """Calculate the CM sampling error for a chunk of data.""" reference_stds, relevant_proportions = sampling_error_components if relevant_proportions is None: @@ -399,3 +411,138 @@ def multiclass_confusion_matrix_sampling_error(sampling_error_components: Tuple, standard_errors = reference_stds / np.sqrt(len(data) * relevant_proportions) return standard_errors + + +def average_precision_sampling_error_components( + y_true_reference: List[np.ndarray], y_pred_proba_reference: List[pd.Series] +): + """Calculate sampling error components for AP using reference data. + + The ``y_true_reference`` and ``y_pred_proba_reference`` lists represent the binarized target values and model + probabilities. The order of the Series in both lists should both match the list of class labels present. + + Parameters + ---------- + y_true_reference: List[np.ndarray] + Target values for the reference dataset. + y_pred_proba_reference: List[pd.Series] + Prediction probability values for the reference dataset. + + Returns + ------- + sampling_error_components: List[Tuple] + """ + + def _get_class_components(y_true_reference: np.ndarray, y_pred_proba_reference: pd.Series): + sample_size = np.minimum(y_true_reference.shape[0] // 2, MAX_RESAMPLE_SIZE) + + y_pred_proba_reference = y_pred_proba_reference.to_numpy() + + ap_results = [] + for _ in range(N_EXPERIMENTS): + _indexes_for_sample = np.random.choice(y_true_reference.shape[0], sample_size, replace=True) + sample_y_true_reference = y_true_reference[_indexes_for_sample] + sample_y_pred_proba_reference = y_pred_proba_reference[_indexes_for_sample] + ap_results.append(average_precision_score(sample_y_true_reference, sample_y_pred_proba_reference)) + return np.var(ap_results), sample_size + + class_components = [] + for y_true_class, y_pred_proba_class in zip(y_true_reference, y_pred_proba_reference): + class_components.append(_get_class_components(y_true_class, y_pred_proba_class)) + + return class_components + + +def average_precision_sampling_error(sampling_error_components, data) -> float: + """Calculate the AUROC sampling error for a chunk of data. + + Parameters + ---------- + sampling_error_components: + a set of parameters that were derived from reference data. + data: + the (chunk) data you want to calculate or estimate a metric for. + + Returns + ------- + sampling_error: float + """ + class_variances = [c[0] * c[1] / len(data) for c in sampling_error_components] + multiclass_std = np.sqrt(np.mean(class_variances)) + return multiclass_std + + +def _calculate_business_value_per_row( + row, + business_value_matrix: np.ndarray, + classes: List[str], +): + """Helper function that calculates business value per row in a dataframe. + + Intended to be used within a pandas apply function. + """ + cm = confusion_matrix(y_true=np.array([row.y_true]), y_pred=np.array([row.y_pred]), labels=classes) + bv = (cm * business_value_matrix).sum() + return bv + + +def business_value_sampling_error_components( + y_true_reference: pd.Series, + y_pred_reference: pd.Series, + business_value_matrix: np.ndarray, + classes: List[str], + normalize_business_value: Optional[str], +) -> Tuple[float, Union[str, None]]: + """Estimate sampling error for the false negative rate. + + Parameters + ---------- + y_true_reference: pd.Series + Target values for the reference dataset. + y_pred_reference: pd.Series + Predictions for the reference dataset. + business_value_matrix: np.ndarray + A nxn matrix of values for the business problem. + classes: List[str] + An alphanumerically sorted list of the unique classes in the multiclass problem + normalize_business_value: Optional[str], default=None + Determines how the business value will be normalized. Allowed values are None and 'per_prediction'. + + Returns + ------- + components: tuple + """ + data = pd.DataFrame( + { + 'y_true': y_true_reference, + 'y_pred': y_pred_reference, + } + ) + bvs = data.apply(lambda x: _calculate_business_value_per_row(x, business_value_matrix, classes), axis=1) + return (bvs.std(), normalize_business_value) + + +def business_value_sampling_error(sampling_error_components: Tuple, data) -> float: + """Calculate the false positive rate sampling error for a chunk of data. + + Parameters + ---------- + sampling_error_components: + a set of parameters that were derived from reference data. + data: + the (chunk) data you want to calculate or estimate a metric for. + + Returns + ------- + sampling_error: float + """ + (reference_std, norm_type) = sampling_error_components + _size = len(data) + + if norm_type is None: + analysis_std = reference_std * _size + else: # norm_type must be 'per_prediciton' + analysis_std = reference_std + + total_value_standard_error = analysis_std / np.sqrt(_size) + return total_value_standard_error diff --git a/tests/performance_calculation/metrics/test_multiclass_classification.py b/tests/performance_calculation/metrics/test_multiclass_classification.py index 1b3a96499..7f610dae3 100644 --- a/tests/performance_calculation/metrics/test_multiclass_classification.py +++ b/tests/performance_calculation/metrics/test_multiclass_classification.py @@ -5,6 +5,7 @@ """Unit tests for performance metrics.""" from typing import Tuple +import numpy as np import pandas as pd import pytest from logging import getLogger @@ -22,6 +23,8 @@ MulticlassClassificationPrecision, MulticlassClassificationRecall, MulticlassClassificationSpecificity, + MulticlassClassificationAP, + MulticlassClassificationBusinessValue ) from nannyml.thresholds import ConstantThreshold, StandardDeviationThreshold @@ -37,6 +40,11 @@ def multiclass_data() -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]: # noq @pytest.fixture(scope='module') def performance_calculator() -> PerformanceCalculator: # noqa: D103 + business_value_matrix = np.array([ + [1, 0, -1], + [0, 1, 0], + [-1, 0, 1] + ]) return PerformanceCalculator( timestamp_column_name='timestamp', y_pred_proba={ @@ -46,13 +54,30 @@ def performance_calculator() -> PerformanceCalculator: # noqa: D103 }, y_pred='y_pred', y_true='y_true', - metrics=['roc_auc', 'f1', 'precision', 'recall', 'specificity', 'accuracy', 'confusion_matrix'], + metrics=[ + 'roc_auc', + 'f1', + 'precision', + 'recall', + 'specificity', + 'accuracy', + 'confusion_matrix', + 'average_precision', + 'business_value' + ], problem_type='classification_multiclass', + business_value_matrix=business_value_matrix, + normalize_business_value='per_prediction' ) @pytest.fixture(scope='module') def realized_performance_metrics(multiclass_data) -> pd.DataFrame: # noqa: D103 + business_value_matrix = np.array([ + [1, 0, -1], + [0, 1, 0], + [-1, 0, 1] + ]) performance_calculator = PerformanceCalculator( y_pred_proba={ 'prepaid_card': 'y_pred_proba_prepaid_card', @@ -61,11 +86,23 @@ def realized_performance_metrics(multiclass_data) -> pd.DataFrame: # noqa: D103 }, y_pred='y_pred', y_true='y_true', - metrics=['roc_auc', 'f1', 'precision', 'recall', 'specificity', 'accuracy', 'confusion_matrix'], + metrics=[ + 'roc_auc', + 'f1', + 'precision', + 'recall', + 'specificity', + 'accuracy', + 'confusion_matrix', + 'average_precision', + 'business_value' + ], problem_type='classification_multiclass', + business_value_matrix=business_value_matrix, + normalize_business_value='per_prediction' ).fit(multiclass_data[0]) results = performance_calculator.calculate( - multiclass_data[1].merge(multiclass_data[2], left_index=True, right_index=True) + multiclass_data[1].merge(multiclass_data[2], on='id', how='left') ).filter(period='analysis') return results.to_df() @@ -89,27 +126,36 @@ def no_timestamp_metrics(performance_calculator, multiclass_data) -> pd.DataFram ('specificity', ProblemType.CLASSIFICATION_MULTICLASS, MulticlassClassificationSpecificity), ('accuracy', ProblemType.CLASSIFICATION_MULTICLASS, MulticlassClassificationAccuracy), ('confusion_matrix', ProblemType.CLASSIFICATION_MULTICLASS, MulticlassClassificationConfusionMatrix), + ('average_precision', ProblemType.CLASSIFICATION_MULTICLASS, MulticlassClassificationAP), + ('business_value', ProblemType.CLASSIFICATION_MULTICLASS, MulticlassClassificationBusinessValue), ], ) def test_metric_factory_returns_correct_metric_given_key_and_problem_type(key, problem_type, metric): # noqa: D103 - calc = PerformanceCalculator( - timestamp_column_name='timestamp', - y_pred_proba={'class1': 'y_pred_proba1', 'class2': 'y_pred_proba2', 'class3': 'y_pred_proba3'}, - y_pred='y_pred', - y_true='y_true', - metrics=['roc_auc', 'f1'], - problem_type='classification_multiclass', - ) + y_pred_proba = { + 'prepaid_card': 'y_pred_proba_prepaid_card', + 'highstreet_card': 'y_pred_proba_highstreet_card', + 'upmarket_card': 'y_pred_proba_upmarket_card', + } + business_value_matrix = np.array([ + [1, 0, -1], + [0, 1, 0], + [-1, 0, 1] + ]) sut = MetricFactory.create( key, problem_type, - y_true=calc.y_true, - y_pred=calc.y_pred, - y_pred_proba=calc.y_pred_proba, + y_true='y_true', + y_pred='y_pred', + y_pred_proba=y_pred_proba, threshold=StandardDeviationThreshold(), + business_value_matrix=business_value_matrix ) assert sut == metric( - y_true=calc.y_true, y_pred=calc.y_pred, y_pred_proba=calc.y_pred_proba, threshold=StandardDeviationThreshold + y_true='y_true', + y_pred='y_pred', + y_pred_proba=y_pred_proba, + threshold=StandardDeviationThreshold, + business_value_matrix=business_value_matrix ) @@ -131,6 +177,8 @@ def test_metric_factory_returns_correct_metric_given_key_and_problem_type(key, p ('true_highstreet_card_pred_upmarket_card', [250, 237, 259, 251, 277, 330, 318, 302, 312, 326]), ('true_highstreet_card_pred_prepaid_card', [275, 261, 250, 248, 240, 421, 404, 396, 412, 390]), ('true_highstreet_card_pred_highstreet_card', [1457, 1536, 1451, 1450, 1488, 1322, 1346, 1397, 1353, 1354]), + ('average_precision', [0.83891, 0.8424, 0.84207, 0.844, 0.8364, 0.59673, 0.60133, 0.60421, 0.60751, 0.6052]), + ('business_value', [2.00122, 2.04414, 2.01853, 2.01854, 2.01693, 1.28921, 1.31007, 1.32972, 1.32404, 1.31623]) ], ) def test_metric_values_are_calculated_correctly(realized_performance_metrics, metric, expected): # noqa: D103 @@ -156,6 +204,8 @@ def test_metric_values_are_calculated_correctly(realized_performance_metrics, me ('true_highstreet_card_pred_upmarket_card', [250, 237, 259, 251, 277, 330, 318, 302, 312, 326]), ('true_highstreet_card_pred_prepaid_card', [275, 261, 250, 248, 240, 421, 404, 396, 412, 390]), ('true_highstreet_card_pred_highstreet_card', [1457, 1536, 1451, 1450, 1488, 1322, 1346, 1397, 1353, 1354]), + ('average_precision', [0.83891, 0.8424, 0.84207, 0.844, 0.8364, 0.59673, 0.60133, 0.60421, 0.60751, 0.6052]), + ('business_value', [2.00122, 2.04414, 2.01853, 2.01854, 2.01693, 1.28921, 1.31007, 1.32972, 1.32404, 1.31623]) ], ) def test_metric_values_without_timestamps_are_calculated_correctly( # noqa: D103 @@ -275,3 +325,58 @@ def test_auroc_errors_out_when_not_all_classes_are_represented_chunk(multiclass_ _ = performance_calculator.calculate(monitored) expected_exc_test = "does not contain all reported classes, cannot calculate" assert expected_exc_test in caplog.text + + +def test_business_value_getting_classes_from_y_pred_proba(multiclass_data): + reference, monitored, targets = multiclass_data + reference['y_true'] = 'prepaid_card' + monitored = monitored.merge(targets, on='id', how='left') + business_value_matrix = np.array([ + [1, 0, -1], + [0, 1, 0], + [-1, 0, 1] + ]) + calc = PerformanceCalculator( + y_pred_proba={ + 'prepaid_card': 'y_pred_proba_prepaid_card', + 'highstreet_card': 'y_pred_proba_highstreet_card', + 'upmarket_card': 'y_pred_proba_upmarket_card', + }, + y_pred='y_pred', + y_true='y_true', + metrics=['business_value'], + problem_type='classification_multiclass', + business_value_matrix=business_value_matrix, + normalize_business_value='per_prediction' + ).fit(reference) + results = calc.calculate(monitored) + assert [ + 2.00122, 2.04414, 2.01853, 2.01854, 2.01693, 1.28921, 1.31007, 1.32972, 1.32404, 1.31623 + ] == list( + results.filter(period='analysis').to_df().round(5).loc[:, ('business_value', 'value')] + ) + + +# TODO: At the moment the test below is invalid because y_pred_proba is mandatory. Uncomment when it is not. +# def test_business_value_getting_classes_without_y_pred_proba(multiclass_data): +# reference, monitored, targets = multiclass_data +# monitored = monitored.merge(targets, on='id', how='left') +# business_value_matrix = np.array([ +# [1, 0, -1], +# [0, 1, 0], +# [-1, 0, 1] +# ]) +# calc = PerformanceCalculator( +# y_pred='y_pred', +# y_true='y_true', +# metrics=['business_value'], +# problem_type='classification_multiclass', +# business_value_matrix=business_value_matrix, +# normalize_business_value='per_prediction' +# ).fit(reference) +# results = calc.calculate(monitored) +# assert [ +# 2.00122, 2.04414, 2.01853, 2.01854, 2.01693, 1.28921, 1.31007, 1.32972, 1.32404, 1.31623 +# ] == list( +# results.filter(period='analysis').to_df().round(5).loc[:, ('business_value', 'value')] +# ) diff --git a/tests/performance_estimation/CBPE/test_cbpe_metrics.py b/tests/performance_estimation/CBPE/test_cbpe_metrics.py index 5335c1013..23f2f8bee 100644 --- a/tests/performance_estimation/CBPE/test_cbpe_metrics.py +++ b/tests/performance_estimation/CBPE/test_cbpe_metrics.py @@ -1,4 +1,5 @@ """Tests.""" +import re import pandas as pd import numpy as np @@ -2657,6 +2658,8 @@ def test_cbpe_for_binary_classification_with_timestamps(calculator_opts, expecte 'estimated_recall': [0.7564129287764665, 0.6934788458355289, 0.6319310599943714], 'estimated_specificity': [0.8782068281303994, 0.8469556750949159, 0.8172644220189141], 'estimated_accuracy': [0.7564451493123628, 0.6946947603445697, 0.6378557309960986], + 'estimated_average_precision': [0.8418535417603635, 0.7785618577588246, 0.6985785036188713], + 'estimated_business_value': [2.0193901626043056, 1.7875283323693987, 1.570045452479401], 'estimated_true_highstreet_card_pred_highstreet_card': [ 4976.829215997277, 5148.649186425118, @@ -2716,6 +2719,8 @@ def test_cbpe_for_binary_classification_with_timestamps(calculator_opts, expecte 'estimated_recall': [0.7564129287764665, 0.6934788458355289, 0.6319310599943714], 'estimated_specificity': [0.8782068281303994, 0.8469556750949159, 0.8172644220189141], 'estimated_accuracy': [0.7564451493123628, 0.6946947603445697, 0.6378557309960986], + 'estimated_average_precision': [0.8418535417603635, 0.7785618577588246, 0.6985785036188713], + 'estimated_business_value': [2.0193901626043056, 1.7875283323693987, 1.570045452479401], 'estimated_true_highstreet_card_pred_highstreet_card': [ 0.7442780881812128, 0.7170050012869645, @@ -2800,6 +2805,18 @@ def test_cbpe_for_binary_classification_with_timestamps(calculator_opts, expecte 0.6364205304514962, 0.6375753072973162, ], + 'estimated_average_precision': [ + 0.8406535565924922, + 0.8410572134298334, + 0.697327636452664, + 0.6984330753389926, + ], + 'estimated_business_value': [ + 2.0134445826512186, + 2.0170794978486395, + 1.5673705142973104, + 1.5671595942359196, + ], 'estimated_true_highstreet_card_pred_highstreet_card': [ 0.7546260682147157, 0.7511343683695074, @@ -2893,6 +2910,18 @@ def test_cbpe_for_binary_classification_with_timestamps(calculator_opts, expecte 0.6364205304514962, 0.6375753072973162, ], + 'estimated_average_precision': [ + 0.8406535565924922, + 0.8410572134298334, + 0.697327636452664, + 0.6984330753389926, + ], + 'estimated_business_value': [ + 2.0134445826512186, + 2.0170794978486395, + 1.5673705142973104, + 1.5671595942359196, + ], 'estimated_true_highstreet_card_pred_highstreet_card': [ 0.24922783612904678, 0.24847524905663304, @@ -2961,6 +2990,8 @@ def test_cbpe_for_binary_classification_with_timestamps(calculator_opts, expecte 'estimated_recall': [0.6957620347508907, 0.6272720458900231], 'estimated_specificity': [0.8480220572478717, 0.8145095377877009], 'estimated_accuracy': [0.6967957612985849, 0.6305270354546132], + 'estimated_average_precision': [0.7812291182204878, 0.6907845497417768], + 'estimated_business_value': [1.7964098918968543, 1.5447162372665988], 'estimated_true_highstreet_card_pred_highstreet_card': [15431.207920621628, 106.61852759787631], 'estimated_true_highstreet_card_pred_prepaid_card': [3140.1950482057946, 27.27202363566655], 'estimated_true_highstreet_card_pred_upmarket_card': [2911.0243109194275, 24.485771034437157], @@ -3061,6 +3092,30 @@ def test_cbpe_for_binary_classification_with_timestamps(calculator_opts, expecte 0.6365172577468735, 0.6393273094601863, ], + 'estimated_average_precision': [ + 0.838071, + 0.843094, + 0.842962, + 0.841563, + 0.838078, + 0.696295, + 0.699327, + 0.695691, + 0.696305, + 0.701142, + ], + 'estimated_business_value': [ + 2.0086174744097525, + 2.0167085528014574, + 2.025151984316981, + 2.018928025883902, + 2.006521418618063, + 1.5644425523502847, + 1.5684601001268144, + 1.5620405529135275, + 1.5668663365944273, + 1.574249644290713, + ], 'estimated_true_highstreet_card_pred_highstreet_card': [ 1483.745037516118, 1536.2546154566053, @@ -3260,6 +3315,30 @@ def test_cbpe_for_binary_classification_with_timestamps(calculator_opts, expecte 0.6365172577468735, 0.6393273094601863, ], + 'estimated_average_precision': [ + 0.838071, + 0.843094, + 0.842962, + 0.841563, + 0.838078, + 0.696295, + 0.699327, + 0.695691, + 0.696305, + 0.701142, + ], + 'estimated_business_value': [ + 2.0086174744097525, + 2.0167085528014574, + 2.025151984316981, + 2.018928025883902, + 2.006521418618063, + 1.5644425523502847, + 1.5684601001268144, + 1.5620405529135275, + 1.5668663365944273, + 1.574249644290713, + ], 'estimated_true_highstreet_card_pred_highstreet_card': [ 1483.745037516118, 1536.2546154566053, @@ -3384,6 +3463,7 @@ def test_cbpe_for_binary_classification_with_timestamps(calculator_opts, expecte ) def test_cbpe_for_multiclass_classification_with_timestamps(calculator_opts, expected): # noqa: D103 ref_df, ana_df, _ = load_synthetic_multiclass_classification_dataset() + business_value_matrix = np.array([[1, 0, -1], [0, 1, 0], [-1, 0, 1]]) cbpe = CBPE( y_pred_proba={ 'upmarket_card': 'y_pred_proba_upmarket_card', @@ -3393,7 +3473,19 @@ def test_cbpe_for_multiclass_classification_with_timestamps(calculator_opts, exp y_pred='y_pred', y_true='y_true', problem_type='classification_multiclass', - metrics=['roc_auc', 'f1', 'precision', 'recall', 'specificity', 'accuracy', 'confusion_matrix'], + metrics=[ + 'roc_auc', + 'f1', + 'precision', + 'recall', + 'specificity', + 'accuracy', + 'average_precision', + 'confusion_matrix', + 'business_value', + ], + business_value_matrix=business_value_matrix, + normalize_business_value='per_prediction', **calculator_opts, ).fit(ref_df) result = cbpe.estimate(ana_df) @@ -3419,6 +3511,8 @@ def test_cbpe_for_multiclass_classification_with_timestamps(calculator_opts, exp 'estimated_recall', 'estimated_specificity', 'estimated_accuracy', + 'estimated_average_precision', + 'estimated_business_value', 'estimated_true_highstreet_card_pred_highstreet_card', 'estimated_true_highstreet_card_pred_prepaid_card', 'estimated_true_highstreet_card_pred_upmarket_card', @@ -3446,7 +3540,9 @@ def test_cbpe_for_multiclass_classification_with_timestamps(calculator_opts, exp BinaryClassificationConfusionMatrix, ], ) -def test_method_logs_warning_when_lower_threshold_is_overridden_by_metric_limits(caplog, metric_cls): # noqa: D103, E501 +def test_method_logs_warning_when_lower_threshold_is_overridden_by_metric_limits( + caplog, metric_cls +): # noqa: D103, E501 reference, _, _ = load_synthetic_binary_classification_dataset() # TODO: move this from CBPE to metrics @@ -3483,6 +3579,8 @@ def test_method_logs_warning_when_lower_threshold_is_overridden_by_metric_limits 'realized_recall': [0.759149, 0.658760, np.nan], 'realized_specificity': [0.879632, 0.829581, np.nan], 'realized_accuracy': [0.75925, 0.65950, np.nan], + 'realized_average_precision': [0.841830, 0.738332, np.nan], + 'realized_business_value': [2.029064521843538, 1.6533562273847497, np.nan], 'realized_true_highstreet_card_pred_highstreet_card': [ 4912.0, 4702.0, @@ -3531,13 +3629,14 @@ def test_method_logs_warning_when_lower_threshold_is_overridden_by_metric_limits } ), ), - ] + ], ) def test_cbpe_for_multiclass_classification_cm_with_nans(calculator_opts, realized): # noqa: D103 """Test Nan Handling of CM MC metric.""" reference, analysis, targets = load_synthetic_multiclass_classification_dataset() analysis = analysis.merge(targets, left_index=True, right_index=True) analysis.y_true[-20_000:] = np.nan + business_value_matrix = np.array([[1, 0, -1], [0, 1, 0], [-1, 0, 1]]) cbpe = CBPE( y_pred_proba={ 'upmarket_card': 'y_pred_proba_upmarket_card', @@ -3547,7 +3646,19 @@ def test_cbpe_for_multiclass_classification_cm_with_nans(calculator_opts, realiz y_pred='y_pred', y_true='y_true', problem_type='classification_multiclass', - metrics=['roc_auc', 'f1', 'precision', 'recall', 'specificity', 'accuracy', 'confusion_matrix'], + metrics=[ + 'roc_auc', + 'f1', + 'precision', + 'recall', + 'specificity', + 'accuracy', + 'average_precision', + 'confusion_matrix', + 'business_value', + ], + business_value_matrix=business_value_matrix, + normalize_business_value='per_prediction', **calculator_opts, ).fit(reference) result = cbpe.estimate(analysis) @@ -3573,6 +3684,8 @@ def test_cbpe_for_multiclass_classification_cm_with_nans(calculator_opts, realiz 'realized_recall', 'realized_specificity', 'realized_accuracy', + 'realized_average_precision', + 'realized_business_value', 'realized_true_highstreet_card_pred_highstreet_card', 'realized_true_highstreet_card_pred_prepaid_card', 'realized_true_highstreet_card_pred_upmarket_card', @@ -3594,7 +3707,7 @@ def test_auroc_errors_out_when_not_all_classes_are_represented_reference(): 'prepaid_card': 'y_pred_proba_prepaid_card', 'highstreet_card': 'y_pred_proba_highstreet_card', 'upmarket_card': 'y_pred_proba_upmarket_card', - 'clazz': 'y_pred_proba_clazz' + 'clazz': 'y_pred_proba_clazz', }, y_pred='y_pred', y_true='y_true', @@ -3625,7 +3738,7 @@ def test_auroc_errors_out_when_not_all_classes_are_represented_chunk(caplog): 'prepaid_card': 'y_pred_proba_prepaid_card', 'highstreet_card': 'y_pred_proba_highstreet_card', 'upmarket_card': 'y_pred_proba_upmarket_card', - 'clazz': 'y_pred_proba_clazz' + 'clazz': 'y_pred_proba_clazz', }, y_pred='y_pred', y_true='y_true', @@ -3636,3 +3749,60 @@ def test_auroc_errors_out_when_not_all_classes_are_represented_chunk(caplog): _ = calc.estimate(monitored) expected_exc_test = "does not contain all reported classes, cannot calculate" assert expected_exc_test in caplog.text + + +def test_cbpe_multiclass_business_value_matrix_square_requirement(): # noqa: D103 + """Test business value matrix.""" + reference, analysis, targets = load_synthetic_multiclass_classification_dataset() + analysis = analysis.merge(targets, left_index=True, right_index=True) + business_value_matrix = np.array( + [ + [1, 0, -1], + [0, 1, 0], + ] + ) + with pytest.raises(InvalidArgumentsException, match="business_value_matrix is not a square matrix but has shape:"): + _ = CBPE( + y_pred_proba={ + 'upmarket_card': 'y_pred_proba_upmarket_card', + 'highstreet_card': 'y_pred_proba_highstreet_card', + 'prepaid_card': 'y_pred_proba_prepaid_card', + }, + y_pred='y_pred', + y_true='y_true', + problem_type='classification_multiclass', + metrics=['business_value'], + business_value_matrix=business_value_matrix, + normalize_business_value='per_prediction', + chunk_number=1, + ) + + +def test_cbpe_multiclass_business_value_matrix_classes_and_bvm_shape(): # noqa: D103 + """Test business value matrix.""" + reference, _, _ = load_synthetic_multiclass_classification_dataset() + business_value_matrix = np.array( + [ + [1, 0, -1, 0], + [0, 1, 0, 0], + [0, 1, 0, 0], + [0, 1, 0, 0], + ] + ) + with pytest.raises( + InvalidArgumentsException, match=re.escape("business_value_matrix has shape (4, 4) but we have 3 classes!") + ): + _ = CBPE( + y_pred_proba={ + 'upmarket_card': 'y_pred_proba_upmarket_card', + 'highstreet_card': 'y_pred_proba_highstreet_card', + 'prepaid_card': 'y_pred_proba_prepaid_card', + }, + y_pred='y_pred', + y_true='y_true', + problem_type='classification_multiclass', + metrics=['business_value'], + business_value_matrix=business_value_matrix, + normalize_business_value='per_prediction', + chunk_number=1, + ).fit(reference)