Integrate experimental metrics with other modules (#549)
* Integrate experimental metrics with other modules

* Add average precision metric to the experimental metrics package

* Fix tutorials

* Add type hints and keyword arguments to metrics classes

* Update nbsphinx version to 0.9.3

* Update nbconvert version to 7.14.2

* Fix type annotations and formatting issues

* Update kernel display name in mortality_prediction.ipynb

* Add guard clause to prevent module execution on import

* Update `torch_distributed.py` with type hints

* Add multiclass and multilabel average precision metrics

* Change Jupyter kernel

* Fix type annotations for metric values in ClassificationPlotter

---------

Co-authored-by: Amrit K <amritk@vectorinstitute.ai>
fcogidi and amrit110 authored Jan 30, 2024
1 parent 5c4ebb2 commit 8fb3cf1
Showing 33 changed files with 1,900 additions and 394 deletions.
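For orientation, a hedged end-to-end sketch of what this commit enables: passing metrics from the experimental package (optionally wrapped in a MetricDict) directly to `evaluate`. The toy dataset and column names below are hypothetical; only the import paths and parameter names come from the diffs that follow.

from datasets import Dataset

from cyclops.evaluate.evaluator import evaluate
from cyclops.evaluate.metrics.experimental import BinaryAveragePrecision
from cyclops.evaluate.metrics.experimental.metric_dict import MetricDict

# Hypothetical toy dataset; real datasets would typically be loaded by name.
ds = Dataset.from_dict({"label": [0, 1, 1, 0], "preds": [0.2, 0.8, 0.6, 0.3]})

results = evaluate(
    dataset=ds,
    metrics=MetricDict(BinaryAveragePrecision()),  # a bare Metric, a sequence, or a dict also works
    target_columns="label",
    prediction_columns="preds",
)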
33 changes: 17 additions & 16 deletions cyclops/evaluate/evaluator.py
@@ -1,5 +1,4 @@
"""Evaluate one or more models on a dataset."""

import logging
import warnings
from dataclasses import asdict
@@ -16,7 +15,9 @@
)
from cyclops.evaluate.fairness.config import FairnessConfig
from cyclops.evaluate.fairness.evaluator import evaluate_fairness
from cyclops.evaluate.metrics.metric import Metric, MetricCollection
from cyclops.evaluate.metrics.experimental.metric import Metric
from cyclops.evaluate.metrics.experimental.metric_dict import MetricDict
from cyclops.evaluate.metrics.experimental.utils.types import Array
from cyclops.evaluate.utils import _format_column_names, choose_split
from cyclops.utils.log import setup_logging

@@ -27,7 +28,7 @@

def evaluate(
dataset: Union[str, Dataset, DatasetDict],
metrics: Union[Metric, Sequence[Metric], Dict[str, Metric], MetricCollection],
metrics: Union[Metric, Sequence[Metric], Dict[str, Metric], MetricDict],
target_columns: Union[str, List[str]],
prediction_columns: Union[str, List[str]],
ignore_columns: Optional[Union[str, List[str]]] = None,
@@ -47,7 +48,7 @@ def evaluate(
The dataset to evaluate on. If a string, the dataset will be loaded
using `datasets.load_dataset`. If `DatasetDict`, the `split` argument
must be specified.
metrics : Union[Metric, Sequence[Metric], Dict[str, Metric], MetricCollection]
metrics : Union[Metric, Sequence[Metric], Dict[str, Metric], MetricDict]
The metrics to compute.
target_columns : Union[str, List[str]]
The name of the column(s) containing the target values. A string value
@@ -202,28 +203,28 @@ def _load_data(


def _prepare_metrics(
metrics: Union[Metric, Sequence[Metric], Dict[str, Metric], MetricCollection],
) -> MetricCollection:
metrics: Union[Metric, Sequence[Metric], Dict[str, Metric], MetricDict],
) -> MetricDict:
"""Prepare metrics for evaluation."""
# TODO: wrap in BootstrappedMetric if computing confidence intervals
# TODO [fcogidi]: wrap in BootstrappedMetric if computing confidence intervals
if isinstance(metrics, (Metric, Sequence, Dict)) and not isinstance(
metrics,
MetricCollection,
MetricDict,
):
return MetricCollection(metrics)
if isinstance(metrics, MetricCollection):
return MetricDict(metrics) # type: ignore[arg-type]
if isinstance(metrics, MetricDict):
return metrics

raise TypeError(
f"Invalid type for `metrics`: {type(metrics)}. "
"Expected one of: Metric, Sequence[Metric], Dict[str, Metric], "
"MetricCollection.",
"MetricDict.",
)


def _compute_metrics(
dataset: Dataset,
metrics: MetricCollection,
metrics: MetricDict,
slice_spec: SliceSpec,
target_columns: Union[str, List[str]],
prediction_columns: Union[str, List[str]],
@@ -266,8 +267,8 @@ def _compute_metrics(
RuntimeWarning,
stacklevel=1,
)
metric_output = {
metric_name: float("NaN") for metric_name in metrics
metric_output: Dict[str, Array] = {
metric_name: float("NaN") for metric_name in metrics # type: ignore[attr-defined,misc]
}
elif (
batch_size is None or batch_size < 0
@@ -293,10 +294,10 @@
)

# update the metric state
metrics.update_state(targets, predictions)
metrics.update(targets, predictions)

metric_output = metrics.compute()
metrics.reset_state()
metrics.reset()

model_name: str = "model_for_%s" % prediction_column
results.setdefault(model_name, {})
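A minimal sketch of the renamed metric lifecycle adopted above, where `update`, `compute`, and `reset` replace the old `update_state`/`reset_state` calls; the numpy inputs are an assumption for illustration.

import numpy as np

from cyclops.evaluate.metrics.experimental import BinaryAveragePrecision
from cyclops.evaluate.metrics.experimental.metric_dict import MetricDict

metrics = MetricDict(BinaryAveragePrecision())  # same wrapping that _prepare_metrics applies

targets = np.array([0, 1, 1, 0])
predictions = np.array([0.1, 0.9, 0.7, 0.4])

metrics.update(targets, predictions)  # previously metrics.update_state(...)
results = metrics.compute()
metrics.reset()                       # previously metrics.reset_state()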
5 changes: 3 additions & 2 deletions cyclops/evaluate/fairness/config.py
@@ -5,14 +5,15 @@

from datasets import Dataset, config

from cyclops.evaluate.metrics.metric import Metric, MetricCollection
from cyclops.evaluate.metrics.experimental.metric import Metric
from cyclops.evaluate.metrics.experimental.metric_dict import MetricDict


@dataclass
class FairnessConfig:
"""Configuration for fairness metrics."""

metrics: Union[str, Callable[..., Any], Metric, MetricCollection]
metrics: Union[str, Callable[..., Any], Metric, MetricDict]
dataset: Dataset
groups: Union[str, List[str]]
target_columns: Union[str, List[str]]
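For illustration, the kind of value the updated `metrics` field now accepts is sketched below; the `num_classes` constructor argument is an assumption, not taken from this hunk.

from cyclops.evaluate.metrics.experimental import MulticlassAUROC
from cyclops.evaluate.metrics.experimental.metric_dict import MetricDict

# A named collection of experimental metrics, as now allowed by FairnessConfig.metrics.
fairness_metrics = MetricDict({"auroc": MulticlassAUROC(num_classes=3)})  # assumed signature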
89 changes: 43 additions & 46 deletions cyclops/evaluate/fairness/evaluator.py
@@ -1,14 +1,13 @@
"""Fairness evaluator."""

import inspect
import itertools
import logging
import warnings
from datetime import datetime
from typing import Any, Callable, Dict, List, Optional, Union

import array_api_compat.numpy
import numpy as np
import numpy.typing as npt
import pandas as pd
from datasets import Dataset, config
from datasets.features import Features
@@ -21,15 +20,14 @@
get_columns_as_numpy_array,
set_decode,
)
from cyclops.evaluate.metrics.factory import create_metric
from cyclops.evaluate.metrics.functional.precision_recall_curve import (
from cyclops.evaluate.metrics.experimental.functional.precision_recall_curve import (
_format_thresholds,
_validate_thresholds,
)
from cyclops.evaluate.metrics.metric import Metric, MetricCollection, OperatorMetric
from cyclops.evaluate.metrics.utils import (
_check_thresholds,
_get_value_if_singleton_array,
)
from cyclops.evaluate.metrics.experimental.metric import Metric, OperatorMetric
from cyclops.evaluate.metrics.experimental.metric_dict import MetricDict
from cyclops.evaluate.metrics.experimental.utils.types import Array
from cyclops.evaluate.metrics.factory import create_metric
from cyclops.evaluate.utils import _format_column_names
from cyclops.utils.log import setup_logging

@@ -39,7 +37,7 @@


def evaluate_fairness(
metrics: Union[str, Callable[..., Any], Metric, MetricCollection],
metrics: Union[str, Callable[..., Any], Metric, MetricDict],
dataset: Dataset,
groups: Union[str, List[str]],
target_columns: Union[str, List[str]],
@@ -62,7 +60,7 @@ def evaluate_fairness(
Parameters
----------
metrics : Union[str, Callable[..., Any], Metric, MetricCollection]
metrics : Union[str, Callable[..., Any], Metric, MetricDict]
The metric or metrics to compute. If a string, it should be the name of a
metric provided by CyclOps. If a callable, it should be a function that
takes target, prediction, and optionally threshold/thresholds as arguments
@@ -147,18 +145,14 @@ def evaluate_fairness(
raise TypeError(
"Expected `dataset` to be of type `Dataset`, but got " f"{type(dataset)}.",
)
_validate_thresholds(thresholds)

_check_thresholds(thresholds)
fmt_thresholds: npt.NDArray[np.float_] = _format_thresholds( # type: ignore
thresholds,
)

metrics_: Union[Callable[..., Any], MetricCollection] = _format_metrics(
metrics_: Union[Callable[..., Any], MetricDict] = _format_metrics(
metrics,
metric_name,
**(metric_kwargs or {}),
)

fmt_thresholds = _format_thresholds(thresholds, xp=array_api_compat.numpy)
fmt_groups: List[str] = _format_column_names(groups)
fmt_target_columns: List[str] = _format_column_names(target_columns)
fmt_prediction_columns: List[str] = _format_column_names(prediction_columns)
@@ -361,15 +355,15 @@ def warn_too_many_unique_values(


def _format_metrics(
metrics: Union[str, Callable[..., Any], Metric, MetricCollection],
metrics: Union[str, Callable[..., Any], Metric, MetricDict],
metric_name: Optional[str] = None,
**metric_kwargs: Any,
) -> Union[Callable[..., Any], Metric, MetricCollection]:
) -> Union[Callable[..., Any], Metric, MetricDict]:
"""Format the metrics argument.
Parameters
----------
metrics : Union[str, Callable[..., Any], Metric, MetricCollection]
metrics : Union[str, Callable[..., Any], Metric, MetricDict]
The metrics to use for computing the metric results.
metric_name : str, optional, default=None
The name of the metric. This is only used if `metrics` is a callable.
@@ -379,23 +373,23 @@ def _format_metrics(
Returns
-------
Union[Callable[..., Any], Metric, MetricCollection]
Union[Callable[..., Any], Metric, MetricDict]
The formatted metrics.
Raises
------
TypeError
If `metrics` is not of type `str`, `Callable`, `Metric`, or `MetricCollection`.
If `metrics` is not of type `str`, `Callable`, `Metric`, or `MetricDict`.
"""
if isinstance(metrics, str):
metrics = create_metric(metric_name=metrics, **metric_kwargs)
metrics = create_metric(metric_name=metrics, experimental=True, **metric_kwargs)
if isinstance(metrics, Metric):
if metric_name is not None and isinstance(metrics, OperatorMetric):
# single metric created from arithmetic operation, with given name
return MetricCollection({metric_name: metrics})
return MetricCollection(metrics)
if isinstance(metrics, MetricCollection):
return MetricDict({metric_name: metrics})
return MetricDict(metrics)
if isinstance(metrics, MetricDict):
return metrics
if callable(metrics):
if metric_name is None:
@@ -407,7 +401,7 @@ def _format_metrics(
return metrics

raise TypeError(
f"Expected `metrics` to be of type `str`, `Metric`, `MetricCollection`, or "
f"Expected `metrics` to be of type `str`, `Metric`, `MetricDict`, or "
f"`Callable`, but got {type(metrics)}.",
)

@@ -701,7 +695,7 @@ def _get_slice_spec(


def _compute_metrics( # noqa: C901, PLR0912
metrics: Union[Callable[..., Any], MetricCollection],
metrics: Union[Callable[..., Any], MetricDict],
dataset: Dataset,
target_columns: List[str],
prediction_column: str,
@@ -713,7 +707,7 @@ def _compute_metrics( # noqa: C901, PLR0912
Parameters
----------
metrics : Union[Callable, MetricCollection]
metrics : Union[Callable, MetricDict]
The metrics to compute.
dataset : Dataset
The dataset to compute the metrics on.
@@ -738,12 +732,19 @@
"Encountered empty dataset while computing metrics. "
"The metric values will be set to `None`."
)
if isinstance(metrics, MetricCollection):
if isinstance(metrics, MetricDict):
if threshold is not None:
# set the threshold for each metric in the collection
for name, metric in metrics.items():
if hasattr(metric, "threshold"):
if isinstance(metric, Metric) and hasattr(metric, "threshold"):
metric.threshold = threshold
elif isinstance(metric, OperatorMetric):
if hasattr(metric.metric_a, "threshold") and hasattr(
metric.metric_b,
"threshold",
):
metric.metric_a.threshold = threshold
metric.metric_b.threshold = threshold # type: ignore[union-attr]
else:
LOGGER.warning(
"Metric %s does not have a threshold attribute. "
@@ -754,7 +755,7 @@
if len(dataset) == 0:
warnings.warn(empty_dataset_msg, RuntimeWarning, stacklevel=1)
results: Dict[str, Any] = {
metric_name: float("NaN") for metric_name in metrics
metric_name: float("NaN") for metric_name in metrics # type: ignore[attr-defined]
}
elif (
batch_size is None or batch_size <= 0
@@ -779,11 +780,11 @@
columns=prediction_column,
)

metrics.update_state(targets, predictions)
metrics.update(targets, predictions)

results = metrics.compute()

metrics.reset_state()
metrics.reset()

return results
if callable(metrics):
@@ -817,26 +818,26 @@ def _compute_metrics( # noqa: C901, PLR0912
return {metric_name.title(): output}

raise TypeError(
"The `metrics` argument must be a string, a Metric, a MetricCollection, "
"The `metrics` argument must be a string, a Metric, a MetricDict, "
f"or a callable. Got {type(metrics)}.",
)


def _get_metric_results_for_prediction_and_slice(
metrics: Union[Callable[..., Any], MetricCollection],
metrics: Union[Callable[..., Any], MetricDict],
dataset: Dataset,
target_columns: List[str],
prediction_column: str,
slice_name: str,
batch_size: Optional[int] = config.DEFAULT_MAX_BATCH_SIZE,
metric_name: Optional[str] = None,
thresholds: Optional[npt.NDArray[np.float_]] = None,
thresholds: Optional[Array] = None,
) -> Dict[str, Dict[str, Any]]:
"""Compute metrics for a slice of a dataset.
Parameters
----------
metrics : Union[Callable, MetricCollection]
metrics : Union[Callable, MetricDict]
The metrics to compute.
dataset : Dataset
The dataset to compute the metrics on.
@@ -850,7 +851,7 @@ def _get_metric_results_for_prediction_and_slice(
The batch size to use for the computation.
metric_name : Optional[str]
The name of the metric to compute.
thresholds : Optional[List[float]]
thresholds : Optional[Array]
The thresholds to use for the metrics.
Returns
@@ -873,7 +874,7 @@ def _get_metric_results_for_prediction_and_slice(
return {slice_name: metric_output}

results: Dict[str, Dict[str, Any]] = {}
for threshold in thresholds:
for threshold in thresholds: # type: ignore[attr-defined]
metric_output = _compute_metrics(
metrics=metrics,
dataset=dataset,
@@ -969,11 +970,7 @@ def _compute_parity_metrics(
)

parity_results[key].setdefault(slice_name, {}).update(
{
parity_metric_name: _get_value_if_singleton_array(
parity_metric_value,
),
},
{parity_metric_name: parity_metric_value},
)

return parity_results
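A hedged usage sketch of the fairness evaluator after this change; the toy dataset, column names, and group values are hypothetical, and only the import path and parameter names visible in the diff above are relied on.

from datasets import Dataset

from cyclops.evaluate.fairness.evaluator import evaluate_fairness
from cyclops.evaluate.metrics.experimental import BinaryAveragePrecision

# Hypothetical toy dataset with a sensitive attribute column.
ds = Dataset.from_dict(
    {
        "target": [0, 1, 1, 0, 1, 0],
        "preds": [0.2, 0.9, 0.4, 0.7, 0.8, 0.1],
        "group": ["a", "a", "a", "b", "b", "b"],
    }
)

fairness_results = evaluate_fairness(
    metrics=BinaryAveragePrecision(),  # a metric name string or a MetricDict also works
    dataset=ds,
    groups="group",
    target_columns="target",
    prediction_columns="preds",
)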
5 changes: 5 additions & 0 deletions cyclops/evaluate/metrics/experimental/__init__.py
@@ -9,6 +9,11 @@
MulticlassAUROC,
MultilabelAUROC,
)
from cyclops.evaluate.metrics.experimental.average_precision import (
BinaryAveragePrecision,
MulticlassAveragePrecision,
MultilabelAveragePrecision,
)
from cyclops.evaluate.metrics.experimental.confusion_matrix import (
BinaryConfusionMatrix,
MulticlassConfusionMatrix,
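A quick sketch of the newly exported average precision metrics; the call pattern mirrors the other experimental metrics, and the `num_classes` argument is an assumption for illustration.

import numpy as np

from cyclops.evaluate.metrics.experimental import (
    BinaryAveragePrecision,
    MulticlassAveragePrecision,
)

binary_ap = BinaryAveragePrecision()
binary_ap.update(np.array([0, 1, 1, 0]), np.array([0.1, 0.8, 0.6, 0.3]))
print(binary_ap.compute())

multiclass_ap = MulticlassAveragePrecision(num_classes=3)  # assumed constructor argument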
(The remaining 29 changed files were not loaded.)
