Commit bfb7c5b

committed
integrate experimental metrics with other modules
1 parent 5c4ebb2 commit bfb7c5b

File tree

6 files changed: +107 -142 lines changed

cyclops/evaluate/evaluator.py

Lines changed: 17 additions & 16 deletions
@@ -1,5 +1,4 @@
 """Evaluate one or more models on a dataset."""
-
 import logging
 import warnings
 from dataclasses import asdict
@@ -16,7 +15,9 @@
 )
 from cyclops.evaluate.fairness.config import FairnessConfig
 from cyclops.evaluate.fairness.evaluator import evaluate_fairness
-from cyclops.evaluate.metrics.metric import Metric, MetricCollection
+from cyclops.evaluate.metrics.experimental.metric import Metric
+from cyclops.evaluate.metrics.experimental.metric_dict import MetricDict
+from cyclops.evaluate.metrics.experimental.utils.types import Array
 from cyclops.evaluate.utils import _format_column_names, choose_split
 from cyclops.utils.log import setup_logging

@@ -27,7 +28,7 @@

 def evaluate(
     dataset: Union[str, Dataset, DatasetDict],
-    metrics: Union[Metric, Sequence[Metric], Dict[str, Metric], MetricCollection],
+    metrics: Union[Metric, Sequence[Metric], Dict[str, Metric], MetricDict],
     target_columns: Union[str, List[str]],
     prediction_columns: Union[str, List[str]],
     ignore_columns: Optional[Union[str, List[str]]] = None,
@@ -47,7 +48,7 @@ def evaluate(
         The dataset to evaluate on. If a string, the dataset will be loaded
         using `datasets.load_dataset`. If `DatasetDict`, the `split` argument
         must be specified.
-    metrics : Union[Metric, Sequence[Metric], Dict[str, Metric], MetricCollection]
+    metrics : Union[Metric, Sequence[Metric], Dict[str, Metric], MetricDict]
         The metrics to compute.
     target_columns : Union[str, List[str]]
         The name of the column(s) containing the target values. A string value
@@ -202,28 +203,28 @@ def _load_data(


 def _prepare_metrics(
-    metrics: Union[Metric, Sequence[Metric], Dict[str, Metric], MetricCollection],
-) -> MetricCollection:
+    metrics: Union[Metric, Sequence[Metric], Dict[str, Metric], MetricDict],
+) -> MetricDict:
     """Prepare metrics for evaluation."""
-    # TODO: wrap in BootstrappedMetric if computing confidence intervals
+    # TODO [fcogidi]: wrap in BootstrappedMetric if computing confidence intervals
     if isinstance(metrics, (Metric, Sequence, Dict)) and not isinstance(
         metrics,
-        MetricCollection,
+        MetricDict,
     ):
-        return MetricCollection(metrics)
-    if isinstance(metrics, MetricCollection):
+        return MetricDict(metrics)  # type: ignore[arg-type]
+    if isinstance(metrics, MetricDict):
         return metrics

     raise TypeError(
         f"Invalid type for `metrics`: {type(metrics)}. "
         "Expected one of: Metric, Sequence[Metric], Dict[str, Metric], "
-        "MetricCollection.",
+        "MetricDict.",
     )


 def _compute_metrics(
     dataset: Dataset,
-    metrics: MetricCollection,
+    metrics: MetricDict,
     slice_spec: SliceSpec,
     target_columns: Union[str, List[str]],
     prediction_columns: Union[str, List[str]],
@@ -266,8 +267,8 @@ def _compute_metrics(
                 RuntimeWarning,
                 stacklevel=1,
             )
-            metric_output = {
-                metric_name: float("NaN") for metric_name in metrics
+            metric_output: Dict[str, Array] = {
+                metric_name: float("NaN") for metric_name in metrics  # type: ignore[attr-defined,misc]
             }
         elif (
             batch_size is None or batch_size < 0
@@ -293,10 +294,10 @@ def _compute_metrics(
                 )

                 # update the metric state
-                metrics.update_state(targets, predictions)
+                metrics.update(targets, predictions)

             metric_output = metrics.compute()
-            metrics.reset_state()
+            metrics.reset()

         model_name: str = "model_for_%s" % prediction_column
         results.setdefault(model_name, {})
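Note: the renamed calls in this file follow the experimental MetricDict lifecycle (update, then compute, then reset) in place of the previous MetricCollection's update_state/reset_state. A minimal usage sketch of that lifecycle follows; the "binary_accuracy" registry name is an assumption used only for illustration, and the targets/predictions are toy NumPy arrays.

import numpy as np

from cyclops.evaluate.metrics.experimental.metric_dict import MetricDict
from cyclops.evaluate.metrics.factory import create_metric

# "binary_accuracy" is an assumed factory-registered name, used only for illustration.
metrics = MetricDict(
    {"accuracy": create_metric(metric_name="binary_accuracy", experimental=True)},
)

targets = np.array([1, 0, 1, 1])
predictions = np.array([1, 0, 0, 1])

metrics.update(targets, predictions)  # accumulate state (previously `update_state`)
results = metrics.compute()           # mapping of metric name to computed value
metrics.reset()                       # clear accumulated state (previously `reset_state`)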

cyclops/evaluate/fairness/config.py

Lines changed: 3 additions & 2 deletions
@@ -5,14 +5,15 @@

 from datasets import Dataset, config

-from cyclops.evaluate.metrics.metric import Metric, MetricCollection
+from cyclops.evaluate.metrics.experimental.metric import Metric
+from cyclops.evaluate.metrics.experimental.metric_dict import MetricDict


 @dataclass
 class FairnessConfig:
     """Configuration for fairness metrics."""

-    metrics: Union[str, Callable[..., Any], Metric, MetricCollection]
+    metrics: Union[str, Callable[..., Any], Metric, MetricDict]
     dataset: Dataset
     groups: Union[str, List[str]]
     target_columns: Union[str, List[str]]

cyclops/evaluate/fairness/evaluator.py

Lines changed: 35 additions & 45 deletions
@@ -1,14 +1,13 @@
 """Fairness evaluator."""
-
 import inspect
 import itertools
 import logging
 import warnings
 from datetime import datetime
 from typing import Any, Callable, Dict, List, Optional, Union

+import array_api_compat.numpy
 import numpy as np
-import numpy.typing as npt
 import pandas as pd
 from datasets import Dataset, config
 from datasets.features import Features
@@ -21,15 +20,14 @@
     get_columns_as_numpy_array,
     set_decode,
 )
-from cyclops.evaluate.metrics.factory import create_metric
-from cyclops.evaluate.metrics.functional.precision_recall_curve import (
+from cyclops.evaluate.metrics.experimental.functional.precision_recall_curve import (
     _format_thresholds,
+    _validate_thresholds,
 )
-from cyclops.evaluate.metrics.metric import Metric, MetricCollection, OperatorMetric
-from cyclops.evaluate.metrics.utils import (
-    _check_thresholds,
-    _get_value_if_singleton_array,
-)
+from cyclops.evaluate.metrics.experimental.metric import Metric, OperatorMetric
+from cyclops.evaluate.metrics.experimental.metric_dict import MetricDict
+from cyclops.evaluate.metrics.experimental.utils.types import Array
+from cyclops.evaluate.metrics.factory import create_metric
 from cyclops.evaluate.utils import _format_column_names
 from cyclops.utils.log import setup_logging

@@ -39,7 +37,7 @@


 def evaluate_fairness(
-    metrics: Union[str, Callable[..., Any], Metric, MetricCollection],
+    metrics: Union[str, Callable[..., Any], Metric, MetricDict],
     dataset: Dataset,
     groups: Union[str, List[str]],
     target_columns: Union[str, List[str]],
@@ -62,7 +60,7 @@ def evaluate_fairness(

     Parameters
     ----------
-    metrics : Union[str, Callable[..., Any], Metric, MetricCollection]
+    metrics : Union[str, Callable[..., Any], Metric, MetricDict]
         The metric or metrics to compute. If a string, it should be the name of a
         metric provided by CyclOps. If a callable, it should be a function that
         takes target, prediction, and optionally threshold/thresholds as arguments
@@ -147,18 +145,14 @@ def evaluate_fairness(
         raise TypeError(
             "Expected `dataset` to be of type `Dataset`, but got " f"{type(dataset)}.",
         )
+    _validate_thresholds(thresholds)

-    _check_thresholds(thresholds)
-    fmt_thresholds: npt.NDArray[np.float_] = _format_thresholds(  # type: ignore
-        thresholds,
-    )
-
-    metrics_: Union[Callable[..., Any], MetricCollection] = _format_metrics(
+    metrics_: Union[Callable[..., Any], MetricDict] = _format_metrics(
         metrics,
         metric_name,
         **(metric_kwargs or {}),
     )
-
+    fmt_thresholds = _format_thresholds(thresholds, xp=array_api_compat.numpy)
     fmt_groups: List[str] = _format_column_names(groups)
     fmt_target_columns: List[str] = _format_column_names(target_columns)
     fmt_prediction_columns: List[str] = _format_column_names(prediction_columns)
@@ -361,15 +355,15 @@ def warn_too_many_unique_values(


 def _format_metrics(
-    metrics: Union[str, Callable[..., Any], Metric, MetricCollection],
+    metrics: Union[str, Callable[..., Any], Metric, MetricDict],
     metric_name: Optional[str] = None,
     **metric_kwargs: Any,
-) -> Union[Callable[..., Any], Metric, MetricCollection]:
+) -> Union[Callable[..., Any], Metric, MetricDict]:
     """Format the metrics argument.

     Parameters
     ----------
-    metrics : Union[str, Callable[..., Any], Metric, MetricCollection]
+    metrics : Union[str, Callable[..., Any], Metric, MetricDict]
         The metrics to use for computing the metric results.
     metric_name : str, optional, default=None
         The name of the metric. This is only used if `metrics` is a callable.
@@ -379,23 +373,23 @@ def _format_metrics(

     Returns
     -------
-    Union[Callable[..., Any], Metric, MetricCollection]
+    Union[Callable[..., Any], Metric, MetricDict]
         The formatted metrics.

     Raises
     ------
     TypeError
-        If `metrics` is not of type `str`, `Callable`, `Metric`, or `MetricCollection`.
+        If `metrics` is not of type `str`, `Callable`, `Metric`, or `MetricDict`.

     """
     if isinstance(metrics, str):
-        metrics = create_metric(metric_name=metrics, **metric_kwargs)
+        metrics = create_metric(metric_name=metrics, experimental=True, **metric_kwargs)
     if isinstance(metrics, Metric):
         if metric_name is not None and isinstance(metrics, OperatorMetric):
             # single metric created from arithmetic operation, with given name
-            return MetricCollection({metric_name: metrics})
-        return MetricCollection(metrics)
-    if isinstance(metrics, MetricCollection):
+            return MetricDict({metric_name: metrics})
+        return MetricDict(metrics)
+    if isinstance(metrics, MetricDict):
         return metrics
     if callable(metrics):
         if metric_name is None:
@@ -407,7 +401,7 @@ def _format_metrics(
         return metrics

     raise TypeError(
-        f"Expected `metrics` to be of type `str`, `Metric`, `MetricCollection`, or "
+        f"Expected `metrics` to be of type `str`, `Metric`, `MetricDict`, or "
         f"`Callable`, but got {type(metrics)}.",
     )

@@ -701,7 +695,7 @@ def _get_slice_spec(


 def _compute_metrics(  # noqa: C901, PLR0912
-    metrics: Union[Callable[..., Any], MetricCollection],
+    metrics: Union[Callable[..., Any], MetricDict],
     dataset: Dataset,
     target_columns: List[str],
     prediction_column: str,
@@ -713,7 +707,7 @@ def _compute_metrics(  # noqa: C901, PLR0912

     Parameters
     ----------
-    metrics : Union[Callable, MetricCollection]
+    metrics : Union[Callable, MetricDict]
         The metrics to compute.
     dataset : Dataset
         The dataset to compute the metrics on.
@@ -738,7 +732,7 @@ def _compute_metrics(  # noqa: C901, PLR0912
         "Encountered empty dataset while computing metrics. "
         "The metric values will be set to `None`."
     )
-    if isinstance(metrics, MetricCollection):
+    if isinstance(metrics, MetricDict):
         if threshold is not None:
             # set the threshold for each metric in the collection
             for name, metric in metrics.items():
@@ -754,7 +748,7 @@ def _compute_metrics(  # noqa: C901, PLR0912
         if len(dataset) == 0:
             warnings.warn(empty_dataset_msg, RuntimeWarning, stacklevel=1)
             results: Dict[str, Any] = {
-                metric_name: float("NaN") for metric_name in metrics
+                metric_name: float("NaN") for metric_name in metrics  # type: ignore[attr-defined]
             }
         elif (
             batch_size is None or batch_size <= 0
@@ -779,11 +773,11 @@ def _compute_metrics(  # noqa: C901, PLR0912
                     columns=prediction_column,
                 )

-                metrics.update_state(targets, predictions)
+                metrics.update(targets, predictions)

             results = metrics.compute()

-            metrics.reset_state()
+            metrics.reset()

         return results
     if callable(metrics):
@@ -817,26 +811,26 @@ def _compute_metrics(  # noqa: C901, PLR0912
         return {metric_name.title(): output}

     raise TypeError(
-        "The `metrics` argument must be a string, a Metric, a MetricCollection, "
+        "The `metrics` argument must be a string, a Metric, a MetricDict, "
         f"or a callable. Got {type(metrics)}.",
     )


 def _get_metric_results_for_prediction_and_slice(
-    metrics: Union[Callable[..., Any], MetricCollection],
+    metrics: Union[Callable[..., Any], MetricDict],
     dataset: Dataset,
     target_columns: List[str],
     prediction_column: str,
     slice_name: str,
     batch_size: Optional[int] = config.DEFAULT_MAX_BATCH_SIZE,
     metric_name: Optional[str] = None,
-    thresholds: Optional[npt.NDArray[np.float_]] = None,
+    thresholds: Optional[Array] = None,
 ) -> Dict[str, Dict[str, Any]]:
     """Compute metrics for a slice of a dataset.

     Parameters
     ----------
-    metrics : Union[Callable, MetricCollection]
+    metrics : Union[Callable, MetricDict]
         The metrics to compute.
     dataset : Dataset
         The dataset to compute the metrics on.
@@ -850,7 +844,7 @@ def _get_metric_results_for_prediction_and_slice(
         The batch size to use for the computation.
     metric_name : Optional[str]
         The name of the metric to compute.
-    thresholds : Optional[List[float]]
+    thresholds : Optional[Array]
         The thresholds to use for the metrics.

     Returns
@@ -873,7 +867,7 @@ def _get_metric_results_for_prediction_and_slice(
         return {slice_name: metric_output}

     results: Dict[str, Dict[str, Any]] = {}
-    for threshold in thresholds:
+    for threshold in thresholds:  # type: ignore[attr-defined]
         metric_output = _compute_metrics(
             metrics=metrics,
             dataset=dataset,
@@ -969,11 +963,7 @@ def _compute_parity_metrics(
             )

             parity_results[key].setdefault(slice_name, {}).update(
-                {
-                    parity_metric_name: _get_value_if_singleton_array(
-                        parity_metric_value,
-                    ),
-                },
+                {parity_metric_name: parity_metric_value},
             )

     return parity_results
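With the changes above, a string passed as metrics is resolved through create_metric(..., experimental=True) and wrapped in a MetricDict, and thresholds are formatted with the array-API helper rather than the NumPy-typed one. A hedged end-to-end sketch follows; the toy dataset, its column names, and the "binary_accuracy" metric name are illustrative assumptions, and the prediction_columns parameter name is inferred from the fmt_prediction_columns variable in this diff.

from datasets import Dataset

from cyclops.evaluate.fairness.evaluator import evaluate_fairness

# Toy data; column names and values are made up for illustration.
ds = Dataset.from_dict(
    {
        "target": [0, 1, 1, 0, 1, 0],
        "preds_prob": [0.2, 0.9, 0.6, 0.4, 0.8, 0.1],
        "group": ["a", "a", "b", "b", "a", "b"],
    },
)

results = evaluate_fairness(
    metrics="binary_accuracy",        # assumed registry name; resolved via create_metric(..., experimental=True)
    dataset=ds,
    groups="group",
    target_columns="target",
    prediction_columns="preds_prob",  # parameter name assumed from the formatted columns above
)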
