Integrate experimental metrics with other modules #549

Merged: 13 commits (Jan 30, 2024)
33 changes: 17 additions & 16 deletions cyclops/evaluate/evaluator.py
@@ -1,5 +1,4 @@
"""Evaluate one or more models on a dataset."""

import logging
import warnings
from dataclasses import asdict
@@ -16,7 +15,9 @@
)
from cyclops.evaluate.fairness.config import FairnessConfig
from cyclops.evaluate.fairness.evaluator import evaluate_fairness
from cyclops.evaluate.metrics.metric import Metric, MetricCollection
from cyclops.evaluate.metrics.experimental.metric import Metric
from cyclops.evaluate.metrics.experimental.metric_dict import MetricDict
from cyclops.evaluate.metrics.experimental.utils.types import Array
from cyclops.evaluate.utils import _format_column_names, choose_split
from cyclops.utils.log import setup_logging

@@ -27,7 +28,7 @@

def evaluate(
dataset: Union[str, Dataset, DatasetDict],
metrics: Union[Metric, Sequence[Metric], Dict[str, Metric], MetricCollection],
metrics: Union[Metric, Sequence[Metric], Dict[str, Metric], MetricDict],
target_columns: Union[str, List[str]],
prediction_columns: Union[str, List[str]],
ignore_columns: Optional[Union[str, List[str]]] = None,
@@ -47,7 +48,7 @@
The dataset to evaluate on. If a string, the dataset will be loaded
using `datasets.load_dataset`. If `DatasetDict`, the `split` argument
must be specified.
metrics : Union[Metric, Sequence[Metric], Dict[str, Metric], MetricCollection]
metrics : Union[Metric, Sequence[Metric], Dict[str, Metric], MetricDict]
The metrics to compute.
target_columns : Union[str, List[str]]
The name of the column(s) containing the target values. A string value
@@ -202,28 +203,28 @@


def _prepare_metrics(
metrics: Union[Metric, Sequence[Metric], Dict[str, Metric], MetricCollection],
) -> MetricCollection:
metrics: Union[Metric, Sequence[Metric], Dict[str, Metric], MetricDict],
) -> MetricDict:
"""Prepare metrics for evaluation."""
# TODO: wrap in BootstrappedMetric if computing confidence intervals
# TODO [fcogidi]: wrap in BootstrappedMetric if computing confidence intervals
if isinstance(metrics, (Metric, Sequence, Dict)) and not isinstance(
metrics,
MetricCollection,
MetricDict,
):
return MetricCollection(metrics)
if isinstance(metrics, MetricCollection):
return MetricDict(metrics) # type: ignore[arg-type]
if isinstance(metrics, MetricDict):

return metrics
[Codecov: added lines #L214-L215 not covered by tests]

raise TypeError(
f"Invalid type for `metrics`: {type(metrics)}. "
"Expected one of: Metric, Sequence[Metric], Dict[str, Metric], "
"MetricCollection.",
"MetricDict.",
)


def _compute_metrics(
dataset: Dataset,
metrics: MetricCollection,
metrics: MetricDict,
slice_spec: SliceSpec,
target_columns: Union[str, List[str]],
prediction_columns: Union[str, List[str]],
@@ -266,8 +267,8 @@
RuntimeWarning,
stacklevel=1,
)
metric_output = {
metric_name: float("NaN") for metric_name in metrics
metric_output: Dict[str, Array] = {
metric_name: float("NaN") for metric_name in metrics # type: ignore[attr-defined,misc]
}
[Codecov: added line #L270 not covered by tests]
elif (
batch_size is None or batch_size < 0
@@ -293,10 +294,10 @@
)

# update the metric state
metrics.update_state(targets, predictions)
metrics.update(targets, predictions)

[Codecov: added line #L297 not covered by tests]

metric_output = metrics.compute()
metrics.reset_state()
metrics.reset()

[Codecov: added line #L300 not covered by tests]

model_name: str = "model_for_%s" % prediction_column
results.setdefault(model_name, {})
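Taken together, the evaluator.py changes replace the legacy `MetricCollection` with the array-API-based `MetricDict` and rename the state methods (`update_state` becomes `update`, `reset_state` becomes `reset`). Below is a minimal sketch of calling `evaluate` with the experimental metrics, based only on the signatures visible in this diff; the dataset contents and column names are hypothetical, not part of this PR.

```python
# Sketch only: assumes the experimental metrics API shown in this PR.
# Dataset contents and the column names "target"/"preds_prob" are made up.
from datasets import Dataset

from cyclops.evaluate.evaluator import evaluate
from cyclops.evaluate.metrics.experimental import BinaryAveragePrecision
from cyclops.evaluate.metrics.experimental.metric_dict import MetricDict

dataset = Dataset.from_dict(
    {"target": [0, 1, 1, 0], "preds_prob": [0.1, 0.9, 0.6, 0.4]},
)

# `metrics` may also be a single Metric, a sequence, or a plain dict;
# _prepare_metrics wraps any of those in a MetricDict.
metrics = MetricDict({"avg_prec": BinaryAveragePrecision()})

results = evaluate(
    dataset=dataset,
    metrics=metrics,
    target_columns="target",
    prediction_columns="preds_prob",
)
```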
5 changes: 3 additions & 2 deletions cyclops/evaluate/fairness/config.py
@@ -5,14 +5,15 @@

from datasets import Dataset, config

from cyclops.evaluate.metrics.metric import Metric, MetricCollection
from cyclops.evaluate.metrics.experimental.metric import Metric
from cyclops.evaluate.metrics.experimental.metric_dict import MetricDict


@dataclass
class FairnessConfig:
"""Configuration for fairness metrics."""

metrics: Union[str, Callable[..., Any], Metric, MetricCollection]
metrics: Union[str, Callable[..., Any], Metric, MetricDict]
dataset: Dataset
groups: Union[str, List[str]]
target_columns: Union[str, List[str]]
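With this change the `metrics` field accepts the experimental types. A short sketch of the values it can now hold (the string form is resolved later by `create_metric(..., experimental=True)`; the registry name used here is hypothetical):

```python
# Sketch of acceptable values for FairnessConfig.metrics after this PR.
from cyclops.evaluate.metrics.experimental import BinaryAveragePrecision
from cyclops.evaluate.metrics.experimental.metric_dict import MetricDict

metrics_by_name = "binary_average_precision"  # hypothetical registry name
metrics_single = BinaryAveragePrecision()  # a single experimental Metric
metrics_named = MetricDict({"avg_prec": BinaryAveragePrecision()})  # a named collection
```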
89 changes: 43 additions & 46 deletions cyclops/evaluate/fairness/evaluator.py
@@ -1,14 +1,13 @@
"""Fairness evaluator."""

import inspect
import itertools
import logging
import warnings
from datetime import datetime
from typing import Any, Callable, Dict, List, Optional, Union

import array_api_compat.numpy
import numpy as np
import numpy.typing as npt
import pandas as pd
from datasets import Dataset, config
from datasets.features import Features
@@ -21,15 +20,14 @@
get_columns_as_numpy_array,
set_decode,
)
from cyclops.evaluate.metrics.factory import create_metric
from cyclops.evaluate.metrics.functional.precision_recall_curve import (
from cyclops.evaluate.metrics.experimental.functional.precision_recall_curve import (
_format_thresholds,
_validate_thresholds,
)
from cyclops.evaluate.metrics.metric import Metric, MetricCollection, OperatorMetric
from cyclops.evaluate.metrics.utils import (
_check_thresholds,
_get_value_if_singleton_array,
)
from cyclops.evaluate.metrics.experimental.metric import Metric, OperatorMetric
from cyclops.evaluate.metrics.experimental.metric_dict import MetricDict
from cyclops.evaluate.metrics.experimental.utils.types import Array
from cyclops.evaluate.metrics.factory import create_metric
from cyclops.evaluate.utils import _format_column_names
from cyclops.utils.log import setup_logging

@@ -39,7 +37,7 @@


def evaluate_fairness(
metrics: Union[str, Callable[..., Any], Metric, MetricCollection],
metrics: Union[str, Callable[..., Any], Metric, MetricDict],
dataset: Dataset,
groups: Union[str, List[str]],
target_columns: Union[str, List[str]],
@@ -62,7 +60,7 @@

Parameters
----------
metrics : Union[str, Callable[..., Any], Metric, MetricCollection]
metrics : Union[str, Callable[..., Any], Metric, MetricDict]
The metric or metrics to compute. If a string, it should be the name of a
metric provided by CyclOps. If a callable, it should be a function that
takes target, prediction, and optionally threshold/thresholds as arguments
@@ -147,18 +145,14 @@
raise TypeError(
"Expected `dataset` to be of type `Dataset`, but got " f"{type(dataset)}.",
)
_validate_thresholds(thresholds)

[Codecov: added line #L148 not covered by tests]

_check_thresholds(thresholds)
fmt_thresholds: npt.NDArray[np.float_] = _format_thresholds( # type: ignore
thresholds,
)

metrics_: Union[Callable[..., Any], MetricCollection] = _format_metrics(
metrics_: Union[Callable[..., Any], MetricDict] = _format_metrics(

metrics,
metric_name,
**(metric_kwargs or {}),
)
[Codecov: added line #L150 not covered by tests]

fmt_thresholds = _format_thresholds(thresholds, xp=array_api_compat.numpy)

[Codecov: added line #L155 not covered by tests]
fmt_groups: List[str] = _format_column_names(groups)
fmt_target_columns: List[str] = _format_column_names(target_columns)
fmt_prediction_columns: List[str] = _format_column_names(prediction_columns)
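For context, a sketch of calling `evaluate_fairness` against the updated module. The group and column names are hypothetical, the metric name is assumed to be known to the experimental metric factory, and only parameters visible in this diff are used.

```python
# Sketch only: group/column names and the metric name are hypothetical.
from datasets import Dataset

from cyclops.evaluate.fairness.evaluator import evaluate_fairness

dataset = Dataset.from_dict(
    {
        "sex": ["F", "M", "F", "M"],
        "target": [0, 1, 1, 0],
        "preds_prob": [0.2, 0.7, 0.9, 0.4],
    },
)

results = evaluate_fairness(
    metrics="binary_accuracy",  # resolved via create_metric(..., experimental=True)
    dataset=dataset,  # must be a datasets.Dataset, not a string
    groups="sex",
    target_columns="target",
    prediction_columns="preds_prob",
)
```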
@@ -361,15 +355,15 @@


def _format_metrics(
metrics: Union[str, Callable[..., Any], Metric, MetricCollection],
metrics: Union[str, Callable[..., Any], Metric, MetricDict],
metric_name: Optional[str] = None,
**metric_kwargs: Any,
) -> Union[Callable[..., Any], Metric, MetricCollection]:
) -> Union[Callable[..., Any], Metric, MetricDict]:
"""Format the metrics argument.

Parameters
----------
metrics : Union[str, Callable[..., Any], Metric, MetricCollection]
metrics : Union[str, Callable[..., Any], Metric, MetricDict]
The metrics to use for computing the metric results.
metric_name : str, optional, default=None
The name of the metric. This is only used if `metrics` is a callable.
@@ -379,23 +373,23 @@

Returns
-------
Union[Callable[..., Any], Metric, MetricCollection]
Union[Callable[..., Any], Metric, MetricDict]
The formatted metrics.

Raises
------
TypeError
If `metrics` is not of type `str`, `Callable`, `Metric`, or `MetricCollection`.
If `metrics` is not of type `str`, `Callable`, `Metric`, or `MetricDict`.

"""
if isinstance(metrics, str):
metrics = create_metric(metric_name=metrics, **metric_kwargs)
metrics = create_metric(metric_name=metrics, experimental=True, **metric_kwargs)

[Codecov: added line #L386 not covered by tests]
if isinstance(metrics, Metric):
if metric_name is not None and isinstance(metrics, OperatorMetric):
# single metric created from arithmetic operation, with given name
return MetricCollection({metric_name: metrics})
return MetricCollection(metrics)
if isinstance(metrics, MetricCollection):
return MetricDict({metric_name: metrics})
return MetricDict(metrics)
if isinstance(metrics, MetricDict):

return metrics
[Codecov: added lines #L390-L392 not covered by tests]
if callable(metrics):
if metric_name is None:
@@ -407,7 +401,7 @@
return metrics

raise TypeError(
f"Expected `metrics` to be of type `str`, `Metric`, `MetricCollection`, or "
f"Expected `metrics` to be of type `str`, `Metric`, `MetricDict`, or "
f"`Callable`, but got {type(metrics)}.",
)

@@ -701,7 +695,7 @@


def _compute_metrics( # noqa: C901, PLR0912
metrics: Union[Callable[..., Any], MetricCollection],
metrics: Union[Callable[..., Any], MetricDict],
dataset: Dataset,
target_columns: List[str],
prediction_column: str,
@@ -713,7 +707,7 @@

Parameters
----------
metrics : Union[Callable, MetricCollection]
metrics : Union[Callable, MetricDict]
The metrics to compute.
dataset : Dataset
The dataset to compute the metrics on.
@@ -738,12 +732,19 @@
"Encountered empty dataset while computing metrics. "
"The metric values will be set to `None`."
)
if isinstance(metrics, MetricCollection):
if isinstance(metrics, MetricDict):

[Codecov: added line #L735 not covered by tests]
if threshold is not None:
# set the threshold for each metric in the collection
for name, metric in metrics.items():
if hasattr(metric, "threshold"):
if isinstance(metric, Metric) and hasattr(metric, "threshold"):

metric.threshold = threshold
[Codecov: added line #L739 not covered by tests]
elif isinstance(metric, OperatorMetric):
if hasattr(metric.metric_a, "threshold") and hasattr(

metric.metric_b,
"threshold",
):
metric.metric_a.threshold = threshold
metric.metric_b.threshold = threshold # type: ignore[union-attr]
[Codecov: added lines #L741-L742 and #L746-L747 not covered by tests]
else:
LOGGER.warning(
"Metric %s does not have a threshold attribute. "
@@ -754,7 +755,7 @@
if len(dataset) == 0:
warnings.warn(empty_dataset_msg, RuntimeWarning, stacklevel=1)
results: Dict[str, Any] = {
metric_name: float("NaN") for metric_name in metrics
metric_name: float("NaN") for metric_name in metrics # type: ignore[attr-defined]
}
elif (
batch_size is None or batch_size <= 0
@@ -779,11 +780,11 @@
columns=prediction_column,
)

metrics.update_state(targets, predictions)
metrics.update(targets, predictions)

[Codecov: added line #L783 not covered by tests]

results = metrics.compute()

metrics.reset_state()
metrics.reset()

[Codecov: added line #L787 not covered by tests]

return results
if callable(metrics):
@@ -817,26 +818,26 @@
return {metric_name.title(): output}

raise TypeError(
"The `metrics` argument must be a string, a Metric, a MetricCollection, "
"The `metrics` argument must be a string, a Metric, a MetricDict, "
f"or a callable. Got {type(metrics)}.",
)


def _get_metric_results_for_prediction_and_slice(
metrics: Union[Callable[..., Any], MetricCollection],
metrics: Union[Callable[..., Any], MetricDict],
dataset: Dataset,
target_columns: List[str],
prediction_column: str,
slice_name: str,
batch_size: Optional[int] = config.DEFAULT_MAX_BATCH_SIZE,
metric_name: Optional[str] = None,
thresholds: Optional[npt.NDArray[np.float_]] = None,
thresholds: Optional[Array] = None,
) -> Dict[str, Dict[str, Any]]:
"""Compute metrics for a slice of a dataset.

Parameters
----------
metrics : Union[Callable, MetricCollection]
metrics : Union[Callable, MetricDict]
The metrics to compute.
dataset : Dataset
The dataset to compute the metrics on.
@@ -850,7 +851,7 @@
The batch size to use for the computation.
metric_name : Optional[str]
The name of the metric to compute.
thresholds : Optional[List[float]]
thresholds : Optional[Array]
The thresholds to use for the metrics.

Returns
@@ -873,7 +874,7 @@
return {slice_name: metric_output}

results: Dict[str, Dict[str, Any]] = {}
for threshold in thresholds:
for threshold in thresholds: # type: ignore[attr-defined]

[Codecov: added line #L877 not covered by tests]
metric_output = _compute_metrics(
metrics=metrics,
dataset=dataset,
@@ -969,11 +970,7 @@
)

parity_results[key].setdefault(slice_name, {}).update(
{
parity_metric_name: _get_value_if_singleton_array(
parity_metric_value,
),
},
{parity_metric_name: parity_metric_value},
)

return parity_results
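To recap the `_compute_metrics` change earlier in this file, here is a re-indented restatement of the new threshold propagation. The helper name is hypothetical and the warning branch for metrics without a `threshold` attribute is elided; this is a sketch of the logic in the diff, not new behavior.

```python
# Re-indented sketch of the threshold handling added to _compute_metrics above.
from cyclops.evaluate.metrics.experimental.metric import Metric, OperatorMetric
from cyclops.evaluate.metrics.experimental.metric_dict import MetricDict


def _propagate_threshold(metrics: MetricDict, threshold: float) -> None:
    """Set `threshold` on every metric in the dict that supports it."""
    for _name, metric in metrics.items():
        if isinstance(metric, Metric) and hasattr(metric, "threshold"):
            metric.threshold = threshold
        elif (
            isinstance(metric, OperatorMetric)
            and hasattr(metric.metric_a, "threshold")
            and hasattr(metric.metric_b, "threshold")
        ):
            metric.metric_a.threshold = threshold
            metric.metric_b.threshold = threshold
```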
5 changes: 5 additions & 0 deletions cyclops/evaluate/metrics/experimental/__init__.py
@@ -9,6 +9,11 @@
MulticlassAUROC,
MultilabelAUROC,
)
from cyclops.evaluate.metrics.experimental.average_precision import (
BinaryAveragePrecision,
MulticlassAveragePrecision,
MultilabelAveragePrecision,
)
from cyclops.evaluate.metrics.experimental.confusion_matrix import (
BinaryConfusionMatrix,
MulticlassConfusionMatrix,
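The `__init__.py` hunk above also exports the new average-precision metrics. A final sketch of the stateful cycle the evaluators now rely on (`update`, `compute`, `reset`); it assumes `BinaryAveragePrecision` can be constructed with defaults and that NumPy arrays are accepted through the array-API compatibility layer used elsewhere in this PR.

```python
# Sketch of the update/compute/reset cycle used by the evaluators in this PR.
import numpy as np

from cyclops.evaluate.metrics.experimental import BinaryAveragePrecision
from cyclops.evaluate.metrics.experimental.metric_dict import MetricDict

metrics = MetricDict({"avg_prec": BinaryAveragePrecision()})

targets = np.array([0, 1, 1, 0])
preds = np.array([0.1, 0.8, 0.6, 0.3])

metrics.update(targets, preds)  # accumulate state (formerly update_state)
print(metrics.compute())  # e.g. {"avg_prec": ...}
metrics.reset()  # clear state (formerly reset_state)
```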