Skip to content

Commit

Permalink
Deal with single class in y_true for needs_calibration check
Browse files Browse the repository at this point in the history
  • Loading branch information
nnansters committed Feb 7, 2024
1 parent 224b2a0 commit 57a6d38
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 0 deletions.
5 changes: 5 additions & 0 deletions nannyml/calibration.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,11 @@ def needs_calibration(
# y_pred_proba = y_pred_proba.reset_index(drop=True)
# y_true = y_true.reset_index(drop=True)

# Check if we have a single class in y_true. This would crash the AUROC check below.
# If we do only have a single class in y_true, no calibration will be required.
if len(np.unique(y_true)) == 1:
return False

if roc_auc_score(y_true, y_pred_proba, multi_class='ovr') > 0.999:
return False

Expand Down
9 changes: 9 additions & 0 deletions tests/test_calibration.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,15 @@ def test_needs_calibration_returns_true_when_calibration_always_improves_ece():
assert sut


def test_needs_calibration_returns_false_when_only_single_class_in_y_true(): # noqa: D103
y_true = pd.Series([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
y_pred_proba = abs(1 - y_true)
shuffled_indexes = np.random.permutation(len(y_true))
y_true, y_pred_proba = y_true[shuffled_indexes], y_pred_proba[shuffled_indexes]
sut = needs_calibration(y_true, y_pred_proba, IsotonicCalibrator())
assert sut is False


def test_needs_calibration_raises_invalid_args_exception_when_y_true_contains_nan(): # noqa: D103
y_true = pd.Series([0, 0, 0, 0, 0, np.NaN, 1, 1, 1, 1, 1, 1])
y_pred_proba = np.asarray([0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1])
Expand Down

0 comments on commit 57a6d38

Please sign in to comment.