From 44b47ca5270570613ae77895d2f2f4c9e8c030f6 Mon Sep 17 00:00:00 2001
From: Niels <94110348+nnansters@users.noreply.github.com>
Date: Tue, 28 Nov 2023 14:12:04 +0100
Subject: [PATCH] Fix accuracy estimation using CBPE for multiclass cases (#346) (#347)

---
 .../performance_estimation/confidence_based/cbpe.py | 12 +++++-------
 1 file changed, 5 insertions(+), 7 deletions(-)

diff --git a/nannyml/performance_estimation/confidence_based/cbpe.py b/nannyml/performance_estimation/confidence_based/cbpe.py
index 75429b7b..c9167307 100644
--- a/nannyml/performance_estimation/confidence_based/cbpe.py
+++ b/nannyml/performance_estimation/confidence_based/cbpe.py
@@ -509,16 +509,14 @@ def _get_class_splits(
     data: pd.DataFrame, y_true: str, y_pred_proba: Dict[str, str], include_targets: bool = True
 ) -> List[Tuple]:
     classes = sorted(y_pred_proba.keys())
-    y_trues: List[np.ndarray] = []
+    y_trues: Dict[str, np.ndarray] = {}
 
     if include_targets:
-        y_trues = list(label_binarize(data[y_true], classes=classes).T)
+        y_trues = {classes[idx]: (label_binarize(data[y_true], classes=classes).T[idx]) for idx in range(len(classes))}
 
-    y_pred_probas = [data[y_pred_proba[clazz]] for clazz in classes]
+    y_pred_probas = {clazz: data[y_pred_proba[clazz]] for clazz in classes}
 
-    return [
-        (classes[idx], y_trues[idx] if include_targets else None, y_pred_probas[idx]) for idx in range(len(classes))
-    ]
+    return [(cls, y_trues[cls] if include_targets else None, y_pred_probas[cls]) for cls in classes]
 
 
 def _fit_calibrators(
@@ -556,7 +554,7 @@ def _calibrate_predicted_probabilities(
         calibrated_probas = np.divide(calibrated_probas, denominator, out=uniform_proba, where=denominator != 0)
 
     calibrated_data = data.copy(deep=True)
-    predicted_class_proba_column_names = sorted([v for k, v in y_pred_proba.items()])
+    predicted_class_proba_column_names = [y_pred_proba[cls] for cls in sorted(y_pred_proba.keys())]
     for idx in range(number_of_classes):
         calibrated_data[predicted_class_proba_column_names[idx]] = calibrated_probas[:, idx]
 
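Note on the fix, with a minimal sketch (not part of the patch): the second hunk addresses the core bug. The calibrated probability matrix appears to be built in class-label order (the code iterates over sorted(y_pred_proba.keys())), but the old code ordered the target DataFrame columns by sorting the probability column names alphabetically. Whenever those column names do not sort in the same order as the class labels, calibrated probabilities are written into the wrong columns, which corrupts the multiclass accuracy estimate. The first hunk makes the same intent structural by keying the per-class splits on the class label instead of a positional index. The class labels and column names below are hypothetical and only illustrate the ordering mismatch:

# Hypothetical class-to-column mapping; only the sort behaviour matters here.
y_pred_proba = {"cat": "y_pred_proba_feline", "dog": "y_pred_proba_canine"}

# Calibrated probability columns come out in class-label order: ['cat', 'dog'].
classes = sorted(y_pred_proba.keys())

# Old behaviour: sort the column names themselves.
old_order = sorted(y_pred_proba.values())           # ['y_pred_proba_canine', 'y_pred_proba_feline']

# Fixed behaviour: look the column name up per class, in class-label order.
new_order = [y_pred_proba[cls] for cls in classes]  # ['y_pred_proba_feline', 'y_pred_proba_canine']

# With the old ordering, the calibrated 'cat' probabilities would land in the
# 'dog' column and vice versa.
assert old_order != new_order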