Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions sklearn/metrics/_ranking.py
Original file line number Diff line number Diff line change
Expand Up @@ -1019,7 +1019,8 @@ def roc_curve(
thresholds : ndarray of shape = (n_thresholds,)
Decreasing thresholds on the decision function used to compute
fpr and tpr. `thresholds[0]` represents no instances being predicted
and is arbitrarily set to `max(y_score) + 1`.
and is set to the next representable value above ``max(y_score)`` for
floating scores (or ``max(y_score) + 1`` otherwise).

See Also
--------
Expand Down Expand Up @@ -1083,7 +1084,11 @@ def roc_curve(
# to make sure that the curve starts at (0, 0)
tps = np.r_[0, tps]
fps = np.r_[0, fps]
thresholds = np.r_[thresholds[0] + 1, thresholds]
if np.issubdtype(thresholds.dtype, np.floating):
prepend_thresh = np.nextafter(thresholds[0], np.inf, dtype=thresholds.dtype)
else:
prepend_thresh = thresholds[0] + 1
thresholds = np.r_[prepend_thresh, thresholds]

if fps[-1] <= 0:
warnings.warn(
Expand Down
44 changes: 44 additions & 0 deletions sklearn/metrics/tests/test_ranking.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,7 @@ def test_roc_curve_end_points():
assert fpr.shape == thr.shape



def test_roc_returns_consistency():
# Test whether the returned threshold matches up with tpr
# make small toy dataset
Expand Down Expand Up @@ -2199,3 +2200,46 @@ def test_ranking_metric_pos_label_types(metric, classes):
assert not np.isnan(metric_1).any()
assert not np.isnan(metric_2).any()
assert not np.isnan(thresholds).any()


def test_roc_curve_thresholds_probabilities_below_one():
# Probabilistic scores below 1.0 should prepend the smallest greater float
y_true = np.array([0, 0, 1, 1])
y_score = np.array([0.75, 0.2, 0.33, 0.9], dtype=np.float64)

fpr, tpr, thresholds = roc_curve(y_true, y_score, drop_intermediate=False)

max_score = y_score.max()
expected_prepend = np.nextafter(max_score, np.inf)

assert thresholds[0] == pytest.approx(expected_prepend)
assert thresholds[0] > max_score
assert np.all(np.diff(thresholds) <= 0)
assert thresholds.shape == fpr.shape == tpr.shape


def test_roc_curve_thresholds_probability_one():
# The prepend threshold for a max score of 1.0 should be nextafter(1.0)
y_true = np.array([0, 1, 0, 1])
y_score = np.array([1.0, 0.6, 0.4, 0.8], dtype=np.float64)

fpr, tpr, thresholds = roc_curve(y_true, y_score, drop_intermediate=False)

expected_prepend = np.nextafter(1.0, np.inf, dtype=thresholds.dtype)

assert thresholds[0] == pytest.approx(expected_prepend)
assert thresholds[1] == pytest.approx(1.0)
assert np.all(np.diff(thresholds) <= 0)
assert thresholds.shape == fpr.shape == tpr.shape


def test_roc_curve_thresholds_integer_scores():
# Non-probability integer scores should continue to prepend max(score) + 1
y_true = np.array([0, 1, 0, 1])
y_score = np.array([2, 3, 1, 0], dtype=np.int64)

fpr, tpr, thresholds = roc_curve(y_true, y_score, drop_intermediate=False)

assert thresholds[0] == y_score.max() + 1
assert np.all(np.diff(thresholds) <= 0)
assert thresholds.shape == fpr.shape == tpr.shape