Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use namedtuple to store curve results (ROC, PR) #572

Merged
merged 5 commits into from
Mar 5, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions cyclops/evaluate/metrics/experimental/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,13 @@
multilabel_tpr,
)
from cyclops.evaluate.metrics.experimental.functional.precision_recall_curve import (
PRCurve,
binary_precision_recall_curve,
multiclass_precision_recall_curve,
multilabel_precision_recall_curve,
)
from cyclops.evaluate.metrics.experimental.functional.roc import (
ROCCurve,
binary_roc,
multiclass_roc,
multilabel_roc,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Functions for computing the precision and recall for different unique thresholds."""
from types import ModuleType
from typing import Any, List, Literal, Optional, Sequence, Tuple, Union
from typing import Any, List, Literal, NamedTuple, Optional, Sequence, Tuple, Union

import array_api_compat as apc
import numpy as np
Expand Down Expand Up @@ -28,6 +28,14 @@
)


class PRCurve(NamedTuple):
"""Named tuple with Precision-Recall curve (Precision, Recall and thresholds)."""

precision: Union[Array, List[Array]]
recall: Union[Array, List[Array]]
thresholds: Union[Array, List[Array]]


def _validate_thresholds(thresholds: Optional[Union[int, List[float], Array]]) -> None:
"""Validate the `thresholds` argument."""
if thresholds is not None and not (
Expand Down Expand Up @@ -352,14 +360,13 @@ def binary_precision_recall_curve(

Returns
-------
precision : Array
The precision values for all unique thresholds. The shape of the array is
PRCurve
A named tuple that contains the following elements:
- `precision` values for all unique thresholds. The shape of the array is
`(num_thresholds + 1,)`.
recall : Array
The recall values for all unique thresholds. The shape of the array is
- `recall` values for all unique thresholds. The shape of the array is
`(num_thresholds + 1,)`.
thresholds : Array
The thresholds used for computing the precision and recall values, in
- `thresholds` used for computing the precision and recall values, in
ascending order. The shape of the array is `(num_thresholds,)`.

Raises
Expand Down Expand Up @@ -688,7 +695,7 @@ def multiclass_precision_recall_curve(
thresholds: Optional[Union[int, List[float], Array]] = None,
average: Optional[Literal["macro", "micro", "none"]] = None,
ignore_index: Optional[Union[int, Tuple[int]]] = None,
) -> Union[Tuple[Array, Array, Array], Tuple[List[Array], List[Array], List[Array]]]:
) -> PRCurve:
"""Compute the precision and recall for all unique thresholds.

Parameters
Expand Down Expand Up @@ -730,18 +737,17 @@ def multiclass_precision_recall_curve(

Returns
-------
precision : Array or List[Array]
The precision values for all unique thresholds. If `thresholds` is `"none"`
PRCurve
A named tuple that contains the following elements:
- `precision` values for all unique thresholds. If `thresholds` is `"none"`
or `None`, a list for each class is returned with 1-D Arrays of shape
`(num_thresholds + 1,)`. Otherwise, a 2-D Array of shape
`(num_thresholds + 1, num_classes)` is returned.
recall : Array or List[Array]
The recall values for all unique thresholds. If `thresholds` is `"none"`
- `recall` values for all unique thresholds. If `thresholds` is `"none"`
or `None`, a list for each class is returned with 1-D Arrays of shape
`(num_thresholds + 1,)`. Otherwise, a 2-D Array of shape
`(num_thresholds + 1, num_classes)` is returned.
thresholds : Array or List[Array]
The thresholds used for computing the precision and recall values, in
- `thresholds` used for computing the precision and recall values, in
ascending order. If `thresholds` is `"none"` or `None`, a list for each
class is returned with 1-D Arrays of shape `(num_thresholds,)`. Otherwise,
a 1-D Array of shape `(num_thresholds,)` is returned.
Expand Down Expand Up @@ -868,12 +874,13 @@ class is returned with 1-D Arrays of shape `(num_thresholds,)`. Otherwise,
average,
xp=xp,
)
return _multiclass_precision_recall_curve_compute(
precision, recall, thresholds_ = _multiclass_precision_recall_curve_compute(
state,
num_classes,
thresholds=thresholds,
average=average,
)
return PRCurve(precision, recall, thresholds_)


def _multilabel_precision_recall_curve_validate_args(
Expand Down Expand Up @@ -1035,7 +1042,7 @@ def multilabel_precision_recall_curve(
num_labels: int,
thresholds: Optional[Union[int, List[float], Array]] = None,
ignore_index: Optional[int] = None,
) -> Union[Tuple[Array, Array, Array], Tuple[List[Array], List[Array], List[Array]]]:
) -> PRCurve:
"""Compute the precision and recall for all unique thresholds.

Parameters
Expand Down Expand Up @@ -1067,21 +1074,20 @@ def multilabel_precision_recall_curve(

Returns
-------
precision : Array or List[Array]
The precision values for all unique thresholds. If `thresholds` is `None`,
a list for each label is returned with 1-D Arrays of shape
PRCurve
A named tuple that contains the following elements:
- `precision` values for all unique thresholds. If `thresholds` is `"none"`
or `None`, a list for each class is returned with 1-D Arrays of shape
`(num_thresholds + 1,)`. Otherwise, a 2-D Array of shape
`(num_thresholds + 1, num_labels)` is returned.
recall : Array or List[Array]
The recall values for all unique thresholds. If `thresholds` is `None`,
a list for each label is returned with 1-D Arrays of shape
`(num_thresholds + 1, num_classes)` is returned.
- `recall` values for all unique thresholds. If `thresholds` is `"none"`
or `None`, a list for each class is returned with 1-D Arrays of shape
`(num_thresholds + 1,)`. Otherwise, a 2-D Array of shape
`(num_thresholds + 1, num_labels)` is returned.
thresholds : Array or List[Array]
The thresholds used for computing the precision and recall values, in
ascending order. If `thresholds` is `None`, a list for each label is
returned with 1-D Arrays of shape `(num_thresholds,)`. Otherwise, a 1-D
Array of shape `(num_thresholds,)` is returned.
`(num_thresholds + 1, num_classes)` is returned.
- `thresholds` used for computing the precision and recall values, in
ascending order. If `thresholds` is `"none"` or `None`, a list for each
class is returned with 1-D Arrays of shape `(num_thresholds,)`. Otherwise,
a 1-D Array of shape `(num_thresholds,)` is returned.

Raises
------
Expand Down Expand Up @@ -1193,9 +1199,10 @@ def multilabel_precision_recall_curve(
thresholds,
xp=xp,
)
return _multilabel_precision_recall_curve_compute(
precision, recall, thresholds_ = _multilabel_precision_recall_curve_compute(
state,
num_labels,
thresholds,
ignore_index,
)
return PRCurve(precision, recall, thresholds_)
81 changes: 38 additions & 43 deletions cyclops/evaluate/metrics/experimental/functional/roc.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Functions for computing Receiver Operating Characteristic (ROC) curves."""
import warnings
from typing import List, Literal, Optional, Tuple, Union
from typing import List, Literal, NamedTuple, Optional, Tuple, Union

import array_api_compat as apc

Expand Down Expand Up @@ -28,6 +28,14 @@
from cyclops.evaluate.metrics.experimental.utils.types import Array


class ROCCurve(NamedTuple):
"""Named tuple to store ROC curve (FPR, TPR and thresholds)."""

fpr: Union[Array, List[Array]]
tpr: Union[Array, List[Array]]
thresholds: Union[Array, List[Array]]


def _binary_roc_compute(
state: Union[Array, Tuple[Array, Array]],
thresholds: Optional[Array],
Expand Down Expand Up @@ -91,7 +99,7 @@ def binary_roc(
preds: Array,
thresholds: Optional[Union[int, List[float], Array]] = None,
ignore_index: Optional[int] = None,
) -> Tuple[Array, Array, Array]:
) -> ROCCurve:
"""Compute the receiver operating characteristic (ROC) curve for binary tasks.

Parameters
Expand Down Expand Up @@ -120,15 +128,11 @@ def binary_roc(

Returns
-------
fpr : Array
The false positive rates for all unique thresholds. The shape of the array is
`(num_thresholds + 1,)`.
tpr : Array
The true positive rates for all unique thresholds. The shape of the array is
`(num_thresholds + 1,)`.
thresholds : Array
The thresholds used for computing the ROC curve values, in descending order.
The shape of the array is `(num_thresholds,)`.
ROCCurve
A named tuple containing the false positive rate (FPR), true positive rate
(TPR) and thresholds. The FPR and TPR are arrays of of shape
amrit110 marked this conversation as resolved.
Show resolved Hide resolved
`(num_thresholds + 1,)` and the thresholds are an array of shape
`(num_thresholds,)`.

Raises
------
Expand Down Expand Up @@ -209,7 +213,8 @@ def binary_roc(
xp=xp,
)
state = _binary_precision_recall_curve_update(target, preds, thresholds, xp=xp)
return _binary_roc_compute(state, thresholds)
fpr, tpr, thresh = _binary_roc_compute(state, thresholds)
return ROCCurve(fpr, tpr, thresh)


def _multiclass_roc_compute(
Expand Down Expand Up @@ -277,7 +282,7 @@ def multiclass_roc(
thresholds: Optional[Union[int, List[float], Array]] = None,
average: Optional[Literal["macro", "micro", "none"]] = None,
ignore_index: Optional[Union[int, Tuple[int]]] = None,
) -> Union[Tuple[Array, Array, Array], Tuple[List[Array], List[Array], List[Array]]]:
) -> ROCCurve:
"""Compute the receiver operating characteristic (ROC) curve for multiclass tasks.

Parameters
Expand Down Expand Up @@ -318,19 +323,13 @@ def multiclass_roc(

Returns
-------
fpr : Array or List[Array]
The false positive rates for all unique thresholds. If `thresholds` is `"none"`
or `None`, a list for each class is returned with 1-D Arrays of shape
`(num_thresholds + 1,)`. Otherwise, a 2-D Array of shape
ROCCurve
A named tuple that contains the false positive rate, true positive rate,
and the thresholds used for computing the ROC curve. If `thresholds` is `"none"`
or `None`, a list of TPRs and FPRs for each class is returned with 1-D Arrays
of shape `(num_thresholds + 1,)`. Otherwise, a 2-D Array of shape
`(num_thresholds + 1, num_classes)` is returned.
tpr : Array or List[Array]
The true positive rates for all unique thresholds. If `thresholds` is `"none"`
or `None`, a list for each class is returned with 1-D Arrays of shape
`(num_thresholds + 1,)`. Otherwise, a 2-D Array of shape
`(num_thresholds + 1, num_classes)` is returned.
thresholds : Array or List[Array]
The thresholds used for computing the ROC curve values, in descending order.
If `thresholds` is `"none"` or `None`, a list for each class is returned
Similarly, a list of thresholds for each class is returned
with 1-D Arrays of shape `(num_thresholds,)`. Otherwise, a 1-D Array of
shape `(num_thresholds,)` is returned.

Expand Down Expand Up @@ -455,12 +454,13 @@ def multiclass_roc(
average,
xp=xp,
)
return _multiclass_roc_compute(
fpr_, tpr_, thresholds_ = _multiclass_roc_compute(
state,
num_classes,
thresholds=thresholds,
average=average,
)
return ROCCurve(fpr_, tpr_, thresholds_)


def _multilabel_roc_compute(
Expand Down Expand Up @@ -504,7 +504,7 @@ def multilabel_roc(
num_labels: int,
thresholds: Optional[Union[int, List[float], Array]] = None,
ignore_index: Optional[int] = None,
) -> Union[Tuple[Array, Array, Array], Tuple[List[Array], List[Array], List[Array]]]:
) -> ROCCurve:
"""Compute the receiver operating characteristic (ROC) curve for multilabel tasks.

Parameters
Expand Down Expand Up @@ -535,21 +535,15 @@ def multilabel_roc(

Returns
-------
fpr : Array or List[Array]
The false positive rates for all unique thresholds. If `thresholds` is `None`,
a list for each label is returned with 1-D Arrays of shape
`(num_thresholds + 1,)`. Otherwise, a 2-D Array of shape
`(num_thresholds + 1, num_labels)` is returned.
tpr : Array or List[Array]
The true positive rates for all unique thresholds. If `thresholds` is `None`,
a list for each label is returned with 1-D Arrays of shape
`(num_thresholds + 1,)`. Otherwise, a 2-D Array of shape
`(num_thresholds + 1, num_labels)` is returned.
thresholds : Array or List[Array]
The thresholds used for computing the ROC curve values, in
descending order. If `thresholds` is `None`, a list for each label is
returned with 1-D Arrays of shape `(num_thresholds,)`. Otherwise, a 1-D
Array of shape `(num_thresholds,)` is returned.
ROCCurve
A named tuple that contains the false positive rate, true positive rate,
and the thresholds used for computing the ROC curve. If `thresholds` is `"none"`
or `None`, a list of TPRs and FPRs for each class is returned with 1-D Arrays
of shape `(num_thresholds + 1,)`. Otherwise, a 2-D Array of shape
`(num_thresholds + 1, num_classes)` is returned.
Similarly, a list of thresholds for each class is returned
with 1-D Arrays of shape `(num_thresholds,)`. Otherwise, a 1-D Array of
shape `(num_thresholds,)` is returned.

Raises
------
Expand Down Expand Up @@ -660,9 +654,10 @@ def multilabel_roc(
thresholds,
xp=xp,
)
return _multilabel_roc_compute(
fpr_, tpr_, thresholds_ = _multilabel_roc_compute(
state,
num_labels,
thresholds,
ignore_index,
)
return ROCCurve(fpr_, tpr_, thresholds_)
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import array_api_compat as apc

from cyclops.evaluate.metrics.experimental.functional.precision_recall_curve import (
PRCurve,
_binary_precision_recall_curve_compute,
_binary_precision_recall_curve_format_arrays,
_binary_precision_recall_curve_update,
Expand Down Expand Up @@ -140,14 +141,18 @@ def _update_state(self, target: Array, preds: Array) -> None:
self.target.append(state[0]) # type: ignore[attr-defined]
self.preds.append(state[1]) # type: ignore[attr-defined]

def _compute_metric(self) -> Tuple[Array, Array, Array]:
def _compute_metric(self) -> PRCurve:
"""Compute the metric."""
state = (
(dim_zero_cat(self.target), dim_zero_cat(self.preds)) # type: ignore[attr-defined]
if self.thresholds is None
else self.confmat # type: ignore[attr-defined]
)
return _binary_precision_recall_curve_compute(state, self.thresholds) # type: ignore[arg-type]
precision, recall, thresholds = _binary_precision_recall_curve_compute(
state,
self.thresholds, # type: ignore
)
return PRCurve(precision, recall, thresholds)


class MulticlassPrecisionRecallCurve(
Expand Down
6 changes: 4 additions & 2 deletions cyclops/evaluate/metrics/experimental/roc.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import List, Tuple, Union

from cyclops.evaluate.metrics.experimental.functional.roc import (
ROCCurve,
_binary_roc_compute,
_multiclass_roc_compute,
_multilabel_roc_compute,
Expand Down Expand Up @@ -55,13 +56,14 @@ class BinaryROC(BinaryPrecisionRecallCurve, registry_key="binary_roc_curve"):

name: str = "ROC Curve"

def _compute_metric(self) -> Tuple[Array, Array, Array]:
def _compute_metric(self) -> ROCCurve: # type: ignore
state = (
(dim_zero_cat(self.target), dim_zero_cat(self.preds)) # type: ignore[attr-defined]
if self.thresholds is None
else self.confmat # type: ignore[attr-defined]
)
return _binary_roc_compute(state, self.thresholds) # type: ignore[arg-type]
fpr, tpr, thresholds = _binary_roc_compute(state, self.thresholds) # type: ignore[arg-type]
return ROCCurve(fpr, tpr, thresholds)


class MulticlassROC(
Expand Down
Loading
Loading