|
1 | 1 | """Functions for computing Receiver Operating Characteristic (ROC) curves."""
|
2 | 2 | import warnings
|
3 |
| -from typing import List, Literal, Optional, Tuple, Union |
| 3 | +from typing import List, Literal, NamedTuple, Optional, Tuple, Union |
4 | 4 |
|
5 | 5 | import array_api_compat as apc
|
6 | 6 |
|
|
28 | 28 | from cyclops.evaluate.metrics.experimental.utils.types import Array
|
29 | 29 |
|
30 | 30 |
|
| 31 | +class ROCCurve(NamedTuple): |
| 32 | + """Named tuple to store ROC curve (FPR, TPR and thresholds).""" |
| 33 | + |
| 34 | + fpr: Union[Array, List[Array]] |
| 35 | + tpr: Union[Array, List[Array]] |
| 36 | + thresholds: Union[Array, List[Array]] |
| 37 | + |
| 38 | + |
31 | 39 | def _binary_roc_compute(
|
32 | 40 | state: Union[Array, Tuple[Array, Array]],
|
33 | 41 | thresholds: Optional[Array],
|
@@ -91,7 +99,7 @@ def binary_roc(
|
91 | 99 | preds: Array,
|
92 | 100 | thresholds: Optional[Union[int, List[float], Array]] = None,
|
93 | 101 | ignore_index: Optional[int] = None,
|
94 |
| -) -> Tuple[Array, Array, Array]: |
| 102 | +) -> ROCCurve: |
95 | 103 | """Compute the receiver operating characteristic (ROC) curve for binary tasks.
|
96 | 104 |
|
97 | 105 | Parameters
|
@@ -120,15 +128,11 @@ def binary_roc(
|
120 | 128 |
|
121 | 129 | Returns
|
122 | 130 | -------
|
123 |
| - fpr : Array |
124 |
| - The false positive rates for all unique thresholds. The shape of the array is |
125 |
| - `(num_thresholds + 1,)`. |
126 |
| - tpr : Array |
127 |
| - The true positive rates for all unique thresholds. The shape of the array is |
128 |
| - `(num_thresholds + 1,)`. |
129 |
| - thresholds : Array |
130 |
| - The thresholds used for computing the ROC curve values, in descending order. |
131 |
| - The shape of the array is `(num_thresholds,)`. |
| 131 | + ROCCurve |
| 132 | + A named tuple containing the false positive rate (FPR), true positive rate |
| 133 | + (TPR) and thresholds. The FPR and TPR are arrays of of shape |
| 134 | + `(num_thresholds + 1,)` and the thresholds are an array of shape |
| 135 | + `(num_thresholds,)`. |
132 | 136 |
|
133 | 137 | Raises
|
134 | 138 | ------
|
@@ -209,7 +213,8 @@ def binary_roc(
|
209 | 213 | xp=xp,
|
210 | 214 | )
|
211 | 215 | state = _binary_precision_recall_curve_update(target, preds, thresholds, xp=xp)
|
212 |
| - return _binary_roc_compute(state, thresholds) |
| 216 | + fpr, tpr, thresh = _binary_roc_compute(state, thresholds) |
| 217 | + return ROCCurve(fpr, tpr, thresh) |
213 | 218 |
|
214 | 219 |
|
215 | 220 | def _multiclass_roc_compute(
|
@@ -277,7 +282,7 @@ def multiclass_roc(
|
277 | 282 | thresholds: Optional[Union[int, List[float], Array]] = None,
|
278 | 283 | average: Optional[Literal["macro", "micro", "none"]] = None,
|
279 | 284 | ignore_index: Optional[Union[int, Tuple[int]]] = None,
|
280 |
| -) -> Union[Tuple[Array, Array, Array], Tuple[List[Array], List[Array], List[Array]]]: |
| 285 | +) -> ROCCurve: |
281 | 286 | """Compute the receiver operating characteristic (ROC) curve for multiclass tasks.
|
282 | 287 |
|
283 | 288 | Parameters
|
@@ -318,19 +323,13 @@ def multiclass_roc(
|
318 | 323 |
|
319 | 324 | Returns
|
320 | 325 | -------
|
321 |
| - fpr : Array or List[Array] |
322 |
| - The false positive rates for all unique thresholds. If `thresholds` is `"none"` |
323 |
| - or `None`, a list for each class is returned with 1-D Arrays of shape |
324 |
| - `(num_thresholds + 1,)`. Otherwise, a 2-D Array of shape |
| 326 | + ROCCurve |
| 327 | + A named tuple that contains the false positive rate, true positive rate, |
| 328 | + and the thresholds used for computing the ROC curve. If `thresholds` is `"none"` |
| 329 | + or `None`, a list of TPRs and FPRs for each class is returned with 1-D Arrays |
| 330 | + of shape `(num_thresholds + 1,)`. Otherwise, a 2-D Array of shape |
325 | 331 | `(num_thresholds + 1, num_classes)` is returned.
|
326 |
| - tpr : Array or List[Array] |
327 |
| - The true positive rates for all unique thresholds. If `thresholds` is `"none"` |
328 |
| - or `None`, a list for each class is returned with 1-D Arrays of shape |
329 |
| - `(num_thresholds + 1,)`. Otherwise, a 2-D Array of shape |
330 |
| - `(num_thresholds + 1, num_classes)` is returned. |
331 |
| - thresholds : Array or List[Array] |
332 |
| - The thresholds used for computing the ROC curve values, in descending order. |
333 |
| - If `thresholds` is `"none"` or `None`, a list for each class is returned |
| 332 | + Similarly, a list of thresholds for each class is returned |
334 | 333 | with 1-D Arrays of shape `(num_thresholds,)`. Otherwise, a 1-D Array of
|
335 | 334 | shape `(num_thresholds,)` is returned.
|
336 | 335 |
|
@@ -455,12 +454,13 @@ def multiclass_roc(
|
455 | 454 | average,
|
456 | 455 | xp=xp,
|
457 | 456 | )
|
458 |
| - return _multiclass_roc_compute( |
| 457 | + fpr_, tpr_, thresholds_ = _multiclass_roc_compute( |
459 | 458 | state,
|
460 | 459 | num_classes,
|
461 | 460 | thresholds=thresholds,
|
462 | 461 | average=average,
|
463 | 462 | )
|
| 463 | + return ROCCurve(fpr_, tpr_, thresholds_) |
464 | 464 |
|
465 | 465 |
|
466 | 466 | def _multilabel_roc_compute(
|
@@ -504,7 +504,7 @@ def multilabel_roc(
|
504 | 504 | num_labels: int,
|
505 | 505 | thresholds: Optional[Union[int, List[float], Array]] = None,
|
506 | 506 | ignore_index: Optional[int] = None,
|
507 |
| -) -> Union[Tuple[Array, Array, Array], Tuple[List[Array], List[Array], List[Array]]]: |
| 507 | +) -> ROCCurve: |
508 | 508 | """Compute the receiver operating characteristic (ROC) curve for multilabel tasks.
|
509 | 509 |
|
510 | 510 | Parameters
|
@@ -535,21 +535,15 @@ def multilabel_roc(
|
535 | 535 |
|
536 | 536 | Returns
|
537 | 537 | -------
|
538 |
| - fpr : Array or List[Array] |
539 |
| - The false positive rates for all unique thresholds. If `thresholds` is `None`, |
540 |
| - a list for each label is returned with 1-D Arrays of shape |
541 |
| - `(num_thresholds + 1,)`. Otherwise, a 2-D Array of shape |
542 |
| - `(num_thresholds + 1, num_labels)` is returned. |
543 |
| - tpr : Array or List[Array] |
544 |
| - The true positive rates for all unique thresholds. If `thresholds` is `None`, |
545 |
| - a list for each label is returned with 1-D Arrays of shape |
546 |
| - `(num_thresholds + 1,)`. Otherwise, a 2-D Array of shape |
547 |
| - `(num_thresholds + 1, num_labels)` is returned. |
548 |
| - thresholds : Array or List[Array] |
549 |
| - The thresholds used for computing the ROC curve values, in |
550 |
| - descending order. If `thresholds` is `None`, a list for each label is |
551 |
| - returned with 1-D Arrays of shape `(num_thresholds,)`. Otherwise, a 1-D |
552 |
| - Array of shape `(num_thresholds,)` is returned. |
| 538 | + ROCCurve |
| 539 | + A named tuple that contains the false positive rate, true positive rate, |
| 540 | + and the thresholds used for computing the ROC curve. If `thresholds` is `"none"` |
| 541 | + or `None`, a list of TPRs and FPRs for each class is returned with 1-D Arrays |
| 542 | + of shape `(num_thresholds + 1,)`. Otherwise, a 2-D Array of shape |
| 543 | + `(num_thresholds + 1, num_classes)` is returned. |
| 544 | + Similarly, a list of thresholds for each class is returned |
| 545 | + with 1-D Arrays of shape `(num_thresholds,)`. Otherwise, a 1-D Array of |
| 546 | + shape `(num_thresholds,)` is returned. |
553 | 547 |
|
554 | 548 | Raises
|
555 | 549 | ------
|
@@ -660,9 +654,10 @@ def multilabel_roc(
|
660 | 654 | thresholds,
|
661 | 655 | xp=xp,
|
662 | 656 | )
|
663 |
| - return _multilabel_roc_compute( |
| 657 | + fpr_, tpr_, thresholds_ = _multilabel_roc_compute( |
664 | 658 | state,
|
665 | 659 | num_labels,
|
666 | 660 | thresholds,
|
667 | 661 | ignore_index,
|
668 | 662 | )
|
| 663 | + return ROCCurve(fpr_, tpr_, thresholds_) |
0 commit comments