From 2ef024e0e4bdf0152db347609e201b20f9819a12 Mon Sep 17 00:00:00 2001 From: Niels <94110348+nnansters@users.noreply.github.com> Date: Fri, 16 Sep 2022 19:03:09 +0200 Subject: [PATCH] Feature/optional timestamp (#121) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Adding chunk index to calculation / estimation results * Adapted step plots * Adapted joy plots * Adapted stacked bar plots * Added tests * Update quickstart docs to not use timestamps * Add tests for runner.py * Clean up temp dirs after testing * Move chunk date assignment to inheriting classes (each chunker should decide to support dates or not) * Display "unified" chunk index (taking reference data into account) in hover * Updated docs warning about using / not using timestamps in the corresponding code samples * Updated CHANGELOG.md * Bump version: 0.6.1 → 0.6.2 * Hopefully fix weird build error due to random rounding?? * Fix linting issue Co-authored-by: Niels Nuyttens --- .bumpversion.cfg | 2 +- CHANGELOG.md | 18 + README.md | 4 +- ...quick-start-drift-distance_from_office.svg | 2 +- .../quick-start-drift-gas_price_per_litre.svg | 2 +- .../quick-start-drift-multivariate.svg | 2 +- ...start-drift-public_transportation_cost.svg | 2 +- .../quick-start-drift-salary_range.svg | 2 +- docs/_static/quick-start-drift-tenure.svg | 2 +- .../quick-start-drift-wfh_prev_workday.svg | 2 +- docs/_static/quick-start-drift-workday.svg | 2 +- docs/_static/quick-start-perf-est.svg | 2 +- docs/_static/quick-start-score-drift.svg | 2 +- docs/example_notebooks/Quickstart.ipynb | 8 +- docs/quick.rst | 6 + docs/tutorials/data_requirements.rst | 23 +- ...or_binary_classification_model_outputs.rst | 6 + ...ulticlass_classification_model_outputs.rst | 6 + ...detection_for_regression_model_outputs.rst | 6 + ...or_binary_classification_model_targets.rst | 6 + ...ulticlass_classification_model_targets.rst | 6 + ...detection_for_regression_model_targets.rst | 6 + 
.../binary_performance_calculation.rst | 6 + .../multiclass_performance_calculation.rst | 6 + .../regression_performance_calculation.rst | 6 + .../binary_performance_estimation.rst | 6 + .../multiclass_performance_estimation.rst | 6 + .../regression_performance_estimation.rst | 6 + nannyml/__init__.py | 2 +- nannyml/base.py | 13 +- nannyml/chunk.py | 107 +-- .../data_reconstruction/calculator.py | 16 +- .../data_reconstruction/results.py | 4 + .../univariate/statistical/calculator.py | 13 +- .../univariate/statistical/results.py | 24 +- .../univariate/statistical/calculator.py | 15 +- .../univariate/statistical/results.py | 39 +- .../target/target_distribution/calculator.py | 9 +- .../target/target_distribution/result.py | 26 +- nannyml/performance_calculation/calculator.py | 9 +- .../performance_calculation/metrics/base.py | 1 - nannyml/performance_calculation/result.py | 4 + .../_cbpe_binary_classification.py | 7 +- .../_cbpe_multiclass_classification.py | 7 +- .../confidence_based/cbpe.py | 16 +- .../confidence_based/metrics.py | 1 - .../confidence_based/results.py | 4 + .../direct_loss_estimation/dle.py | 8 +- .../direct_loss_estimation/metrics.py | 1 - .../direct_loss_estimation/result.py | 4 + nannyml/plots/_joy_plot.py | 102 ++- nannyml/plots/_stacked_bar_plot.py | 86 +- nannyml/plots/_step_plot.py | 210 ++++- nannyml/runner.py | 22 +- pyproject.toml | 2 +- tests/drift/test_data_reconstruction_drift.py | 118 ++- tests/drift/test_drift.py | 167 +++- tests/drift/test_output_drift.py | 860 +++++++++++++++++- tests/drift/test_target_distribution.py | 697 +++++++++++++- ... 
implementation visualisation update.ipynb | 198 +++- .../metrics/test_binary_classification.py | 28 + .../metrics/test_multiclass_classification.py | 35 +- .../metrics/test_regression.py | 64 ++ .../test_performance_calculator.py | 105 ++- .../performance_estimation/CBPE/test_cbpe.py | 70 ++ .../CBPE/test_cbpe_metrics.py | 726 +++++++++++++++ tests/performance_estimation/DLE/test_dle.py | 32 + .../DLE/test_dle_metrics.py | 325 +++++++ tests/performance_estimation/test_base.py | 2 +- tests/test_chunk.py | 99 +- tests/test_runner.py | 124 +++ 71 files changed, 4190 insertions(+), 365 deletions(-) create mode 100644 tests/test_runner.py diff --git a/.bumpversion.cfg b/.bumpversion.cfg index 83e2dc14..0ec714fe 100644 --- a/.bumpversion.cfg +++ b/.bumpversion.cfg @@ -1,5 +1,5 @@ [bumpversion] -current_version = 0.6.1 +current_version = 0.6.2 commit = True tag = True diff --git a/CHANGELOG.md b/CHANGELOG.md index 773cfc5c..41a60f2b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,24 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [0.6.2] - 2022-09-16 + +### Changed + +- Made the `timestamp_column_name` required by all calculators and estimators optional. The main consequences of this + are plots have a chunk-index based x-axis now when no timestamp column name was given. You can also not chunk by + period when the timestamp column name is not specified. + +### Fixed + +- Added missing `s3fs` dependency +- Fixed outdated plotting kind constants in the runner (used by CLI) +- Fixed some missing images and incorrect version numbers in the README, thanks [@NeoKish](https://github.com/NeoKish)! 
+ +### Added + +- Added a lot of additional tests, mainly concerning plotting and the [`Runner`](nannyml/runner.py) class + ## [0.6.1] - 2022-09-09 ### Changed diff --git a/README.md b/README.md index 7d7cadc1..4ea56082 100644 --- a/README.md +++ b/README.md @@ -69,7 +69,7 @@ Allowing you to have the following benefits: | 🔬 **[Technical reference]** | Monitor the performance of your ML models. | | 🔎 **[Blog]** | Thoughts on post-deployment data science from the NannyML team. | | 📬 **[Newsletter]** | All things post-deployment data science. Subscribe to see the latest papers and blogs. | -| 💎 **[New in v0.6.1]** | New features, bug fixes. | +| 💎 **[New in v0.6.2]** | New features, bug fixes. | | 🧑‍💻 **[Contribute]** | How to contribute to the NannyML project and codebase. | | **[Join slack]** | Need help with your specific use case? Say hi on slack! | @@ -77,7 +77,7 @@ Allowing you to have the following benefits: [Performance Estimation]: https://nannyml.readthedocs.io/en/stable/how_it_works/performance_estimation.html [Key Concepts]: https://nannyml.readthedocs.io/en/stable/glossary.html [Technical Reference]:https://nannyml.readthedocs.io/en/stable/nannyml/modules.html -[New in v0.6.1]: https://github.com/NannyML/nannyml/releases/latest/ +[New in v0.6.2]: https://github.com/NannyML/nannyml/releases/latest/ [Real World Example]: https://nannyml.readthedocs.io/en/stable/examples/california_housing.html [Blog]: https://www.nannyml.com/blog [Newsletter]: https://mailchi.mp/022c62281d13/postdeploymentnewsletter diff --git a/docs/_static/quick-start-drift-distance_from_office.svg b/docs/_static/quick-start-drift-distance_from_office.svg index 108379f9..eaf9f25a 100644 --- a/docs/_static/quick-start-drift-distance_from_office.svg +++ b/docs/_static/quick-start-drift-distance_from_office.svg @@ -1 +1 @@ -201520162017201820192020202100.10.20.30.4Reference periodAnalysis periodP-value is signficantData driftKS statistic for distance_from_officeTimeKS statistic \ No newline 
at end of file +0510152000.10.20.30.4Reference periodAnalysis periodP-value is signficantData driftKS statistic for distance_from_officeChunkKS statistic \ No newline at end of file diff --git a/docs/_static/quick-start-drift-gas_price_per_litre.svg b/docs/_static/quick-start-drift-gas_price_per_litre.svg index 4943a050..7deea8c7 100644 --- a/docs/_static/quick-start-drift-gas_price_per_litre.svg +++ b/docs/_static/quick-start-drift-gas_price_per_litre.svg @@ -1 +1 @@ -20152016201720182019202020210.0080.010.0120.0140.0160.0180.02Reference periodAnalysis periodP-value is signficantData driftKS statistic for gas_price_per_litreTimeKS statistic \ No newline at end of file +051015200.0080.010.0120.0140.0160.0180.02Reference periodAnalysis periodP-value is signficantData driftKS statistic for gas_price_per_litreChunkKS statistic \ No newline at end of file diff --git a/docs/_static/quick-start-drift-multivariate.svg b/docs/_static/quick-start-drift-multivariate.svg index d78cdf89..b10d59b4 100644 --- a/docs/_static/quick-start-drift-multivariate.svg +++ b/docs/_static/quick-start-drift-multivariate.svg @@ -1 +1 @@ -20152016201720182019202020211.11.151.21.25Reference periodAnalysis periodConfidence bandData drift thresholdData driftData Reconstruction DriftTimeReconstruction Error1.141.10 \ No newline at end of file +051015201.11.151.21.25Reference periodAnalysis periodConfidence bandData drift thresholdData driftData Reconstruction DriftChunkReconstruction Error1.141.10 \ No newline at end of file diff --git a/docs/_static/quick-start-drift-public_transportation_cost.svg b/docs/_static/quick-start-drift-public_transportation_cost.svg index 6862146a..4d580746 100644 --- a/docs/_static/quick-start-drift-public_transportation_cost.svg +++ b/docs/_static/quick-start-drift-public_transportation_cost.svg @@ -1 +1 @@ -201520162017201820192020202100.050.10.150.2Reference periodAnalysis periodP-value is signficantData driftKS statistic for public_transportation_costTimeKS 
statistic \ No newline at end of file +0510152000.050.10.150.2Reference periodAnalysis periodP-value is signficantData driftKS statistic for public_transportation_costChunkKS statistic \ No newline at end of file diff --git a/docs/_static/quick-start-drift-salary_range.svg b/docs/_static/quick-start-drift-salary_range.svg index 7f8b88b2..86795118 100644 --- a/docs/_static/quick-start-drift-salary_range.svg +++ b/docs/_static/quick-start-drift-salary_range.svg @@ -1 +1 @@ -20152016201720182019202020210100200300400500Reference periodAnalysis periodP-value is signficantData driftChi-square statistic for salary_rangeTimeChi-square statistic \ No newline at end of file +051015200100200300400500Reference periodAnalysis periodP-value is signficantData driftChi-square statistic for salary_rangeChunkChi-square statistic \ No newline at end of file diff --git a/docs/_static/quick-start-drift-tenure.svg b/docs/_static/quick-start-drift-tenure.svg index f7b34c5b..ae77389c 100644 --- a/docs/_static/quick-start-drift-tenure.svg +++ b/docs/_static/quick-start-drift-tenure.svg @@ -1 +1 @@ -20152016201720182019202020210.010.0150.02Reference periodAnalysis periodP-value is signficantData driftKS statistic for tenureTimeKS statistic \ No newline at end of file +051015200.010.0150.02Reference periodAnalysis periodP-value is signficantData driftKS statistic for tenureChunkKS statistic \ No newline at end of file diff --git a/docs/_static/quick-start-drift-wfh_prev_workday.svg b/docs/_static/quick-start-drift-wfh_prev_workday.svg index db2cefeb..2607a43c 100644 --- a/docs/_static/quick-start-drift-wfh_prev_workday.svg +++ b/docs/_static/quick-start-drift-wfh_prev_workday.svg @@ -1 +1 @@ -2015201620172018201920202021020040060080010001200Reference periodAnalysis periodP-value is signficantData driftChi-square statistic for wfh_prev_workdayTimeChi-square statistic \ No newline at end of file +05101520020040060080010001200Reference periodAnalysis periodP-value is signficantData 
driftChi-square statistic for wfh_prev_workdayChunkChi-square statistic \ No newline at end of file diff --git a/docs/_static/quick-start-drift-workday.svg b/docs/_static/quick-start-drift-workday.svg index f15f4778..2cbc0bf0 100644 --- a/docs/_static/quick-start-drift-workday.svg +++ b/docs/_static/quick-start-drift-workday.svg @@ -1 +1 @@ -201520162017201820192020202101234567Reference periodAnalysis periodP-value is signficantData driftChi-square statistic for workdayTimeChi-square statistic \ No newline at end of file +0510152001234567Reference periodAnalysis periodP-value is signficantData driftChi-square statistic for workdayChunkChi-square statistic \ No newline at end of file diff --git a/docs/_static/quick-start-perf-est.svg b/docs/_static/quick-start-perf-est.svg index 4950fa2e..71b49ace 100644 --- a/docs/_static/quick-start-perf-est.svg +++ b/docs/_static/quick-start-perf-est.svg @@ -1 +1 @@ -20152016201720182019202020210.9550.960.9650.970.9750.98Reference period (realized ROC AUC)Analysis period (estimated ROC AUC)Confidence bandPerformance thresholdDegraded performanceCBPE - Estimated ROC AUCTimeROC AUC0.980.96 \ No newline at end of file +051015200.9550.960.9650.970.9750.98Reference period (realized ROC AUC)Analysis period (estimated ROC AUC)Confidence bandPerformance thresholdDegraded performanceCBPE - Estimated ROC AUCChunkROC AUC0.980.96 \ No newline at end of file diff --git a/docs/_static/quick-start-score-drift.svg b/docs/_static/quick-start-score-drift.svg index 9f17ba5d..e80aca29 100644 --- a/docs/_static/quick-start-score-drift.svg +++ b/docs/_static/quick-start-score-drift.svg @@ -1 +1 @@ -201520162017201820192020202100.010.020.030.04Reference periodAnalysis periodP-value is signficantData driftKS statistic for y_predTimeKS statistic \ No newline at end of file +05101520051015202530Reference periodAnalysis periodP-value is signficantData driftChi-square statistic for y_predChunkChi-square statistic \ No newline at end of file diff --git 
a/docs/example_notebooks/Quickstart.ipynb b/docs/example_notebooks/Quickstart.ipynb index a5a303fc..f08e0ec8 100644 --- a/docs/example_notebooks/Quickstart.ipynb +++ b/docs/example_notebooks/Quickstart.ipynb @@ -391,7 +391,6 @@ " y_pred_proba='y_pred_proba',\n", " y_pred='y_pred',\n", " y_true='work_home_actual',\n", - " timestamp_column_name='timestamp',\n", " metrics=['roc_auc'],\n", " chunk_size=chunk_size,\n", " problem_type='classification_binary',\n", @@ -427,7 +426,6 @@ "# Let's initialize the object that will perform the Univariate Drift calculations\n", "univariate_calculator = nml.UnivariateStatisticalDriftCalculator(\n", " feature_column_names=feature_column_names,\n", - " timestamp_column_name='timestamp',\n", " chunk_size=chunk_size\n", ")\n", "univariate_calculator = univariate_calculator.fit(reference)\n", @@ -600,7 +598,6 @@ "calc = nml.StatisticalOutputDriftCalculator(\n", " y_pred='y_pred',\n", " y_pred_proba='y_pred_proba',\n", - " timestamp_column_name='timestamp',\n", " problem_type='classification_binary'\n", ")\n", "calc.fit(reference)\n", @@ -626,7 +623,10 @@ "outputs": [], "source": [ "# Let's initialize the object that will perform Data Reconstruction with PCA\n", - "rcerror_calculator = nml.DataReconstructionDriftCalculator(feature_column_names=feature_column_names, timestamp_column_name='timestamp', chunk_size=chunk_size).fit(reference_data=reference)\n", + "rcerror_calculator = nml.DataReconstructionDriftCalculator(\n", + " feature_column_names=feature_column_names,\n", + " chunk_size=chunk_size\n", + ").fit(reference_data=reference)\n", "# let's see Reconstruction error statistics for all available data\n", "rcerror_results = rcerror_calculator.calculate(analysis)\n", "figure = rcerror_results.plot(kind='drift', plot_reference=True)\n", diff --git a/docs/quick.rst b/docs/quick.rst index 5e097b79..e0c286c5 100644 --- a/docs/quick.rst +++ b/docs/quick.rst @@ -42,6 +42,12 @@ concepts and functionalities. 
If you want to know what is implemented under the visit :ref:`how it works`. Finally, if you just look for examples on other datasets or ML problems look through our :ref:`examples`. +.. note:: + The following example does not use any :term:`timestamps`. + These are optional but have an impact on the way data is chunked and results are plotted. + You can read more about them in the :ref:`data requirements`. + + ------------- Just the code diff --git a/docs/tutorials/data_requirements.rst b/docs/tutorials/data_requirements.rst index 7f5853e9..b3cf6601 100644 --- a/docs/tutorials/data_requirements.rst +++ b/docs/tutorials/data_requirements.rst @@ -109,6 +109,8 @@ Below we see the columns our dataset contains and explain their purpose. +----+------------------------+----------------+-----------------------+------------------------------+--------------------+-----------+----------+ +.. _data_requirements_columns_timestamp: + Timestamp ^^^^^^^^^ @@ -124,7 +126,24 @@ In the sample data this is the ``timestamp`` column. - *ISO 8601*, e.g. ``2021-10-13T08:47:23Z`` - *Unix-epoch* in units of seconds, e.g. ``1513393355`` -Currently required for all features of NannyML, though we are looking to drop this requirement in a future release. + +.. warning:: + This column is optional. When a timestamp column is not provided, plots will no longer make use of a time based x-axis + but will use the index of the chunks instead. The following plots illustrate this: + + .. figure:: /_static/drift-guide-salary_range.svg + + Plot using a time based X-axis + + + .. figure:: /_static/quick-start-drift-salary_range.svg + + Plot using an index based X-axis + + + Some :class:`~nannyml.chunk.Chunker` classes might require the presence of a timestamp, such as the + :class:`~nannyml.chunk.PeriodBasedChunker`. 
+ Target ^^^^^^ @@ -183,7 +202,7 @@ You can see those requirements in the table below: +--------------+-------------------------------------+-------------------------------------+-----------------------------------+-----------------------------------+-----------------------------------+-----------------------------------+ | Data | Performance Estimation | Realized Performance | Univariate Feature Drift | Multivariate Feature Drift | Target Drift | Output Drift | +==============+=====================================+=====================================+===================================+===================================+===================================+===================================+ -| timestamp | Required (reference and analysis) | Required (reference and analysis) | Required (reference and analysis) | Required (reference and analysis) | Required (reference and analysis) | Required (reference and analysis) | +| timestamp | | | | | | | +--------------+-------------------------------------+-------------------------------------+-----------------------------------+-----------------------------------+-----------------------------------+-----------------------------------+ | features | | | Required (reference and analysis) | Required (reference and analysis) | | | +--------------+-------------------------------------+-------------------------------------+-----------------------------------+-----------------------------------+-----------------------------------+-----------------------------------+ diff --git a/docs/tutorials/detecting_data_drift/model_outputs/drift_detection_for_binary_classification_model_outputs.rst b/docs/tutorials/detecting_data_drift/model_outputs/drift_detection_for_binary_classification_model_outputs.rst index 0590e6d3..19e3ad7a 100644 --- a/docs/tutorials/detecting_data_drift/model_outputs/drift_detection_for_binary_classification_model_outputs.rst +++ 
b/docs/tutorials/detecting_data_drift/model_outputs/drift_detection_for_binary_classification_model_outputs.rst @@ -13,6 +13,12 @@ If the model's population changes, then its actions will be different. The difference in actions is very important to know as soon as possible because they directly affect the business results from operating a machine learning model. +.. note:: + The following example uses :term:`timestamps`. + These are optional but have an impact on the way data is chunked and results are plotted. + You can read more about them in the :ref:`data requirements`. + + Just The Code ------------------------------------ diff --git a/docs/tutorials/detecting_data_drift/model_outputs/drift_detection_for_multiclass_classification_model_outputs.rst b/docs/tutorials/detecting_data_drift/model_outputs/drift_detection_for_multiclass_classification_model_outputs.rst index 5cb63456..40c452e9 100644 --- a/docs/tutorials/detecting_data_drift/model_outputs/drift_detection_for_multiclass_classification_model_outputs.rst +++ b/docs/tutorials/detecting_data_drift/model_outputs/drift_detection_for_multiclass_classification_model_outputs.rst @@ -13,6 +13,12 @@ If the model's population changes, then our populations' actions will be differe The difference in actions is very important to know as soon as possible because they directly affect the business results from operating a machine learning model. +.. note:: + The following example uses :term:`timestamps`. + These are optional but have an impact on the way data is chunked and results are plotted. + You can read more about them in the :ref:`data requirements`. 
+ + Just The Code ------------------------------------ diff --git a/docs/tutorials/detecting_data_drift/model_outputs/drift_detection_for_regression_model_outputs.rst b/docs/tutorials/detecting_data_drift/model_outputs/drift_detection_for_regression_model_outputs.rst index 3d959a2c..e379e7f7 100644 --- a/docs/tutorials/detecting_data_drift/model_outputs/drift_detection_for_regression_model_outputs.rst +++ b/docs/tutorials/detecting_data_drift/model_outputs/drift_detection_for_regression_model_outputs.rst @@ -13,6 +13,12 @@ If the model's population changes, then the outcome will be different. The difference in actions is very important to know as soon as possible because they directly affect the business results from operating a machine learning model. +.. note:: + The following example uses :term:`timestamps`. + These are optional but have an impact on the way data is chunked and results are plotted. + You can read more about them in the :ref:`data requirements`. + + Just The Code ------------- diff --git a/docs/tutorials/detecting_data_drift/model_targets/drift_detection_for_binary_classification_model_targets.rst b/docs/tutorials/detecting_data_drift/model_targets/drift_detection_for_binary_classification_model_targets.rst index feba3f27..abc4904c 100644 --- a/docs/tutorials/detecting_data_drift/model_targets/drift_detection_for_binary_classification_model_targets.rst +++ b/docs/tutorials/detecting_data_drift/model_targets/drift_detection_for_binary_classification_model_targets.rst @@ -23,6 +23,12 @@ of the available target values for each chunk, for both binary and multiclass cl .. note:: The Target Drift detection process can handle missing target values across all :term:`data periods`. +.. note:: + The following example uses :term:`timestamps`. + These are optional but have an impact on the way data is chunked and results are plotted. + You can read more about them in the :ref:`data requirements`. 
+ + Just The Code ------------------------------------ diff --git a/docs/tutorials/detecting_data_drift/model_targets/drift_detection_for_multiclass_classification_model_targets.rst b/docs/tutorials/detecting_data_drift/model_targets/drift_detection_for_multiclass_classification_model_targets.rst index 59698be5..ddc6fd34 100644 --- a/docs/tutorials/detecting_data_drift/model_targets/drift_detection_for_multiclass_classification_model_targets.rst +++ b/docs/tutorials/detecting_data_drift/model_targets/drift_detection_for_multiclass_classification_model_targets.rst @@ -23,6 +23,12 @@ of the available target values for each chunk, for both binary and multiclass cl .. note:: The Target Drift detection process can handle missing target values across all :term:`data periods`. +.. note:: + The following example uses :term:`timestamps`. + These are optional but have an impact on the way data is chunked and results are plotted. + You can read more about them in the :ref:`data requirements`. + + Just The Code ------------------------------------ diff --git a/docs/tutorials/detecting_data_drift/model_targets/drift_detection_for_regression_model_targets.rst b/docs/tutorials/detecting_data_drift/model_targets/drift_detection_for_regression_model_targets.rst index e0f4b8ab..53f196fc 100644 --- a/docs/tutorials/detecting_data_drift/model_targets/drift_detection_for_regression_model_targets.rst +++ b/docs/tutorials/detecting_data_drift/model_targets/drift_detection_for_regression_model_targets.rst @@ -21,6 +21,12 @@ but also show the target distribution results per chunk with joyploys. .. note:: The Target Drift detection process can handle missing target values across all :term:`data periods`. +.. note:: + The following example uses :term:`timestamps`. + These are optional but have an impact on the way data is chunked and results are plotted. + You can read more about them in the :ref:`data requirements`. 
+ + Just The Code ------------- diff --git a/docs/tutorials/performance_calculation/binary_performance_calculation.rst b/docs/tutorials/performance_calculation/binary_performance_calculation.rst index 6cc37ec1..d3ffeb8a 100644 --- a/docs/tutorials/performance_calculation/binary_performance_calculation.rst +++ b/docs/tutorials/performance_calculation/binary_performance_calculation.rst @@ -4,6 +4,12 @@ Monitoring Realized Performance for Binary Classification ================================================================ +.. note:: + The following example uses :term:`timestamps`. + These are optional but have an impact on the way data is chunked and results are plotted. + You can read more about them in the :ref:`data requirements`. + + Just The Code ============== diff --git a/docs/tutorials/performance_calculation/multiclass_performance_calculation.rst b/docs/tutorials/performance_calculation/multiclass_performance_calculation.rst index 866956cb..c66248a5 100644 --- a/docs/tutorials/performance_calculation/multiclass_performance_calculation.rst +++ b/docs/tutorials/performance_calculation/multiclass_performance_calculation.rst @@ -4,6 +4,12 @@ Monitoring Realized Performance for Multiclass Classification ================================================================ +.. note:: + The following example uses :term:`timestamps`. + These are optional but have an impact on the way data is chunked and results are plotted. + You can read more about them in the :ref:`data requirements`. 
+ + Just The Code ============== diff --git a/docs/tutorials/performance_calculation/regression_performance_calculation.rst b/docs/tutorials/performance_calculation/regression_performance_calculation.rst index 178cecef..692d404e 100644 --- a/docs/tutorials/performance_calculation/regression_performance_calculation.rst +++ b/docs/tutorials/performance_calculation/regression_performance_calculation.rst @@ -4,6 +4,12 @@ Monitoring Realized Performance for Regression ============================================== +.. note:: + The following example uses :term:`timestamps`. + These are optional but have an impact on the way data is chunked and results are plotted. + You can read more about them in the :ref:`data requirements`. + + Just The Code ============= diff --git a/docs/tutorials/performance_estimation/binary_performance_estimation.rst b/docs/tutorials/performance_estimation/binary_performance_estimation.rst index 706d94ec..465ecedf 100644 --- a/docs/tutorials/performance_estimation/binary_performance_estimation.rst +++ b/docs/tutorials/performance_estimation/binary_performance_estimation.rst @@ -8,6 +8,12 @@ This tutorial explains how to use NannyML to estimate the performance of binary models in the absence of target data. To find out how CBPE estimates performance, read the :ref:`explanation of Confidence-based Performance Estimation`. +.. note:: + The following example uses :term:`timestamps`. + These are optional but have an impact on the way data is chunked and results are plotted. + You can read more about them in the :ref:`data requirements`. + + .. 
_performance-estimation-binary-just-the-code: diff --git a/docs/tutorials/performance_estimation/multiclass_performance_estimation.rst b/docs/tutorials/performance_estimation/multiclass_performance_estimation.rst index 526a8590..b2e1ba51 100644 --- a/docs/tutorials/performance_estimation/multiclass_performance_estimation.rst +++ b/docs/tutorials/performance_estimation/multiclass_performance_estimation.rst @@ -8,6 +8,12 @@ This tutorial explains how to use NannyML to estimate the performance of multicl models in the absence of target data. To find out how CBPE estimates performance, read the :ref:`explanation of Confidence-based Performance Estimation`. +.. note:: + The following example uses :term:`timestamps`. + These are optional but have an impact on the way data is chunked and results are plotted. + You can read more about them in the :ref:`data requirements`. + + Just The Code ------------- diff --git a/docs/tutorials/performance_estimation/regression_performance_estimation.rst b/docs/tutorials/performance_estimation/regression_performance_estimation.rst index a8838a5e..a7ca2348 100644 --- a/docs/tutorials/performance_estimation/regression_performance_estimation.rst +++ b/docs/tutorials/performance_estimation/regression_performance_estimation.rst @@ -8,6 +8,12 @@ This tutorial explains how to use NannyML to estimate the performance of regress models in the absence of target data. To find out how DLE estimates performance, read the :ref:`explanation of how Direct Loss Estimation works`. +.. note:: + The following example uses :term:`timestamps`. + These are optional but have an impact on the way data is chunked and results are plotted. + You can read more about them in the :ref:`data requirements`. + + .. 
_performance-estimation-regression-just-the-code: Just The Code diff --git a/nannyml/__init__.py b/nannyml/__init__.py index a9791ee2..f4afd5a6 100644 --- a/nannyml/__init__.py +++ b/nannyml/__init__.py @@ -32,7 +32,7 @@ # Dev branch marker is: 'X.Y.dev' or 'X.Y.devN' where N is an integer. # 'X.Y.dev0' is the canonical version of 'X.Y.dev' # -__version__ = '0.6.1' +__version__ = '0.6.2' import logging diff --git a/nannyml/base.py b/nannyml/base.py index 0712306a..475502e4 100644 --- a/nannyml/base.py +++ b/nannyml/base.py @@ -66,6 +66,7 @@ def __init__( chunk_number: int = None, chunk_period: str = None, chunker: Chunker = None, + timestamp_column_name: Optional[str] = None, ): """Creates a new instance of an abstract DriftCalculator. @@ -83,7 +84,11 @@ def __init__( chunker : Chunker The `Chunker` used to split the data sets into a lists of chunks. """ - self.chunker = ChunkerFactory.get_chunker(chunk_size, chunk_number, chunk_period, chunker) + self.chunker = ChunkerFactory.get_chunker( + chunk_size, chunk_number, chunk_period, chunker, timestamp_column_name + ) + + self.timestamp_column_name = timestamp_column_name @property def _logger(self) -> logging.Logger: @@ -167,6 +172,7 @@ def __init__( chunk_number: int = None, chunk_period: str = None, chunker: Chunker = None, + timestamp_column_name: str = None, ): """Creates a new instance of an abstract DriftCalculator. @@ -184,7 +190,10 @@ def __init__( chunker : Chunker The `Chunker` used to split the data sets into a lists of chunks. 
""" - self.chunker = ChunkerFactory.get_chunker(chunk_size, chunk_number, chunk_period, chunker) + self.chunker = ChunkerFactory.get_chunker( + chunk_size, chunk_number, chunk_period, chunker, timestamp_column_name + ) + self.timestamp_column_name = timestamp_column_name @property def _logger(self) -> logging.Logger: diff --git a/nannyml/chunk.py b/nannyml/chunk.py index 71140336..06ca5127 100644 --- a/nannyml/chunk.py +++ b/nannyml/chunk.py @@ -9,7 +9,7 @@ import logging import warnings from datetime import datetime -from typing import List +from typing import List, Optional import numpy as np import pandas as pd @@ -28,8 +28,8 @@ def __init__( self, key: str, data: pd.DataFrame, - start_datetime: datetime = datetime.max, - end_datetime: datetime = datetime.max, + start_datetime: Optional[datetime] = None, + end_datetime: Optional[datetime] = None, period: str = None, ): """Creates a new chunk. @@ -55,8 +55,9 @@ def __init__( self.start_datetime = start_datetime self.end_datetime = end_datetime - self.start_index: int = 0 - self.end_index: int = 0 + self.start_index: int = -1 + self.end_index: int = -1 + self.chunk_index: int = -1 def __repr__(self): """Returns textual summary of a chunk. @@ -97,14 +98,13 @@ class Chunker(abc.ABC): or a preferred number of Chunks. """ - def __init__(self): + def __init__(self, timestamp_column_name: Optional[str] = None): """Creates a new Chunker.""" - pass + self.timestamp_column_name = timestamp_column_name def split( self, data: pd.DataFrame, - timestamp_column_name: str, columns=None, ) -> List[Chunk]: """Splits a given data frame into a list of chunks. @@ -123,8 +123,6 @@ def split( ---------- data: DataFrame The data to be split into chunks - timestamp_column_name: str - Name of the column containing the timestamp of an observation. columns: List[str], default=None A list of columns to be included in the resulting chunk data. Unlisted columns will be dropped. 
@@ -134,23 +132,25 @@ def split( The list of chunks """ - if timestamp_column_name not in data.columns: - raise InvalidArgumentsException( - f"timestamp column '{timestamp_column_name}' not in columns: {list(data.columns)}." - ) + if self.timestamp_column_name: + if self.timestamp_column_name not in data.columns: + raise InvalidArgumentsException( + f"timestamp column '{self.timestamp_column_name}' not in columns: {list(data.columns)}." + ) - data = data.sort_values(by=[timestamp_column_name]).reset_index(drop=True) + data = data.sort_values(by=[self.timestamp_column_name]).reset_index(drop=True) try: - chunks = self._split(data, timestamp_column_name) + chunks = self._split(data) except Exception as exc: raise ChunkerException(f"could not split data into chunks: {exc}") - for c in chunks: - c.start_index, c.end_index = _get_boundary_indices(c) + for chunk_index, chunk in enumerate(chunks): + chunk.start_index, chunk.end_index = _get_boundary_indices(chunk) + chunk.chunk_index = chunk_index if columns is not None: - c.data = c.data[columns] + chunk.data = chunk.data[columns] if len(chunks) < 6: # TODO wording @@ -163,7 +163,7 @@ def split( # TODO wording @abc.abstractmethod - def _split(self, data: pd.DataFrame, timestamp_column_name: str) -> List[Chunk]: + def _split(self, data: pd.DataFrame) -> List[Chunk]: """Splits the DataFrame into chunks. Abstract method, to be implemented within inheriting classes. 
@@ -202,17 +202,20 @@ def get_chunker( chunk_number: int = None, chunk_period: str = None, chunker: Chunker = None, + timestamp_column_name: str = None, ) -> Chunker: if chunker is not None: return chunker if chunk_size: - return SizeBasedChunker(chunk_size=chunk_size) # type: ignore + return SizeBasedChunker(chunk_size=chunk_size, timestamp_column_name=timestamp_column_name) # type: ignore elif chunk_number: - return CountBasedChunker(chunk_count=chunk_number) # type: ignore + return CountBasedChunker( + chunk_count=chunk_number, timestamp_column_name=timestamp_column_name # type: ignore + ) elif chunk_period: - return PeriodBasedChunker(offset=chunk_period) # type: ignore + return PeriodBasedChunker(offset=chunk_period, timestamp_column_name=timestamp_column_name) # type: ignore else: - return DefaultChunker() # type: ignore + return DefaultChunker(timestamp_column_name=timestamp_column_name) # type: ignore class PeriodBasedChunker(Chunker): @@ -224,20 +227,21 @@ class PeriodBasedChunker(Chunker): >>> from nannyml.chunk import PeriodBasedChunker >>> df = pd.read_parquet('/path/to/my/data.pq') - >>> chunker = PeriodBasedChunker(date_column_name='observation_date', offset='M') + >>> chunker = PeriodBasedChunker(timestamp_column_name='observation_date', offset='M') >>> chunks = chunker.split(data=df) Or chunk using weekly periods >>> from nannyml.chunk import PeriodBasedChunker >>> df = pd.read_parquet('/path/to/my/data.pq') - >>> chunker = PeriodBasedChunker(date_column=df['observation_date'], offset='W', minimum_chunk_size=50) + >>> chunker = PeriodBasedChunker(timestamp_column_name=df['observation_date'], offset='W', minimum_chunk_size=50) >>> chunks = chunker.split(data=df) """ def __init__( self, + timestamp_column_name: str, offset: str = 'W', ): """Creates a new PeriodBasedChunker. @@ -252,14 +256,14 @@ def __init__( ------- chunker: a PeriodBasedChunker instance used to split data into time-based Chunks. 
""" - super().__init__() + super().__init__(timestamp_column_name) self.offset = offset - def _split(self, data: pd.DataFrame, timestamp_column_name: str) -> List[Chunk]: + def _split(self, data: pd.DataFrame) -> List[Chunk]: chunks = [] try: - grouped_data = data.groupby(pd.to_datetime(data[timestamp_column_name]).dt.to_period(self.offset)) + grouped_data = data.groupby(pd.to_datetime(data[self.timestamp_column_name]).dt.to_period(self.offset)) k: Period for k in grouped_data.groups.keys(): @@ -268,11 +272,11 @@ def _split(self, data: pd.DataFrame, timestamp_column_name: str) -> List[Chunk]: ) chunks.append(chunk) except KeyError: - raise ChunkerException(f"could not find date_column '{timestamp_column_name}' in given data") + raise ChunkerException(f"could not find date_column '{self.timestamp_column_name}' in given data") except ParserError: raise ChunkerException( - f"could not parse date_column '{timestamp_column_name}' values as dates." + f"could not parse date_column '{self.timestamp_column_name}' values as dates." f"Please verify if you've specified the correct date column." ) @@ -299,7 +303,7 @@ class SizeBasedChunker(Chunker): """ - def __init__(self, chunk_size: int, drop_incomplete: bool = False): + def __init__(self, chunk_size: int, drop_incomplete: bool = False, timestamp_column_name: Optional[str] = None): """Create a new SizeBasedChunker. Parameters @@ -315,7 +319,7 @@ def __init__(self, chunk_size: int, drop_incomplete: bool = False): chunker: a size-based instance used to split data into Chunks of a constant size. 
""" - super().__init__() + super().__init__(timestamp_column_name) # TODO wording if not isinstance(chunk_size, (int, np.int64)): @@ -334,17 +338,14 @@ def __init__(self, chunk_size: int, drop_incomplete: bool = False): self.chunk_size = chunk_size self.drop_incomplete = drop_incomplete - def _split(self, data: pd.DataFrame, timestamp_column_name: str) -> List[Chunk]: + def _split(self, data: pd.DataFrame) -> List[Chunk]: def _create_chunk(index: int, data: pd.DataFrame, chunk_size: int) -> Chunk: chunk_data = data.loc[index : index + chunk_size - 1, :] - min_date = pd.to_datetime(chunk_data[timestamp_column_name].min()) - max_date = pd.to_datetime(chunk_data[timestamp_column_name].max()) - return Chunk( - key=f'[{index}:{index + chunk_size - 1}]', - data=chunk_data, - start_datetime=min_date, - end_datetime=max_date, - ) + chunk = Chunk(key=f'[{index}:{index + chunk_size - 1}]', data=chunk_data) + if self.timestamp_column_name: + chunk.start_datetime = pd.to_datetime(chunk.data[self.timestamp_column_name].min()) + chunk.end_datetime = pd.to_datetime(chunk.data[self.timestamp_column_name].max()) + return chunk data = data.copy().reset_index(drop=True) chunks = [ @@ -377,7 +378,7 @@ class CountBasedChunker(Chunker): """ - def __init__(self, chunk_count: int): + def __init__(self, chunk_count: int, timestamp_column_name: Optional[str] = None): """Creates a new CountBasedChunker. It will calculate the amount of observations per chunk based on the given chunk count. 
@@ -393,7 +394,7 @@ def __init__(self, chunk_count: int): chunker: CountBasedChunker """ - super().__init__() + super().__init__(timestamp_column_name) # TODO wording if not isinstance(chunk_count, int): @@ -411,14 +412,17 @@ def __init__(self, chunk_count: int): self.chunk_count = chunk_count - def _split(self, data: pd.DataFrame, timestamp_column_name: str) -> List[Chunk]: + def _split(self, data: pd.DataFrame) -> List[Chunk]: if data.shape[0] == 0: return [] data = data.copy().reset_index() chunk_size = data.shape[0] // self.chunk_count - chunks = SizeBasedChunker(chunk_size=chunk_size).split(data=data, timestamp_column_name=timestamp_column_name) + chunks = SizeBasedChunker(chunk_size=chunk_size, timestamp_column_name=self.timestamp_column_name).split( + data=data + ) + return chunks @@ -435,19 +439,18 @@ class DefaultChunker(Chunker): DEFAULT_CHUNK_COUNT = 10 - def __init__(self): + def __init__(self, timestamp_column_name: Optional[str] = None): """Creates a new DefaultChunker.""" - super(DefaultChunker, self).__init__() + super(DefaultChunker, self).__init__(timestamp_column_name) - def _split(self, data: pd.DataFrame, timestamp_column_name: str) -> List[Chunk]: + def _split(self, data: pd.DataFrame) -> List[Chunk]: if data.shape[0] == 0: return [] data = data.copy().reset_index(drop=True) chunk_size = data.shape[0] // self.DEFAULT_CHUNK_COUNT - chunks = SizeBasedChunker(chunk_size=chunk_size).split( - data=data, - timestamp_column_name=timestamp_column_name, + chunks = SizeBasedChunker(chunk_size=chunk_size, timestamp_column_name=self.timestamp_column_name).split( + data=data ) return chunks diff --git a/nannyml/drift/model_inputs/multivariate/data_reconstruction/calculator.py b/nannyml/drift/model_inputs/multivariate/data_reconstruction/calculator.py index 288dc86d..f47a9112 100644 --- a/nannyml/drift/model_inputs/multivariate/data_reconstruction/calculator.py +++ b/nannyml/drift/model_inputs/multivariate/data_reconstruction/calculator.py @@ -26,7 +26,7 @@ 
class DataReconstructionDriftCalculator(AbstractCalculator): def __init__( self, feature_column_names: List[str], - timestamp_column_name: str, + timestamp_column_name: str = None, n_components: Union[int, float, str] = 0.65, chunk_size: int = None, chunk_number: int = None, @@ -42,7 +42,7 @@ def __init__( feature_column_names: List[str] A list containing the names of features in the provided data set. All of these features will be used by the multivariate data reconstruction drift calculator to calculate an aggregate drift score. - timestamp_column_name: str + timestamp_column_name: str, default=None The name of the column containing the timestamp of the model prediction. n_components: Union[int, float, str], default=0.65 The n_components parameter as passed to the sklearn.decomposition.PCA constructor. @@ -91,12 +91,13 @@ def __init__( >>> fig = results.plot(kind='drift', plot_reference=True) >>> fig.show() """ - super(DataReconstructionDriftCalculator, self).__init__(chunk_size, chunk_number, chunk_period, chunker) + super(DataReconstructionDriftCalculator, self).__init__( + chunk_size, chunk_number, chunk_period, chunker, timestamp_column_name + ) self.feature_column_names = feature_column_names self.continuous_feature_column_names: List[str] = [] self.categorical_feature_column_names: List[str] = [] - self.timestamp_column_name = timestamp_column_name self._n_components = n_components self._scaler = None @@ -200,14 +201,13 @@ def _calculate(self, data: pd.DataFrame, *args, **kwargs) -> DataReconstructionD data, self.feature_column_names ) - chunks = self.chunker.split( - data, columns=self.feature_column_names, timestamp_column_name=self.timestamp_column_name - ) + chunks = self.chunker.split(data, columns=self.feature_column_names) res = pd.DataFrame.from_records( [ { 'key': chunk.key, + 'chunk_index': chunk.chunk_index, 'start_index': chunk.start_index, 'end_index': chunk.end_index, 'start_date': chunk.start_datetime, @@ -237,7 +237,7 @@ def _calculate(self, 
data: pd.DataFrame, *args, **kwargs) -> DataReconstructionD return DataReconstructionDriftCalculatorResult(results_data=res, calculator=self) def _calculate_alert_thresholds(self, reference_data) -> Tuple[float, float]: - reference_chunks = self.chunker.split(reference_data, self.timestamp_column_name) # type: ignore + reference_chunks = self.chunker.split(reference_data) # type: ignore reference_reconstruction_error = pd.Series( [ _calculate_reconstruction_error_for_data( diff --git a/nannyml/drift/model_inputs/multivariate/data_reconstruction/results.py b/nannyml/drift/model_inputs/multivariate/data_reconstruction/results.py index 6efb01f6..e713e348 100644 --- a/nannyml/drift/model_inputs/multivariate/data_reconstruction/results.py +++ b/nannyml/drift/model_inputs/multivariate/data_reconstruction/results.py @@ -103,6 +103,8 @@ def _plot_drift(data: pd.DataFrame, calculator, plot_reference: bool) -> go.Figu reference_results['period'] = 'reference' data = pd.concat([reference_results, data], ignore_index=True) + is_time_based_x_axis = calculator.timestamp_column_name is not None + fig = _step_plot( table=data, metric_column_name='reconstruction_error', @@ -118,6 +120,8 @@ def _plot_drift(data: pd.DataFrame, calculator, plot_reference: bool) -> go.Figu lower_confidence_column_name='lower_confidence_bound', upper_confidence_column_name='upper_confidence_bound', plot_confidence_for_reference=True, + start_date_column_name='start_date' if is_time_based_x_axis else None, + end_date_column_name='end_date' if is_time_based_x_axis else None, ) return fig diff --git a/nannyml/drift/model_inputs/univariate/statistical/calculator.py b/nannyml/drift/model_inputs/univariate/statistical/calculator.py index 725795bc..40189b78 100644 --- a/nannyml/drift/model_inputs/univariate/statistical/calculator.py +++ b/nannyml/drift/model_inputs/univariate/statistical/calculator.py @@ -26,7 +26,7 @@ class UnivariateStatisticalDriftCalculator(AbstractCalculator): def __init__( self, 
feature_column_names: List[str], - timestamp_column_name: str, + timestamp_column_name: str = None, chunk_size: int = None, chunk_number: int = None, chunk_period: str = None, @@ -39,7 +39,7 @@ def __init__( feature_column_names: List[str] A list containing the names of features in the provided data set. A drift score will be calculated for each entry in this list. - timestamp_column_name: str + timestamp_column_name: str, default=None The name of the column containing the timestamp of the model prediction. chunk_size: int Splits the data into chunks containing `chunks_size` observations. @@ -82,14 +82,14 @@ def __init__( >>> fig = results.plot(kind='feature_drift', plot_reference=True, feature_column_name='distance_from_office') >>> fig.show() """ - super(UnivariateStatisticalDriftCalculator, self).__init__(chunk_size, chunk_number, chunk_period, chunker) + super(UnivariateStatisticalDriftCalculator, self).__init__( + chunk_size, chunk_number, chunk_period, chunker, timestamp_column_name + ) self.feature_column_names = feature_column_names self.continuous_column_names: List[str] = [] self.categorical_column_names: List[str] = [] - self.timestamp_column_name = timestamp_column_name - # required for distribution plots self.previous_reference_data: Optional[pd.DataFrame] = None self.previous_reference_results: Optional[pd.DataFrame] = None @@ -121,7 +121,7 @@ def _calculate(self, data: pd.DataFrame, *args, **kwargs) -> UnivariateStatistic data, self.feature_column_names ) - chunks = self.chunker.split(data, self.timestamp_column_name) + chunks = self.chunker.split(data) chunk_drifts = [] # Calculate chunk-wise drift statistics. 
@@ -129,6 +129,7 @@ def _calculate(self, data: pd.DataFrame, *args, **kwargs) -> UnivariateStatistic for chunk in chunks: chunk_drift: Dict[str, Any] = { 'key': chunk.key, + 'chunk_index': chunk.chunk_index, 'start_index': chunk.start_index, 'end_index': chunk.end_index, 'start_date': chunk.start_datetime, diff --git a/nannyml/drift/model_inputs/univariate/statistical/results.py b/nannyml/drift/model_inputs/univariate/statistical/results.py index 537868fc..bc32d00d 100644 --- a/nannyml/drift/model_inputs/univariate/statistical/results.py +++ b/nannyml/drift/model_inputs/univariate/statistical/results.py @@ -174,6 +174,8 @@ def _feature_drift( reference_results['period'] = 'reference' data = pd.concat([reference_results, data], ignore_index=True) + is_time_based_x_axis = calculator.timestamp_column_name is not None + fig = _step_plot( table=data, metric_column_name=metric_column_name, @@ -185,6 +187,8 @@ def _feature_drift( y_axis_title=metric_label, v_line_separating_analysis_period=plot_reference, statistically_significant_column_name=drift_column_name, + start_date_column_name='start_date' if is_time_based_x_axis else None, + end_date_column_name='end_date' if is_time_based_x_axis else None, ) return fig @@ -233,7 +237,7 @@ def _plot_continuous_feature_distribution( drift_data['period'] = 'analysis' data['period'] = 'analysis' - feature_table = _create_feature_table(calculator.chunker.split(data, calculator.timestamp_column_name)) + feature_table = _create_feature_table(calculator.chunker.split(data)) if plot_reference: if calculator.previous_reference_results is None: @@ -245,11 +249,11 @@ def _plot_continuous_feature_distribution( reference_drift['period'] = 'reference' drift_data = pd.concat([reference_drift, drift_data], ignore_index=True) - reference_feature_table = _create_feature_table( - calculator.chunker.split(calculator.previous_reference_data, calculator.timestamp_column_name) - ) + reference_feature_table = 
_create_feature_table(calculator.chunker.split(calculator.previous_reference_data)) feature_table = pd.concat([reference_feature_table, feature_table], ignore_index=True) + is_time_based_x_axis = calculator.timestamp_column_name is not None + fig = _joy_plot( feature_table=feature_table, drift_table=drift_data, @@ -259,6 +263,8 @@ def _plot_continuous_feature_distribution( x_axis_title=x_axis_title, title=title, style='vertical', + start_date_column_name='start_date' if is_time_based_x_axis else None, + end_date_column_name='end_date' if is_time_based_x_axis else None, ) return fig @@ -281,7 +287,7 @@ def _plot_categorical_feature_distribution( drift_data['period'] = 'analysis' data['period'] = 'analysis' - feature_table = _create_feature_table(calculator.chunker.split(data, calculator.timestamp_column_name)) + feature_table = _create_feature_table(calculator.chunker.split(data)) if plot_reference: if calculator.previous_reference_results is None: @@ -293,11 +299,11 @@ def _plot_categorical_feature_distribution( reference_drift['period'] = 'reference' drift_data = pd.concat([reference_drift, drift_data], ignore_index=True) - reference_feature_table = _create_feature_table( - calculator.chunker.split(calculator.previous_reference_data, calculator.timestamp_column_name) - ) + reference_feature_table = _create_feature_table(calculator.chunker.split(calculator.previous_reference_data)) feature_table = pd.concat([reference_feature_table, feature_table], ignore_index=True) + is_time_based_x_axis = calculator.timestamp_column_name is not None + fig = _stacked_bar_plot( feature_table=feature_table, drift_table=drift_data, @@ -306,6 +312,8 @@ def _plot_categorical_feature_distribution( feature_column_name=feature_column_name, yaxis_title=yaxis_title, title=title, + start_date_column_name='start_date' if is_time_based_x_axis else None, + end_date_column_name='end_date' if is_time_based_x_axis else None, ) return fig diff --git 
a/nannyml/drift/model_outputs/univariate/statistical/calculator.py b/nannyml/drift/model_outputs/univariate/statistical/calculator.py index 84af3dff..d45a24c5 100644 --- a/nannyml/drift/model_outputs/univariate/statistical/calculator.py +++ b/nannyml/drift/model_outputs/univariate/statistical/calculator.py @@ -25,9 +25,9 @@ class StatisticalOutputDriftCalculator(AbstractCalculator): def __init__( self, y_pred: str, - timestamp_column_name: str, problem_type: Union[str, ProblemType], y_pred_proba: ModelOutputsType = None, + timestamp_column_name: str = None, chunk_size: int = None, chunk_number: int = None, chunk_period: str = None, @@ -44,7 +44,7 @@ def __init__( The dictionary maps a class/label string to the column name containing model outputs for that class/label. y_pred: str The name of the column containing your model predictions. - timestamp_column_name: str + timestamp_column_name: str, default=None The name of the column containing the timestamp of the model prediction. chunk_size: int, default=None Splits the data into chunks containing `chunks_size` observations. 
@@ -90,11 +90,12 @@ def __init__( >>> results.plot(kind='prediction_drift', plot_reference=True).show() >>> results.plot(kind='prediction_distribution', plot_reference=True).show() """ - super(StatisticalOutputDriftCalculator, self).__init__(chunk_size, chunk_number, chunk_period, chunker) + super(StatisticalOutputDriftCalculator, self).__init__( + chunk_size, chunk_number, chunk_period, chunker, timestamp_column_name + ) self.y_pred_proba = y_pred_proba self.y_pred = y_pred - self.timestamp_column_name = timestamp_column_name if isinstance(problem_type, str): problem_type = ProblemType.parse(problem_type) @@ -156,16 +157,14 @@ def _calculate(self, data: pd.DataFrame, *args, **kwargs) -> UnivariateDriftResu elif self.problem_type == ProblemType.REGRESSION: continuous_columns += [self.y_pred] - chunks = self.chunker.split( - data, columns=continuous_columns + categorical_columns, timestamp_column_name=self.timestamp_column_name - ) - + chunks = self.chunker.split(data, columns=continuous_columns + categorical_columns) chunk_drifts = [] # Calculate chunk-wise drift statistics. # Append all into resulting DataFrame indexed by chunk key. 
for chunk in chunks: chunk_drift: Dict[str, Any] = { 'key': chunk.key, + 'chunk_index': chunk.chunk_index, 'start_index': chunk.start_index, 'end_index': chunk.end_index, 'start_date': chunk.start_datetime, diff --git a/nannyml/drift/model_outputs/univariate/statistical/results.py b/nannyml/drift/model_outputs/univariate/statistical/results.py index 28044d7a..cb69affa 100644 --- a/nannyml/drift/model_outputs/univariate/statistical/results.py +++ b/nannyml/drift/model_outputs/univariate/statistical/results.py @@ -245,6 +245,8 @@ def _plot_prediction_drift( reference_results['period'] = 'reference' data = pd.concat([reference_results, data], ignore_index=True) + is_time_based_x_axis = calculator.timestamp_column_name is not None + fig = _step_plot( table=data, metric_column_name=metric_column_name, @@ -256,6 +258,8 @@ def _plot_prediction_drift( y_axis_title=metric_label, v_line_separating_analysis_period=plot_period_separator, statistically_significant_column_name=drift_column_name, + start_date_column_name='start_date' if is_time_based_x_axis else None, + end_date_column_name='end_date' if is_time_based_x_axis else None, ) return fig @@ -288,12 +292,11 @@ def _plot_prediction_distribution( prediction_column_name = calculator.y_pred axis_title = f'{prediction_column_name}' drift_column_name = f'{prediction_column_name}_alert' - title = f'Distribution over time for {prediction_column_name}' drift_data['period'] = 'analysis' data['period'] = 'analysis' - feature_table = _create_feature_table(calculator.chunker.split(data, calculator.timestamp_column_name)) + feature_table = _create_feature_table(calculator.chunker.split(data)) if plot_reference: reference_drift = calculator.previous_reference_results.copy() @@ -305,11 +308,11 @@ def _plot_prediction_distribution( reference_drift['period'] = 'reference' drift_data = pd.concat([reference_drift, drift_data], ignore_index=True) - reference_feature_table = _create_feature_table( - 
calculator.chunker.split(calculator.previous_reference_data, calculator.timestamp_column_name) - ) + reference_feature_table = _create_feature_table(calculator.chunker.split(calculator.previous_reference_data)) feature_table = pd.concat([reference_feature_table, feature_table], ignore_index=True) + is_time_based_x_axis = calculator.timestamp_column_name is not None + if calculator.problem_type in [ProblemType.CLASSIFICATION_BINARY, ProblemType.CLASSIFICATION_MULTICLASS]: fig = _stacked_bar_plot( feature_table=feature_table, @@ -318,7 +321,9 @@ def _plot_prediction_distribution( drift_column_name=drift_column_name, feature_column_name=prediction_column_name, yaxis_title=axis_title, - title=title, + title=f'Distribution over time for {prediction_column_name}', + start_date_column_name='start_date' if is_time_based_x_axis else None, + end_date_column_name='end_date' if is_time_based_x_axis else None, ) elif calculator.problem_type == ProblemType.REGRESSION: fig = _joy_plot( @@ -329,8 +334,12 @@ def _plot_prediction_distribution( feature_column_name=prediction_column_name, x_axis_title=axis_title, post_kde_clip=clip, - title=title, + title=f'Distribution over time for {prediction_column_name}' + if is_time_based_x_axis + else f'Distribution over chunks for {prediction_column_name}', style='vertical', + start_date_column_name='start_date' if is_time_based_x_axis else None, + end_date_column_name='end_date' if is_time_based_x_axis else None, ) else: raise RuntimeError( @@ -399,6 +408,8 @@ def _plot_score_drift( reference_results['period'] = 'reference' data = pd.concat([reference_results, data], ignore_index=True) + is_time_based_x_axis = calculator.timestamp_column_name is not None + fig = _step_plot( table=data, metric_column_name=metric_column_name, @@ -410,6 +421,8 @@ def _plot_score_drift( y_axis_title=metric_label, v_line_separating_analysis_period=plot_period_separator, statistically_significant_column_name=drift_column_name, + start_date_column_name='start_date' 
if is_time_based_x_axis else None, + end_date_column_name='end_date' if is_time_based_x_axis else None, ) return fig @@ -477,7 +490,7 @@ def _plot_score_distribution( drift_data['period'] = 'analysis' data['period'] = 'analysis' - feature_table = _create_feature_table(calculator.chunker.split(data, calculator.timestamp_column_name)) + feature_table = _create_feature_table(calculator.chunker.split(data)) if plot_reference: reference_drift = calculator.previous_reference_results.copy() @@ -489,11 +502,11 @@ def _plot_score_distribution( reference_drift['period'] = 'reference' drift_data = pd.concat([reference_drift, drift_data], ignore_index=True) - reference_feature_table = _create_feature_table( - calculator.chunker.split(calculator.previous_reference_data, calculator.timestamp_column_name) - ) + reference_feature_table = _create_feature_table(calculator.chunker.split(calculator.previous_reference_data)) feature_table = pd.concat([reference_feature_table, feature_table], ignore_index=True) + is_time_based_x_axis = calculator.timestamp_column_name is not None + if _column_is_categorical(data[output_column_name]): fig = _stacked_bar_plot( feature_table=feature_table, @@ -503,6 +516,8 @@ def _plot_score_distribution( feature_column_name=output_column_name, yaxis_title=axis_title, title=title, + start_date_column_name='start_date' if is_time_based_x_axis else None, + end_date_column_name='end_date' if is_time_based_x_axis else None, ) elif _column_is_continuous(data[output_column_name]): fig = _joy_plot( @@ -515,6 +530,8 @@ def _plot_score_distribution( post_kde_clip=clip, title=title, style='vertical', + start_date_column_name='start_date' if is_time_based_x_axis else None, + end_date_column_name='end_date' if is_time_based_x_axis else None, ) else: raise RuntimeError( diff --git a/nannyml/drift/target/target_distribution/calculator.py b/nannyml/drift/target/target_distribution/calculator.py index 4ab573fd..381d3e14 100644 --- 
a/nannyml/drift/target/target_distribution/calculator.py +++ b/nannyml/drift/target/target_distribution/calculator.py @@ -28,8 +28,8 @@ class TargetDistributionCalculator(AbstractCalculator): def __init__( self, y_true: str, - timestamp_column_name: str, problem_type: Union[str, ProblemType], + timestamp_column_name: str = None, chunk_size: int = None, chunk_number: int = None, chunk_period: str = None, @@ -41,7 +41,7 @@ def __init__( ---------- y_true: str The name of the column containing your model target values. - timestamp_column_name: str + timestamp_column_name: str, default=None The name of the column containing the timestamp of the model prediction. chunk_size: int, default=None Splits the data into chunks containing `chunks_size` observations. @@ -83,10 +83,9 @@ def __init__( >>> results.plot(kind='target_drift', plot_reference=True).show() >>> results.plot(kind='target_distribution', plot_reference=True).show() """ - super().__init__(chunk_size, chunk_number, chunk_period, chunker) + super().__init__(chunk_size, chunk_number, chunk_period, chunker, timestamp_column_name) self.y_true = y_true - self.timestamp_column_name = timestamp_column_name if isinstance(problem_type, str): problem_type = ProblemType.parse(problem_type) @@ -135,7 +134,6 @@ def _calculate(self, data: pd.DataFrame, *args, **kwargs): chunks = self.chunker.split( data, columns=[self.y_true, 'NML_TARGET_INCOMPLETE'], - timestamp_column_name=self.timestamp_column_name, ) # Construct result frame @@ -145,6 +143,7 @@ def _calculate(self, data: pd.DataFrame, *args, **kwargs): [ { 'key': chunk.key, + 'chunk_index': chunk.chunk_index, 'start_index': chunk.start_index, 'end_index': chunk.end_index, 'start_date': chunk.start_datetime, diff --git a/nannyml/drift/target/target_distribution/result.py b/nannyml/drift/target/target_distribution/result.py index 8a2b3269..9032ba3f 100644 --- a/nannyml/drift/target/target_distribution/result.py +++ b/nannyml/drift/target/target_distribution/result.py @@ 
-114,6 +114,8 @@ def _plot_target_drift( reference_results['period'] = 'reference' data = pd.concat([reference_results, data], ignore_index=True) + is_time_based_x_axis = self.calculator.timestamp_column_name is not None + if self.calculator.problem_type == ProblemType.REGRESSION: return _step_plot( table=data, @@ -124,6 +126,8 @@ def _plot_target_drift( title=f'KS statistic over time for {self.calculator.y_true}', y_axis_title='KS statistic', v_line_separating_analysis_period=plot_period_separator, + start_date_column_name='start_date' if is_time_based_x_axis else None, + end_date_column_name='end_date' if is_time_based_x_axis else None, ) elif self.calculator.problem_type in [ProblemType.CLASSIFICATION_BINARY, ProblemType.CLASSIFICATION_MULTICLASS]: return _step_plot( @@ -137,6 +141,8 @@ def _plot_target_drift( v_line_separating_analysis_period=plot_period_separator, partial_target_column_name='targets_missing_rate', statistically_significant_column_name='significant', + start_date_column_name='start_date' if is_time_based_x_axis else None, + end_date_column_name='end_date' if is_time_based_x_axis else None, ) else: raise RuntimeError( @@ -160,6 +166,8 @@ def _plot_target_distribution(self, plot_reference: bool) -> go.Figure: reference_results['period'] = 'reference' results_data = pd.concat([reference_results, results_data.copy()], ignore_index=True) + is_time_based_x_axis = self.calculator.timestamp_column_name is not None + if self.calculator.problem_type in [ProblemType.CLASSIFICATION_BINARY, ProblemType.CLASSIFICATION_MULTICLASS]: return _step_plot( table=results_data, @@ -172,14 +180,14 @@ def _plot_target_distribution(self, plot_reference: bool) -> go.Figure: v_line_separating_analysis_period=plot_period_separator, partial_target_column_name='targets_missing_rate', statistically_significant_column_name='significant', + start_date_column_name='start_date' if is_time_based_x_axis else None, + end_date_column_name='end_date' if is_time_based_x_axis else None, 
) if self.calculator.problem_type == ProblemType.REGRESSION: feature_table = pd.concat( [ chunk.data.assign(key=chunk.key) - for chunk in self.calculator.chunker.split( - self.calculator.previous_analysis_data, self.calculator.timestamp_column_name - ) + for chunk in self.calculator.chunker.split(self.calculator.previous_analysis_data) ] ) @@ -196,13 +204,13 @@ def _plot_target_distribution(self, plot_reference: bool) -> go.Figure: reference_feature_table = pd.concat( [ chunk.data.assign(key=chunk.key) - for chunk in self.calculator.chunker.split( - self.calculator.previous_reference_data, self.calculator.timestamp_column_name - ) + for chunk in self.calculator.chunker.split(self.calculator.previous_reference_data) ] ) feature_table = pd.concat([reference_feature_table, feature_table], ignore_index=True) + is_time_based_x_axis = self.calculator.timestamp_column_name is not None + return _joy_plot( feature_table=feature_table, drift_table=results_data, @@ -211,8 +219,12 @@ def _plot_target_distribution(self, plot_reference: bool) -> go.Figure: feature_column_name=self.calculator.y_true, x_axis_title=f'{self.calculator.y_true}', post_kde_clip=None, - title=f'Distribution over time for {self.calculator.y_true}', + title=f'Distribution over time for {self.calculator.y_true}' + if is_time_based_x_axis + else f'Distribution over chunks for {self.calculator.y_true}', style='vertical', + start_date_column_name='start_date' if is_time_based_x_axis else None, + end_date_column_name='end_date' if is_time_based_x_axis else None, ) else: raise RuntimeError( diff --git a/nannyml/performance_calculation/calculator.py b/nannyml/performance_calculation/calculator.py index 38914284..7d61bf82 100644 --- a/nannyml/performance_calculation/calculator.py +++ b/nannyml/performance_calculation/calculator.py @@ -28,12 +28,12 @@ class PerformanceCalculator(AbstractCalculator): def __init__( self, - timestamp_column_name: str, metrics: List[str], y_true: str, y_pred: str, problem_type: 
Union[str, ProblemType], y_pred_proba: ModelOutputsType = None, + timestamp_column_name: str = None, chunk_size: int = None, chunk_number: int = None, chunk_period: str = None, @@ -52,7 +52,7 @@ def __init__( The dictionary maps a class/label string to the column name containing model outputs for that class/label. y_pred: str The name of the column containing your model predictions. - timestamp_column_name: str + timestamp_column_name: str, default=None The name of the column containing the timestamp of the model prediction. metrics: List[str] A list of metrics to calculate. @@ -102,8 +102,6 @@ def __init__( self.y_pred_proba = y_pred_proba - self.timestamp_column_name = timestamp_column_name - if isinstance(problem_type, str): problem_type = ProblemType.parse(problem_type) self.problem_type = problem_type @@ -163,13 +161,14 @@ def _calculate(self, data: pd.DataFrame, *args, **kwargs) -> PerformanceCalculat 'Please ensure you run ``calculator.fit()`` ' 'before running ``calculator.calculate()``' ) - chunks = self.chunker.split(data, timestamp_column_name=self.timestamp_column_name) + chunks = self.chunker.split(data) # Construct result frame res = pd.DataFrame.from_records( [ { 'key': chunk.key, + 'chunk_index': chunk.chunk_index, 'start_index': chunk.start_index, 'end_index': chunk.end_index, 'start_date': chunk.start_datetime, diff --git a/nannyml/performance_calculation/metrics/base.py b/nannyml/performance_calculation/metrics/base.py index 6e5f50e6..b6ff5051 100644 --- a/nannyml/performance_calculation/metrics/base.py +++ b/nannyml/performance_calculation/metrics/base.py @@ -74,7 +74,6 @@ def fit(self, reference_data: pd.DataFrame, chunker: Chunker): # Calculate alert thresholds reference_chunks = chunker.split( reference_data, - timestamp_column_name=self.calculator.timestamp_column_name, ) self.lower_threshold, self.upper_threshold = self._calculate_alert_thresholds( reference_chunks=reference_chunks, diff --git a/nannyml/performance_calculation/result.py 
b/nannyml/performance_calculation/result.py index 30e7279f..c2f34504 100644 --- a/nannyml/performance_calculation/result.py +++ b/nannyml/performance_calculation/result.py @@ -158,6 +158,8 @@ def _plot_performance_metric( reference_results['period'] = 'reference' results_data = pd.concat([reference_results, results_data], ignore_index=True) + is_time_based_x_axis = calculator.timestamp_column_name is not None + # Plot metric performance fig = _step_plot( table=results_data, @@ -175,6 +177,8 @@ def _plot_performance_metric( y_axis_title='Realized performance', v_line_separating_analysis_period=plot_period_separator, sampling_error_column_name=f'{metric.column_name}_sampling_error', + start_date_column_name='start_date' if is_time_based_x_axis else None, + end_date_column_name='end_date' if is_time_based_x_axis else None, ) return fig diff --git a/nannyml/performance_estimation/confidence_based/_cbpe_binary_classification.py b/nannyml/performance_estimation/confidence_based/_cbpe_binary_classification.py index e53db605..a3deffa9 100644 --- a/nannyml/performance_estimation/confidence_based/_cbpe_binary_classification.py +++ b/nannyml/performance_estimation/confidence_based/_cbpe_binary_classification.py @@ -23,8 +23,8 @@ def __init__( y_pred: str, y_pred_proba: ModelOutputsType, y_true: str, - timestamp_column_name: str, problem_type: Union[str, ProblemType], + timestamp_column_name: str = None, chunk_size: int = None, chunk_number: int = None, chunk_period: str = None, @@ -98,12 +98,13 @@ def _estimate(self, data: pd.DataFrame, *args, **kwargs) -> CBPEPerformanceEstim if self.needs_calibration: data[self.y_pred_proba] = self.calibrator.calibrate(data[self.y_pred_proba]) - chunks = self.chunker.split(data, timestamp_column_name=self.timestamp_column_name) + chunks = self.chunker.split(data) res = pd.DataFrame.from_records( [ { 'key': chunk.key, + 'chunk_index': chunk.chunk_index, 'start_index': chunk.start_index, 'end_index': chunk.end_index, 'start_date': 
chunk.start_datetime, @@ -137,5 +138,5 @@ def _estimate_chunk(self, chunk: Chunk) -> Dict: estimated_metric > metric.upper_threshold or estimated_metric < metric.lower_threshold ) estimates['period'] = 'analysis' - estimates['estimated'] = True + estimates['estimated'] = True return estimates diff --git a/nannyml/performance_estimation/confidence_based/_cbpe_multiclass_classification.py b/nannyml/performance_estimation/confidence_based/_cbpe_multiclass_classification.py index 5ed60d60..f24b80ad 100644 --- a/nannyml/performance_estimation/confidence_based/_cbpe_multiclass_classification.py +++ b/nannyml/performance_estimation/confidence_based/_cbpe_multiclass_classification.py @@ -25,8 +25,8 @@ def __init__( y_pred: str, y_pred_proba: ModelOutputsType, y_true: str, - timestamp_column_name: str, problem_type: Union[str, ProblemType], + timestamp_column_name: str = None, chunk_size: int = None, chunk_number: int = None, chunk_period: str = None, @@ -86,12 +86,13 @@ def _estimate(self, data: pd.DataFrame, *args, **kwargs) -> CBPEPerformanceEstim data = _calibrate_predicted_probabilities(data, self.y_true, self.y_pred_proba, self._calibrators) - chunks = self.chunker.split(data, timestamp_column_name=self.timestamp_column_name) + chunks = self.chunker.split(data) res = pd.DataFrame.from_records( [ { 'key': chunk.key, + 'chunk_index': chunk.chunk_index, 'start_index': chunk.start_index, 'end_index': chunk.end_index, 'start_date': chunk.start_datetime, @@ -125,7 +126,7 @@ def _estimate_for_chunk(self, chunk: Chunk) -> Dict: estimated_metric > metric.upper_threshold or estimated_metric < metric.lower_threshold ) estimates['period'] = 'analysis' - estimates['estimated'] = True + estimates['estimated'] = True return estimates diff --git a/nannyml/performance_estimation/confidence_based/cbpe.py b/nannyml/performance_estimation/confidence_based/cbpe.py index 94af070f..2df6f0ef 100644 --- a/nannyml/performance_estimation/confidence_based/cbpe.py +++ 
b/nannyml/performance_estimation/confidence_based/cbpe.py @@ -44,8 +44,8 @@ def __init__( y_pred: str, y_pred_proba: ModelOutputsType, y_true: str, - timestamp_column_name: str, problem_type: Union[str, ProblemType], + timestamp_column_name: str = None, chunk_size: int = None, chunk_number: int = None, chunk_period: str = None, @@ -57,6 +57,17 @@ def __init__( Parameters ---------- + y_true: str + The name of the column containing target values (that are provided in reference data during fitting). + y_pred_proba: ModelOutputsType + Name(s) of the column(s) containing your model output. + Pass a single string when there is only a single model output column, e.g. in binary classification cases. + Pass a dictionary when working with multiple output columns, e.g. in multiclass classification cases. + The dictionary maps a class/label string to the column name containing model outputs for that class/label. + y_pred: str + The name of the column containing your model predictions. + timestamp_column_name: str, default=None + The name of the column containing the timestamp of the model prediction. metrics: List[str] A list of metrics to calculate. 
chunk_size: int, default=None @@ -114,12 +125,11 @@ def __init__( >>> results.plot(metric=metric, plot_reference=True).show() """ - super().__init__(chunk_size, chunk_number, chunk_period, chunker) + super().__init__(chunk_size, chunk_number, chunk_period, chunker, timestamp_column_name) self.y_true = y_true self.y_pred = y_pred self.y_pred_proba = y_pred_proba - self.timestamp_column_name = timestamp_column_name if metrics is None or len(metrics) == 0: raise InvalidArgumentsException( diff --git a/nannyml/performance_estimation/confidence_based/metrics.py b/nannyml/performance_estimation/confidence_based/metrics.py index fd99d6f7..63d40109 100644 --- a/nannyml/performance_estimation/confidence_based/metrics.py +++ b/nannyml/performance_estimation/confidence_based/metrics.py @@ -72,7 +72,6 @@ def fit(self, reference_data: pd.DataFrame): # Calculate alert thresholds reference_chunks = self.estimator.chunker.split( reference_data, - timestamp_column_name=self.estimator.timestamp_column_name, ) self.lower_threshold, self.upper_threshold = self._alert_thresholds(reference_chunks) diff --git a/nannyml/performance_estimation/confidence_based/results.py b/nannyml/performance_estimation/confidence_based/results.py index 485b2297..94e3c348 100644 --- a/nannyml/performance_estimation/confidence_based/results.py +++ b/nannyml/performance_estimation/confidence_based/results.py @@ -170,6 +170,8 @@ def _plot_cbpe_performance_estimation( lambda r: r[f'alert_{metric.column_name}'] if r['period'] == 'analysis' else False, axis=1 ) + is_time_based_x_axis = estimator.timestamp_column_name is not None + # Plot estimated performance fig = _step_plot( table=estimation_results, @@ -193,6 +195,8 @@ def _plot_cbpe_performance_estimation( lower_confidence_column_name=f'lower_confidence_{metric.column_name}', upper_confidence_column_name=f'upper_confidence_{metric.column_name}', sampling_error_column_name=f'sampling_error_{metric.column_name}', + start_date_column_name='start_date' if 
is_time_based_x_axis else None, + end_date_column_name='end_date' if is_time_based_x_axis else None, ) return fig diff --git a/nannyml/performance_estimation/direct_loss_estimation/dle.py b/nannyml/performance_estimation/direct_loss_estimation/dle.py index 54b97b05..6f915816 100644 --- a/nannyml/performance_estimation/direct_loss_estimation/dle.py +++ b/nannyml/performance_estimation/direct_loss_estimation/dle.py @@ -41,7 +41,7 @@ def __init__( feature_column_names: List[str], y_pred: str, y_true: str, - timestamp_column_name: str, + timestamp_column_name: str = None, chunk_size: int = None, chunk_number: int = None, chunk_period: str = None, @@ -159,12 +159,11 @@ def __init__( >>> results = estimator.estimate(analysis_df) """ - super().__init__(chunk_size, chunk_number, chunk_period, chunker) + super().__init__(chunk_size, chunk_number, chunk_period, chunker, timestamp_column_name) self.feature_column_names = feature_column_names self.y_pred = y_pred self.y_true = y_true - self.timestamp_column_name = timestamp_column_name if metrics is None: metrics = DEFAULT_METRICS @@ -233,12 +232,13 @@ def _estimate(self, data: pd.DataFrame, *args, **kwargs) -> AbstractEstimatorRes lambda x: self._categorical_encoders[x.name].transform(x) ) - chunks = self.chunker.split(data, timestamp_column_name=self.timestamp_column_name) + chunks = self.chunker.split(data) res = pd.DataFrame.from_records( [ { 'key': chunk.key, + 'chunk_index': chunk.chunk_index, 'start_index': chunk.start_index, 'end_index': chunk.end_index, 'start_date': chunk.start_datetime, diff --git a/nannyml/performance_estimation/direct_loss_estimation/metrics.py b/nannyml/performance_estimation/direct_loss_estimation/metrics.py index bf44530e..f5674f09 100644 --- a/nannyml/performance_estimation/direct_loss_estimation/metrics.py +++ b/nannyml/performance_estimation/direct_loss_estimation/metrics.py @@ -98,7 +98,6 @@ def fit(self, reference_data: pd.DataFrame): # Calculate alert thresholds reference_chunks = 
self.estimator.chunker.split( reference_data, - timestamp_column_name=self.estimator.timestamp_column_name, ) self.lower_threshold, self.upper_threshold = self._alert_thresholds(reference_chunks) diff --git a/nannyml/performance_estimation/direct_loss_estimation/result.py b/nannyml/performance_estimation/direct_loss_estimation/result.py index c83e7007..7d21537a 100644 --- a/nannyml/performance_estimation/direct_loss_estimation/result.py +++ b/nannyml/performance_estimation/direct_loss_estimation/result.py @@ -75,11 +75,15 @@ def _plot_direct_error_estimation_performance( lambda r: r[f'alert_{metric.column_name}'] if r['period'] == 'analysis' else False, axis=1 ) + is_time_based_x_axis = estimator.timestamp_column_name is not None + # Plot estimated performance fig = _step_plot( table=estimation_results, metric_column_name='plottable', chunk_column_name=CHUNK_KEY_COLUMN_NAME, + start_date_column_name='start_date' if is_time_based_x_axis else None, + end_date_column_name='end_date' if is_time_based_x_axis else None, chunk_legend_labels=[ f'Reference period (realized {metric.display_name})', f'Analysis period (estimated {metric.display_name})', diff --git a/nannyml/plots/_joy_plot.py b/nannyml/plots/_joy_plot.py index 80713558..0f6b12c3 100644 --- a/nannyml/plots/_joy_plot.py +++ b/nannyml/plots/_joy_plot.py @@ -113,14 +113,13 @@ def _create_joy_table( drift_table, kde_table, feature_column_name, - chunk_column_name='chunk', - chunk_type_column_name='chunk_type', - end_date_column_name='end_date', - drift_column_name='drift', - chunk_types=None, + chunk_column_name, + chunk_type_column_name, + end_date_column_name, + chunk_index_column_name, + drift_column_name, + chunk_types, ): - if chunk_types is None: - chunk_types = ['reference', 'analysis'] joy_table = pd.merge(drift_table, kde_table) @@ -132,7 +131,10 @@ def _create_joy_table( joy_table.loc[joy_table[drift_column_name], 'hue'] = -1 # Sort to make sure most current chunks are plotted in front of the others - 
joy_table = joy_table.sort_values(end_date_column_name, ascending=True).reset_index(drop=True) + if end_date_column_name: + joy_table = joy_table.sort_values(end_date_column_name, ascending=True).reset_index(drop=True) + else: + joy_table = joy_table.sort_values(chunk_index_column_name, ascending=True).reset_index(drop=True) return joy_table @@ -142,6 +144,7 @@ def _create_joy_plot( chunk_column_name, start_date_column_name, end_date_column_name, + chunk_index_column_name, chunk_type_column_name, drift_column_name, chunk_types, @@ -174,10 +177,20 @@ def _create_joy_plot( if colors is None: colors = [Colors.BLUE_SKY_CRAYOLA, Colors.INDIGO_PERSIAN, Colors.GRAY_DARK, Colors.RED_IMPERIAL] + is_time_based_x_axis = start_date_column_name and end_date_column_name + offset = ( + joy_table.loc[joy_table['period'] == 'reference', 'chunk_index'].max() + 1 + if len(joy_table.loc[joy_table['period'] == 'reference']) > 0 + else 0 + ) + joy_table['chunk_index_unified'] = [ + idx + offset if period == 'analysis' else idx + for idx, period in zip(joy_table['chunk_index'], joy_table['period']) + ] + colors_transparent = [ 'rgba{}'.format(matplotlib.colors.to_rgba(matplotlib.colors.to_rgb(color), alpha)) for color in colors ] - hover_template = chunk_hover_label + ' %{customdata[0]}: %{customdata[1]} - %{customdata[2]}, %{customdata[3]}' layout = go.Layout( title=title, @@ -210,8 +223,16 @@ def _create_joy_plot( fig = go.Figure(layout=layout) for i, row in joy_table.iterrows(): - y_date_position = row[start_date_column_name] - y_date_height_scaler = row[start_date_column_name] - row[end_date_column_name] + start_date_label_hover, end_date_label_hover = '', '' + if is_time_based_x_axis: + y_date_position = row[start_date_column_name] + y_date_height_scaler = row[start_date_column_name] - row[end_date_column_name] + start_date_label_hover = row[start_date_column_name].strftime(date_label_hover_format) + end_date_label_hover = row[end_date_column_name].strftime(date_label_hover_format) 
+ else: + y_date_position = row['chunk_index_unified'] + y_date_height_scaler = -1 + kde_support = row['kde_support'] kde_density_scaled = row['kde_density_scaled'] * joy_overlap kde_quartiles = [(q[0], q[1] * joy_overlap) for q in row['kde_quartiles_scaled']] @@ -220,9 +241,6 @@ def _create_joy_plot( color_fill = colors_transparent[row['hue']] trace_name = hue_joy_hover_labels[row['hue']] - start_date_label_hover = row[start_date_column_name].strftime(date_label_hover_format) - end_date_label_hover = row[end_date_column_name].strftime(date_label_hover_format) - # ____Plot elements___# fig.add_trace( go.Scatter( @@ -252,12 +270,27 @@ def _create_joy_plot( if quartiles_legend_label: for kde_quartile in kde_quartiles: - hover_content = ( - row[chunk_column_name], - start_date_label_hover, - end_date_label_hover, - np.round(kde_quartile[0], 3), - ) + if is_time_based_x_axis: + hover_content = ( + row[chunk_column_name], + start_date_label_hover, + end_date_label_hover, + np.round(kde_quartile[0], 3), + ) + hover_template = ( + chunk_hover_label + + ' %{customdata[0]}: %{customdata[1]} - %{customdata[2]}, %{customdata[3]}' + ) + else: + hover_content = ( + row[chunk_column_name], + row['chunk_index_unified'], + np.round(kde_quartile[0], 3), + ) + hover_template = ( + chunk_hover_label + + ' %{customdata[0]}: chunk index %{customdata[1]}, %{customdata[2]}' + ) hover_data = np.asarray([hover_content, hover_content]) @@ -279,8 +312,12 @@ def _create_joy_plot( ) # ____Add elements to legend___# - x = [np.nan] * len(joy_table) if style == 'horizontal' else joy_table[end_date_column_name] - y = joy_table[end_date_column_name] if style == 'horizontal' else [np.nan] * len(joy_table) + if is_time_based_x_axis: + x = [np.nan] * len(joy_table) if style == 'horizontal' else joy_table[end_date_column_name] + y = joy_table[end_date_column_name] if style == 'horizontal' else [np.nan] * len(joy_table) + else: + x = [np.nan] * len(joy_table) if style == 'horizontal' else 
joy_table['chunk_index_unified'] + y = joy_table['chunk_index_unified'] if style == 'horizontal' else [np.nan] * len(joy_table) # Add joy coloring for i, hue_label in enumerate(hue_legend_labels): @@ -324,8 +361,9 @@ def _joy_plot( feature_table, feature_column_name, chunk_column_name='chunk', - start_date_column_name='start_date', - end_date_column_name='end_date', + start_date_column_name: str = None, + end_date_column_name: str = None, + chunk_index_column_name='chunk_index', chunk_type_column_name='period', drift_column_name='drift', chunk_types=None, @@ -337,10 +375,10 @@ def _joy_plot( joy_hover_format='{0:.2f}', joy_overlap=1, figure=None, - title='Feature: distribution over time', + title=None, x_axis_title='Feature', x_axis_lim=None, - y_axis_title='Time', + y_axis_title=None, alpha=0.2, colors=None, kde_cut=3, @@ -364,6 +402,16 @@ def _joy_plot( if colors is None: colors = [Colors.BLUE_SKY_CRAYOLA, Colors.INDIGO_PERSIAN, Colors.GRAY_DARK, Colors.RED_IMPERIAL] + is_time_based_x_axis = start_date_column_name and end_date_column_name + + if not x_axis_title: + x_axis_title = ( + 'Feature distribution over time' if is_time_based_x_axis else 'Feature distribution across chunks' + ) + + if not y_axis_title: + y_axis_title = 'Time' if is_time_based_x_axis else 'Chunk index' + kde_table = _create_kde_table( feature_table, feature_column_name, chunk_column_name, kde_cut, kde_clip, post_kde_clip ) @@ -375,6 +423,7 @@ def _joy_plot( chunk_column_name, chunk_type_column_name, end_date_column_name, + chunk_index_column_name, drift_column_name, chunk_types, ) @@ -384,6 +433,7 @@ def _joy_plot( chunk_column_name, start_date_column_name, end_date_column_name, + chunk_index_column_name, chunk_type_column_name, drift_column_name, chunk_types, diff --git a/nannyml/plots/_stacked_bar_plot.py b/nannyml/plots/_stacked_bar_plot.py index ea801cee..be4b525e 100644 --- a/nannyml/plots/_stacked_bar_plot.py +++ b/nannyml/plots/_stacked_bar_plot.py @@ -67,6 +67,7 @@ def 
_create_stacked_bar_table( value_counts_table, start_date_column_name, end_date_column_name, + chunk_index_column_name, chunk_type_column_name, chunk_column_name, drift_column_name, @@ -82,15 +83,26 @@ def _create_stacked_bar_table( if drift_column_name and drift_column_name in stacked_bar_table.columns: stacked_bar_table.loc[stacked_bar_table[drift_column_name], 'hue'] = -1 - stacked_bar_table = stacked_bar_table.sort_values(end_date_column_name, ascending=True).reset_index(drop=True) - stacked_bar_table['next_end_date'] = stacked_bar_table[end_date_column_name].shift(-1) + if start_date_column_name and end_date_column_name: + stacked_bar_table = stacked_bar_table.sort_values(end_date_column_name, ascending=True).reset_index(drop=True) + stacked_bar_table['next_end_date'] = stacked_bar_table[end_date_column_name].shift(-1) - stacked_bar_table['start_date_label_hover'] = stacked_bar_table[start_date_column_name].dt.strftime( - date_label_hover_format - ) - stacked_bar_table['end_date_label_hover'] = stacked_bar_table[end_date_column_name].dt.strftime( - date_label_hover_format + stacked_bar_table['start_date_label_hover'] = stacked_bar_table[start_date_column_name].dt.strftime( + date_label_hover_format + ) + stacked_bar_table['end_date_label_hover'] = stacked_bar_table[end_date_column_name].dt.strftime( + date_label_hover_format + ) + + offset = ( + stacked_bar_table.loc[stacked_bar_table['period'] == 'reference', 'chunk_index'].max() + 1 + if len(stacked_bar_table.loc[stacked_bar_table['period'] == 'reference']) > 0 + else 0 ) + stacked_bar_table['chunk_index_unified'] = [ + idx + offset if period == 'analysis' else idx + for idx, period in zip(stacked_bar_table['chunk_index'], stacked_bar_table['period']) + ] return stacked_bar_table @@ -100,6 +112,7 @@ def _create_stacked_bar_plot( feature_column_name, start_date_column_name, end_date_column_name, + chunk_index_column_name, chunk_type_column_name, chunk_column_name, chunk_types, @@ -129,10 +142,25 @@ def 
_create_stacked_bar_plot( for color in colors ] - hover_template = ( - chunk_hover_label - + ' %{customdata[0]}: %{customdata[1]} - %{customdata[2]}; (%{customdata[3]}, %{customdata[4]})' - ) + is_time_based_x_axis = start_date_column_name and end_date_column_name + if is_time_based_x_axis: + hover_template = ( + chunk_hover_label + + ' %{customdata[0]}: %{customdata[1]} - %{customdata[2]}; (%{customdata[3]}, %{customdata[4]})' + ) + custom_data = [ + chunk_column_name, + 'start_date_label_hover', + 'end_date_label_hover', + 'value_counts_normalised', + 'value_counts', + ] + else: + hover_template = ( + chunk_hover_label + + ' %{customdata[0]}: chunk index %{customdata[1]}, (%{customdata[2]}, %{customdata[3]})' + ) + custom_data = [chunk_column_name, 'chunk_index_unified', 'value_counts_normalised', 'value_counts'] layout = go.Layout( title=title, @@ -156,20 +184,17 @@ def _create_stacked_bar_plot( stacked_bar_table[feature_column_name] == category, ] - hover_data = data[ - [ - chunk_column_name, - 'start_date_label_hover', - 'end_date_label_hover', - 'value_counts_normalised', - 'value_counts', - ] - ].values + hover_data = data[custom_data].values + + if is_time_based_x_axis: + x = data[start_date_column_name] + else: + x = data['chunk_index_unified'] fig.add_trace( go.Bar( name=category, - x=data[start_date_column_name], + x=x, y=data['value_counts_normalised'], orientation='v', marker_line_color=data['hue'].apply(lambda hue: colors[hue] if hue == -1 else 'rgba(255,255,255,1)'), @@ -187,18 +212,22 @@ def _create_stacked_bar_plot( for i, chunk_type in enumerate(chunk_types): subset = stacked_bar_table.loc[stacked_bar_table[chunk_type_column_name] == chunk_type] if subset.shape[0] > 0: + x0 = subset[start_date_column_name].min() if is_time_based_x_axis else subset['chunk_index_unified'].min() + x1 = subset[end_date_column_name].max() if is_time_based_x_axis else subset['chunk_index_unified'].max() + 1 fig.add_shape( y0=0, y1=1.05, - 
x0=subset[start_date_column_name].min(), - x1=subset[end_date_column_name].max(), + x0=x0, + x1=x1, line_color=colors_transparant[i], layer='above', line_width=2, line=dict(dash='dash'), ), fig.add_annotation( - x=subset[start_date_column_name].mean(), + x=subset[start_date_column_name].mean() + if is_time_based_x_axis + else subset['chunk_index_unified'].mean(), y=1.025, text=chunk_type_labels[i], font=dict(color=colors[i]), @@ -209,7 +238,7 @@ def _create_stacked_bar_plot( # ____Add elements to legend___# x = [np.nan] * len(data) - y = data[start_date_column_name] + y = data[start_date_column_name] if is_time_based_x_axis else [np.nan] * len(data) # Add chunk types for i, hue_label in enumerate(chunk_types): @@ -264,8 +293,9 @@ def _stacked_bar_plot( feature_table, drift_table, feature_column_name, - start_date_column_name='start_date', - end_date_column_name='end_date', + start_date_column_name=None, + end_date_column_name=None, + chunk_index_column_name='chunk_index', chunk_type_column_name='period', chunk_column_name='chunk', drift_column_name='drift', @@ -319,6 +349,7 @@ def _stacked_bar_plot( value_counts_table, start_date_column_name, end_date_column_name, + chunk_index_column_name, chunk_type_column_name, chunk_column_name, drift_column_name, @@ -331,6 +362,7 @@ def _stacked_bar_plot( feature_column_name, start_date_column_name, end_date_column_name, + chunk_index_column_name, chunk_type_column_name, chunk_column_name, chunk_types, diff --git a/nannyml/plots/_step_plot.py b/nannyml/plots/_step_plot.py index 101a4a00..08161e7a 100644 --- a/nannyml/plots/_step_plot.py +++ b/nannyml/plots/_step_plot.py @@ -22,21 +22,35 @@ def _data_prep_step_plot( data: pd.DataFrame, metric_column_name: str, - start_date_column_name: str, - end_date_column_name: str, partial_target_column_name: str, sampling_error_column_name: str, drift_column_name: str, + start_date_column_name: str = None, + end_date_column_name: str = None, hover_metric_format='{0:.4f}', 
hover_date_label_format='%b-%d-%Y', ): data = data.copy() - data['mid_point_date'] = ( - data[start_date_column_name] + (data[end_date_column_name] - data[start_date_column_name]) / 2 + + if start_date_column_name and end_date_column_name: + # we have a time based X-axis + data['mid_point_date'] = ( + data[start_date_column_name] + (data[end_date_column_name] - data[start_date_column_name]) / 2 + ) + data['start_date_label'] = data[start_date_column_name].dt.strftime(hover_date_label_format) + data['end_date_label'] = data[end_date_column_name].dt.strftime(hover_date_label_format) + + offset = ( + data.loc[data['period'] == 'reference', 'chunk_index'].max() + 1 + if len(data.loc[data['period'] == 'reference']) > 0 + else 0 ) + data['chunk_index_unified'] = [ + idx + offset if period == 'analysis' else idx for idx, period in zip(data['chunk_index'], data['period']) + ] + data['metric_label'] = data[metric_column_name].apply(lambda x: hover_metric_format.format(x)) - data['start_date_label'] = data[start_date_column_name].dt.strftime(hover_date_label_format) - data['end_date_label'] = data[end_date_column_name].dt.strftime(hover_date_label_format) + if sampling_error_column_name is not None: data['plt_sampling_error'] = np.round(SAMPLING_ERROR_RANGE * data[sampling_error_column_name], 4) @@ -59,11 +73,15 @@ def _data_prep_step_plot( return data -def _add_artificial_end_point(data: pd.DataFrame, start_date_column_name: str, end_date_column_name: str): +def _add_artificial_end_point( + data: pd.DataFrame, start_date_column_name: str, end_date_column_name: str, chunk_index_column_name: str +): data_point_hack = data.tail(1).copy() - data_point_hack[start_date_column_name] = data_point_hack[end_date_column_name] - data_point_hack[end_date_column_name] = pd.NaT - data_point_hack['mid_point_date'] = pd.NaT + if start_date_column_name and end_date_column_name: + data_point_hack[start_date_column_name] = data_point_hack[end_date_column_name] + 
data_point_hack[end_date_column_name] = pd.NaT + data_point_hack['mid_point_date'] = pd.NaT + data_point_hack[chunk_index_column_name] = data_point_hack[chunk_index_column_name] + 1 data_point_hack.index = data_point_hack.index + 1 return pd.concat([data, data_point_hack], axis=0) @@ -81,8 +99,9 @@ def _step_plot( drift_column_name=None, partial_target_column_name=None, chunk_column_name='chunk', - start_date_column_name='start_date', - end_date_column_name='end_date', + start_date_column_name=None, + end_date_column_name=None, + chunk_index_column_name='chunk_index', chunk_type_column_name='period', chunk_types=None, confidence_legend_label='Confidence band', @@ -100,7 +119,7 @@ def _step_plot( v_line_separating_analysis_period=True, figure=None, title='Metric over time', - x_axis_title='Time', + x_axis_title=None, y_axis_title='Metric', y_axis_lim=None, alpha=0.2, @@ -121,14 +140,19 @@ def _step_plot( if colors is None: colors = [Colors.BLUE_SKY_CRAYOLA, Colors.INDIGO_PERSIAN, Colors.GRAY_DARK, Colors.RED_IMPERIAL] + is_time_based_x_axis = start_date_column_name and end_date_column_name + + if not x_axis_title: + x_axis_title = 'Time' if is_time_based_x_axis else 'Chunk' + data = _data_prep_step_plot( table, metric_column_name, - start_date_column_name, - end_date_column_name, partial_target_column_name, sampling_error_column_name, drift_column_name, + start_date_column_name, + end_date_column_name, hover_metric_format, hover_date_label_format, ) @@ -137,34 +161,48 @@ def _step_plot( 'rgba{}'.format(matplotlib.colors.to_rgba(matplotlib.colors.to_rgb(color), alpha)) for color in colors ] + custom_data_columns = [ + chunk_column_name, + 'metric_label', + 'hover_period', + 'hover_alert', + 'incomplete_target_percentage', + ] + # This has been updated to show the general shape, but the period label and other details are hard-coded. 
# I think this needs to be put together more conditionally when building each figure, but I couldn't figure out how # The border can also be changed, but I think that also means this needs restructuring? # https://plotly.com/python/hover-text-and-formatting/#customizing-hover-label-appearance hover_template = ( - '%{customdata[4]}     %{customdata[5]}
' # noqa: E501 + f'%{{customdata[{custom_data_columns.index("hover_period")}]}}     ' # noqa: E501 + + f'%{{customdata[{custom_data_columns.index("hover_alert")}]}}
' + hover_labels[0] - + ': %{customdata[0]}     ' - + 'From %{customdata[1]} to %{customdata[2]}    
' - + hover_labels[1] - + ': %{customdata[3]}     ' - + '%{customdata[6]}    ' + + f': %{{customdata[{custom_data_columns.index(chunk_column_name)}]}}     ' ) + if is_time_based_x_axis: + custom_data_columns += ['start_date_label', 'end_date_label'] + hover_template += ( + f'From %{{customdata[{custom_data_columns.index("start_date_label")}]}} to ' + f'%{{customdata[{custom_data_columns.index("end_date_label")}]}}    
' + ) + else: + custom_data_columns += ['chunk_index_unified'] + hover_template += ( + f'Chunk index: %{{customdata[{custom_data_columns.index("chunk_index_unified")}]}}
' + ) - custom_data_columns = [ - chunk_column_name, - 'start_date_label', - 'end_date_label', - 'metric_label', - 'hover_period', - 'hover_alert', - 'incomplete_target_percentage', - ] + hover_template += ( + hover_labels[1] + + f': %{{customdata[{custom_data_columns.index("metric_label")}]}}     ' + + f'%{{customdata[{custom_data_columns.index("incomplete_target_percentage")}]}}    ' + ) if sampling_error_column_name is not None: - hover_template += '
Sampling error range: +/-%{customdata[7]}' # noqa: E501 custom_data_columns += ['plt_sampling_error'] - + hover_template += ( + '
Sampling error range: +/-' + + f'%{{customdata[{custom_data_columns.index("plt_sampling_error")}]}}' # noqa: E501 + ) hover_template += '' layout = go.Layout( @@ -194,7 +232,15 @@ def _step_plot( # Plot line separating reference and analysis period _plot_reference_analysis_separator( - fig, data, colors, v_line_separating_analysis_period, chunk_type_column_name, chunk_types + fig, + data, + colors, + v_line_separating_analysis_period, + chunk_type_column_name, + chunk_types, + start_date_column_name, + end_date_column_name, + 'chunk_index_unified', ) # Plot confidence band, if the metric estimated @@ -209,6 +255,7 @@ def _step_plot( start_date_column_name, end_date_column_name, plot_for_reference=plot_confidence_for_reference, + chunk_index_column_name='chunk_index_unified', ) # Plot statistically significant band @@ -220,6 +267,7 @@ def _step_plot( metric_column_name, start_date_column_name, end_date_column_name, + chunk_index_column_name='chunk_index_unified', ) # Plot metric for reference and analysis period @@ -235,6 +283,7 @@ def _step_plot( start_date_column_name, end_date_column_name, partial_target_column_name, + 'chunk_index_unified', ) # Plot metric if partial target in analysis period @@ -249,6 +298,7 @@ def _step_plot( end_date_column_name, partial_target_column_name, partial_target_legend_label, + 'chunk_index_unified', ) # Plot reference and analysis markers that did not drift @@ -264,7 +314,9 @@ def _step_plot( chunk_column_name, chunk_type_column_name, chunk_types, + start_date_column_name, end_date_column_name, + 'chunk_index_unified', ) # Plot data drifted markers and areas @@ -281,11 +333,16 @@ def _step_plot( chunk_column_name, start_date_column_name, end_date_column_name, + 'chunk_index_unified', ) # ____Add elements to legend, order matters___# - x = [data['mid_point_date'].head(1).values, data['mid_point_date'].tail(1).values] + if is_time_based_x_axis: + x = [data['mid_point_date'].head(1).values, data['mid_point_date'].tail(1).values] + 
else: + x = data[chunk_index_column_name] + y = [np.nan, np.nan] # Add confidence band @@ -390,6 +447,7 @@ def _plot_metric( start_date_column_name, end_date_column_name, partial_target_column_name, + chunk_index_column_name, ): if partial_target_column_name and partial_target_column_name in data.columns: subset = data.loc[data[partial_target_column_name] == 0] @@ -398,7 +456,16 @@ def _plot_metric( for i, chunk_type in enumerate(chunk_types): data_subset = subset.loc[subset[chunk_type_column_name] == chunk_type] - data_subset = _add_artificial_end_point(data_subset, start_date_column_name, end_date_column_name) + + data_subset = _add_artificial_end_point( + data_subset, start_date_column_name, end_date_column_name, chunk_index_column_name + ) + + if start_date_column_name and end_date_column_name: + x = data_subset[start_date_column_name] + else: + x = data_subset[chunk_index_column_name] + dash = None if estimated_column_name and estimated_column_name in data.columns: if not data_subset.empty and data_subset[estimated_column_name].head(1).values[0]: @@ -407,7 +474,7 @@ def _plot_metric( go.Scatter( name=chunk_legend_labels[i], mode='lines', - x=data_subset[start_date_column_name], + x=x, y=data_subset[metric_column_name], line=dict(shape='hv', color=colors[i], width=2, dash=dash), hoverinfo='skip', @@ -427,15 +494,22 @@ def _plot_metric_partial_target( end_date_column_name, partial_target_column_name, partial_target_legend_label, + chunk_index_column_name, ): if partial_target_column_name and partial_target_column_name in data.columns: data_subset = data.loc[(data[chunk_type_column_name] == chunk_types[1])] - data_subset = _add_artificial_end_point(data_subset, start_date_column_name, end_date_column_name) + if start_date_column_name and end_date_column_name: + data_subset = _add_artificial_end_point( + data_subset, start_date_column_name, end_date_column_name, chunk_index_column_name + ) + x = data_subset[start_date_column_name] + else: + x = 
data_subset[chunk_index_column_name] fig.add_trace( go.Scatter( name=partial_target_legend_label, mode='lines', - x=data_subset[start_date_column_name], + x=x, y=data_subset[metric_column_name], line=dict(shape='hv', color=colors[1], width=2, dash='dot'), hoverinfo='skip', @@ -456,7 +530,9 @@ def _plot_non_drifted_markers( chunk_column_name, chunk_type_column_name, chunk_types, + start_date_column_name, end_date_column_name, + chunk_index_column_name, ): for i, chunk_type in enumerate(chunk_types): if drift_column_name and drift_column_name in data.columns: @@ -464,11 +540,16 @@ def _plot_non_drifted_markers( else: data_subset = data.loc[(data[chunk_type_column_name] == chunk_type)] + if start_date_column_name and end_date_column_name: + x = data_subset['mid_point_date'] + else: + x = data_subset[chunk_index_column_name] + 0.5 + fig.add_trace( go.Scatter( name=hover_marker_labels[i], mode='markers', - x=data_subset['mid_point_date'], + x=x, y=data_subset[metric_column_name], marker=dict(color=colors[i], size=6, symbol='square'), customdata=data_subset[custom_data_columns].values, @@ -491,12 +572,20 @@ def _plot_drifted_markers_and_areas( chunk_column_name, start_date_column_name, end_date_column_name, + chunk_index_column_name, ): if drift_column_name and drift_column_name in data.columns: for i, row in data.loc[data[drift_column_name], :].iterrows(): + if start_date_column_name and end_date_column_name: + x0 = row[start_date_column_name] + x1 = row[end_date_column_name] + else: + x0 = row[chunk_index_column_name] + x1 = x0 + 1 + fig.add_vrect( - x0=row[start_date_column_name], - x1=row[end_date_column_name], + x0=x0, + x1=x1, fillcolor=colors[-1], opacity=alpha, layer='below', @@ -504,11 +593,16 @@ def _plot_drifted_markers_and_areas( ) data_subset = data.loc[data[drift_column_name]] + if start_date_column_name and end_date_column_name: + x = data_subset['mid_point_date'] + else: + x = data_subset[chunk_index_column_name] + 0.5 + fig.add_trace( go.Scatter( 
name=hover_marker_labels[2], mode='markers', - x=data_subset['mid_point_date'], + x=x, y=data_subset[metric_column_name], marker=dict(color=colors[-1], size=6, symbol='diamond'), customdata=data_subset[custom_data_columns].values, @@ -556,13 +650,22 @@ def _plot_reference_analysis_separator( v_line_separating_analysis_period: bool, chunk_type_column_name: str, chunk_types: List[str], + start_date_column_name: str, + end_date_column_name: str, + chunk_index_column_name: str, ): if v_line_separating_analysis_period: data_subset = data.loc[ data[chunk_type_column_name] == chunk_types[1], ].head(1) + + if start_date_column_name and end_date_column_name: + x = pd.to_datetime(data_subset[start_date_column_name].values[0]) + else: + x = data_subset[chunk_index_column_name].values[0] + fig.add_vline( - x=pd.to_datetime(data_subset['start_date'].values[0]), + x=x, line=dict(color=colors[1], width=1, dash='dash'), layer='below', ) @@ -579,6 +682,7 @@ def _plot_confidence_band( start_date_column_name: str, end_date_column_name: str, plot_for_reference: bool, + chunk_index_column_name: str, ): if ( lower_confidence_column_name @@ -587,12 +691,19 @@ def _plot_confidence_band( ): def _plot(data_subset, fill_color): - data_subset = _add_artificial_end_point(data_subset, start_date_column_name, end_date_column_name) + data_subset = _add_artificial_end_point( + data_subset, start_date_column_name, end_date_column_name, chunk_index_column_name + ) + if start_date_column_name and end_date_column_name: + x = data_subset[start_date_column_name] + else: + x = data_subset[chunk_index_column_name] + fig.add_traces( [ go.Scatter( mode='lines', - x=data_subset[start_date_column_name], + x=x, y=data_subset[upper_confidence_column_name], line=dict(shape='hv', color='rgba(0,0,0,0)'), hoverinfo='skip', @@ -600,7 +711,7 @@ def _plot(data_subset, fill_color): ), go.Scatter( mode='lines', - x=data_subset[start_date_column_name], + x=x, y=data_subset[lower_confidence_column_name], 
line=dict(shape='hv', color='rgba(0,0,0,0)'), fill='tonexty', @@ -624,14 +735,19 @@ def _plot_statistical_significance_band( metric_column_name, start_date_column_name, end_date_column_name, + chunk_index_column_name, ): if statistically_significant_column_name is not None and statistically_significant_column_name in data.columns: data_subset = data.loc[data[statistically_significant_column_name]] for i, row in data_subset.iterrows(): + if start_date_column_name and end_date_column_name: + x = [row[start_date_column_name], row[end_date_column_name]] + else: + x = [row[chunk_index_column_name], row[chunk_index_column_name] + 1] fig.add_trace( go.Scatter( mode='lines', - x=[row[start_date_column_name], row[end_date_column_name]], + x=x, y=[row[metric_column_name], row[metric_column_name]], line=dict(color=colors_transparent[1], width=9), hoverinfo='skip', diff --git a/nannyml/runner.py b/nannyml/runner.py index 778290e1..534d30b2 100644 --- a/nannyml/runner.py +++ b/nannyml/runner.py @@ -141,7 +141,7 @@ def _run_statistical_univariate_feature_drift_calculator( console.log('fitting on reference data') calc = UnivariateStatisticalDriftCalculator( feature_column_names=column_mapping['features'], - timestamp_column_name=column_mapping['timestamp'], + timestamp_column_name=column_mapping.get('timestamp', None), chunker=chunker, ).fit(reference_data) @@ -191,7 +191,7 @@ def _run_data_reconstruction_multivariate_feature_drift_calculator( console.log('fitting on reference data') calc = DataReconstructionDriftCalculator( feature_column_names=column_mapping['features'], - timestamp_column_name=column_mapping['timestamp'], + timestamp_column_name=column_mapping.get('timestamp', None), chunker=chunker, ).fit(reference_data) @@ -235,8 +235,8 @@ def _run_statistical_model_output_drift_calculator( console.log('fitting on reference data') calc = StatisticalOutputDriftCalculator( y_pred=column_mapping['y_pred'], - y_pred_proba=column_mapping['y_pred_proba'], - 
timestamp_column_name=column_mapping['timestamp'], + y_pred_proba=column_mapping.get('y_pred_proba', None), + timestamp_column_name=column_mapping.get('timestamp', None), problem_type=problem_type, chunker=chunker, ).fit(reference_data) @@ -276,8 +276,8 @@ def _run_statistical_model_output_drift_calculator( } elif problem_type == ProblemType.REGRESSION: plots = { - 'prediction_drift_statistic': results.plot('prediction_drift', 'statistic'), - 'prediction_drift_metric': results.plot('prediction_drift', 'p_value'), + 'prediction_drift_ks_stat': results.plot('prediction_drift', 'statistic'), + 'prediction_drift_p_value': results.plot('prediction_drift', 'p_value'), 'prediction_distribution': results.plot('prediction_distribution'), } except Exception as exc: @@ -327,7 +327,7 @@ def _run_target_distribution_drift_calculator( console.log('fitting on reference data') calc = TargetDistributionCalculator( y_true=column_mapping['y_true'], - timestamp_column_name=column_mapping['timestamp'], + timestamp_column_name=column_mapping.get('timestamp', None), chunker=chunker, problem_type=problem_type, ).fit(reference_data) @@ -393,8 +393,8 @@ def _run_realized_performance_calculator( calc = PerformanceCalculator( y_true=column_mapping['y_true'], y_pred=column_mapping['y_pred'], - y_pred_proba=column_mapping['y_pred_proba'], - timestamp_column_name=column_mapping['timestamp'], + y_pred_proba=column_mapping.get('y_pred_proba', None), + timestamp_column_name=column_mapping.get('timestamp', None), chunker=chunker, metrics=metrics, problem_type=problem_type, @@ -456,7 +456,7 @@ def _run_cbpe_performance_estimation( y_true=column_mapping['y_true'], y_pred=column_mapping['y_pred'], y_pred_proba=column_mapping['y_pred_proba'], - timestamp_column_name=column_mapping['timestamp'], + timestamp_column_name=column_mapping.get('timestamp', None), problem_type=problem_type, chunker=chunker, metrics=metrics, @@ -516,7 +516,7 @@ def _run_dee_performance_estimation( 
feature_column_names=column_mapping['features'], y_true=column_mapping['y_true'], y_pred=column_mapping['y_pred'], - timestamp_column_name=column_mapping['timestamp'], + timestamp_column_name=column_mapping.get('timestamp', None), chunker=chunker, metrics=DEFAULT_METRICS, ).fit(reference_data) diff --git a/pyproject.toml b/pyproject.toml index 5c80e46d..348e9142 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [tool] [tool.poetry] name = "nannyml" -version = "0.6.1" +version = "0.6.2" homepage = "https://github.com/nannyml/nannyml" description = "NannyML, Your library for monitoring model performance." authors = ["Niels Nuyttens "] diff --git a/tests/drift/test_data_reconstruction_drift.py b/tests/drift/test_data_reconstruction_drift.py index a2831001..b4bc75cd 100644 --- a/tests/drift/test_data_reconstruction_drift.py +++ b/tests/drift/test_data_reconstruction_drift.py @@ -168,8 +168,9 @@ def test_data_reconstruction_drift_calculator_should_contain_chunk_details_and_s drift = calc.calculate(data=sample_drift_data) sut = drift.data.columns - assert len(sut) == 12 + assert len(sut) == 13 assert 'key' in sut + assert 'chunk_index' in sut assert 'start_index' in sut assert 'start_date' in sut assert 'end_index' in sut @@ -193,7 +194,7 @@ def test_data_reconstruction_drift_calculator_should_contain_a_row_for_each_chun ).fit(ref_data) drift = calc.calculate(data=sample_drift_data) - expected = len(PeriodBasedChunker(offset='W').split(sample_drift_data, timestamp_column_name='timestamp')) + expected = len(PeriodBasedChunker(offset='W', timestamp_column_name='timestamp').split(sample_drift_data)) sut = len(drift.data) assert sut == expected @@ -273,6 +274,82 @@ def test_data_reconstruction_drift_calculator_numeric_results(sample_drift_data) pd.testing.assert_frame_equal(expected_drift, drift.data[['key', 'reconstruction_error']]) +@pytest.mark.parametrize( + 'calculator_opts, expected', + [ + ( + {'chunk_size': 5000}, + [0.7998744001719177, 
0.8020996183121666, 0.8043000024523013, 0.735524850766471, 0.7608678766056979], + ), + ( + {'chunk_size': 5000, 'timestamp_column_name': 'timestamp'}, + [0.7998744001719177, 0.8020996183121666, 0.8043000024523013, 0.735524850766471, 0.7608678766056979], + ), + ( + {'chunk_number': 5}, + [0.7975183099468669, 0.8101736730245841, 0.7942220878040264, 0.7855043522106143, 0.7388546967488279], + ), + ( + {'chunk_number': 5, 'timestamp_column_name': 'timestamp'}, + [0.7975183099468669, 0.8101736730245841, 0.7942220878040264, 0.7855043522106143, 0.7388546967488279], + ), + ( + {'chunk_period': 'M', 'timestamp_column_name': 'timestamp'}, + [0.7925562396242019, 0.81495562506899, 0.7914354678003803, 0.7766351972000973, 0.7442465240638783], + ), + ( + {}, + [ + 0.7899751792798048, + 0.805061440613929, + 0.828509894279626, + 0.7918374517695422, + 0.7904321700296298, + 0.798012005578423, + 0.8123277037588652, + 0.7586810006623634, + 0.721358457793149, + 0.7563509357045066, + ], + ), + ( + {'timestamp_column_name': 'timestamp'}, + [ + 0.7899751792798048, + 0.805061440613929, + 0.828509894279626, + 0.7918374517695422, + 0.7904321700296298, + 0.798012005578423, + 0.8123277037588652, + 0.7586810006623634, + 0.721358457793149, + 0.7563509357045066, + ], + ), + ], + ids=[ + 'size_based_without_timestamp', + 'size_based_with_timestamp', + 'count_based_without_timestamp', + 'count_based_with_timestamp', + 'period_based_with_timestamp', + 'default_without_timestamp', + 'default_with_timestamp', + ], +) +def test_data_reconstruction_drift_calculator_works_with_chunker( + sample_drift_data, calculator_opts, expected # noqa: D103 +): + ref_data = sample_drift_data.loc[sample_drift_data['period'] == 'reference'] + calc = DataReconstructionDriftCalculator(feature_column_names=['f1', 'f2', 'f3', 'f4'], **calculator_opts).fit( + ref_data + ) + sut = calc.calculate(data=sample_drift_data).data + + assert all(round(sut['reconstruction_error'], 5) == [round(n, 5) for n in expected]) + + def 
test_data_reconstruction_drift_calculator_with_only_numeric_should_not_fail(sample_drift_data): # noqa: D103 ref_data = sample_drift_data.loc[sample_drift_data['period'] == 'reference'] calc = DataReconstructionDriftCalculator(feature_column_names=['f1', 'f2'], timestamp_column_name='timestamp').fit( @@ -344,16 +421,6 @@ def test_data_reconstruction_drift_calculator_raises_type_error_when_missing_fea ) -def test_data_reconstruction_drift_calculator_raises_type_error_when_missing_timestamp_column_name( # noqa: D103 - sample_drift_data, -): - with pytest.raises(TypeError): - DataReconstructionDriftCalculator( - feature_column_names=['f1', 'f2', 'f3', 'f4'], - chunk_period='W', - ) - - def test_data_reconstruction_drift_chunked_by_size_has_fixed_sampling_error(sample_drift_data): # noqa: D103 ref_data = sample_drift_data.loc[sample_drift_data['period'] == 'reference'] @@ -384,3 +451,30 @@ def test_data_reconstruction_drift_chunked_by_period_has_variable_sampling_error assert np.array_equal( np.round(results.data['sampling_error'], 4), np.round([0.009511, 0.009005, 0.008710, 0.008854, 0.009899], 4) ) + + +@pytest.mark.parametrize( + 'calc_args, plot_args', + [ + ({'timestamp_column_name': 'timestamp'}, {'kind': 'drift', 'plot_reference': False}), + ({}, {'kind': 'drift', 'plot_reference': False}), + ({'timestamp_column_name': 'timestamp'}, {'kind': 'drift', 'plot_reference': True}), + ({}, {'kind': 'drift', 'plot_reference': True}), + ], + ids=[ + 'drift_with_timestamp_without_reference', + 'drift_without_timestamp_without_reference', + 'drift_with_timestamp_with_reference', + 'drift_without_timestamp_with_reference', + ], +) +def test_result_plots_raise_no_exceptions(sample_drift_data, calc_args, plot_args): # noqa: D103 + ref_data = sample_drift_data.loc[sample_drift_data['period'] == 'reference'] + + calc = DataReconstructionDriftCalculator(feature_column_names=['f1', 'f2', 'f3', 'f4'], **calc_args).fit(ref_data) + sut = calc.calculate(data=sample_drift_data) + + 
try: + _ = sut.plot(**plot_args) + except Exception as exc: + pytest.fail(f"an unexpected exception occurred: {exc}") diff --git a/tests/drift/test_drift.py b/tests/drift/test_drift.py index eb019184..b633a70a 100644 --- a/tests/drift/test_drift.py +++ b/tests/drift/test_drift.py @@ -168,8 +168,8 @@ def test_base_drift_calculator_uses_default_chunker_when_no_chunker_specified(sa @pytest.mark.parametrize( 'chunker', [ - (PeriodBasedChunker(offset='W')), - (PeriodBasedChunker(offset='M')), + (PeriodBasedChunker(offset='W', timestamp_column_name='timestamp')), + (PeriodBasedChunker(offset='M', timestamp_column_name='timestamp')), (SizeBasedChunker(chunk_size=1000)), CountBasedChunker(chunk_count=25), ], @@ -184,7 +184,7 @@ def test_univariate_statistical_drift_calculator_should_return_a_row_for_each_an ).fit(ref_data) sut = calc.calculate(data=sample_drift_data).data - chunks = chunker.split(sample_drift_data, timestamp_column_name='timestamp') + chunks = chunker.split(sample_drift_data) assert len(chunks) == sut.shape[0] chunk_keys = [c.key for c in chunks] assert 'key' in sut.columns @@ -240,6 +240,88 @@ def test_statistical_drift_calculator_deals_with_missing_class_labels(sample_dri assert not np.isnan(results.data.loc[0, 'f3_p_value']) +@pytest.mark.parametrize( + 'calculator_opts, expected', + [ + ( + {'chunk_size': 5000}, + [0.005067460317460304, 0.004932539682539705, 0.01185952380952382, 0.2435952380952381, 0.21061507936507934], + ), + ( + {'chunk_size': 5000, 'timestamp_column_name': 'timestamp'}, + [0.005067460317460304, 0.004932539682539705, 0.01185952380952382, 0.2435952380952381, 0.21061507936507934], + ), + ( + {'chunk_number': 5}, + [0.008829365079365048, 0.007886904761904734, 0.015178571428571375, 0.06502976190476184, 0.2535218253968254], + ), + ( + {'chunk_number': 5, 'timestamp_column_name': 'timestamp'}, + [0.008829365079365048, 0.007886904761904734, 0.015178571428571375, 0.06502976190476184, 0.2535218253968254], + ), + ( + {'chunk_period': 'M', 
'timestamp_column_name': 'timestamp'}, + [ + 0.007646520146520119, + 0.008035714285714257, + 0.009456605222734282, + 0.09057539682539684, + 0.25612599206349207, + ], + ), + ( + {}, + [ + 0.011011904761904723, + 0.01736111111111116, + 0.015773809523809523, + 0.011011904761904723, + 0.016865079365079305, + 0.01468253968253963, + 0.018650793650793696, + 0.11398809523809528, + 0.25496031746031744, + 0.2530753968253968, + ], + ), + ( + {'timestamp_column_name': 'timestamp'}, + [ + 0.011011904761904723, + 0.01736111111111116, + 0.015773809523809523, + 0.011011904761904723, + 0.016865079365079305, + 0.01468253968253963, + 0.018650793650793696, + 0.11398809523809528, + 0.25496031746031744, + 0.2530753968253968, + ], + ), + ], + ids=[ + 'size_based_without_timestamp', + 'size_based_with_timestamp', + 'count_based_without_timestamp', + 'count_based_with_timestamp', + 'period_based_with_timestamp', + 'default_without_timestamp', + 'default_with_timestamp', + ], +) +def test_univariate_statistical_drift_calculator_works_with_chunker( + sample_drift_data, calculator_opts, expected # noqa: D103 +): + ref_data = sample_drift_data.loc[sample_drift_data['period'] == 'reference'] + calc = UnivariateStatisticalDriftCalculator(feature_column_names=['f1', 'f2', 'f3', 'f4'], **calculator_opts).fit( + ref_data + ) + sut = calc.calculate(data=sample_drift_data).data + + assert all(sut['f1_dstat'] == expected) + + def test_statistical_drift_calculator_raises_type_error_when_features_missing(): # noqa: D103 with pytest.raises(TypeError, match='feature_column_names'): @@ -281,3 +363,82 @@ def test_base_drift_calculator_given_non_empty_features_list_should_only_calcula assert len([col for col in list(sut.data.columns) if col.startswith('f2')]) == 0 assert len([col for col in list(sut.data.columns) if col.startswith('f4')]) == 0 + + +@pytest.mark.parametrize( + 'calc_args, plot_args', + [ + ( + {'timestamp_column_name': 'timestamp'}, + {'kind': 'feature_drift', 'plot_reference': False, 
'feature_column_name': 'f1'}, + ), + ({}, {'kind': 'feature_drift', 'plot_reference': False, 'feature_column_name': 'f1'}), + ( + {'timestamp_column_name': 'timestamp'}, + {'kind': 'feature_drift', 'plot_reference': True, 'feature_column_name': 'f1'}, + ), + ({}, {'kind': 'feature_drift', 'plot_reference': True, 'feature_column_name': 'f1'}), + ( + {'timestamp_column_name': 'timestamp'}, + {'kind': 'feature_drift', 'plot_reference': False, 'feature_column_name': 'f3'}, + ), + ({}, {'kind': 'feature_drift', 'plot_reference': False, 'feature_column_name': 'f3'}), + ( + {'timestamp_column_name': 'timestamp'}, + {'kind': 'feature_drift', 'plot_reference': True, 'feature_column_name': 'f3'}, + ), + ({}, {'kind': 'feature_drift', 'plot_reference': True, 'feature_column_name': 'f3'}), + ( + {'timestamp_column_name': 'timestamp'}, + {'kind': 'feature_distribution', 'plot_reference': False, 'feature_column_name': 'f1'}, + ), + ({}, {'kind': 'feature_distribution', 'plot_reference': False, 'feature_column_name': 'f1'}), + ( + {'timestamp_column_name': 'timestamp'}, + {'kind': 'feature_distribution', 'plot_reference': True, 'feature_column_name': 'f1'}, + ), + ({}, {'kind': 'feature_distribution', 'plot_reference': True, 'feature_column_name': 'f1'}), + ( + {'timestamp_column_name': 'timestamp'}, + {'kind': 'feature_distribution', 'plot_reference': False, 'feature_column_name': 'f3'}, + ), + ({}, {'kind': 'feature_distribution', 'plot_reference': False, 'feature_column_name': 'f3'}), + ( + {'timestamp_column_name': 'timestamp'}, + {'kind': 'feature_distribution', 'plot_reference': True, 'feature_column_name': 'f3'}, + ), + ({}, {'kind': 'feature_distribution', 'plot_reference': True, 'feature_column_name': 'f3'}), + ], + ids=[ + 'continuous_feature_drift_with_timestamp_without_reference', + 'continuous_feature_drift_without_timestamp_without_reference', + 'continuous_feature_drift_with_timestamp_with_reference', + 'continuous_feature_drift_without_timestamp_with_reference', + 
'categorical_feature_drift_with_timestamp_without_reference', + 'categorical_feature_drift_without_timestamp_without_reference', + 'categorical_feature_drift_with_timestamp_with_reference', + 'categorical_feature_drift_without_timestamp_with_reference', + 'continuous_feature_distribution_with_timestamp_without_reference', + 'continuous_feature_distribution_without_timestamp_without_reference', + 'continuous_feature_distribution_with_timestamp_with_reference', + 'continuous_feature_distribution_without_timestamp_with_reference', + 'categorical_feature_distribution_with_timestamp_without_reference', + 'categorical_feature_distribution_without_timestamp_without_reference', + 'categorical_feature_distribution_with_timestamp_with_reference', + 'categorical_feature_distribution_without_timestamp_with_reference', + ], +) +def test_result_plots_raise_no_exceptions(sample_drift_data, calc_args, plot_args): # noqa: D103 + ref_data = sample_drift_data.loc[sample_drift_data['period'] == 'reference'] + ana_data = sample_drift_data.loc[sample_drift_data['period'] == 'analysis'] + + calc = UnivariateStatisticalDriftCalculator( + feature_column_names=['f1', 'f3'], + **calc_args, + ).fit(ref_data) + sut = calc.calculate(data=ana_data) + + try: + _ = sut.plot(**plot_args) + except Exception as exc: + pytest.fail(f"an unexpected exception occurred: {exc}") diff --git a/tests/drift/test_output_drift.py b/tests/drift/test_output_drift.py index c244f09f..6b1daf7a 100644 --- a/tests/drift/test_output_drift.py +++ b/tests/drift/test_output_drift.py @@ -1,14 +1,17 @@ # Author: Niels Nuyttens # # License: Apache Software License 2.0 -from typing import Tuple import numpy as np import pandas as pd import pytest from nannyml._typing import ProblemType -from nannyml.datasets import load_synthetic_car_price_dataset +from nannyml.datasets import ( + load_synthetic_binary_classification_dataset, + load_synthetic_car_price_dataset, + load_synthetic_multiclass_classification_dataset, +) from 
nannyml.drift.model_outputs.univariate.statistical import StatisticalOutputDriftCalculator @@ -109,13 +112,6 @@ def sample_drift_data_with_nans(sample_drift_data) -> pd.DataFrame: # noqa: D10 return data -@pytest.fixture -def regression_data() -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]: # noqa: D103 - ref_df, ana_df, tgt_df = load_synthetic_car_price_dataset() - - return ref_df, ana_df, tgt_df - - def test_output_drift_calculator_with_params_should_not_fail(sample_drift_data): # noqa: D103 ref_data = sample_drift_data.loc[sample_drift_data['period'] == 'reference'] calc = StatisticalOutputDriftCalculator( @@ -146,8 +142,8 @@ def test_output_drift_calculator_with_default_params_should_not_fail(sample_drif pytest.fail() -def test_output_drift_calculator_for_regression_problems(regression_data): # noqa: D103 - reference, analysis, _ = regression_data +def test_output_drift_calculator_for_regression_problems(): # noqa: D103 + reference, analysis, _ = load_synthetic_car_price_dataset() calc = StatisticalOutputDriftCalculator( y_pred='y_pred', timestamp_column_name='timestamp', @@ -164,3 +160,845 @@ def test_output_drift_calculator_for_regression_problems(regression_data): # no round(results.data['y_pred_p_value'], 5) == [0.588, 0.501, 0.999, 0.599, 0.289, 0.809, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] ).all() + + +@pytest.mark.parametrize( + 'calculator_opts, expected', + [ + ( + {'chunk_size': 10000}, + pd.DataFrame( + { + 'key': [ + '[0:9999]', + '[10000:19999]', + '[20000:29999]', + '[30000:39999]', + '[40000:49999]', + '[50000:59999]', + ], + 'y_pred_dstat': [ + 0.01046666666666668, + 0.007200000000000012, + 0.007183333333333319, + 0.2041, + 0.20484999999999998, + 0.21286666666666665, + ], + 'y_pred_p_value': [0.303, 0.763, 0.766, 0.0, 0.0, 0.0], + } + ), + ), + ( + {'chunk_size': 10000, 'timestamp_column_name': 'timestamp'}, + pd.DataFrame( + { + 'key': [ + '[0:9999]', + '[10000:19999]', + '[20000:29999]', + '[30000:39999]', + '[40000:49999]', + '[50000:59999]', + ], 
+ 'y_pred_dstat': [ + 0.01046666666666668, + 0.007200000000000012, + 0.007183333333333319, + 0.2041, + 0.20484999999999998, + 0.21286666666666665, + ], + 'y_pred_p_value': [0.303, 0.763, 0.766, 0.0, 0.0, 0.0], + } + ), + ), + ( + {'chunk_number': 5}, + pd.DataFrame( + { + 'key': ['[0:11999]', '[12000:23999]', '[24000:35999]', '[36000:47999]', '[48000:59999]'], + 'y_pred_dstat': [ + 0.009250000000000008, + 0.007400000000000018, + 0.10435000000000001, + 0.20601666666666663, + 0.21116666666666667, + ], + 'y_pred_p_value': [0.357, 0.641, 0.0, 0.0, 0.0], + } + ), + ), + ( + {'chunk_number': 5, 'timestamp_column_name': 'timestamp'}, + pd.DataFrame( + { + 'key': ['[0:11999]', '[12000:23999]', '[24000:35999]', '[36000:47999]', '[48000:59999]'], + 'y_pred_dstat': [ + 0.009250000000000008, + 0.007400000000000018, + 0.10435000000000001, + 0.20601666666666663, + 0.21116666666666667, + ], + 'y_pred_p_value': [0.357, 0.641, 0.0, 0.0, 0.0], + } + ), + ), + ( + {'chunk_period': 'M', 'timestamp_column_name': 'timestamp'}, + pd.DataFrame( + { + 'key': ['2017-02', '2017-03'], + 'y_pred_dstat': [0.010885590414630303, 0.20699588120912144], + 'y_pred_p_value': [0.015, 0.0], + } + ), + ), + ( + {}, + pd.DataFrame( + { + 'key': [ + '[0:5999]', + '[6000:11999]', + '[12000:17999]', + '[18000:23999]', + '[24000:29999]', + '[30000:35999]', + '[36000:41999]', + '[42000:47999]', + '[48000:53999]', + '[54000:59999]', + ], + 'y_pred_dstat': [ + 0.009183333333333321, + 0.016349999999999976, + 0.01079999999999999, + 0.010183333333333336, + 0.01065000000000002, + 0.20288333333333336, + 0.20734999999999998, + 0.20468333333333333, + 0.20713333333333334, + 0.21588333333333334, + ], + 'y_pred_p_value': [0.743, 0.107, 0.544, 0.62, 0.562, 0.0, 0.0, 0.0, 0.0, 0.0], + } + ), + ), + ( + {'timestamp_column_name': 'timestamp'}, + pd.DataFrame( + { + 'key': [ + '[0:5999]', + '[6000:11999]', + '[12000:17999]', + '[18000:23999]', + '[24000:29999]', + '[30000:35999]', + '[36000:41999]', + '[42000:47999]', + 
'[48000:53999]', + '[54000:59999]', + ], + 'y_pred_dstat': [ + 0.009183333333333321, + 0.016349999999999976, + 0.01079999999999999, + 0.010183333333333336, + 0.01065000000000002, + 0.20288333333333336, + 0.20734999999999998, + 0.20468333333333333, + 0.20713333333333334, + 0.21588333333333334, + ], + 'y_pred_p_value': [0.743, 0.107, 0.544, 0.62, 0.562, 0.0, 0.0, 0.0, 0.0, 0.0], + } + ), + ), + ], + ids=[ + 'size_based_without_timestamp', + 'size_based_with_timestamp', + 'count_based_without_timestamp', + 'count_based_with_timestamp', + 'period_based_with_timestamp', + 'default_without_timestamp', + 'default_with_timestamp', + ], +) +def test_univariate_statistical_drift_calculator_for_regression_works_with_chunker( + calculator_opts, expected # noqa: D103 +): + reference, analysis, _ = load_synthetic_car_price_dataset() + calc = StatisticalOutputDriftCalculator( + y_pred='y_pred', + problem_type=ProblemType.REGRESSION, + **calculator_opts, + ).fit(reference) + results = calc.calculate(analysis) + + pd.testing.assert_frame_equal(expected, results.data[['key', 'y_pred_dstat', 'y_pred_p_value']]) + + +@pytest.mark.parametrize( + 'calculator_opts, expected', + [ + ( + {'chunk_size': 10000}, + pd.DataFrame( + { + 'key': ['[0:9999]', '[10000:19999]', '[20000:29999]', '[30000:39999]', '[40000:49999]'], + 'y_pred_chi2': [ + 0.860333803964031, + 3.0721462648836715, + 6.609667643816801, + 19.49553770190838, + 24.09326946563376, + ], + 'y_pred_p_value': [0.354, 0.08, 0.01, 0.0, 0.0], + 'y_pred_proba_dstat': [ + 0.009019999999999972, + 0.011160000000000003, + 0.07168000000000002, + 0.1286, + 0.12749999999999997, + ], + 'y_pred_proba_p_value': [0.504, 0.249, 0.0, 0.0, 0.0], + } + ), + ), + ( + {'chunk_size': 10000, 'timestamp_column_name': 'timestamp'}, + pd.DataFrame( + { + 'key': ['[0:9999]', '[10000:19999]', '[20000:29999]', '[30000:39999]', '[40000:49999]'], + 'y_pred_chi2': [ + 0.860333803964031, + 3.0721462648836715, + 6.609667643816801, + 19.49553770190838, + 
24.09326946563376, + ], + 'y_pred_p_value': [0.354, 0.08, 0.01, 0.0, 0.0], + 'y_pred_proba_dstat': [ + 0.009019999999999972, + 0.011160000000000003, + 0.07168000000000002, + 0.1286, + 0.12749999999999997, + ], + 'y_pred_proba_p_value': [0.504, 0.249, 0.0, 0.0, 0.0], + } + ), + ), + ( + {'chunk_number': 5}, + pd.DataFrame( + { + 'key': ['[0:9999]', '[10000:19999]', '[20000:29999]', '[30000:39999]', '[40000:49999]'], + 'y_pred_chi2': [ + 0.860333803964031, + 3.0721462648836715, + 6.609667643816801, + 19.49553770190838, + 24.09326946563376, + ], + 'y_pred_p_value': [0.354, 0.08, 0.01, 0.0, 0.0], + 'y_pred_proba_dstat': [ + 0.009019999999999972, + 0.011160000000000003, + 0.07168000000000002, + 0.1286, + 0.12749999999999997, + ], + 'y_pred_proba_p_value': [0.504, 0.249, 0.0, 0.0, 0.0], + } + ), + ), + ( + {'chunk_number': 5, 'timestamp_column_name': 'timestamp'}, + pd.DataFrame( + { + 'key': ['[0:9999]', '[10000:19999]', '[20000:29999]', '[30000:39999]', '[40000:49999]'], + 'y_pred_chi2': [ + 0.860333803964031, + 3.0721462648836715, + 6.609667643816801, + 19.49553770190838, + 24.09326946563376, + ], + 'y_pred_p_value': [0.354, 0.08, 0.01, 0.0, 0.0], + 'y_pred_proba_dstat': [ + 0.009019999999999972, + 0.011160000000000003, + 0.07168000000000002, + 0.1286, + 0.12749999999999997, + ], + 'y_pred_proba_p_value': [0.504, 0.249, 0.0, 0.0, 0.0], + } + ), + ), + ( + {'chunk_period': 'Y', 'timestamp_column_name': 'timestamp'}, + pd.DataFrame( + { + 'key': ['2017', '2018', '2019', '2020', '2021'], + 'y_pred_chi2': [ + 7.70713490741521, + 4.264683512119149, + 14.259031460383845, + 30.73593452676024, + 0.12120817142905127, + ], + 'y_pred_p_value': [0.006, 0.039, 0.0, 0.0, 0.728], + 'y_pred_proba_dstat': [ + 0.0258059773828756, + 0.010519551545707828, + 0.09013549032688456, + 0.12807136369707492, + 0.46668, + ], + 'y_pred_proba_p_value': [0.005, 0.153, 0.0, 0.0, 0.016], + } + ), + ), + ( + {}, + pd.DataFrame( + { + 'key': [ + '[0:4999]', + '[5000:9999]', + '[10000:14999]', + 
'[15000:19999]', + '[20000:24999]', + '[25000:29999]', + '[30000:34999]', + '[35000:39999]', + '[40000:44999]', + '[45000:49999]', + ], + 'y_pred_chi2': [ + 7.442382882761337, + 1.800169196688272, + 1.7285289531065517, + 1.5896121630342237, + 0.060895769836341554, + 12.512106658022049, + 11.393384406782644, + 9.813531942242996, + 3.786524136939082, + 27.99003983833193, + ], + 'y_pred_p_value': [0.006, 0.18, 0.189, 0.207, 0.805, 0.0, 0.001, 0.002, 0.052, 0.0], + 'y_pred_proba_dstat': [ + 0.025300000000000045, + 0.012299999999999978, + 0.01641999999999999, + 0.010580000000000034, + 0.014080000000000037, + 0.13069999999999998, + 0.12729999999999997, + 0.1311, + 0.11969999999999997, + 0.13751999999999998, + ], + 'y_pred_proba_p_value': [0.006, 0.494, 0.17, 0.685, 0.325, 0.0, 0.0, 0.0, 0.0, 0.0], + } + ), + ), + ( + {'timestamp_column_name': 'timestamp'}, + pd.DataFrame( + { + 'key': [ + '[0:4999]', + '[5000:9999]', + '[10000:14999]', + '[15000:19999]', + '[20000:24999]', + '[25000:29999]', + '[30000:34999]', + '[35000:39999]', + '[40000:44999]', + '[45000:49999]', + ], + 'y_pred_chi2': [ + 7.442382882761337, + 1.800169196688272, + 1.7285289531065517, + 1.5896121630342237, + 0.060895769836341554, + 12.512106658022049, + 11.393384406782644, + 9.813531942242996, + 3.786524136939082, + 27.99003983833193, + ], + 'y_pred_p_value': [0.006, 0.18, 0.189, 0.207, 0.805, 0.0, 0.001, 0.002, 0.052, 0.0], + 'y_pred_proba_dstat': [ + 0.025300000000000045, + 0.012299999999999978, + 0.01641999999999999, + 0.010580000000000034, + 0.014080000000000037, + 0.13069999999999998, + 0.12729999999999997, + 0.1311, + 0.11969999999999997, + 0.13751999999999998, + ], + 'y_pred_proba_p_value': [0.006, 0.494, 0.17, 0.685, 0.325, 0.0, 0.0, 0.0, 0.0, 0.0], + } + ), + ), + ], + ids=[ + 'size_based_without_timestamp', + 'size_based_with_timestamp', + 'count_based_without_timestamp', + 'count_based_with_timestamp', + 'period_based_with_timestamp', + 'default_without_timestamp', + 'default_with_timestamp', 
+ ], +) +def test_univariate_statistical_drift_calculator_for_binary_classification_works_with_chunker( + calculator_opts, expected # noqa: D103 +): + reference, analysis, _ = load_synthetic_binary_classification_dataset() + calc = StatisticalOutputDriftCalculator( + y_pred='y_pred', + y_pred_proba='y_pred_proba', + problem_type=ProblemType.CLASSIFICATION_BINARY, + **calculator_opts, + ).fit(reference) + results = calc.calculate(analysis) + + pd.testing.assert_frame_equal( + expected, results.data[['key', 'y_pred_chi2', 'y_pred_p_value', 'y_pred_proba_dstat', 'y_pred_proba_p_value']] + ) + + +@pytest.mark.parametrize( + 'calculator_opts, expected', + [ + ( + {'chunk_size': 10000}, + pd.DataFrame( + { + 'key': [ + '[0:9999]', + '[10000:19999]', + '[20000:29999]', + '[30000:39999]', + '[40000:49999]', + '[50000:59999]', + ], + 'y_pred_chi2': [ + 1.894844933794635, + 1.007925369679508, + 5.341968282280158, + 228.5670714965706, + 263.46933956608086, + 228.68832812811945, + ], + 'y_pred_p_value': [0.388, 0.604, 0.069, 0.0, 0.0, 0.0], + 'y_pred_proba_upmarket_card_dstat': [ + 0.009450000000000014, + 0.006950000000000012, + 0.014050000000000007, + 0.14831666666666665, + 0.13885000000000003, + 0.14631666666666668, + ], + 'y_pred_proba_upmarket_card_p_value': [0.426, 0.799, 0.067, 0.0, 0.0, 0.0], + } + ), + ), + ( + {'chunk_size': 10000, 'timestamp_column_name': 'timestamp'}, + pd.DataFrame( + { + 'key': [ + '[0:9999]', + '[10000:19999]', + '[20000:29999]', + '[30000:39999]', + '[40000:49999]', + '[50000:59999]', + ], + 'y_pred_chi2': [ + 1.894844933794635, + 1.007925369679508, + 5.341968282280158, + 228.5670714965706, + 263.46933956608086, + 228.68832812811945, + ], + 'y_pred_p_value': [0.388, 0.604, 0.069, 0.0, 0.0, 0.0], + 'y_pred_proba_upmarket_card_dstat': [ + 0.009450000000000014, + 0.006950000000000012, + 0.014050000000000007, + 0.14831666666666665, + 0.13885000000000003, + 0.14631666666666668, + ], + 'y_pred_proba_upmarket_card_p_value': [0.426, 0.799, 0.067, 0.0, 
0.0, 0.0], + } + ), + ), + ( + {'chunk_number': 5}, + pd.DataFrame( + { + 'key': ['[0:11999]', '[12000:23999]', '[24000:35999]', '[36000:47999]', '[48000:59999]'], + 'y_pred_chi2': [ + 1.8853789747457756, + 0.9860257560785328, + 72.71926368401432, + 306.95758434731476, + 275.337354950812, + ], + 'y_pred_p_value': [0.39, 0.611, 0.0, 0.0, 0.0], + 'y_pred_proba_upmarket_card_dstat': [ + 0.0076166666666666605, + 0.00876666666666659, + 0.07696666666666666, + 0.14246666666666666, + 0.1448, + ], + 'y_pred_proba_upmarket_card_p_value': [0.605, 0.423, 0.0, 0.0, 0.0], + } + ), + ), + ( + {'chunk_number': 5, 'timestamp_column_name': 'timestamp'}, + pd.DataFrame( + { + 'key': ['[0:11999]', '[12000:23999]', '[24000:35999]', '[36000:47999]', '[48000:59999]'], + 'y_pred_chi2': [ + 1.8853789747457756, + 0.9860257560785328, + 72.71926368401432, + 306.95758434731476, + 275.337354950812, + ], + 'y_pred_p_value': [0.39, 0.611, 0.0, 0.0, 0.0], + 'y_pred_proba_upmarket_card_dstat': [ + 0.0076166666666666605, + 0.00876666666666659, + 0.07696666666666666, + 0.14246666666666666, + 0.1448, + ], + 'y_pred_proba_upmarket_card_p_value': [0.605, 0.423, 0.0, 0.0, 0.0], + } + ), + ), + ( + {'chunk_period': 'Y', 'timestamp_column_name': 'timestamp'}, + pd.DataFrame( + { + 'key': ['2020', '2021'], + 'y_pred_chi2': [207.2554347384955, 6.288483345011454], + 'y_pred_p_value': [0.0, 0.043], + 'y_pred_proba_upmarket_card_dstat': [0.07220877811346088, 0.16619285714285714], + 'y_pred_proba_upmarket_card_p_value': [0.0, 0.0], + } + ), + ), + ( + {}, + pd.DataFrame( + { + 'key': [ + '[0:5999]', + '[6000:11999]', + '[12000:17999]', + '[18000:23999]', + '[24000:29999]', + '[30000:35999]', + '[36000:41999]', + '[42000:47999]', + '[48000:53999]', + '[54000:59999]', + ], + 'y_pred_chi2': [ + 2.4199133518806706, + 1.2633881212231448, + 0.21170529003761418, + 1.0459424531991828, + 2.891011519973576, + 131.23790859647167, + 155.59305405725468, + 182.00063726142486, + 137.68526858822366, + 164.40667669928519, + ], + 
'y_pred_p_value': [0.298, 0.532, 0.9, 0.593, 0.236, 0.0, 0.0, 0.0, 0.0, 0.0], + 'y_pred_proba_upmarket_card_dstat': [ + 0.012283333333333368, + 0.008450000000000013, + 0.007866666666666688, + 0.01261666666666661, + 0.01261666666666661, + 0.14679999999999999, + 0.14471666666666666, + 0.14096666666666668, + 0.14205, + 0.14755, + ], + 'y_pred_proba_upmarket_card_p_value': [0.38, 0.828, 0.886, 0.347, 0.347, 0.0, 0.0, 0.0, 0.0, 0.0], + } + ), + ), + ( + {'timestamp_column_name': 'timestamp'}, + pd.DataFrame( + { + 'key': [ + '[0:5999]', + '[6000:11999]', + '[12000:17999]', + '[18000:23999]', + '[24000:29999]', + '[30000:35999]', + '[36000:41999]', + '[42000:47999]', + '[48000:53999]', + '[54000:59999]', + ], + 'y_pred_chi2': [ + 2.4199133518806706, + 1.2633881212231448, + 0.21170529003761418, + 1.0459424531991828, + 2.891011519973576, + 131.23790859647167, + 155.59305405725468, + 182.00063726142486, + 137.68526858822366, + 164.40667669928519, + ], + 'y_pred_p_value': [0.298, 0.532, 0.9, 0.593, 0.236, 0.0, 0.0, 0.0, 0.0, 0.0], + 'y_pred_proba_upmarket_card_dstat': [ + 0.012283333333333368, + 0.008450000000000013, + 0.007866666666666688, + 0.01261666666666661, + 0.01261666666666661, + 0.14679999999999999, + 0.14471666666666666, + 0.14096666666666668, + 0.14205, + 0.14755, + ], + 'y_pred_proba_upmarket_card_p_value': [0.38, 0.828, 0.886, 0.347, 0.347, 0.0, 0.0, 0.0, 0.0, 0.0], + } + ), + ), + ], + ids=[ + 'size_based_without_timestamp', + 'size_based_with_timestamp', + 'count_based_without_timestamp', + 'count_based_with_timestamp', + 'period_based_with_timestamp', + 'default_without_timestamp', + 'default_with_timestamp', + ], +) +def test_univariate_statistical_drift_calculator_for_multiclass_classification_works_with_chunker( + calculator_opts, expected # noqa: D103 +): + reference, analysis, _ = load_synthetic_multiclass_classification_dataset() + calc = StatisticalOutputDriftCalculator( + y_pred='y_pred', + y_pred_proba={ + 'upmarket_card': 
'y_pred_proba_upmarket_card', + 'highstreet_card': 'y_pred_proba_highstreet_card', + 'prepaid_card': 'y_pred_proba_prepaid_card', + }, + problem_type=ProblemType.CLASSIFICATION_MULTICLASS, + **calculator_opts, + ).fit(reference) + results = calc.calculate(analysis) + + pd.testing.assert_frame_equal( + expected, + results.data[ + [ + 'key', + 'y_pred_chi2', + 'y_pred_p_value', + 'y_pred_proba_upmarket_card_dstat', + 'y_pred_proba_upmarket_card_p_value', + ] + ], + ) + + +@pytest.mark.parametrize( + 'calc_args, plot_args', + [ + ( + {'timestamp_column_name': 'timestamp'}, + {'kind': 'score_drift', 'plot_reference': False, 'class_label': 'upmarket_card'}, + ), + ({}, {'kind': 'score_drift', 'plot_reference': False, 'class_label': 'upmarket_card'}), + ( + {'timestamp_column_name': 'timestamp'}, + {'kind': 'score_drift', 'plot_reference': True, 'class_label': 'upmarket_card'}, + ), + ({}, {'kind': 'score_drift', 'plot_reference': True, 'class_label': 'upmarket_card'}), + ( + {'timestamp_column_name': 'timestamp'}, + {'kind': 'score_distribution', 'plot_reference': False, 'class_label': 'upmarket_card'}, + ), + ({}, {'kind': 'score_distribution', 'plot_reference': False, 'class_label': 'upmarket_card'}), + ( + {'timestamp_column_name': 'timestamp'}, + {'kind': 'score_distribution', 'plot_reference': True, 'class_label': 'upmarket_card'}, + ), + ({}, {'kind': 'score_distribution', 'plot_reference': True, 'class_label': 'upmarket_card'}), + ({'timestamp_column_name': 'timestamp'}, {'kind': 'prediction_drift', 'plot_reference': False}), + ({}, {'kind': 'prediction_drift', 'plot_reference': False}), + ({'timestamp_column_name': 'timestamp'}, {'kind': 'prediction_drift', 'plot_reference': True}), + ({}, {'kind': 'prediction_drift', 'plot_reference': True}), + ({'timestamp_column_name': 'timestamp'}, {'kind': 'prediction_distribution', 'plot_reference': False}), + ({}, {'kind': 'prediction_distribution', 'plot_reference': False}), + ({'timestamp_column_name': 'timestamp'}, 
{'kind': 'prediction_distribution', 'plot_reference': True}), + ({}, {'kind': 'prediction_distribution', 'plot_reference': True}), + ], + ids=[ + 'score_drift_with_timestamp_without_reference', + 'score_drift_without_timestamp_without_reference', + 'score_drift_with_timestamp_with_reference', + 'score_drift_without_timestamp_with_reference', + 'score_distribution_with_timestamp_without_reference', + 'score_distribution_without_timestamp_without_reference', + 'score_distribution_with_timestamp_with_reference', + 'score_distribution_without_timestamp_with_reference', + 'prediction_drift_with_timestamp_without_reference', + 'prediction_drift_without_timestamp_without_reference', + 'prediction_drift_with_timestamp_with_reference', + 'prediction_drift_without_timestamp_with_reference', + 'prediction_distribution_with_timestamp_without_reference', + 'prediction_distribution_without_timestamp_without_reference', + 'prediction_distribution_with_timestamp_with_reference', + 'prediction_distribution_without_timestamp_with_reference', + ], +) +def test_multiclass_classification_result_plots_raise_no_exceptions(calc_args, plot_args): # noqa: D103 + reference, analysis, _ = load_synthetic_multiclass_classification_dataset() + calc = StatisticalOutputDriftCalculator( + y_pred='y_pred', + y_pred_proba={ + 'upmarket_card': 'y_pred_proba_upmarket_card', + 'highstreet_card': 'y_pred_proba_highstreet_card', + 'prepaid_card': 'y_pred_proba_prepaid_card', + }, + problem_type=ProblemType.CLASSIFICATION_MULTICLASS, + **calc_args, + ).fit(reference) + sut = calc.calculate(analysis) + + try: + _ = sut.plot(**plot_args) + except Exception as exc: + pytest.fail(f"an unexpected exception occurred: {exc}") + + +@pytest.mark.parametrize( + 'calc_args, plot_args', + [ + ({'timestamp_column_name': 'timestamp'}, {'kind': 'score_drift', 'plot_reference': False}), + ({}, {'kind': 'score_drift', 'plot_reference': False}), + ({'timestamp_column_name': 'timestamp'}, {'kind': 'score_drift', 
'plot_reference': True}), + ({}, {'kind': 'score_drift', 'plot_reference': True}), + ({'timestamp_column_name': 'timestamp'}, {'kind': 'score_distribution', 'plot_reference': False}), + ({}, {'kind': 'score_distribution', 'plot_reference': False}), + ({'timestamp_column_name': 'timestamp'}, {'kind': 'score_distribution', 'plot_reference': True}), + ({}, {'kind': 'score_distribution', 'plot_reference': True}), + ({'timestamp_column_name': 'timestamp'}, {'kind': 'prediction_drift', 'plot_reference': False}), + ({}, {'kind': 'prediction_drift', 'plot_reference': False}), + ({'timestamp_column_name': 'timestamp'}, {'kind': 'prediction_drift', 'plot_reference': True}), + ({}, {'kind': 'prediction_drift', 'plot_reference': True}), + ({'timestamp_column_name': 'timestamp'}, {'kind': 'prediction_distribution', 'plot_reference': False}), + ({}, {'kind': 'prediction_distribution', 'plot_reference': False}), + ({'timestamp_column_name': 'timestamp'}, {'kind': 'prediction_distribution', 'plot_reference': True}), + ({}, {'kind': 'prediction_distribution', 'plot_reference': True}), + ], + ids=[ + 'score_drift_with_timestamp_without_reference', + 'score_drift_without_timestamp_without_reference', + 'score_drift_with_timestamp_with_reference', + 'score_drift_without_timestamp_with_reference', + 'score_distribution_with_timestamp_without_reference', + 'score_distribution_without_timestamp_without_reference', + 'score_distribution_with_timestamp_with_reference', + 'score_distribution_without_timestamp_with_reference', + 'prediction_drift_with_timestamp_without_reference', + 'prediction_drift_without_timestamp_without_reference', + 'prediction_drift_with_timestamp_with_reference', + 'prediction_drift_without_timestamp_with_reference', + 'prediction_distribution_with_timestamp_without_reference', + 'prediction_distribution_without_timestamp_without_reference', + 'prediction_distribution_with_timestamp_with_reference', + 'prediction_distribution_without_timestamp_with_reference', + ], 
+) +def test_binary_classification_result_plots_raise_no_exceptions(calc_args, plot_args): # noqa: D103 + reference, analysis, _ = load_synthetic_binary_classification_dataset() + calc = StatisticalOutputDriftCalculator( + y_pred='y_pred', y_pred_proba='y_pred_proba', problem_type=ProblemType.CLASSIFICATION_BINARY, **calc_args + ).fit(reference) + sut = calc.calculate(analysis) + + try: + _ = sut.plot(**plot_args) + except Exception as exc: + pytest.fail(f"an unexpected exception occurred: {exc}") + + +@pytest.mark.parametrize( + 'calc_args, plot_args', + [ + ({'timestamp_column_name': 'timestamp'}, {'kind': 'prediction_drift', 'plot_reference': False}), + ({}, {'kind': 'prediction_drift', 'plot_reference': False}), + ({'timestamp_column_name': 'timestamp'}, {'kind': 'prediction_drift', 'plot_reference': True}), + ({}, {'kind': 'prediction_drift', 'plot_reference': True}), + ({'timestamp_column_name': 'timestamp'}, {'kind': 'prediction_distribution', 'plot_reference': False}), + ({}, {'kind': 'prediction_distribution', 'plot_reference': False}), + ({'timestamp_column_name': 'timestamp'}, {'kind': 'prediction_distribution', 'plot_reference': True}), + ({}, {'kind': 'prediction_distribution', 'plot_reference': True}), + ], + ids=[ + 'prediction_drift_with_timestamp_without_reference', + 'prediction_drift_without_timestamp_without_reference', + 'prediction_drift_with_timestamp_with_reference', + 'prediction_drift_without_timestamp_with_reference', + 'prediction_distribution_with_timestamp_without_reference', + 'prediction_distribution_without_timestamp_without_reference', + 'prediction_distribution_with_timestamp_with_reference', + 'prediction_distribution_without_timestamp_with_reference', + ], +) +def test_regression_result_plots_raise_no_exceptions(calc_args, plot_args): # noqa: D103 + reference, analysis, _ = load_synthetic_car_price_dataset() + calc = StatisticalOutputDriftCalculator(y_pred='y_pred', problem_type=ProblemType.REGRESSION, **calc_args).fit( + 
reference + ) + sut = calc.calculate(analysis) + + try: + _ = sut.plot(**plot_args) + except Exception as exc: + pytest.fail(f"an unexpected exception occurred: {exc}") diff --git a/tests/drift/test_target_distribution.py b/tests/drift/test_target_distribution.py index fbf74834..80049df9 100644 --- a/tests/drift/test_target_distribution.py +++ b/tests/drift/test_target_distribution.py @@ -7,7 +7,12 @@ import pandas as pd import pytest -from nannyml.datasets import load_synthetic_car_price_dataset +from nannyml._typing import ProblemType +from nannyml.datasets import ( + load_synthetic_binary_classification_dataset, + load_synthetic_car_price_dataset, + load_synthetic_multiclass_classification_dataset, +) from nannyml.drift.target.target_distribution import TargetDistributionCalculator @@ -182,3 +187,693 @@ def test_target_distribution_calculator_for_regression_problems_mean_drift(regre 4787.09417, ] ).all() + + +@pytest.mark.parametrize( + 'calculator_opts, expected', + [ + ( + {'chunk_size': 10000}, + pd.DataFrame( + { + 'key': [ + '[0:9999]', + '[10000:19999]', + '[20000:29999]', + '[30000:39999]', + '[40000:49999]', + '[50000:59999]', + ], + 'metric_target_drift': [4834.7893, 4776.1287, 4839.639, 4868.6151, 4873.9113, 4818.2043], + 'statistical_target_drift': [ + 0.014583333333333337, + 0.006916666666666682, + 0.008950000000000014, + 0.17545, + 0.178, + 0.1867666666666667, + ], + } + ), + ), + ( + {'chunk_size': 10000, 'timestamp_column_name': 'timestamp'}, + pd.DataFrame( + { + 'key': [ + '[0:9999]', + '[10000:19999]', + '[20000:29999]', + '[30000:39999]', + '[40000:49999]', + '[50000:59999]', + ], + 'metric_target_drift': [4834.7893, 4776.1287, 4839.639, 4868.6151, 4873.9113, 4818.2043], + 'statistical_target_drift': [ + 0.014583333333333337, + 0.006916666666666682, + 0.008950000000000014, + 0.17545, + 0.178, + 0.1867666666666667, + ], + } + ), + ), + ( + {'chunk_number': 5}, + pd.DataFrame( + { + 'key': ['[0:11999]', '[12000:23999]', '[24000:35999]', 
'[36000:47999]', '[48000:59999]'], + 'metric_target_drift': [ + 4826.761333333333, + 4815.80275, + 4825.885083333334, + 4871.522833333333, + 4836.101083333333, + ], + 'statistical_target_drift': [ + 0.014516666666666678, + 0.00869999999999993, + 0.08675, + 0.1785, + 0.18508333333333332, + ], + } + ), + ), + ( + {'chunk_number': 5, 'timestamp_column_name': 'timestamp'}, + pd.DataFrame( + { + 'key': ['[0:11999]', '[12000:23999]', '[24000:35999]', '[36000:47999]', '[48000:59999]'], + 'metric_target_drift': [ + 4826.761333333333, + 4815.80275, + 4825.885083333334, + 4871.522833333333, + 4836.101083333333, + ], + 'statistical_target_drift': [ + 0.014516666666666678, + 0.00869999999999993, + 0.08675, + 0.1785, + 0.18508333333333332, + ], + } + ), + ), + ( + {'chunk_period': 'M', 'timestamp_column_name': 'timestamp'}, + pd.DataFrame( + { + 'key': ['2017-02', '2017-03'], + 'metric_target_drift': [4826.1297808607915, 4845.4011313417], + 'statistical_target_drift': [0.008650415155814883, 0.17979488539272875], + } + ), + ), + ( + {}, + pd.DataFrame( + { + 'key': [ + '[0:5999]', + '[6000:11999]', + '[12000:17999]', + '[18000:23999]', + '[24000:29999]', + '[30000:35999]', + '[36000:41999]', + '[42000:47999]', + '[48000:53999]', + '[54000:59999]', + ], + 'metric_target_drift': [ + 4862.9411666666665, + 4790.5815, + 4793.349333333334, + 4838.256166666667, + 4799.1335, + 4852.636666666666, + 4875.456666666667, + 4867.589, + 4885.108, + 4787.094166666667, + ], + 'statistical_target_drift': [ + 0.014249999999999985, + 0.016566666666666674, + 0.010066666666666668, + 0.011916666666666659, + 0.008666666666666656, + 0.1716833333333333, + 0.18011666666666665, + 0.17906666666666665, + 0.18323333333333333, + 0.1873833333333333, + ], + } + ), + ), + ( + {'timestamp_column_name': 'timestamp'}, + pd.DataFrame( + { + 'key': [ + '[0:5999]', + '[6000:11999]', + '[12000:17999]', + '[18000:23999]', + '[24000:29999]', + '[30000:35999]', + '[36000:41999]', + '[42000:47999]', + '[48000:53999]', + 
'[54000:59999]', + ], + 'metric_target_drift': [ + 4862.9411666666665, + 4790.5815, + 4793.349333333334, + 4838.256166666667, + 4799.1335, + 4852.636666666666, + 4875.456666666667, + 4867.589, + 4885.108, + 4787.094166666667, + ], + 'statistical_target_drift': [ + 0.014249999999999985, + 0.016566666666666674, + 0.010066666666666668, + 0.011916666666666659, + 0.008666666666666656, + 0.1716833333333333, + 0.18011666666666665, + 0.17906666666666665, + 0.18323333333333333, + 0.1873833333333333, + ], + } + ), + ), + ], + ids=[ + 'size_based_without_timestamp', + 'size_based_with_timestamp', + 'count_based_without_timestamp', + 'count_based_with_timestamp', + 'period_based_with_timestamp', + 'default_without_timestamp', + 'default_with_timestamp', + ], +) +def test_target_drift_for_regression_works_with_chunker(calculator_opts, expected): # noqa: D103 + reference, analysis, analysis_targets = load_synthetic_car_price_dataset() + calc = TargetDistributionCalculator( + y_true='y_true', + problem_type=ProblemType.REGRESSION, + **calculator_opts, + ).fit(reference) + results = calc.calculate(analysis.join(analysis_targets)) + + pd.testing.assert_frame_equal(expected, results.data[['key', 'metric_target_drift', 'statistical_target_drift']]) + + +@pytest.mark.parametrize( + 'calculator_opts, expected', + [ + ( + {'chunk_size': 10000}, + pd.DataFrame( + { + 'key': ['[0:9999]', '[10000:19999]', '[20000:29999]', '[30000:39999]', '[40000:49999]'], + 'metric_target_drift': [0.5044, 0.4911, 0.501, 0.5028, 0.5041], + 'statistical_target_drift': [ + 0.7552537772547374, + 2.3632451058508988, + 0.061653341622277646, + 0.3328533514553376, + 0.6630536280238508, + ], + } + ), + ), + ( + {'chunk_size': 10000, 'timestamp_column_name': 'timestamp'}, + pd.DataFrame( + { + 'key': ['[0:9999]', '[10000:19999]', '[20000:29999]', '[30000:39999]', '[40000:49999]'], + 'metric_target_drift': [0.5044, 0.4911, 0.501, 0.5028, 0.5041], + 'statistical_target_drift': [ + 0.7552537772547374, + 
2.3632451058508988, + 0.061653341622277646, + 0.3328533514553376, + 0.6630536280238508, + ], + } + ), + ), + ( + {'chunk_number': 5}, + pd.DataFrame( + { + 'key': ['[0:9999]', '[10000:19999]', '[20000:29999]', '[30000:39999]', '[40000:49999]'], + 'metric_target_drift': [0.5044, 0.4911, 0.501, 0.5028, 0.5041], + 'statistical_target_drift': [ + 0.7552537772547374, + 2.3632451058508988, + 0.061653341622277646, + 0.3328533514553376, + 0.6630536280238508, + ], + } + ), + ), + ( + {'chunk_number': 5, 'timestamp_column_name': 'timestamp'}, + pd.DataFrame( + { + 'key': ['[0:9999]', '[10000:19999]', '[20000:29999]', '[30000:39999]', '[40000:49999]'], + 'metric_target_drift': [0.5044, 0.4911, 0.501, 0.5028, 0.5041], + 'statistical_target_drift': [ + 0.7552537772547374, + 2.3632451058508988, + 0.061653341622277646, + 0.3328533514553376, + 0.6630536280238508, + ], + } + ), + ), + ( + {'chunk_period': 'Y', 'timestamp_column_name': 'timestamp'}, + pd.DataFrame( + { + 'key': ['2017', '2018', '2019', '2020', '2021'], + 'metric_target_drift': [ + 0.5175686591276252, + 0.49144221838927954, + 0.5021347565043363, + 0.5030052090289836, + 0.4, + ], + 'statistical_target_drift': [ + 5.760397194889025, + 3.0356541961122385, + 0.2909492514342073, + 0.5271435212375012, + 0.09826781851967825, + ], + } + ), + ), + ( + {}, + pd.DataFrame( + { + 'key': [ + '[0:4999]', + '[5000:9999]', + '[10000:14999]', + '[15000:19999]', + '[20000:24999]', + '[25000:29999]', + '[30000:34999]', + '[35000:39999]', + '[40000:44999]', + '[45000:49999]', + ], + 'metric_target_drift': [ + 0.5172, + 0.4916, + 0.4858, + 0.4964, + 0.5026, + 0.4994, + 0.505, + 0.5006, + 0.4964, + 0.5118, + ], + 'statistical_target_drift': [ + 5.574578416652949, + 1.1261313647806923, + 3.3976543904089302, + 0.17136216283219585, + 0.15396546757525365, + 8.909097694730939e-05, + 0.5126563744826438, + 0.015056370086959939, + 0.17136216283219585, + 2.666406909476474, + ], + } + ), + ), + ( + {'timestamp_column_name': 'timestamp'}, + 
pd.DataFrame( + { + 'key': [ + '[0:4999]', + '[5000:9999]', + '[10000:14999]', + '[15000:19999]', + '[20000:24999]', + '[25000:29999]', + '[30000:34999]', + '[35000:39999]', + '[40000:44999]', + '[45000:49999]', + ], + 'metric_target_drift': [ + 0.5172, + 0.4916, + 0.4858, + 0.4964, + 0.5026, + 0.4994, + 0.505, + 0.5006, + 0.4964, + 0.5118, + ], + 'statistical_target_drift': [ + 5.574578416652949, + 1.1261313647806923, + 3.3976543904089302, + 0.17136216283219585, + 0.15396546757525365, + 8.909097694730939e-05, + 0.5126563744826438, + 0.015056370086959939, + 0.17136216283219585, + 2.666406909476474, + ], + } + ), + ), + ], + ids=[ + 'size_based_without_timestamp', + 'size_based_with_timestamp', + 'count_based_without_timestamp', + 'count_based_with_timestamp', + 'period_based_with_timestamp', + 'default_without_timestamp', + 'default_with_timestamp', + ], +) +def test_target_drift_for_binary_classification_works_with_chunker(calculator_opts, expected): # noqa: D103 + reference, analysis, analysis_targets = load_synthetic_binary_classification_dataset() + calc = TargetDistributionCalculator( + y_true='work_home_actual', + problem_type=ProblemType.CLASSIFICATION_BINARY, + **calculator_opts, + ).fit(reference) + results = calc.calculate(analysis.merge(analysis_targets, on='identifier')) + + pd.testing.assert_frame_equal(expected, results.data[['key', 'metric_target_drift', 'statistical_target_drift']]) + + +@pytest.mark.parametrize( + 'calculator_opts, expected', + [ + ( + {'chunk_size': 10000}, + pd.DataFrame( + { + 'key': [ + '[0:9999]', + '[10000:19999]', + '[20000:29999]', + '[30000:39999]', + '[40000:49999]', + '[50000:59999]', + ], + 'statistical_target_drift': [ + 0.12834899947890172, + 3.1960676394816177, + 1.295948474797905, + 19.16124656547084, + 18.025854609422936, + 24.018246053152254, + ], + } + ), + ), + ( + {'chunk_size': 10000, 'timestamp_column_name': 'timestamp'}, + pd.DataFrame( + { + 'key': [ + '[0:9999]', + '[10000:19999]', + '[20000:29999]', + 
'[30000:39999]', + '[40000:49999]', + '[50000:59999]', + ], + 'statistical_target_drift': [ + 0.12834899947890172, + 3.1960676394816177, + 1.295948474797905, + 19.16124656547084, + 18.025854609422936, + 24.018246053152254, + ], + } + ), + ), + ( + {'chunk_number': 5}, + pd.DataFrame( + { + 'key': ['[0:11999]', '[12000:23999]', '[24000:35999]', '[36000:47999]', '[48000:59999]'], + 'statistical_target_drift': [ + 0.26657487437501814, + 2.312782400721795, + 12.420850432581522, + 21.51752691617733, + 24.77164649048273, + ], + } + ), + ), + ( + {'chunk_number': 5, 'timestamp_column_name': 'timestamp'}, + pd.DataFrame( + { + 'key': ['[0:11999]', '[12000:23999]', '[24000:35999]', '[36000:47999]', '[48000:59999]'], + 'statistical_target_drift': [ + 0.26657487437501814, + 2.312782400721795, + 12.420850432581522, + 21.51752691617733, + 24.77164649048273, + ], + } + ), + ), + ( + {'chunk_period': 'Y', 'timestamp_column_name': 'timestamp'}, + pd.DataFrame( + {'key': ['2020', '2021'], 'statistical_target_drift': [19.18793594439608, 0.4207915478721409]} + ), + ), + ( + {}, + pd.DataFrame( + { + 'key': [ + '[0:5999]', + '[6000:11999]', + '[12000:17999]', + '[18000:23999]', + '[24000:29999]', + '[30000:35999]', + '[36000:41999]', + '[42000:47999]', + '[48000:53999]', + '[54000:59999]', + ], + 'statistical_target_drift': [ + 0.5215450181058865, + 2.1122555296182584, + 0.9401078333614571, + 2.130103897306355, + 2.2209947008941855, + 14.42105157991354, + 6.009302835706899, + 19.749168564900494, + 14.08710527642606, + 12.884655612509915, + ], + } + ), + ), + ( + {'timestamp_column_name': 'timestamp'}, + pd.DataFrame( + { + 'key': [ + '[0:5999]', + '[6000:11999]', + '[12000:17999]', + '[18000:23999]', + '[24000:29999]', + '[30000:35999]', + '[36000:41999]', + '[42000:47999]', + '[48000:53999]', + '[54000:59999]', + ], + 'statistical_target_drift': [ + 0.5215450181058865, + 2.1122555296182584, + 0.9401078333614571, + 2.130103897306355, + 2.2209947008941855, + 14.42105157991354, + 
6.009302835706899, + 19.749168564900494, + 14.08710527642606, + 12.884655612509915, + ], + } + ), + ), + ], + ids=[ + 'size_based_without_timestamp', + 'size_based_with_timestamp', + 'count_based_without_timestamp', + 'count_based_with_timestamp', + 'period_based_with_timestamp', + 'default_without_timestamp', + 'default_with_timestamp', + ], +) +def test_target_drift_for_multiclass_classification_works_with_chunker(calculator_opts, expected): # noqa: D103 + reference, analysis, analysis_targets = load_synthetic_multiclass_classification_dataset() + calc = TargetDistributionCalculator( + y_true='y_true', + problem_type=ProblemType.CLASSIFICATION_MULTICLASS, + **calculator_opts, + ).fit(reference) + results = calc.calculate(analysis.merge(analysis_targets, on='identifier')) + + pd.testing.assert_frame_equal(expected, results.data[['key', 'statistical_target_drift']]) + + +@pytest.mark.parametrize( + 'calc_args, plot_args', + [ + ({'timestamp_column_name': 'timestamp'}, {'kind': 'target_drift', 'plot_reference': False}), + ({}, {'kind': 'target_drift', 'plot_reference': False}), + ({'timestamp_column_name': 'timestamp'}, {'kind': 'target_drift', 'plot_reference': True}), + ({}, {'kind': 'target_drift', 'plot_reference': True}), + ({'timestamp_column_name': 'timestamp'}, {'kind': 'target_distribution', 'plot_reference': False}), + ({}, {'kind': 'target_distribution', 'plot_reference': False}), + ({'timestamp_column_name': 'timestamp'}, {'kind': 'target_distribution', 'plot_reference': True}), + ({}, {'kind': 'target_distribution', 'plot_reference': True}), + ], + ids=[ + 'target_drift_with_timestamp_without_reference', + 'target_drift_without_timestamp_without_reference', + 'target_drift_with_timestamp_with_reference', + 'target_drift_without_timestamp_with_reference', + 'target_distribution_with_timestamp_without_reference', + 'target_distribution_without_timestamp_without_reference', + 'target_distribution_with_timestamp_with_reference', + 
'target_distribution_without_timestamp_with_reference', + ], +) +def test_multiclass_classification_result_plots_raise_no_exceptions(calc_args, plot_args): # noqa: D103 + reference, analysis, analysis_targets = load_synthetic_multiclass_classification_dataset() + calc = TargetDistributionCalculator( + y_true='y_true', problem_type=ProblemType.CLASSIFICATION_MULTICLASS, **calc_args + ).fit(reference) + sut = calc.calculate(analysis.merge(analysis_targets, on='identifier')) + + try: + _ = sut.plot(**plot_args) + except Exception as exc: + pytest.fail(f"an unexpected exception occurred: {exc}") + + +@pytest.mark.parametrize( + 'calc_args, plot_args', + [ + ({'timestamp_column_name': 'timestamp'}, {'kind': 'target_drift', 'plot_reference': False}), + ({}, {'kind': 'target_drift', 'plot_reference': False}), + ({'timestamp_column_name': 'timestamp'}, {'kind': 'target_drift', 'plot_reference': True}), + ({}, {'kind': 'target_drift', 'plot_reference': True}), + ({'timestamp_column_name': 'timestamp'}, {'kind': 'target_distribution', 'plot_reference': False}), + ({}, {'kind': 'target_distribution', 'plot_reference': False}), + ({'timestamp_column_name': 'timestamp'}, {'kind': 'target_distribution', 'plot_reference': True}), + ({}, {'kind': 'target_distribution', 'plot_reference': True}), + ], + ids=[ + 'target_drift_with_timestamp_without_reference', + 'target_drift_without_timestamp_without_reference', + 'target_drift_with_timestamp_with_reference', + 'target_drift_without_timestamp_with_reference', + 'target_distribution_with_timestamp_without_reference', + 'target_distribution_without_timestamp_without_reference', + 'target_distribution_with_timestamp_with_reference', + 'target_distribution_without_timestamp_with_reference', + ], +) +def test_binary_classification_result_plots_raise_no_exceptions(calc_args, plot_args): # noqa: D103 + reference, analysis, analysis_targets = load_synthetic_binary_classification_dataset() + calc = TargetDistributionCalculator( + 
y_true='work_home_actual', problem_type=ProblemType.CLASSIFICATION_BINARY, **calc_args + ).fit(reference) + sut = calc.calculate(analysis.merge(analysis_targets, on='identifier')) + + try: + _ = sut.plot(**plot_args) + except Exception as exc: + pytest.fail(f"an unexpected exception occurred: {exc}") + + +@pytest.mark.parametrize( + 'calc_args, plot_args', + [ + ({'timestamp_column_name': 'timestamp'}, {'kind': 'target_drift', 'plot_reference': False}), + ({}, {'kind': 'target_drift', 'plot_reference': False}), + ({'timestamp_column_name': 'timestamp'}, {'kind': 'target_drift', 'plot_reference': True}), + ({}, {'kind': 'target_drift', 'plot_reference': True}), + ({'timestamp_column_name': 'timestamp'}, {'kind': 'target_distribution', 'plot_reference': False}), + ({}, {'kind': 'target_distribution', 'plot_reference': False}), + ({'timestamp_column_name': 'timestamp'}, {'kind': 'target_distribution', 'plot_reference': True}), + ({}, {'kind': 'target_distribution', 'plot_reference': True}), + ], + ids=[ + 'target_drift_with_timestamp_without_reference', + 'target_drift_without_timestamp_without_reference', + 'target_drift_with_timestamp_with_reference', + 'target_drift_without_timestamp_with_reference', + 'target_distribution_with_timestamp_without_reference', + 'target_distribution_without_timestamp_without_reference', + 'target_distribution_with_timestamp_with_reference', + 'target_distribution_without_timestamp_with_reference', + ], +) +def test_regression_result_plots_raise_no_exceptions(calc_args, plot_args): # noqa: D103 + reference, analysis, analysis_targets = load_synthetic_car_price_dataset() + calc = TargetDistributionCalculator(y_true='y_true', problem_type=ProblemType.REGRESSION, **calc_args).fit( + reference + ) + sut = calc.calculate(analysis.join(analysis_targets)) + + try: + _ = sut.plot(**plot_args) + except Exception as exc: + pytest.fail(f"an unexpected exception occurred: {exc}") diff --git a/tests/manual/reference implementation visualisation 
update.ipynb b/tests/manual/reference implementation visualisation update.ipynb index 292e147f..42bd4083 100644 --- a/tests/manual/reference implementation visualisation update.ipynb +++ b/tests/manual/reference implementation visualisation update.ipynb @@ -3,7 +3,11 @@ { "cell_type": "markdown", "id": "a58f2e02", - "metadata": {}, + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, "source": [ "Done\n", "- created step plot functionality\n", @@ -22,7 +26,11 @@ "cell_type": "code", "execution_count": null, "id": "379cab7b", - "metadata": {}, + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, "outputs": [], "source": [ "# add incomplete target data functionality\n", @@ -33,7 +41,11 @@ "cell_type": "code", "execution_count": null, "id": "c7dd6f05", - "metadata": {}, + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, "outputs": [], "source": [ "%load_ext autoreload\n", @@ -44,7 +56,11 @@ "cell_type": "code", "execution_count": null, "id": "fdf38b50", - "metadata": {}, + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, "outputs": [], "source": [ "import pandas as pd\n", @@ -56,7 +72,10 @@ "execution_count": null, "id": "2e7b0b88", "metadata": { - "scrolled": true + "scrolled": true, + "pycharm": { + "name": "#%%\n" + } }, "outputs": [], "source": [ @@ -67,7 +86,11 @@ "cell_type": "code", "execution_count": null, "id": "52670c2f", - "metadata": {}, + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, "outputs": [], "source": [ "reference, analysis, analysis_target = nml.load_synthetic_binary_classification_dataset()\n", @@ -79,7 +102,11 @@ "cell_type": "code", "execution_count": null, "id": "df13c412", - "metadata": {}, + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, "outputs": [], "source": [ "metadata = nml.extract_metadata(data = reference, model_name='wfh_predictor')\n", @@ -90,7 +117,11 @@ "cell_type": "code", "execution_count": null, "id": "f0bcc524", - "metadata": {}, + "metadata": { + "pycharm": { + "name": "#%%\n" + 
} + }, "outputs": [], "source": [ "CHUNK_KEY_COLUMN_NAME = 'key'" @@ -99,7 +130,11 @@ { "cell_type": "markdown", "id": "9cf8134d", - "metadata": {}, + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, "source": [ "# Confidence based performance estimation" ] @@ -108,7 +143,11 @@ "cell_type": "code", "execution_count": null, "id": "dc816d8e", - "metadata": {}, + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, "outputs": [], "source": [ "estimator = nml.CBPE(model_metadata=metadata, chunk_size=chunk_size)\n", @@ -120,7 +159,11 @@ "cell_type": "code", "execution_count": null, "id": "613c8682", - "metadata": {}, + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, "outputs": [], "source": [ "estimation_results = estimated_performance.data" @@ -129,7 +172,11 @@ { "cell_type": "markdown", "id": "55e25e7e", - "metadata": {}, + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, "source": [ "still a bug here, roc_auc in referene period should not be estimated\n", "\n", @@ -141,7 +188,10 @@ "execution_count": null, "id": "79254d59", "metadata": { - "scrolled": false + "scrolled": false, + "pycharm": { + "name": "#%%\n" + } }, "outputs": [], "source": [ @@ -180,7 +230,11 @@ "cell_type": "code", "execution_count": null, "id": "7db6af89", - "metadata": {}, + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, "outputs": [], "source": [ "import plotly.graph_objects as go\n", @@ -241,7 +295,11 @@ { "cell_type": "markdown", "id": "e4a3b842", - "metadata": {}, + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, "source": [ "# Reconstruction error" ] @@ -250,7 +308,11 @@ "cell_type": "code", "execution_count": null, "id": "8b3def68", - "metadata": {}, + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, "outputs": [], "source": [ "rcerror_calculator = nml.DataReconstructionDriftCalculator(model_metadata=metadata, chunk_size=chunk_size)\n", @@ -262,7 +324,11 @@ "cell_type": "code", "execution_count": null, "id": "1ddaae9e", - 
"metadata": {}, + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, "outputs": [], "source": [ "data = rcerror_results.data" @@ -272,7 +338,11 @@ "cell_type": "code", "execution_count": null, "id": "99d94ee8", - "metadata": {}, + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, "outputs": [], "source": [ "plot_partition_separator = len(data.value_counts()) > 1\n", @@ -296,7 +366,11 @@ { "cell_type": "markdown", "id": "c334cd0d", - "metadata": {}, + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, "source": [ "# Continuous/cartegoric univariate data drift" ] @@ -305,7 +379,11 @@ "cell_type": "code", "execution_count": null, "id": "5382410d", - "metadata": {}, + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, "outputs": [], "source": [ "univariate_calculator = nml.UnivariateStatisticalDriftCalculator(model_metadata=metadata, chunk_size=chunk_size)\n", @@ -317,7 +395,11 @@ "cell_type": "code", "execution_count": null, "id": "9094d88b", - "metadata": {}, + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, "outputs": [], "source": [ "data = univariate_results.data\n", @@ -329,7 +411,11 @@ "cell_type": "code", "execution_count": null, "id": "9462f047", - "metadata": {}, + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, "outputs": [], "source": [ "plot_partition_separator = len(data.value_counts()) > 1\n", @@ -354,7 +440,11 @@ { "cell_type": "markdown", "id": "fdda3efc", - "metadata": {}, + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, "source": [ "# Realised performance monitoring" ] @@ -363,7 +453,11 @@ "cell_type": "code", "execution_count": null, "id": "e157a7d0", - "metadata": {}, + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, "outputs": [], "source": [ "rperf_results = estimation_results.copy()\n", @@ -375,7 +469,11 @@ "cell_type": "code", "execution_count": null, "id": "9f05e8a2", - "metadata": {}, + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, "outputs": [], "source": [ 
"rperf_results['realised_target_percentage'] = 1\n", @@ -385,7 +483,11 @@ { "cell_type": "markdown", "id": "5b374a9f", - "metadata": {}, + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, "source": [ "when realised target percange < 0.25, rows should be removed \n", "\n", @@ -398,7 +500,11 @@ "cell_type": "code", "execution_count": null, "id": "79da6c9c", - "metadata": {}, + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, "outputs": [], "source": [ "plot_partition_separator = len(estimation_results.value_counts()) > 1\n", @@ -427,7 +533,11 @@ { "cell_type": "markdown", "id": "b57b796e", - "metadata": {}, + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, "source": [ "# Target distribution monitoring" ] @@ -436,7 +546,11 @@ "cell_type": "code", "execution_count": null, "id": "7180f8c7", - "metadata": {}, + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, "outputs": [], "source": [ "tdist_results = estimation_results.copy()\n", @@ -453,7 +567,11 @@ "cell_type": "code", "execution_count": null, "id": "6c250bf4", - "metadata": {}, + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, "outputs": [], "source": [ "tdist_results['realised_target_percentage'] = 1\n", @@ -463,7 +581,11 @@ { "cell_type": "markdown", "id": "479a6808", - "metadata": {}, + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, "source": [ "when realised target percange < 0.25, rows should be removed \n", "\n", @@ -474,7 +596,11 @@ "cell_type": "code", "execution_count": null, "id": "43d1581b", - "metadata": {}, + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, "outputs": [], "source": [ "tdist_results = tdist_results.loc[tdist_results['realised_target_percentage'] > 0.25, ]" @@ -484,7 +610,11 @@ "cell_type": "code", "execution_count": null, "id": "c7b0bcde", - "metadata": {}, + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, "outputs": [], "source": [ "plot_partition_separator = len(estimation_results.value_counts()) > 1\n", @@ -532,4 
+662,4 @@ }, "nbformat": 4, "nbformat_minor": 5 -} +} \ No newline at end of file diff --git a/tests/performance_calculation/metrics/test_binary_classification.py b/tests/performance_calculation/metrics/test_binary_classification.py index f57b212c..6b08dc49 100644 --- a/tests/performance_calculation/metrics/test_binary_classification.py +++ b/tests/performance_calculation/metrics/test_binary_classification.py @@ -50,6 +50,18 @@ def realized_performance_metrics(performance_calculator, binary_data) -> pd.Data return results.data +@pytest.fixture(scope='module') +def no_timestamp_metrics(binary_data): + calc = PerformanceCalculator( + y_pred_proba='y_pred_proba', + y_pred='y_pred', + y_true='work_home_actual', + metrics=['roc_auc', 'f1', 'precision', 'recall', 'specificity', 'accuracy'], + problem_type='classification_binary', + ).fit(binary_data[0]) + return calc.calculate(binary_data[1].merge(binary_data[2], on='identifier')).data + + @pytest.mark.parametrize( 'key,problem_type,metric', [ @@ -88,3 +100,19 @@ def test_metric_factory_returns_correct_metric_given_key_and_problem_type(key, p def test_metric_values_are_calculated_correctly(realized_performance_metrics, metric, expected): metric_values = realized_performance_metrics[metric] assert (round(metric_values, 5) == expected).all() + + +@pytest.mark.parametrize( + 'metric, expected', + [ + ('roc_auc', [0.97096, 0.97025, 0.97628, 0.96772, 0.96989, 0.96005, 0.95853, 0.95904, 0.96309, 0.95756]), + ('f1', [0.92186, 0.92124, 0.92678, 0.91684, 0.92356, 0.87424, 0.87672, 0.86806, 0.883, 0.86775]), + ('precision', [0.96729, 0.96607, 0.96858, 0.96819, 0.9661, 0.94932, 0.95777, 0.95012, 0.95718, 0.94271]), + ('recall', [0.88051, 0.88039, 0.88843, 0.87067, 0.8846, 0.81017, 0.80832, 0.79904, 0.8195, 0.80383]), + ('specificity', [0.9681, 0.9701, 0.97277, 0.9718, 0.96864, 0.95685, 0.96364, 0.95795, 0.96386, 0.94879]), + ('accuracy', [0.9228, 0.926, 0.9318, 0.9216, 0.9264, 0.8836, 0.8852, 0.8784, 0.8922, 0.8746]), + ], +) +def 
test_metric_values_without_timestamp_are_calculated_correctly(no_timestamp_metrics, metric, expected): + metric_values = no_timestamp_metrics[metric] + assert (round(metric_values, 5) == expected).all() diff --git a/tests/performance_calculation/metrics/test_multiclass_classification.py b/tests/performance_calculation/metrics/test_multiclass_classification.py index a395a8df..335ccc99 100644 --- a/tests/performance_calculation/metrics/test_multiclass_classification.py +++ b/tests/performance_calculation/metrics/test_multiclass_classification.py @@ -49,7 +49,24 @@ def performance_calculator() -> PerformanceCalculator: @pytest.fixture(scope='module') -def realized_performance_metrics(performance_calculator, multiclass_data) -> pd.DataFrame: +def realized_performance_metrics(multiclass_data) -> pd.DataFrame: + performance_calculator = PerformanceCalculator( + y_pred_proba={ + 'prepaid_card': 'y_pred_proba_prepaid_card', + 'highstreet_card': 'y_pred_proba_highstreet_card', + 'upmarket_card': 'y_pred_proba_upmarket_card', + }, + y_pred='y_pred', + y_true='y_true', + metrics=['roc_auc', 'f1', 'precision', 'recall', 'specificity', 'accuracy'], + problem_type='classification_multiclass', + ).fit(multiclass_data[0]) + results = performance_calculator.calculate(multiclass_data[1].merge(multiclass_data[2], on='identifier')) + return results.data + + +@pytest.fixture(scope='module') +def no_timestamp_metrics(performance_calculator, multiclass_data) -> pd.DataFrame: performance_calculator.fit(multiclass_data[0]) results = performance_calculator.calculate(multiclass_data[1].merge(multiclass_data[2], on='identifier')) return results.data @@ -93,3 +110,19 @@ def test_metric_factory_returns_correct_metric_given_key_and_problem_type(key, p def test_metric_values_are_calculated_correctly(realized_performance_metrics, metric, expected): metric_values = realized_performance_metrics[metric] assert (round(metric_values, 5) == expected).all() + + +@pytest.mark.parametrize( + 'metric, 
expected', + [ + ('roc_auc', [0.90759, 0.91053, 0.90941, 0.91158, 0.90753, 0.74859, 0.75114, 0.7564, 0.75856, 0.75394]), + ('f1', [0.7511, 0.76305, 0.75849, 0.75894, 0.75796, 0.55711, 0.55915, 0.56506, 0.5639, 0.56164]), + ('precision', [0.75127, 0.76313, 0.7585, 0.75897, 0.75795, 0.5597, 0.56291, 0.56907, 0.56667, 0.56513]), + ('recall', [0.75103, 0.76315, 0.75848, 0.75899, 0.75798, 0.55783, 0.56017, 0.56594, 0.56472, 0.56277]), + ('specificity', [0.87555, 0.88151, 0.87937, 0.87963, 0.87899, 0.77991, 0.78068, 0.78422, 0.78342, 0.78243]), + ('accuracy', [0.75117, 0.763, 0.75867, 0.75917, 0.758, 0.56083, 0.56233, 0.56983, 0.56783, 0.566]), + ], +) +def test_metric_values_without_timestamps_are_calculated_correctly(no_timestamp_metrics, metric, expected): + metric_values = no_timestamp_metrics[metric] + assert (round(metric_values, 5) == expected).all() diff --git a/tests/performance_calculation/metrics/test_regression.py b/tests/performance_calculation/metrics/test_regression.py index cb4518bc..dd131a84 100644 --- a/tests/performance_calculation/metrics/test_regression.py +++ b/tests/performance_calculation/metrics/test_regression.py @@ -44,6 +44,23 @@ def realized_performance_metrics(performance_calculator, regression_data) -> pd. 
return results.data +@pytest.fixture(scope='module') +def no_timestamp_metrics(regression_data) -> pd.DataFrame: + # Get rid of negative values for log based metrics + reference = regression_data[0][~(regression_data[0]['y_pred'] < 0)] + analysis = regression_data[1][~(regression_data[1]['y_pred'] < 0)] + + performance_calculator = PerformanceCalculator( + timestamp_column_name='timestamp', + y_pred='y_pred', + y_true='y_true', + metrics=['mae', 'mape', 'mse', 'msle', 'rmse', 'rmsle'], + problem_type='regression', + ).fit(reference) + results = performance_calculator.calculate(analysis.join(regression_data[2])) + return results.data + + @pytest.mark.parametrize( 'key,problem_type,metric', [ @@ -113,3 +130,50 @@ def test_metric_factory_returns_correct_metric_given_key_and_problem_type(key, p def test_metric_values_are_calculated_correctly(realized_performance_metrics, metric, expected): metric_values = realized_performance_metrics[metric] assert (round(metric_values, 5) == expected).all() + + +@pytest.mark.parametrize( + 'metric, expected', + [ + ( + 'mae', + [853.39967, 853.13667, 846.304, 855.4945, 849.3295, 702.51767, 700.73583, 684.70167, 705.814, 698.34383], + ), + ('mape', [0.22871, 0.23082, 0.22904, 0.23362, 0.23389, 0.26286, 0.26346, 0.26095, 0.26537, 0.26576]), + ( + 'mse', + [ + 1143129.298, + 1139867.667, + 1128720.807, + 1158285.6715, + 1124285.66517, + 829589.49233, + 829693.3775, + 792286.80933, + 835916.964, + 825935.67917, + ], + ), + ('msle', [0.07049, 0.06999, 0.06969, 0.07193, 0.07249, 0.10495, 0.10481, 0.10435, 0.10471, 0.10588]), + ( + 'rmse', + [ + 1069.17225, + 1067.64585, + 1062.41273, + 1076.23681, + 1060.32338, + 910.81803, + 910.87506, + 890.10494, + 914.28495, + 908.81003, + ], + ), + ('rmsle', [0.2655, 0.26456, 0.26399, 0.2682, 0.26924, 0.32396, 0.32375, 0.32303, 0.3236, 0.32539]), + ], +) +def test_metric_values_without_timestamps_are_calculated_correctly(no_timestamp_metrics, metric, expected): + metric_values = 
no_timestamp_metrics[metric] + assert (round(metric_values, 5) == expected).all() diff --git a/tests/performance_calculation/test_performance_calculator.py b/tests/performance_calculation/test_performance_calculator.py index fe477111..c4fb2bda 100644 --- a/tests/performance_calculation/test_performance_calculator.py +++ b/tests/performance_calculation/test_performance_calculator.py @@ -9,7 +9,12 @@ import pandas as pd import pytest -from nannyml.datasets import load_synthetic_binary_classification_dataset +from nannyml._typing import ProblemType +from nannyml.datasets import ( + load_synthetic_binary_classification_dataset, + load_synthetic_car_price_dataset, + load_synthetic_multiclass_classification_dataset, +) from nannyml.exceptions import InvalidArgumentsException from nannyml.performance_calculation import PerformanceCalculator from nannyml.performance_calculation.metrics.binary_classification import ( @@ -139,3 +144,101 @@ def test_calculator_calculate_should_include_target_completeness_rate(data): # assert 'targets_missing_rate' in sut.data.columns assert sut.data.loc[0, 'targets_missing_rate'] == 0.1 assert sut.data.loc[1, 'targets_missing_rate'] == 0.9 + + +@pytest.mark.parametrize( + 'calc_args, plot_args', + [ + ({'timestamp_column_name': 'timestamp'}, {'kind': 'performance', 'plot_reference': False, 'metric': 'mae'}), + ({}, {'kind': 'performance', 'plot_reference': False, 'metric': 'mae'}), + ({'timestamp_column_name': 'timestamp'}, {'kind': 'performance', 'plot_reference': True, 'metric': 'mae'}), + ({}, {'kind': 'performance', 'plot_reference': True, 'metric': 'mae'}), + ], + ids=[ + 'performance_with_timestamp_without_reference', + 'performance_without_timestamp_without_reference', + 'performance_with_timestamp_with_reference', + 'performance_without_timestamp_with_reference', + ], +) +def test_regression_result_plots_raise_no_exceptions(calc_args, plot_args): # noqa: D103 + reference, analysis, analysis_targets = load_synthetic_car_price_dataset() 
+ calc = PerformanceCalculator( + y_true='y_true', y_pred='y_pred', problem_type=ProblemType.REGRESSION, metrics=['mae', 'mape'], **calc_args + ).fit(reference) + sut = calc.calculate(analysis.join(analysis_targets)) + + try: + _ = sut.plot(**plot_args) + except Exception as exc: + pytest.fail(f"an unexpected exception occurred: {exc}") + + +@pytest.mark.parametrize( + 'calc_args, plot_args', + [ + ({'timestamp_column_name': 'timestamp'}, {'kind': 'performance', 'plot_reference': False, 'metric': 'f1'}), + ({}, {'kind': 'performance', 'plot_reference': False, 'metric': 'f1'}), + ({'timestamp_column_name': 'timestamp'}, {'kind': 'performance', 'plot_reference': True, 'metric': 'f1'}), + ({}, {'kind': 'performance', 'plot_reference': True, 'metric': 'f1'}), + ], + ids=[ + 'performance_with_timestamp_without_reference', + 'performance_without_timestamp_without_reference', + 'performance_with_timestamp_with_reference', + 'performance_without_timestamp_with_reference', + ], +) +def test_multiclass_classification_result_plots_raise_no_exceptions(calc_args, plot_args): # noqa: D103 + reference, analysis, analysis_targets = load_synthetic_multiclass_classification_dataset() + calc = PerformanceCalculator( + y_true='y_true', + y_pred='y_pred', + y_pred_proba={ + 'upmarket_card': 'y_pred_proba_upmarket_card', + 'highstreet_card': 'y_pred_proba_highstreet_card', + 'prepaid_card': 'y_pred_proba_prepaid_card', + }, + problem_type=ProblemType.CLASSIFICATION_MULTICLASS, + metrics=['roc_auc', 'f1'], + **calc_args, + ).fit(reference) + sut = calc.calculate(analysis.merge(analysis_targets, on='identifier')) + + try: + _ = sut.plot(**plot_args) + except Exception as exc: + pytest.fail(f"an unexpected exception occurred: {exc}") + + +@pytest.mark.parametrize( + 'calc_args, plot_args', + [ + ({'timestamp_column_name': 'timestamp'}, {'kind': 'performance', 'plot_reference': False, 'metric': 'f1'}), + ({}, {'kind': 'performance', 'plot_reference': False, 'metric': 'f1'}), + 
({'timestamp_column_name': 'timestamp'}, {'kind': 'performance', 'plot_reference': True, 'metric': 'f1'}), + ({}, {'kind': 'performance', 'plot_reference': True, 'metric': 'f1'}), + ], + ids=[ + 'performance_with_timestamp_without_reference', + 'performance_without_timestamp_without_reference', + 'performance_with_timestamp_with_reference', + 'performance_without_timestamp_with_reference', + ], +) +def test_binary_classification_result_plots_raise_no_exceptions(calc_args, plot_args): # noqa: D103 + reference, analysis, analysis_targets = load_synthetic_binary_classification_dataset() + calc = PerformanceCalculator( + y_true='work_home_actual', + y_pred='y_pred', + y_pred_proba='y_pred_proba', + problem_type=ProblemType.CLASSIFICATION_BINARY, + metrics=['roc_auc', 'f1'], + **calc_args, + ).fit(reference) + sut = calc.calculate(analysis.merge(analysis_targets, on='identifier')) + + try: + _ = sut.plot(**plot_args) + except Exception as exc: + pytest.fail(f"an unexpected exception occurred: {exc}") diff --git a/tests/performance_estimation/CBPE/test_cbpe.py b/tests/performance_estimation/CBPE/test_cbpe.py index 5149a96b..8c456a04 100644 --- a/tests/performance_estimation/CBPE/test_cbpe.py +++ b/tests/performance_estimation/CBPE/test_cbpe.py @@ -16,6 +16,7 @@ import pytest from pytest_mock import MockerFixture +from nannyml._typing import ProblemType from nannyml.calibration import Calibrator, IsotonicCalibrator from nannyml.datasets import ( load_synthetic_binary_classification_dataset, @@ -500,3 +501,72 @@ def test_cbpe_for_multiclass_classification_chunked_by_period_should_include_var assert f'sampling_error_{metric}' in results.data.columns assert np.array_equal(np.round(results.data[f'sampling_error_{metric}'], 4), np.round(sampling_error, 4)) + + +@pytest.mark.parametrize( + 'estimator_args, plot_args', + [ + ({'timestamp_column_name': 'timestamp'}, {'kind': 'performance', 'plot_reference': False, 'metric': 'f1'}), + ({}, {'kind': 'performance', 'plot_reference': 
False, 'metric': 'f1'}), + ({'timestamp_column_name': 'timestamp'}, {'kind': 'performance', 'plot_reference': True, 'metric': 'f1'}), + ({}, {'kind': 'performance', 'plot_reference': True, 'metric': 'f1'}), + ], + ids=[ + 'performance_with_timestamp_without_reference', + 'performance_without_timestamp_without_reference', + 'performance_with_timestamp_with_reference', + 'performance_without_timestamp_with_reference', + ], +) +def test_multiclass_classification_result_plots_raise_no_exceptions(estimator_args, plot_args): # noqa: D103 + reference, analysis, analysis_targets = load_synthetic_multiclass_classification_dataset() + est = CBPE( + y_true='y_true', + y_pred='y_pred', + y_pred_proba={ + 'upmarket_card': 'y_pred_proba_upmarket_card', + 'highstreet_card': 'y_pred_proba_highstreet_card', + 'prepaid_card': 'y_pred_proba_prepaid_card', + }, + problem_type=ProblemType.CLASSIFICATION_MULTICLASS, + metrics=['roc_auc', 'f1'], + ).fit(reference) + sut = est.estimate(analysis) + + try: + _ = sut.plot(**plot_args) + except Exception as exc: + pytest.fail(f"an unexpected exception occurred: {exc}") + + +@pytest.mark.parametrize( + 'estimator_args, plot_args', + [ + ({'timestamp_column_name': 'timestamp'}, {'kind': 'performance', 'plot_reference': False, 'metric': 'f1'}), + ({}, {'kind': 'performance', 'plot_reference': False, 'metric': 'f1'}), + ({'timestamp_column_name': 'timestamp'}, {'kind': 'performance', 'plot_reference': True, 'metric': 'f1'}), + ({}, {'kind': 'performance', 'plot_reference': True, 'metric': 'f1'}), + ], + ids=[ + 'performance_with_timestamp_without_reference', + 'performance_without_timestamp_without_reference', + 'performance_with_timestamp_with_reference', + 'performance_without_timestamp_with_reference', + ], +) +def test_binary_classification_result_plots_raise_no_exceptions(estimator_args, plot_args): # noqa: D103 + reference, analysis, analysis_targets = load_synthetic_binary_classification_dataset() + est = CBPE( + y_true='work_home_actual', 
+ y_pred='y_pred', + y_pred_proba='y_pred_proba', + problem_type=ProblemType.CLASSIFICATION_BINARY, + metrics=['roc_auc', 'f1'], + **estimator_args, + ).fit(reference) + sut = est.estimate(analysis) + + try: + _ = sut.plot(**plot_args) + except Exception as exc: + pytest.fail(f"an unexpected exception occurred: {exc}") diff --git a/tests/performance_estimation/CBPE/test_cbpe_metrics.py b/tests/performance_estimation/CBPE/test_cbpe_metrics.py index e69de29b..fadde0fd 100644 --- a/tests/performance_estimation/CBPE/test_cbpe_metrics.py +++ b/tests/performance_estimation/CBPE/test_cbpe_metrics.py @@ -0,0 +1,726 @@ +import pandas as pd +import pytest + +from nannyml.datasets import ( + load_synthetic_binary_classification_dataset, + load_synthetic_multiclass_classification_dataset, +) +from nannyml.performance_estimation.confidence_based import CBPE + + +@pytest.mark.parametrize( + 'calculator_opts, expected', + [ + ( + {'chunk_size': 20000}, + pd.DataFrame( + { + 'key': ['[0:19999]', '[20000:39999]', '[40000:49999]'], + 'estimated_roc_auc': [0.9690827692140925, 0.9634851917907716, 0.9612426494258154], + 'estimated_f1': [0.9479079222515973, 0.9294599255593571, 0.9245516739588603], + 'estimated_precision': [0.9436121782324026, 0.9217384775672881, 0.9159308766434395], + 'estimated_recall': [0.9522429574319092, 0.9373118323003661, 0.933336292106975], + 'estimated_specificity': [0.9434949869571513, 0.9153107130639775, 0.906418072349441], + 'estimated_accuracy': [0.9478536003143163, 0.9266531619733247, 0.9204715081119524], + } + ), + ), + ( + {'chunk_size': 20000, 'timestamp_column_name': 'timestamp'}, + pd.DataFrame( + { + 'key': ['[0:19999]', '[20000:39999]', '[40000:49999]'], + 'estimated_roc_auc': [0.9690827692140925, 0.9634851917907716, 0.9612426494258154], + 'estimated_f1': [0.9479079222515973, 0.9294599255593571, 0.9245516739588603], + 'estimated_precision': [0.9436121782324026, 0.9217384775672881, 0.9159308766434395], + 'estimated_recall': [0.9522429574319092, 
0.9373118323003661, 0.933336292106975], + 'estimated_specificity': [0.9434949869571513, 0.9153107130639775, 0.906418072349441], + 'estimated_accuracy': [0.9478536003143163, 0.9266531619733247, 0.9204715081119524], + } + ), + ), + ( + {'chunk_number': 4}, + pd.DataFrame( + { + 'key': ['[0:12499]', '[12500:24999]', '[25000:37499]', '[37500:49999]'], + 'estimated_roc_auc': [ + 0.9690496133543502, + 0.9690315182120458, + 0.9607404611197395, + 0.9611281635643135, + ], + 'estimated_f1': [0.9478220802546018, 0.9483683062849226, 0.9235839411719146, 0.9239534478264141], + 'estimated_precision': [ + 0.9432817252782693, + 0.9446550044048795, + 0.9145613073300847, + 0.915231419744167, + ], + 'estimated_recall': [0.9524063553647716, 0.9521109162736728, 0.9327863750020324, 0.932843314818147], + 'estimated_specificity': [ + 0.9426505776503413, + 0.9446298328855539, + 0.9050174893945843, + 0.9065202406788675, + ], + 'estimated_accuracy': [ + 0.9475319773956268, + 0.9483565214666658, + 0.9194997180834459, + 0.920199809204049, + ], + } + ), + ), + ( + {'chunk_number': 4, 'timestamp_column_name': 'timestamp'}, + pd.DataFrame( + { + 'key': ['[0:12499]', '[12500:24999]', '[25000:37499]', '[37500:49999]'], + 'estimated_roc_auc': [ + 0.9690496133543502, + 0.9690315182120458, + 0.9607404611197395, + 0.9611281635643135, + ], + 'estimated_f1': [0.9478220802546018, 0.9483683062849226, 0.9235839411719146, 0.9239534478264141], + 'estimated_precision': [ + 0.9432817252782693, + 0.9446550044048795, + 0.9145613073300847, + 0.915231419744167, + ], + 'estimated_recall': [0.9524063553647716, 0.9521109162736728, 0.9327863750020324, 0.932843314818147], + 'estimated_specificity': [ + 0.9426505776503413, + 0.9446298328855539, + 0.9050174893945843, + 0.9065202406788675, + ], + 'estimated_accuracy': [ + 0.9475319773956268, + 0.9483565214666658, + 0.9194997180834459, + 0.920199809204049, + ], + } + ), + ), + ( + {'chunk_period': 'Y', 'timestamp_column_name': 'timestamp'}, + pd.DataFrame( + { + 'key': 
['2017', '2018', '2019', '2020', '2021'], + 'estimated_roc_auc': [ + 0.9686251180452822, + 0.9691860189228905, + 0.9643157928165687, + 0.9610299125179039, + 0.8947719335229752, + ], + 'estimated_f1': [ + 0.9486231864324036, + 0.9477092734637503, + 0.9315900332119164, + 0.9240488520044474, + 0.7926388251111386, + ], + 'estimated_precision': [ + 0.9425338528308135, + 0.9440235151260787, + 0.9244858823267199, + 0.9151737753536913, + 0.7100248791165871, + ], + 'estimated_recall': [ + 0.9547917131585199, + 0.9514239252802291, + 0.9388042123614001, + 0.9330977499034733, + 0.8970090514396291, + ], + 'estimated_specificity': [ + 0.937163050852636, + 0.9454556899788352, + 0.918667486315966, + 0.9060337870230502, + 0.8302703490347314, + ], + 'estimated_accuracy': [ + 0.9463140395517069, + 0.9483894898124492, + 0.9290320609580357, + 0.9201265813431901, + 0.8514010785547443, + ], + } + ), + ), + ( + {}, + pd.DataFrame( + { + 'key': [ + '[0:4999]', + '[5000:9999]', + '[10000:14999]', + '[15000:19999]', + '[20000:24999]', + '[25000:29999]', + '[30000:34999]', + '[35000:39999]', + '[40000:44999]', + '[45000:49999]', + ], + 'estimated_roc_auc': [ + 0.9686306166809373, + 0.9690438435474089, + 0.9694438261961335, + 0.9690472634498744, + 0.9688726952274674, + 0.960478016244988, + 0.9611336210051199, + 0.9605358219511105, + 0.9618691204348024, + 0.9605366452565602, + ], + 'estimated_f1': [ + 0.948555321454138, + 0.9465779465209089, + 0.9488069354531159, + 0.9476546805658866, + 0.9488338110572719, + 0.9232325655782327, + 0.923595783273164, + 0.9229020961475524, + 0.9241722368115827, + 0.9249152193013231, + ], + 'estimated_precision': [ + 0.9425440716307568, + 0.9431073537123396, + 0.9442277523749669, + 0.9446336452028188, + 0.9453707449594683, + 0.9148732952385749, + 0.9143489884377147, + 0.9133821290777778, + 0.9173539177641505, + 0.9145726617897999, + ], + 'estimated_recall': [ + 0.9546437391537488, + 0.95007417692678, + 0.9534307499357034, + 0.9506951010870748, + 0.9523223421280594, 
+ 0.9317460031975845, + 0.9330315141543621, + 0.9326226021419973, + 0.9310926704269641, + 0.9354943725540698, + ], + 'estimated_specificity': [ + 0.9372742491954283, + 0.9446319110387503, + 0.9458015988378999, + 0.9459382667254527, + 0.9442380077839293, + 0.9054074470674046, + 0.9054733303403276, + 0.905142438988179, + 0.9120143191041326, + 0.9005790635739856, + ], + 'estimated_accuracy': [ + 0.9462845122805443, + 0.947306078326665, + 0.949543087716159, + 0.9482807229338969, + 0.9483068458984663, + 0.9191503008524827, + 0.9197921095331023, + 0.9193633916092473, + 0.9217811573797979, + 0.9191618588441066, + ], + } + ), + ), + ( + {'timestamp_column_name': 'timestamp'}, + pd.DataFrame( + { + 'key': [ + '[0:4999]', + '[5000:9999]', + '[10000:14999]', + '[15000:19999]', + '[20000:24999]', + '[25000:29999]', + '[30000:34999]', + '[35000:39999]', + '[40000:44999]', + '[45000:49999]', + ], + 'estimated_roc_auc': [ + 0.9686306166809373, + 0.9690438435474089, + 0.9694438261961335, + 0.9690472634498744, + 0.9688726952274674, + 0.960478016244988, + 0.9611336210051199, + 0.9605358219511105, + 0.9618691204348024, + 0.9605366452565602, + ], + 'estimated_f1': [ + 0.948555321454138, + 0.9465779465209089, + 0.9488069354531159, + 0.9476546805658866, + 0.9488338110572719, + 0.9232325655782327, + 0.923595783273164, + 0.9229020961475524, + 0.9241722368115827, + 0.9249152193013231, + ], + 'estimated_precision': [ + 0.9425440716307568, + 0.9431073537123396, + 0.9442277523749669, + 0.9446336452028188, + 0.9453707449594683, + 0.9148732952385749, + 0.9143489884377147, + 0.9133821290777778, + 0.9173539177641505, + 0.9145726617897999, + ], + 'estimated_recall': [ + 0.9546437391537488, + 0.95007417692678, + 0.9534307499357034, + 0.9506951010870748, + 0.9523223421280594, + 0.9317460031975845, + 0.9330315141543621, + 0.9326226021419973, + 0.9310926704269641, + 0.9354943725540698, + ], + 'estimated_specificity': [ + 0.9372742491954283, + 0.9446319110387503, + 0.9458015988378999, + 
0.9459382667254527, + 0.9442380077839293, + 0.9054074470674046, + 0.9054733303403276, + 0.905142438988179, + 0.9120143191041326, + 0.9005790635739856, + ], + 'estimated_accuracy': [ + 0.9462845122805443, + 0.947306078326665, + 0.949543087716159, + 0.9482807229338969, + 0.9483068458984663, + 0.9191503008524827, + 0.9197921095331023, + 0.9193633916092473, + 0.9217811573797979, + 0.9191618588441066, + ], + } + ), + ), + ], + ids=[ + 'size_based_without_timestamp', + 'size_based_with_timestamp', + 'count_based_without_timestamp', + 'count_based_with_timestamp', + 'period_based_with_timestamp', + 'default_without_timestamp', + 'default_with_timestamp', + ], +) +def test_cbpe_for_binary_classification_with_timestamps(calculator_opts, expected): + ref_df, ana_df, _ = load_synthetic_binary_classification_dataset() + cbpe = CBPE( + y_pred_proba='y_pred_proba', + y_pred='y_pred', + y_true='work_home_actual', + problem_type='classification_binary', + metrics=['roc_auc', 'f1', 'precision', 'recall', 'specificity', 'accuracy'], + **calculator_opts, + ).fit(ref_df) + sut = cbpe.estimate(ana_df).data + + pd.testing.assert_frame_equal( + expected, + sut[ + [ + 'key', + 'estimated_roc_auc', + 'estimated_f1', + 'estimated_precision', + 'estimated_recall', + 'estimated_specificity', + 'estimated_accuracy', + ] + ], + ) + + +@pytest.mark.parametrize( + 'calculator_opts, expected', + [ + ( + {'chunk_size': 20000}, + pd.DataFrame( + { + 'key': ['[0:19999]', '[20000:39999]', '[40000:59999]'], + 'estimated_roc_auc': [0.9092377595865466, 0.8683877226653395, 0.8204766170638091], + 'estimated_f1': [0.756401608336434, 0.6937135623882767, 0.632386421613214], + 'estimated_precision': [0.7564437378390059, 0.694174192229447, 0.6336288859123612], + 'estimated_recall': [0.7564129287764665, 0.6934788458355289, 0.6319310599943714], + 'estimated_specificity': [0.8782068281303994, 0.8469556750949159, 0.8172644220189141], + 'estimated_accuracy': [0.7564451493123628, 0.6946947603445697, 
0.6378557309960986], + } + ), + ), + ( + {'chunk_size': 20000, 'timestamp_column_name': 'timestamp'}, + pd.DataFrame( + { + 'key': ['[0:19999]', '[20000:39999]', '[40000:59999]'], + 'estimated_roc_auc': [0.9092377595865466, 0.8683877226653395, 0.8204766170638091], + 'estimated_f1': [0.756401608336434, 0.6937135623882767, 0.632386421613214], + 'estimated_precision': [0.7564437378390059, 0.694174192229447, 0.6336288859123612], + 'estimated_recall': [0.7564129287764665, 0.6934788458355289, 0.6319310599943714], + 'estimated_specificity': [0.8782068281303994, 0.8469556750949159, 0.8172644220189141], + 'estimated_accuracy': [0.7564451493123628, 0.6946947603445697, 0.6378557309960986], + } + ), + ), + ( + {'chunk_number': 4}, + pd.DataFrame( + { + 'key': ['[0:14999]', '[15000:29999]', '[30000:44999]', '[45000:59999]'], + 'estimated_roc_auc': [ + 0.9085083182636969, + 0.9088360564807361, + 0.8196861857675541, + 0.8203213219880933, + ], + 'estimated_f1': [0.7550059244451006, 0.7562711250144366, 0.63091155676697, 0.6324244687112559], + 'estimated_precision': [ + 0.755038246904623, + 0.7562647262876293, + 0.6323547131368327, + 0.6335323150520741, + ], + 'estimated_recall': [ + 0.7550277340784216, + 0.7562926950204228, + 0.6304009454574501, + 0.6320155112489632, + ], + 'estimated_specificity': [ + 0.8775094795233379, + 0.8781429133214084, + 0.8165537125162895, + 0.8172408983542975, + ], + 'estimated_accuracy': [ + 0.7550428613792668, + 0.7562888217426292, + 0.6364205304514962, + 0.6375753072973162, + ], + } + ), + ), + ( + {'chunk_number': 4, 'timestamp_column_name': 'timestamp'}, + pd.DataFrame( + { + 'key': ['[0:14999]', '[15000:29999]', '[30000:44999]', '[45000:59999]'], + 'estimated_roc_auc': [ + 0.9085083182636969, + 0.9088360564807361, + 0.8196861857675541, + 0.8203213219880933, + ], + 'estimated_f1': [0.7550059244451006, 0.7562711250144366, 0.63091155676697, 0.6324244687112559], + 'estimated_precision': [ + 0.755038246904623, + 0.7562647262876293, + 0.6323547131368327, 
+ 0.6335323150520741, + ], + 'estimated_recall': [ + 0.7550277340784216, + 0.7562926950204228, + 0.6304009454574501, + 0.6320155112489632, + ], + 'estimated_specificity': [ + 0.8775094795233379, + 0.8781429133214084, + 0.8165537125162895, + 0.8172408983542975, + ], + 'estimated_accuracy': [ + 0.7550428613792668, + 0.7562888217426292, + 0.6364205304514962, + 0.6375753072973162, + ], + } + ), + ), + ( + {'chunk_period': 'Y', 'timestamp_column_name': 'timestamp'}, + pd.DataFrame( + { + 'key': ['2020', '2021'], + 'estimated_roc_auc': [0.8698976466237168, 0.8161798350988554], + 'estimated_f1': [0.6959459683194374, 0.6271637037915178], + 'estimated_precision': [0.696279612597813, 0.6275707355339551], + 'estimated_recall': [0.6957620347508907, 0.6272720458900231], + 'estimated_specificity': [0.8480220572478717, 0.8145095377877009], + 'estimated_accuracy': [0.6967957612985849, 0.6305270354546132], + } + ), + ), + ( + {}, + pd.DataFrame( + { + 'key': [ + '[0:5999]', + '[6000:11999]', + '[12000:17999]', + '[18000:23999]', + '[24000:29999]', + '[30000:35999]', + '[36000:41999]', + '[42000:47999]', + '[48000:53999]', + '[54000:59999]', + ], + 'estimated_roc_auc': [ + 0.9070370295881952, + 0.9099482873688786, + 0.909958298321231, + 0.9091054631936624, + 0.9071886074481345, + 0.8195150939735344, + 0.8202573402661897, + 0.819127271546289, + 0.81940620684774, + 0.8215839856335078, + ], + 'estimated_f1': [ + 0.7533014808638749, + 0.7564216285126095, + 0.7581655498090148, + 0.7565574256686872, + 0.7536178781761181, + 0.6309845876312757, + 0.6314817817562323, + 0.6305520915509012, + 0.6317356567785939, + 0.6334434546737131, + ], + 'estimated_precision': [ + 0.7533705392643878, + 0.7564117249294325, + 0.758189598004742, + 0.7565615847700302, + 0.7536066338250612, + 0.6320478731185537, + 0.632902872215884, + 0.6320206664159113, + 0.6327414423931249, + 0.6349486020523588, + ], + 'estimated_recall': [ + 0.7532927299604006, + 0.756484874870973, + 0.7581945446990431, + 0.7565540393358208, 
+ 0.7536741378197004, + 0.6306782567019867, + 0.6308507454911164, + 0.6299485723087591, + 0.6315005643336115, + 0.6328910840233342, + ], + 'estimated_specificity': [ + 0.876648985916452, + 0.8781935469456502, + 0.87910164279675, + 0.8783467614315459, + 0.8767972368373717, + 0.816629841587921, + 0.8166384221108975, + 0.8163342736347389, + 0.8167795549189886, + 0.8180293638991758, + ], + 'estimated_accuracy': [ + 0.7533903547412437, + 0.7564140260171383, + 0.7582062911442542, + 0.7566888307842633, + 0.7536297051178407, + 0.6358257165904244, + 0.6366025073789935, + 0.6367168031955528, + 0.6365172577468735, + 0.6393273094601863, + ], + } + ), + ), + ( + {'timestamp_column_name': 'timestamp'}, + pd.DataFrame( + { + 'key': [ + '[0:5999]', + '[6000:11999]', + '[12000:17999]', + '[18000:23999]', + '[24000:29999]', + '[30000:35999]', + '[36000:41999]', + '[42000:47999]', + '[48000:53999]', + '[54000:59999]', + ], + 'estimated_roc_auc': [ + 0.9070370295881952, + 0.9099482873688786, + 0.909958298321231, + 0.9091054631936624, + 0.9071886074481345, + 0.8195150939735344, + 0.8202573402661897, + 0.819127271546289, + 0.81940620684774, + 0.8215839856335078, + ], + 'estimated_f1': [ + 0.7533014808638749, + 0.7564216285126095, + 0.7581655498090148, + 0.7565574256686872, + 0.7536178781761181, + 0.6309845876312757, + 0.6314817817562323, + 0.6305520915509012, + 0.6317356567785939, + 0.6334434546737131, + ], + 'estimated_precision': [ + 0.7533705392643878, + 0.7564117249294325, + 0.758189598004742, + 0.7565615847700302, + 0.7536066338250612, + 0.6320478731185537, + 0.632902872215884, + 0.6320206664159113, + 0.6327414423931249, + 0.6349486020523588, + ], + 'estimated_recall': [ + 0.7532927299604006, + 0.756484874870973, + 0.7581945446990431, + 0.7565540393358208, + 0.7536741378197004, + 0.6306782567019867, + 0.6308507454911164, + 0.6299485723087591, + 0.6315005643336115, + 0.6328910840233342, + ], + 'estimated_specificity': [ + 0.876648985916452, + 0.8781935469456502, + 0.87910164279675, 
+ 0.8783467614315459, + 0.8767972368373717, + 0.816629841587921, + 0.8166384221108975, + 0.8163342736347389, + 0.8167795549189886, + 0.8180293638991758, + ], + 'estimated_accuracy': [ + 0.7533903547412437, + 0.7564140260171383, + 0.7582062911442542, + 0.7566888307842633, + 0.7536297051178407, + 0.6358257165904244, + 0.6366025073789935, + 0.6367168031955528, + 0.6365172577468735, + 0.6393273094601863, + ], + } + ), + ), + ], + ids=[ + 'size_based_without_timestamp', + 'size_based_with_timestamp', + 'count_based_without_timestamp', + 'count_based_with_timestamp', + 'period_based_with_timestamp', + 'default_without_timestamp', + 'default_with_timestamp', + ], +) +def test_cbpe_for_multiclass_classification_with_timestamps(calculator_opts, expected): + ref_df, ana_df, _ = load_synthetic_multiclass_classification_dataset() + cbpe = CBPE( + y_pred_proba={ + 'upmarket_card': 'y_pred_proba_upmarket_card', + 'highstreet_card': 'y_pred_proba_highstreet_card', + 'prepaid_card': 'y_pred_proba_prepaid_card', + }, + y_pred='y_pred', + y_true='y_true', + problem_type='classification_multiclass', + metrics=['roc_auc', 'f1', 'precision', 'recall', 'specificity', 'accuracy'], + **calculator_opts, + ).fit(ref_df) + sut = cbpe.estimate(ana_df).data + + pd.testing.assert_frame_equal( + expected, + sut[ + [ + 'key', + 'estimated_roc_auc', + 'estimated_f1', + 'estimated_precision', + 'estimated_recall', + 'estimated_specificity', + 'estimated_accuracy', + ] + ], + ) diff --git a/tests/performance_estimation/DLE/test_dle.py b/tests/performance_estimation/DLE/test_dle.py index ecc79ca8..75f1bc73 100644 --- a/tests/performance_estimation/DLE/test_dle.py +++ b/tests/performance_estimation/DLE/test_dle.py @@ -255,3 +255,35 @@ def test_result_plot_contains_reference_data_when_plot_reference_set_to_true(est sut = estimates.plot(metric=metric, plot_reference=True) assert len(sut.to_dict()['data'][2]['x']) > 0 assert len(sut.to_dict()['data'][2]['y']) > 0 + + +@pytest.mark.parametrize( + 
'estimator_args, plot_args', + [ + ({'timestamp_column_name': 'timestamp'}, {'kind': 'performance', 'plot_reference': False, 'metric': 'mae'}), + ({}, {'kind': 'performance', 'plot_reference': False, 'metric': 'mae'}), + ({'timestamp_column_name': 'timestamp'}, {'kind': 'performance', 'plot_reference': True, 'metric': 'mae'}), + ({}, {'kind': 'performance', 'plot_reference': True, 'metric': 'mae'}), + ], + ids=[ + 'performance_with_timestamp_without_reference', + 'performance_without_timestamp_without_reference', + 'performance_with_timestamp_with_reference', + 'performance_without_timestamp_with_reference', + ], +) +def test_regression_result_plots_raise_no_exceptions(estimator_args, plot_args):  # noqa: D103 + reference, analysis, analysis_targets = load_synthetic_car_price_dataset() + est = DLE( + feature_column_names=[col for col in reference.columns if col not in ['y_true', 'y_pred', 'timestamp']], + y_true='y_true', + y_pred='y_pred', + metrics=['mae', 'mape'], + **estimator_args, + ).fit(reference) + sut = est.estimate(analysis) + + try: + _ = sut.plot(**plot_args) + except Exception as exc: + pytest.fail(f"an unexpected exception occurred: {exc}") diff --git a/tests/performance_estimation/DLE/test_dle_metrics.py b/tests/performance_estimation/DLE/test_dle_metrics.py index 9a81ddf0..12823aa3 100644 --- a/tests/performance_estimation/DLE/test_dle_metrics.py +++ b/tests/performance_estimation/DLE/test_dle_metrics.py @@ -6,6 +6,8 @@ from nannyml._typing import ProblemType from nannyml.base import AbstractEstimator, AbstractEstimatorResult +from nannyml.datasets import load_synthetic_car_price_dataset +from nannyml.performance_estimation.direct_loss_estimation import DLE from nannyml.performance_estimation.direct_loss_estimation.metrics import MetricFactory @@ -21,3 +23,326 @@ def test_metric_creation_with_non_dle_estimator_raises_runtime_exc(metric): with
pytest.raises(RuntimeError, match='not an instance of type DLE'): MetricFactory.create(key=metric, problem_type=ProblemType.REGRESSION, kwargs={'estimator': FakeEstimator()}) + + +@pytest.mark.parametrize( + 'calculator_opts, expected', + [ + ( + {'chunk_size': 20000}, + pd.DataFrame( + { + 'key': ['[0:19999]', '[20000:39999]', '[40000:59999]'], + 'estimated_mae': [845.9611134332384, 781.4926674835554, 711.0517135412116], + 'estimated_mape': [0.23267452687056098, 0.24304465695708963, 0.25434522765463374], + 'estimated_mse': [1122878.6557536805, 996772.391655789, 860754.7329283138], + 'estimated_rmse': [1059.659688651824, 998.384891540226, 927.7686850332435], + 'estimated_msle': [0.07129931251850186, 0.08237610554658686, 0.09424921080964306], + 'estimated_rmsle': [0.2670193111340486, 0.2870123787340659, 0.30700034333798887], + } + ), + ), + ( + {'chunk_size': 20000, 'timestamp_column_name': 'timestamp'}, + pd.DataFrame( + { + 'key': ['[0:19999]', '[20000:39999]', '[40000:59999]'], + 'estimated_mae': [845.9611134332384, 781.4926674835554, 711.0517135412116], + 'estimated_mape': [0.23267452687056098, 0.24304465695708963, 0.25434522765463374], + 'estimated_mse': [1122878.6557536805, 996772.391655789, 860754.7329283138], + 'estimated_rmse': [1059.659688651824, 998.384891540226, 927.7686850332435], + 'estimated_msle': [0.07129931251850186, 0.08237610554658686, 0.09424921080964306], + 'estimated_rmsle': [0.2670193111340486, 0.2870123787340659, 0.30700034333798887], + } + ), + ), + ( + {'chunk_number': 4}, + pd.DataFrame( + { + 'key': ['[0:14999]', '[15000:29999]', '[30000:44999]', '[45000:59999]'], + 'estimated_mae': [847.3424959557017, 845.7373344640564, 711.5439127244119, 713.3835827998372], + 'estimated_mape': [ + 0.23214149556800917, + 0.23284470158214202, + 0.25471199464536814, + 0.2537210235141931, + ], + 'estimated_mse': [1128043.558035664, 1120912.7518854835, 859451.4065392805, 865466.657323283], + 'estimated_rmse': [1062.093949721805, 1058.7316713339048, 
927.0660205936148, 930.3046045910355], + 'estimated_msle': [ + 0.07099665700945837, + 0.07144067806672186, + 0.09442423585275446, + 0.09370460090404097, + ], + 'estimated_rmsle': [ + 0.2664519788056722, + 0.26728389039880773, + 0.3072852678745834, + 0.30611207245719824, + ], + } + ), + ), + ( + {'chunk_number': 4, 'timestamp_column_name': 'timestamp'}, + pd.DataFrame( + { + 'key': ['[0:14999]', '[15000:29999]', '[30000:44999]', '[45000:59999]'], + 'estimated_mae': [847.3424959557017, 845.7373344640564, 711.5439127244119, 713.3835827998372], + 'estimated_mape': [ + 0.23214149556800917, + 0.23284470158214202, + 0.25471199464536814, + 0.2537210235141931, + ], + 'estimated_mse': [1128043.558035664, 1120912.7518854835, 859451.4065392805, 865466.657323283], + 'estimated_rmse': [1062.093949721805, 1058.7316713339048, 927.0660205936148, 930.3046045910355], + 'estimated_msle': [ + 0.07099665700945837, + 0.07144067806672186, + 0.09442423585275446, + 0.09370460090404097, + ], + 'estimated_rmsle': [ + 0.2664519788056722, + 0.26728389039880773, + 0.3072852678745834, + 0.30611207245719824, + ], + } + ), + ), + ( + {'chunk_period': 'M', 'timestamp_column_name': 'timestamp'}, + pd.DataFrame( + { + 'key': ['2017-02', '2017-03'], + 'estimated_mae': [839.9892842078503, 711.6793261625646], + 'estimated_mape': [0.2334885072120842, 0.25441754369504815], + 'estimated_mse': [1111996.0629481985, 860567.8087450431], + 'estimated_rmse': [1054.5122393543845, 927.6679409923807], + 'estimated_msle': [0.07234480460729872, 0.09418692237490388], + 'estimated_rmsle': [0.2689698953550354, 0.3068988797224647], + } + ), + ), + ( + {}, + pd.DataFrame( + { + 'key': [ + '[0:5999]', + '[6000:11999]', + '[12000:17999]', + '[18000:23999]', + '[24000:29999]', + '[30000:35999]', + '[36000:41999]', + '[42000:47999]', + '[48000:53999]', + '[54000:59999]', + ], + 'estimated_mae': [ + 849.5639685964373, + 848.7184466434562, + 842.7092353630159, + 849.0077257199042, + 842.7001997265812, + 715.7915405144998, + 
714.4308784285377, + 711.8195599307788, + 714.5584926607276, + 705.7182672760792, + ], + 'estimated_mape': [ + 0.23107162545446974, + 0.23273121873466127, + 0.23346395434814515, + 0.23156279696192553, + 0.2336358973761762, + 0.25298911958498194, + 0.2547138215609582, + 0.253774827282499, + 0.2534991075167855, + 0.25610566945367863, + ], + 'estimated_mse': [ + 1139377.7053649914, + 1129419.217327195, + 1112040.535613815, + 1128981.1043172355, + 1112572.2121796315, + 865825.3950750288, + 865532.7479281372, + 862284.1323365847, + 869065.8827812723, + 849587.0015353857, + ], + 'estimated_rmse': [ + 1067.4163692603704, + 1062.741368973277, + 1054.533325985393, + 1062.5352249771465, + 1054.7853867871092, + 930.4973912241929, + 930.3401248619438, + 928.5925545343257, + 932.2370314363576, + 921.7304386507943, + ], + 'estimated_msle': [ + 0.0706372455894371, + 0.07116430399895682, + 0.07171993662227429, + 0.07056145753021863, + 0.07201039394956373, + 0.09311525247406541, + 0.09410921276805023, + 0.09402118738701887, + 0.09370361672238206, + 0.09537282254047202, + ], + 'estimated_rmsle': [ + 0.2657766836828187, + 0.26676638468697067, + 0.26780578153257684, + 0.26563406696095776, + 0.26834752458251543, + 0.30514791900661126, + 0.3067722490187961, + 0.3066287452066731, + 0.30611046490177696, + 0.30882490595881673, + ], + } + ), + ), + ( + {'timestamp_column_name': 'timestamp'}, + pd.DataFrame( + { + 'key': [ + '[0:5999]', + '[6000:11999]', + '[12000:17999]', + '[18000:23999]', + '[24000:29999]', + '[30000:35999]', + '[36000:41999]', + '[42000:47999]', + '[48000:53999]', + '[54000:59999]', + ], + 'estimated_mae': [ + 849.5639685964373, + 848.7184466434562, + 842.7092353630159, + 849.0077257199042, + 842.7001997265812, + 715.7915405144998, + 714.4308784285377, + 711.8195599307788, + 714.5584926607276, + 705.7182672760792, + ], + 'estimated_mape': [ + 0.23107162545446974, + 0.23273121873466127, + 0.23346395434814515, + 0.23156279696192553, + 0.2336358973761762, + 
0.25298911958498194, + 0.2547138215609582, + 0.253774827282499, + 0.2534991075167855, + 0.25610566945367863, + ], + 'estimated_mse': [ + 1139377.7053649914, + 1129419.217327195, + 1112040.535613815, + 1128981.1043172355, + 1112572.2121796315, + 865825.3950750288, + 865532.7479281372, + 862284.1323365847, + 869065.8827812723, + 849587.0015353857, + ], + 'estimated_rmse': [ + 1067.4163692603704, + 1062.741368973277, + 1054.533325985393, + 1062.5352249771465, + 1054.7853867871092, + 930.4973912241929, + 930.3401248619438, + 928.5925545343257, + 932.2370314363576, + 921.7304386507943, + ], + 'estimated_msle': [ + 0.0706372455894371, + 0.07116430399895682, + 0.07171993662227429, + 0.07056145753021863, + 0.07201039394956373, + 0.09311525247406541, + 0.09410921276805023, + 0.09402118738701887, + 0.09370361672238206, + 0.09537282254047202, + ], + 'estimated_rmsle': [ + 0.2657766836828187, + 0.26676638468697067, + 0.26780578153257684, + 0.26563406696095776, + 0.26834752458251543, + 0.30514791900661126, + 0.3067722490187961, + 0.3066287452066731, + 0.30611046490177696, + 0.30882490595881673, + ], + } + ), + ), + ], + ids=[ + 'size_based_without_timestamp', + 'size_based_with_timestamp', + 'count_based_without_timestamp', + 'count_based_with_timestamp', + 'period_based_with_timestamp', + 'default_without_timestamp', + 'default_with_timestamp', + ], +) +def test_dle_for_regression_with_timestamps(calculator_opts, expected): + ref_df, ana_df, _ = load_synthetic_car_price_dataset() + dle = DLE( + feature_column_names=[col for col in ref_df.columns if col not in ['timestamp', 'y_true', 'y_pred']], + y_pred='y_pred', + y_true='y_true', + metrics=['mae', 'mape', 'mse', 'msle', 'rmse', 'rmsle'], + **calculator_opts, + ).fit(ref_df) + sut = dle.estimate(ana_df).data + + pd.testing.assert_frame_equal( + expected, + sut[ + [ + 'key', + 'estimated_mae', + 'estimated_mape', + 'estimated_mse', + 'estimated_rmse', + 'estimated_msle', + 'estimated_rmsle', + ] + ], + ) diff --git 
a/tests/performance_estimation/test_base.py b/tests/performance_estimation/test_base.py index 2eaf5913..86045d5b 100644 --- a/tests/performance_estimation/test_base.py +++ b/tests/performance_estimation/test_base.py @@ -44,7 +44,7 @@ def _fit(self, reference_data: pd.DataFrame, *args, **kwargs): # noqa: D102 return self def _estimate(self, data: pd.DataFrame, *args, **kwargs) -> SimpleEstimatorResult: # noqa: D102 - chunks = self.chunker.split(data, timestamp_column_name='timestamp') + chunks = self.chunker.split(data) return SimpleEstimatorResult( results_data=pd.DataFrame(columns=data.columns).assign(key=[chunk.key for chunk in chunks]), calculator=self, diff --git a/tests/test_chunk.py b/tests/test_chunk.py index 09ce49ac..55951ee1 100644 --- a/tests/test_chunk.py +++ b/tests/test_chunk.py @@ -12,7 +12,7 @@ import pytest from pandas import Timestamp -from nannyml.chunk import Chunk, Chunker, CountBasedChunker, DefaultChunker, SizeBasedChunker +from nannyml.chunk import Chunk, Chunker, CountBasedChunker, DefaultChunker, PeriodBasedChunker, SizeBasedChunker from nannyml.exceptions import ChunkerException, InvalidArgumentsException rng = np.random.default_rng() @@ -142,17 +142,17 @@ def test_chunk_len_should_return_0_for_empty_chunk(): # noqa: D103 def test_chunker_should_log_warning_when_less_than_6_chunks(sample_chunk_data, caplog): # noqa: D103 class SimpleChunker(Chunker): - def _split(self, data: pd.DataFrame, timestamp_column_name: str, minimum_chunk_size: int = None) -> List[Chunk]: + def _split(self, data: pd.DataFrame) -> List[Chunk]: return [Chunk(key='row0', data=data)] c = SimpleChunker() with pytest.warns(UserWarning, match="The resulting number of chunks is too low."): - _ = c.split(sample_chunk_data, timestamp_column_name='timestamp') + _ = c.split(sample_chunk_data) def test_chunker_should_set_index_boundaries(sample_chunk_data): # noqa: D103 class SimpleChunker(Chunker): - def _split(self, data: pd.DataFrame, timestamp_column_name: str, 
minimum_chunk_size: int = None) -> List[Chunk]: + def _split(self, data: pd.DataFrame) -> List[Chunk]: return [ Chunk(key='[0:6665]', data=data.iloc[0:6666, :]), Chunk(key='[6666:13331]', data=data.iloc[6666:13332, :]), @@ -160,7 +160,7 @@ def _split(self, data: pd.DataFrame, timestamp_column_name: str, minimum_chunk_s ] chunker = SimpleChunker() - sut = chunker.split(data=sample_chunk_data, timestamp_column_name='timestamp') + sut = chunker.split(data=sample_chunk_data) assert sut[0].start_index == 0 assert sut[0].end_index == 6665 assert sut[1].start_index == 6666 @@ -171,22 +171,22 @@ def _split(self, data: pd.DataFrame, timestamp_column_name: str, minimum_chunk_s def test_chunker_should_include_all_data_columns_by_default(sample_chunk_data): # noqa: D103 class SimpleChunker(Chunker): - def _split(self, data: pd.DataFrame, timestamp_column_name: str, minimum_chunk_size: int = None) -> List[Chunk]: + def _split(self, data: pd.DataFrame) -> List[Chunk]: return [Chunk(key='row0', data=data)] c = SimpleChunker() - sut = c.split(sample_chunk_data, timestamp_column_name='timestamp')[0].data.columns + sut = c.split(sample_chunk_data)[0].data.columns assert sorted(sut) == sorted(sample_chunk_data.columns) def test_chunker_should_only_include_listed_columns_when_given_columns_param(sample_chunk_data): # noqa: D103 class SimpleChunker(Chunker): - def _split(self, data: pd.DataFrame, timestamp_column_name: str, minimum_chunk_size: int = None) -> List[Chunk]: + def _split(self, data: pd.DataFrame) -> List[Chunk]: return [Chunk(key='row0', data=data)] columns = ['f1', 'f3', 'period'] c = SimpleChunker() - sut = c.split(sample_chunk_data, columns=columns, timestamp_column_name='timestamp')[0].data.columns + sut = c.split(sample_chunk_data, columns=columns)[0].data.columns assert sorted(sut) == sorted(columns) @@ -194,32 +194,25 @@ def test_chunker_should_raise_chunker_exception_upon_exception_during_inherited_ sample_chunk_data, ): class SimpleChunker(Chunker): - def 
_split(self, data: pd.DataFrame, timestamp_column_name: str, minimum_chunk_size: int = None) -> List[Chunk]: + def _split(self, data: pd.DataFrame) -> List[Chunk]: raise RuntimeError("oops, I broke it again") c = SimpleChunker() with pytest.raises(ChunkerException): - _ = c.split(sample_chunk_data, timestamp_column_name='timestamp') - - -def test_chunker_should_fail_when_timestamp_column_not_provided(sample_chunk_data): # noqa: D103 - class SimpleChunker(Chunker): - def _split(self, data: pd.DataFrame, timestamp_column_name: str, minimum_chunk_size: int = None) -> List[Chunk]: - return [Chunk(key='row0', data=data)] - - c = SimpleChunker() - with pytest.raises(TypeError, match="'timestamp_column_name'"): - c.split(sample_chunk_data) + _ = c.split(sample_chunk_data) def test_chunker_should_fail_when_timestamp_column_is_not_present(sample_chunk_data): # noqa: D103 class SimpleChunker(Chunker): - def _split(self, data: pd.DataFrame, timestamp_column_name: str, minimum_chunk_size: int = None) -> List[Chunk]: + def __init__(self): + super().__init__(timestamp_column_name='foo') + + def _split(self, data: pd.DataFrame) -> List[Chunk]: return [Chunk(key='row0', data=data)] c = SimpleChunker() with pytest.raises(InvalidArgumentsException, match="timestamp column 'foo' not in columns"): - c.split(sample_chunk_data, timestamp_column_name='foo') + c.split(sample_chunk_data) def test_size_based_chunker_raises_exception_when_passed_nan_size(sample_chunk_data): # noqa: D103 @@ -239,16 +232,14 @@ def test_size_based_chunker_raises_exception_when_passed_zero_size(sample_chunk_ def test_size_based_chunker_works_with_empty_dataset(): # noqa: D103 chunker = SizeBasedChunker(chunk_size=100) - sut = chunker.split( - pd.DataFrame(columns=['date', 'timestamp', 'f1', 'f2', 'f3', 'f4']), timestamp_column_name='timestamp' - ) + sut = chunker.split(pd.DataFrame(columns=['date', 'timestamp', 'f1', 'f2', 'f3', 'f4'])) assert len(sut) == 0 def 
test_size_based_chunker_returns_chunks_of_required_size(sample_chunk_data): # noqa: D103 chunk_size = 1500 chunker = SizeBasedChunker(chunk_size=chunk_size) - sut = chunker.split(sample_chunk_data, timestamp_column_name='timestamp') + sut = chunker.split(sample_chunk_data) assert len(sut[0]) == chunk_size assert len(sut) == math.ceil(sample_chunk_data.shape[0] / chunk_size) @@ -257,7 +248,7 @@ def test_size_based_chunker_returns_last_chunk_that_is_partially_filled(sample_c chunk_size = 3333 expected_last_chunk_size = sample_chunk_data.shape[0] % chunk_size chunker = SizeBasedChunker(chunk_size) - sut = chunker.split(sample_chunk_data, timestamp_column_name='timestamp') + sut = chunker.split(sample_chunk_data) assert len(sut[-1]) == expected_last_chunk_size @@ -267,7 +258,7 @@ def test_size_based_chunker_works_when_data_set_is_multiple_of_chunk_size(sample chunker = SizeBasedChunker(chunk_size) sut = [] try: - sut = chunker.split(data, timestamp_column_name='timestamp') + sut = chunker.split(data) except Exception as exc: pytest.fail(f'an unexpected exception occurred: {exc}') @@ -279,13 +270,13 @@ def test_size_based_chunker_drops_last_incomplete_chunk_when_set_drop_incomplete ): chunk_size = 3333 chunker = SizeBasedChunker(chunk_size, drop_incomplete=True) - sut = chunker.split(sample_chunk_data, timestamp_column_name='timestamp') + sut = chunker.split(sample_chunk_data) assert len(sut[-1]) == chunk_size def test_size_based_chunker_uses_observations_to_set_chunk_date_boundaries(sample_chunk_data): # noqa: D103 - chunker = SizeBasedChunker(chunk_size=5000) - sut = chunker.split(sample_chunk_data, timestamp_column_name='timestamp') + chunker = SizeBasedChunker(chunk_size=5000, timestamp_column_name='timestamp') + sut = chunker.split(sample_chunk_data) assert sut[0].start_datetime == Timestamp(year=2020, month=1, day=6, hour=0, minute=0, second=0) assert sut[-1].end_datetime == Timestamp(year=2020, month=5, day=24, hour=23, minute=50, second=0) @@ -296,7 +287,7 @@ 
def test_size_based_chunker_assigns_observation_range_to_chunk_keys(sample_chunk last_chunk_end = sample_chunk_data.shape[0] - 1 chunker = SizeBasedChunker(chunk_size=chunk_size) - sut = chunker.split(sample_chunk_data, timestamp_column_name='timestamp') + sut = chunker.split(sample_chunk_data) assert sut[0].key == '[0:1499]' assert sut[1].key == '[1500:2999]' assert sut[-1].key == f'[{last_chunk_start}:{last_chunk_end}]' @@ -319,23 +310,21 @@ def test_count_based_chunker_raises_exception_when_passed_zero_size(sample_chunk def test_count_based_chunker_works_with_empty_dataset(): # noqa: D103 chunker = CountBasedChunker(chunk_count=5) - sut = chunker.split( - pd.DataFrame(columns=['date', 'timestamp', 'f1', 'f2', 'f3', 'f4']), timestamp_column_name='timestamp' - ) + sut = chunker.split(pd.DataFrame(columns=['date', 'timestamp', 'f1', 'f2', 'f3', 'f4'])) assert len(sut) == 0 def test_count_based_chunker_returns_chunks_of_required_size(sample_chunk_data): # noqa: D103 chunk_count = 5 chunker = CountBasedChunker(chunk_count=chunk_count) - sut = chunker.split(sample_chunk_data, timestamp_column_name='timestamp') + sut = chunker.split(sample_chunk_data) assert len(sut[0]) == sample_chunk_data.shape[0] // chunk_count assert len(sut) == chunk_count def test_count_based_chunker_uses_observations_to_set_chunk_date_boundaries(sample_chunk_data): # noqa: D103 - chunker = CountBasedChunker(chunk_count=20) - sut = chunker.split(sample_chunk_data, timestamp_column_name='timestamp') + chunker = CountBasedChunker(chunk_count=20, timestamp_column_name='timestamp') + sut = chunker.split(sample_chunk_data) assert sut[0].start_datetime == Timestamp(year=2020, month=1, day=6, hour=0, minute=0, second=0) assert sut[-1].end_datetime == Timestamp(year=2020, month=5, day=24, hour=23, minute=50, second=0) @@ -344,7 +333,7 @@ def test_count_based_chunker_assigns_observation_range_to_chunk_keys(sample_chun chunk_count = 5 chunker = CountBasedChunker(chunk_count=chunk_count) - sut = 
chunker.split(sample_chunk_data, timestamp_column_name='timestamp') + sut = chunker.split(sample_chunk_data) assert sut[0].key == '[0:4031]' assert sut[1].key == '[4032:8063]' assert sut[-1].key == '[16128:20159]' @@ -352,8 +341,36 @@ def test_count_based_chunker_assigns_observation_range_to_chunk_keys(sample_chun def test_default_chunker_splits_into_ten_chunks(sample_chunk_data): # noqa: D103 expected_size = sample_chunk_data.shape[0] / 10 - sut = DefaultChunker().split(sample_chunk_data, timestamp_column_name='timestamp') + sut = DefaultChunker().split(sample_chunk_data) assert len(sut) == 10 assert len(sut[0]) == expected_size assert len(sut[1]) == expected_size assert len(sut[-1]) == expected_size + + +@pytest.mark.parametrize( + 'chunker', [SizeBasedChunker(chunk_size=5000), CountBasedChunker(chunk_count=10), DefaultChunker()] +) +def test_chunker_without_timestamp_column_sets_date_boundaries_to_none(  # noqa: D103 + sample_chunk_data, chunker +): + sut = chunker.split(sample_chunk_data) + assert all([chunk.start_datetime is None for chunk in sut]) + assert all([chunk.end_datetime is None for chunk in sut]) + + +@pytest.mark.parametrize( + 'chunker', + [ + SizeBasedChunker(chunk_size=5000), + CountBasedChunker(chunk_count=10), + DefaultChunker(), + SizeBasedChunker(chunk_size=5000, timestamp_column_name='timestamp'), + CountBasedChunker(chunk_count=10, timestamp_column_name='timestamp'), + DefaultChunker(timestamp_column_name='timestamp'), + PeriodBasedChunker(offset='W', timestamp_column_name='timestamp'), + ], +) +def test_chunker_sets_chunk_index(sample_chunk_data, chunker):  # noqa: D103 + sut = chunker.split(sample_chunk_data) + assert all([chunk.chunk_index == chunk_index for chunk_index, chunk in enumerate(sut)]) diff --git a/tests/test_runner.py b/tests/test_runner.py new file mode 100644 index 00000000..9ac94279 --- /dev/null +++ b/tests/test_runner.py @@ -0,0 +1,124 @@ +# Author: Niels Nuyttens +# +# License: Apache Software 
License 2.0 +import tempfile + +import pytest + +from nannyml._typing import ProblemType +from nannyml.chunk import DefaultChunker +from nannyml.datasets import ( + load_synthetic_binary_classification_dataset, + load_synthetic_car_price_dataset, + load_synthetic_multiclass_classification_dataset, +) +from nannyml.io.file_writer import FileWriter +from nannyml.runner import run + + +@pytest.mark.parametrize('timestamp_column_name', [None, 'timestamp'], ids=['without_timestamp', 'with_timestamp']) +def test_runner_executes_for_binary_classification_without_exceptions(timestamp_column_name): + reference, analysis, analysis_targets = load_synthetic_binary_classification_dataset() + analysis_with_targets = analysis.merge(analysis_targets, on='identifier') + + try: + with tempfile.TemporaryDirectory() as tmpdir: + run( + reference_data=reference, + analysis_data=analysis_with_targets, + column_mapping={ + 'features': [ + 'distance_from_office', + 'salary_range', + 'gas_price_per_litre', + 'public_transportation_cost', + 'wfh_prev_workday', + 'workday', + 'tenure', + ], + 'y_pred': 'y_pred', + 'y_pred_proba': 'y_pred_proba', + 'y_true': 'work_home_actual', + 'timestamp': timestamp_column_name, + }, + problem_type=ProblemType.CLASSIFICATION_BINARY, + chunker=DefaultChunker(timestamp_column_name=timestamp_column_name), + writer=FileWriter(filepath=tmpdir, data_format='parquet'), + run_in_console=False, + ignore_errors=False, + ) + except Exception as exc: + pytest.fail(f"an unexpected exception occurred: {exc}") + + +@pytest.mark.parametrize('timestamp_column_name', [None, 'timestamp'], ids=['without_timestamp', 'with_timestamp']) +def test_runner_executes_for_multiclass_classification_without_exceptions(timestamp_column_name): + reference, analysis, analysis_targets = load_synthetic_multiclass_classification_dataset() + analysis_with_targets = analysis.merge(analysis_targets, on='identifier') + + try: + with tempfile.TemporaryDirectory() as tmpdir: + run( + 
reference_data=reference, + analysis_data=analysis_with_targets, + column_mapping={ + 'features': [ + 'acq_channel', + 'app_behavioral_score', + 'requested_credit_limit', + 'app_channel', + 'credit_bureau_score', + 'stated_income', + 'is_customer', + ], + 'y_pred': 'y_pred', + 'y_pred_proba': { + 'prepaid_card': 'y_pred_proba_prepaid_card', + 'highstreet_card': 'y_pred_proba_highstreet_card', + 'upmarket_card': 'y_pred_proba_upmarket_card', + }, + 'y_true': 'y_true', + 'timestamp': timestamp_column_name, + }, + problem_type=ProblemType.CLASSIFICATION_MULTICLASS, + chunker=DefaultChunker(timestamp_column_name=timestamp_column_name), + writer=FileWriter(filepath=tmpdir, data_format='parquet'), + run_in_console=False, + ignore_errors=False, + ) + except Exception as exc: + pytest.fail(f"an unexpected exception occurred: {exc}") + + +@pytest.mark.parametrize('timestamp_column_name', [None, 'timestamp'], ids=['without_timestamp', 'with_timestamp']) +def test_runner_executes_for_regression_without_exceptions(timestamp_column_name): + reference, analysis, analysis_targets = load_synthetic_car_price_dataset() + analysis_with_targets = analysis.join(analysis_targets) + + try: + with tempfile.TemporaryDirectory() as tmpdir: + run( + reference_data=reference, + analysis_data=analysis_with_targets, + column_mapping={ + 'features': [ + 'car_age', + 'km_driven', + 'price_new', + 'accident_count', + 'door_count', + 'transmission', + 'fuel', + ], + 'y_pred': 'y_pred', + 'y_true': 'y_true', + 'timestamp': timestamp_column_name, + }, + problem_type=ProblemType.REGRESSION, + chunker=DefaultChunker(timestamp_column_name=timestamp_column_name), + writer=FileWriter(filepath=tmpdir, data_format='parquet'), + run_in_console=False, + ignore_errors=False, + ) + except Exception as exc: + pytest.fail(f"an unexpected exception occurred: {exc}")