From 03c580cca04b8a92148d7e34cd067d098b81daaf Mon Sep 17 00:00:00 2001 From: Nikolaos Perrakis Date: Tue, 9 Jul 2024 21:06:22 +0300 Subject: [PATCH 1/4] realized perf MC AUROC class handling --- .../metrics/binary_classification.py | 2 +- .../metrics/multiclass_classification.py | 78 +++++++++++-------- .../metrics/test_multiclass_classification.py | 56 +++++++++++-- 3 files changed, 97 insertions(+), 39 deletions(-) diff --git a/nannyml/performance_calculation/metrics/binary_classification.py b/nannyml/performance_calculation/metrics/binary_classification.py index c70033ad..b28d08ff 100644 --- a/nannyml/performance_calculation/metrics/binary_classification.py +++ b/nannyml/performance_calculation/metrics/binary_classification.py @@ -721,7 +721,6 @@ def __init__( Name(s) of the column(s) containing your model output. For binary classification, pass a single string refering to the model output column. """ - if normalize_business_value not in [None, "per_prediction"]: raise InvalidArgumentsException( f"normalize_business_value must be None or 'per_prediction', but got {normalize_business_value}" @@ -863,6 +862,7 @@ def __init__( self._sampling_error_components: Tuple = () def __str__(self): + """Get string representation of metric.""" return "confusion_matrix" def fit(self, reference_data: pd.DataFrame, chunker: Chunker): diff --git a/nannyml/performance_calculation/metrics/multiclass_classification.py b/nannyml/performance_calculation/metrics/multiclass_classification.py index 75bceec6..ad5fbabc 100644 --- a/nannyml/performance_calculation/metrics/multiclass_classification.py +++ b/nannyml/performance_calculation/metrics/multiclass_classification.py @@ -19,7 +19,7 @@ ) from sklearn.preprocessing import LabelBinarizer, label_binarize -from nannyml._typing import ProblemType, class_labels, model_output_column_names +from nannyml._typing import ProblemType, class_labels from nannyml.base import _list_missing, common_nan_removal from nannyml.chunk import Chunker from nannyml.exceptions import InvalidArgumentsException @@ -84,8 +84,16 @@ def __init__( upper_threshold_limit=1, components=[("ROC AUC", "roc_auc")], ) - # FIXME: Should we check the y_pred_proba argument here to ensure it's a dict? self.y_pred_proba: Dict[str, str] + # Move check here, since we have all the info we need for checking. + if not isinstance(self.y_pred_proba, Dict): + raise InvalidArgumentsException( + f"'y_pred_proba' is of type {type(self.y_pred_proba)}\n" + "multiclass use cases require 'y_pred_proba' to be a dictionary mapping classes to columns." 
+            )
+        # classes and class probability columns
+        self.classes: List[str] = [""]
+        self.class_probability_columns: List[str]
 
         # sampling error
         self._sampling_error_components: List[Tuple] = []
 
@@ -95,61 +103,65 @@ def __str__(self):
         return "roc_auc"
 
     def _fit(self, reference_data: pd.DataFrame):
-        classes = class_labels(self.y_pred_proba)
-        class_y_pred_proba_columns = model_output_column_names(self.y_pred_proba)
-        _list_missing([self.y_true] + class_y_pred_proba_columns, list(reference_data.columns))
+        # set up sorted classes and prob_column_names to use across metric class
+        self.classes = class_labels(self.y_pred_proba)
+        self.class_probability_columns = [self.y_pred_proba[clazz] for clazz in self.classes]
+
+        _list_missing([self.y_true] + self.class_probability_columns, list(reference_data.columns))
         reference_data, empty = common_nan_removal(
-            reference_data[[self.y_true] + class_y_pred_proba_columns], [self.y_true] + class_y_pred_proba_columns
+            reference_data[[self.y_true] + self.class_probability_columns],
+            [self.y_true] + self.class_probability_columns
         )
         if empty:
-            self._sampling_error_components = [(np.NaN, 0) for class_col in class_y_pred_proba_columns]
+            self._sampling_error_components = [(np.NaN, 0) for clazz in self.classes]
+            # TODO: Ideally we would also raise an error here!
         else:
+            # test if reference data are represented correctly
+            observed_classes = set(reference_data[self.y_true].unique())
+            if not observed_classes == set(self.classes):
+                self._logger.error(
+                    "The specified classification classes are not the same as the classes observed in the reference "
+                    "targets."
+                )
+                raise InvalidArgumentsException(
+                    "y_pred_proba class and class probabilities dictionary does not match reference data.")
+
             # sampling error
-            binarized_y_true = list(label_binarize(reference_data[self.y_true], classes=classes).T)
-            y_pred_proba = [reference_data[self.y_pred_proba[clazz]].T for clazz in classes]
+            binarized_y_true = list(label_binarize(reference_data[self.y_true], classes=self.classes).T)
+            y_pred_proba = [reference_data[self.y_pred_proba[clazz]].T for clazz in self.classes]
             self._sampling_error_components = auroc_sampling_error_components(
                 y_true_reference=binarized_y_true, y_pred_proba_reference=y_pred_proba
             )
 
     def _calculate(self, data: pd.DataFrame):
-        if not isinstance(self.y_pred_proba, Dict):
-            raise InvalidArgumentsException(
-                f"'y_pred_proba' is of type {type(self.y_pred_proba)}\n"
-                f"multiclass use cases require 'y_pred_proba' to "
-                "be a dictionary mapping classes to columns."
-            )
-
-        class_y_pred_proba_columns = model_output_column_names(self.y_pred_proba)
-        _list_missing([self.y_true] + class_y_pred_proba_columns, data)
+        _list_missing([self.y_true] + self.class_probability_columns, data)
         data, empty = common_nan_removal(
-            data[[self.y_true] + class_y_pred_proba_columns], [self.y_true] + class_y_pred_proba_columns
+            data[[self.y_true] + self.class_probability_columns], [self.y_true] + self.class_probability_columns
         )
         if empty:
-            warnings.warn(f"Too many missing values, cannot calculate {self.display_name}. " f"Returning NaN.")
+            _message = f"Too many missing values, cannot calculate {self.display_name}. " f"Returning NaN."
+ self._logger.warning(_message) + warnings.warn(_message) return np.NaN - labels, class_probability_columns = [], [] - for label in sorted(list(self.y_pred_proba.keys())): - labels.append(label) - class_probability_columns.append(self.y_pred_proba[label]) - y_true = data[self.y_true] - y_pred_proba = data[class_probability_columns] + y_pred_proba = data[self.class_probability_columns] - if y_true.nunique() <= 1: - warnings.warn( - f"'{self.y_true}' only contains a single class for chunk, cannot calculate {self.display_name}. " + if set(y_true.unique()) != set(self.classes): + _message = ( + f"'{self.y_true}' does not contain all reported classes, cannot calculate {self.display_name}. " "Returning NaN." ) + warnings.warn(_message) + self._logger.warning(_message) return np.NaN else: - return roc_auc_score(y_true, y_pred_proba, multi_class='ovr', average='macro', labels=labels) + return roc_auc_score(y_true, y_pred_proba, multi_class='ovr', average='macro', labels=self.classes) def _sampling_error(self, data: pd.DataFrame) -> float: - class_y_pred_proba_columns = model_output_column_names(self.y_pred_proba) - _list_missing([self.y_true] + class_y_pred_proba_columns, data) + _list_missing([self.y_true] + self.class_probability_columns, data) data, empty = common_nan_removal( - data[[self.y_true] + class_y_pred_proba_columns], [self.y_true] + class_y_pred_proba_columns + data[[self.y_true] + self.class_probability_columns], [self.y_true] + self.class_probability_columns ) if empty: warnings.warn( diff --git a/tests/performance_calculation/metrics/test_multiclass_classification.py b/tests/performance_calculation/metrics/test_multiclass_classification.py index 6a4d1379..1b3a9649 100644 --- a/tests/performance_calculation/metrics/test_multiclass_classification.py +++ b/tests/performance_calculation/metrics/test_multiclass_classification.py @@ -1,15 +1,13 @@ -# Author: Niels Nuyttens -# # # License: Apache Software License 2.0 # Author: Niels Nuyttens -# -# License: Apache Software License 2.0 + """Unit tests for performance metrics.""" from typing import Tuple import pandas as pd import pytest +from logging import getLogger from nannyml import PerformanceCalculator from nannyml._typing import ProblemType @@ -27,6 +25,8 @@ ) from nannyml.thresholds import ConstantThreshold, StandardDeviationThreshold +LOGGER = getLogger(__name__) + @pytest.fixture(scope='module') def multiclass_data() -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]: # noqa: D103 @@ -94,7 +94,7 @@ def no_timestamp_metrics(performance_calculator, multiclass_data) -> pd.DataFram def test_metric_factory_returns_correct_metric_given_key_and_problem_type(key, problem_type, metric): # noqa: D103 calc = PerformanceCalculator( timestamp_column_name='timestamp', - y_pred_proba='y_pred_proba', + y_pred_proba={'class1': 'y_pred_proba1', 'class2': 'y_pred_proba2', 'class3': 'y_pred_proba3'}, y_pred='y_pred', y_true='y_true', metrics=['roc_auc', 'f1'], @@ -229,3 +229,49 @@ def test_metric_logs_warning_when_upper_threshold_is_overridden_by_metric_limits f'{metric.display_name} upper threshold value 2 overridden by ' f'upper threshold value limit {metric.upper_threshold_value_limit}' in caplog.messages ) + + +def test_auroc_errors_out_when_not_all_classes_are_represented_reference(multiclass_data, caplog): + LOGGER.info("testing test_auroc_errors_out_when_not_all_classes_are_represented_reference") + reference, _, _ = multiclass_data + reference['y_pred_proba_clazz'] = reference['y_pred_proba_upmarket_card'] + performance_calculator = 
PerformanceCalculator( + y_pred_proba={ + 'prepaid_card': 'y_pred_proba_prepaid_card', + 'highstreet_card': 'y_pred_proba_highstreet_card', + 'upmarket_card': 'y_pred_proba_upmarket_card', + 'clazz': 'y_pred_proba_clazz' + }, + y_pred='y_pred', + y_true='y_true', + metrics=['roc_auc'], + problem_type='classification_multiclass', + ) + performance_calculator.fit(reference) + expected_exc_test = "y_pred_proba class and class probabilities dictionary does not match reference data." + assert expected_exc_test in caplog.text + + +def test_auroc_errors_out_when_not_all_classes_are_represented_chunk(multiclass_data, caplog): + LOGGER.info("testing test_auroc_errors_out_when_not_all_classes_are_represented_chunk") + reference, monitored, targets = multiclass_data + monitored = monitored.merge(targets) + reference['y_pred_proba_clazz'] = reference['y_pred_proba_upmarket_card'] + monitored['y_pred_proba_clazz'] = monitored['y_pred_proba_upmarket_card'] + reference['y_true'].iloc[-1000:] = 'clazz' + performance_calculator = PerformanceCalculator( + y_pred_proba={ + 'prepaid_card': 'y_pred_proba_prepaid_card', + 'highstreet_card': 'y_pred_proba_highstreet_card', + 'upmarket_card': 'y_pred_proba_upmarket_card', + 'clazz': 'y_pred_proba_clazz' + }, + y_pred='y_pred', + y_true='y_true', + metrics=['roc_auc'], + problem_type='classification_multiclass', + ) + performance_calculator.fit(reference) + _ = performance_calculator.calculate(monitored) + expected_exc_test = "does not contain all reported classes, cannot calculate" + assert expected_exc_test in caplog.text From b441a07e618eb585c78c8af3aed7ec835a6b70d6 Mon Sep 17 00:00:00 2001 From: Nikolaos Perrakis Date: Tue, 9 Jul 2024 23:39:51 +0300 Subject: [PATCH 2/4] add CBPE MC AUROC class checks --- .../metrics/multiclass_classification.py | 16 +++-- .../confidence_based/cbpe.py | 7 ++- .../confidence_based/metrics.py | 58 +++++++++++-------- .../CBPE/test_cbpe_metrics.py | 56 ++++++++++++++++++ 4 files changed, 100 insertions(+), 37 deletions(-) diff --git a/nannyml/performance_calculation/metrics/multiclass_classification.py b/nannyml/performance_calculation/metrics/multiclass_classification.py index ad5fbabc..76d21ebf 100644 --- a/nannyml/performance_calculation/metrics/multiclass_classification.py +++ b/nannyml/performance_calculation/metrics/multiclass_classification.py @@ -85,17 +85,8 @@ def __init__( components=[("ROC AUC", "roc_auc")], ) self.y_pred_proba: Dict[str, str] - # Move check here, since we have all the info we need for checking. - if not isinstance(self.y_pred_proba, Dict): - raise InvalidArgumentsException( - f"'y_pred_proba' is of type {type(self.y_pred_proba)}\n" - "multiclass use cases require 'y_pred_proba' to be a dictionary mapping classes to columns." - ) - # classes and class probability columns self.classes: List[str] = [""] self.class_probability_columns: List[str] - - # sampling error self._sampling_error_components: List[Tuple] = [] def __str__(self): @@ -134,6 +125,13 @@ def _fit(self, reference_data: pd.DataFrame): ) def _calculate(self, data: pd.DataFrame): + if not isinstance(self.y_pred_proba, Dict): + raise InvalidArgumentsException( + f"'y_pred_proba' is of type {type(self.y_pred_proba)}\n" + f"multiclass use cases require 'y_pred_proba' to " + "be a dictionary mapping classes to columns." 
+            )
+
         _list_missing([self.y_true] + self.class_probability_columns, data)
         data, empty = common_nan_removal(
             data[[self.y_true] + self.class_probability_columns], [self.y_true] + self.class_probability_columns
diff --git a/nannyml/performance_estimation/confidence_based/cbpe.py b/nannyml/performance_estimation/confidence_based/cbpe.py
index 545f19a5..9739841c 100644
--- a/nannyml/performance_estimation/confidence_based/cbpe.py
+++ b/nannyml/performance_estimation/confidence_based/cbpe.py
@@ -541,11 +541,12 @@ def _fit_calibrators(
     noop_calibrator = NoopCalibrator()
 
     for clazz, y_true, y_pred_proba in _get_class_splits(reference_data, y_true_col, y_pred_proba_col):
+        _calibrator = copy.deepcopy(calibrator)
         if not needs_calibration(np.asarray(y_true), np.asarray(y_pred_proba), calibrator):
-            calibrator = noop_calibrator
+            _calibrator = noop_calibrator
 
-        calibrator.fit(y_pred_proba, y_true)
-        fitted_calibrators[clazz] = copy.deepcopy(calibrator)
+        _calibrator.fit(y_pred_proba, y_true)
+        fitted_calibrators[clazz] = copy.deepcopy(_calibrator)
 
     return fitted_calibrators
 
diff --git a/nannyml/performance_estimation/confidence_based/metrics.py b/nannyml/performance_estimation/confidence_based/metrics.py
index 4596ec33..e7e0fde6 100644
--- a/nannyml/performance_estimation/confidence_based/metrics.py
+++ b/nannyml/performance_estimation/confidence_based/metrics.py
@@ -2327,36 +2327,43 @@ def __init__(
             threshold=threshold,
             components=[('ROC AUC', 'roc_auc')],
         )
-        # FIXME: Should we check the y_pred_proba argument here to ensure it's a dict?
         self.y_pred_proba: Dict[str, str]
-
-        # sampling error
+        self.classes: List[str] = [""]
+        self.class_probability_columns: List[str]
+        self.class_uncalibrated_y_pred_proba_columns: List[str]
         self._sampling_error_components: List[Tuple] = []
 
     def _fit(self, reference_data: pd.DataFrame):
-        classes = class_labels(self.y_pred_proba)
-        class_y_pred_proba_columns = model_output_column_names(self.y_pred_proba)
-        class_uncalibrated_y_pred_proba_columns = ['uncalibrated_' + el for el in class_y_pred_proba_columns]
-        _list_missing([self.y_true] + class_uncalibrated_y_pred_proba_columns, list(reference_data.columns))
+        self.classes = class_labels(self.y_pred_proba)
+        self.class_probability_columns = [self.y_pred_proba[clazz] for clazz in self.classes]
+        self.class_uncalibrated_y_pred_proba_columns = ['uncalibrated_' + el for el in self.class_probability_columns]
+        _list_missing([self.y_true] + self.class_uncalibrated_y_pred_proba_columns, list(reference_data.columns))
         # filter nans here
         reference_data, empty = common_nan_removal(
-            reference_data[[self.y_true] + class_uncalibrated_y_pred_proba_columns],
-            [self.y_true] + class_uncalibrated_y_pred_proba_columns,
+            reference_data[[self.y_true] + self.class_uncalibrated_y_pred_proba_columns],
+            [self.y_true] + self.class_uncalibrated_y_pred_proba_columns,
         )
         if empty:
-            self._sampling_error_components = [(np.NaN, 0) for class_col in class_y_pred_proba_columns]
+            self._sampling_error_components = [(np.NaN, 0) for clazz in self.classes]
         else:
+            # test if reference data are represented correctly
+            observed_classes = set(reference_data[self.y_true].unique())
+            if not observed_classes == set(self.classes):
+                self._logger.error(
+                    "The specified classification classes are not the same as the classes observed in the reference "
+                    "targets."
+ ) + raise InvalidArgumentsException( + "y_pred_proba class and class probabilities dictionary does not match reference data.") # sampling error - binarized_y_true = list(label_binarize(reference_data[self.y_true], classes=classes).T) - y_pred_proba = [reference_data['uncalibrated_' + self.y_pred_proba[clazz]].T for clazz in classes] + binarized_y_true = list(label_binarize(reference_data[self.y_true], classes=self.classes).T) + y_pred_proba = [reference_data['uncalibrated_' + self.y_pred_proba[clazz]].T for clazz in self.classes] self._sampling_error_components = mse.auroc_sampling_error_components( y_true_reference=binarized_y_true, y_pred_proba_reference=y_pred_proba ) def _estimate(self, data: pd.DataFrame): - class_y_pred_proba_columns = model_output_column_names(self.y_pred_proba) - class_uncalibrated_y_pred_proba_columns = ['uncalibrated_' + el for el in class_y_pred_proba_columns] - needed_columns = class_y_pred_proba_columns + class_uncalibrated_y_pred_proba_columns + needed_columns = self.class_probability_columns + self.class_uncalibrated_y_pred_proba_columns try: _list_missing(needed_columns, list(data.columns)) except InvalidArgumentsException as ex: @@ -2390,9 +2397,7 @@ def _estimate(self, data: pd.DataFrame): return multiclass_roc_auc def _sampling_error(self, data: pd.DataFrame) -> float: - class_y_pred_proba_columns = model_output_column_names(self.y_pred_proba) - class_uncalibrated_y_pred_proba_columns = ['uncalibrated_' + el for el in class_y_pred_proba_columns] - needed_columns = class_y_pred_proba_columns + class_uncalibrated_y_pred_proba_columns + needed_columns = self.class_probability_columns + self.class_uncalibrated_y_pred_proba_columns _list_missing(needed_columns, data) data, empty = common_nan_removal(data[needed_columns], needed_columns) if empty: @@ -2404,10 +2409,8 @@ def _sampling_error(self, data: pd.DataFrame) -> float: return mse.auroc_sampling_error(self._sampling_error_components, data) def _realized_performance(self, data: pd.DataFrame) -> float: - class_y_pred_proba_columns = model_output_column_names(self.y_pred_proba) - class_uncalibrated_y_pred_proba_columns = ['uncalibrated_' + el for el in class_y_pred_proba_columns] try: - _list_missing([self.y_true] + class_uncalibrated_y_pred_proba_columns, data) + _list_missing([self.y_true] + self.class_uncalibrated_y_pred_proba_columns, data) except InvalidArgumentsException as ex: if "missing required columns" in str(ex): self._logger.debug(str(ex)) @@ -2415,14 +2418,19 @@ def _realized_performance(self, data: pd.DataFrame) -> float: else: raise ex - data, empty = common_nan_removal(data, [self.y_true] + class_uncalibrated_y_pred_proba_columns) + data, empty = common_nan_removal(data, [self.y_true] + self.class_uncalibrated_y_pred_proba_columns) if empty: warnings.warn(f"Too many missing values, cannot calculate {self.display_name}. " f"Returning NaN.") return np.NaN y_true = data[self.y_true] - if y_true.nunique() <= 1: - warnings.warn("Too few unique values present in 'y_true', returning NaN as realized ROC-AUC.") + if set(y_true.unique()) != set(self.classes): + _message = ( + f"'{self.y_true}' does not contain all reported classes, cannot calculate {self.display_name}. " + "Returning NaN." 
+ ) + warnings.warn(_message) + self._logger.warning(_message) return np.NaN _, y_pred_probas, labels = _get_multiclass_uncalibrated_predictions(data, self.y_pred, self.y_pred_proba) @@ -3158,7 +3166,7 @@ def _multi_class_confusion_matrix_realized_performance(self, data: pd.DataFrame) warnings.warn( f"Too few unique values present in 'y_pred', returning NaN as realized {self.display_name} score." ) - return nan_array + return nan_array cm = confusion_matrix( data[self.y_true], data[self.y_pred], labels=self.classes, normalize=self.normalize_confusion_matrix diff --git a/tests/performance_estimation/CBPE/test_cbpe_metrics.py b/tests/performance_estimation/CBPE/test_cbpe_metrics.py index c2ae06cb..5335c101 100644 --- a/tests/performance_estimation/CBPE/test_cbpe_metrics.py +++ b/tests/performance_estimation/CBPE/test_cbpe_metrics.py @@ -3,6 +3,7 @@ import pandas as pd import numpy as np import pytest +from logging import getLogger from nannyml.chunk import DefaultChunker, SizeBasedChunker from nannyml.datasets import ( @@ -21,6 +22,9 @@ BinaryClassificationSpecificity, ) from nannyml.thresholds import ConstantThreshold +from nannyml.exceptions import InvalidArgumentsException + +LOGGER = getLogger(__name__) @pytest.mark.parametrize( @@ -3580,3 +3584,55 @@ def test_cbpe_for_multiclass_classification_cm_with_nans(calculator_opts, realiz 'realized_true_upmarket_card_pred_upmarket_card', ] pd.testing.assert_frame_equal(realized, sut) + + +def test_auroc_errors_out_when_not_all_classes_are_represented_reference(): + reference, _, _ = load_synthetic_multiclass_classification_dataset() + reference['y_pred_proba_clazz'] = reference['y_pred_proba_upmarket_card'] + calc = CBPE( + y_pred_proba={ + 'prepaid_card': 'y_pred_proba_prepaid_card', + 'highstreet_card': 'y_pred_proba_highstreet_card', + 'upmarket_card': 'y_pred_proba_upmarket_card', + 'clazz': 'y_pred_proba_clazz' + }, + y_pred='y_pred', + y_true='y_true', + metrics=['roc_auc'], + problem_type='classification_multiclass', + ) + expected_exc_test = "y_pred_proba class and class probabilities dictionary does not match reference data." + with pytest.raises(InvalidArgumentsException, match=expected_exc_test): + calc.fit(reference) + + +def test_auroc_errors_out_when_not_all_classes_are_represented_chunk(caplog): + LOGGER.info("testing test_auroc_errors_out_when_not_all_classes_are_represented_chunk") + reference, monitored, targets = load_synthetic_multiclass_classification_dataset() + monitored = monitored.merge(targets) + # Uncalibrated probabilities need to sum up to 1 per row. 
+    reference['y_pred_proba_clazz'] = 0.1
+    reference['y_pred_proba_prepaid_card'] = 0.9 * reference['y_pred_proba_prepaid_card']
+    reference['y_pred_proba_highstreet_card'] = 0.9 * reference['y_pred_proba_highstreet_card']
+    reference['y_pred_proba_upmarket_card'] = 0.9 * reference['y_pred_proba_upmarket_card']
+    monitored['y_pred_proba_clazz'] = 0.1
+    monitored['y_pred_proba_prepaid_card'] = 0.9 * monitored['y_pred_proba_prepaid_card']
+    monitored['y_pred_proba_highstreet_card'] = 0.9 * monitored['y_pred_proba_highstreet_card']
+    monitored['y_pred_proba_upmarket_card'] = 0.9 * monitored['y_pred_proba_upmarket_card']
+    reference['y_true'].iloc[-1000:] = 'clazz'
+    calc = CBPE(
+        y_pred_proba={
+            'prepaid_card': 'y_pred_proba_prepaid_card',
+            'highstreet_card': 'y_pred_proba_highstreet_card',
+            'upmarket_card': 'y_pred_proba_upmarket_card',
+            'clazz': 'y_pred_proba_clazz'
+        },
+        y_pred='y_pred',
+        y_true='y_true',
+        metrics=['roc_auc'],
+        problem_type='classification_multiclass',
+    )
+    calc.fit(reference)
+    _ = calc.estimate(monitored)
+    expected_exc_test = "does not contain all reported classes, cannot calculate"
+    assert expected_exc_test in caplog.text
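
Usage sketch (illustrative, not part of the patch series): the snippet below mirrors the new CBPE
reference-data test, using the synthetic multiclass dataset bundled with NannyML. The class 'clazz'
and its column 'y_pred_proba_clazz' are introduced purely to trigger the new check; with these
patches applied, fitting raises InvalidArgumentsException because the 'y_pred_proba' mapping lists
a class that never appears in the reference targets.

    import nannyml as nml
    from nannyml.datasets import load_synthetic_multiclass_classification_dataset
    from nannyml.exceptions import InvalidArgumentsException

    reference, monitored, targets = load_synthetic_multiclass_classification_dataset()
    # Add a probability column for a class that does not occur in the reference targets.
    reference['y_pred_proba_clazz'] = reference['y_pred_proba_upmarket_card']

    estimator = nml.CBPE(
        y_pred_proba={
            'prepaid_card': 'y_pred_proba_prepaid_card',
            'highstreet_card': 'y_pred_proba_highstreet_card',
            'upmarket_card': 'y_pred_proba_upmarket_card',
            'clazz': 'y_pred_proba_clazz',
        },
        y_pred='y_pred',
        y_true='y_true',
        metrics=['roc_auc'],
        problem_type='classification_multiclass',
    )

    try:
        estimator.fit(reference)
    except InvalidArgumentsException as exc:
        # Expected: "y_pred_proba class and class probabilities dictionary
        # does not match reference data."
        print(exc)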