Skip to content

Commit

Permalink
Fix accuracy estimation using CBPE for multiclass cases (#346) (#347)
Browse files Browse the repository at this point in the history
  • Loading branch information
nnansters authored Nov 28, 2023
1 parent b31f877 commit 44b47ca
Showing 1 changed file with 5 additions and 7 deletions.
12 changes: 5 additions & 7 deletions nannyml/performance_estimation/confidence_based/cbpe.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,16 +509,14 @@ def _get_class_splits(
data: pd.DataFrame, y_true: str, y_pred_proba: Dict[str, str], include_targets: bool = True
) -> List[Tuple]:
classes = sorted(y_pred_proba.keys())
y_trues: List[np.ndarray] = []
y_trues: Dict[str, np.ndarray] = {}

if include_targets:
y_trues = list(label_binarize(data[y_true], classes=classes).T)
y_trues = {classes[idx]: (label_binarize(data[y_true], classes=classes).T[idx]) for idx in range(len(classes))}

y_pred_probas = [data[y_pred_proba[clazz]] for clazz in classes]
y_pred_probas = {clazz: data[y_pred_proba[clazz]] for clazz in classes}

return [
(classes[idx], y_trues[idx] if include_targets else None, y_pred_probas[idx]) for idx in range(len(classes))
]
return [(cls, y_trues[cls] if include_targets else None, y_pred_probas[cls]) for cls in classes]


def _fit_calibrators(
Expand Down Expand Up @@ -556,7 +554,7 @@ def _calibrate_predicted_probabilities(
calibrated_probas = np.divide(calibrated_probas, denominator, out=uniform_proba, where=denominator != 0)

calibrated_data = data.copy(deep=True)
predicted_class_proba_column_names = sorted([v for k, v in y_pred_proba.items()])
predicted_class_proba_column_names = [y_pred_proba[cls] for cls in sorted(y_pred_proba.keys())]
for idx in range(number_of_classes):
calibrated_data[predicted_class_proba_column_names[idx]] = calibrated_probas[:, idx]

Expand Down

0 comments on commit 44b47ca

Please sign in to comment.