1
1
"""Fairness evaluator."""
2
-
3
2
import inspect
4
3
import itertools
5
4
import logging
6
5
import warnings
7
6
from datetime import datetime
8
7
from typing import Any , Callable , Dict , List , Optional , Union
9
8
9
+ import array_api_compat .numpy
10
10
import numpy as np
11
- import numpy .typing as npt
12
11
import pandas as pd
13
12
from datasets import Dataset , config
14
13
from datasets .features import Features
21
20
get_columns_as_numpy_array ,
22
21
set_decode ,
23
22
)
24
- from cyclops .evaluate .metrics .factory import create_metric
25
- from cyclops .evaluate .metrics .functional .precision_recall_curve import (
23
+ from cyclops .evaluate .metrics .experimental .functional .precision_recall_curve import (
26
24
_format_thresholds ,
25
+ _validate_thresholds ,
27
26
)
28
- from cyclops .evaluate .metrics .metric import Metric , MetricCollection , OperatorMetric
29
- from cyclops .evaluate .metrics .utils import (
30
- _check_thresholds ,
31
- _get_value_if_singleton_array ,
32
- )
27
+ from cyclops .evaluate .metrics .experimental .metric import Metric , OperatorMetric
28
+ from cyclops .evaluate .metrics .experimental .metric_dict import MetricDict
29
+ from cyclops .evaluate .metrics .experimental .utils .types import Array
30
+ from cyclops .evaluate .metrics .factory import create_metric
33
31
from cyclops .evaluate .utils import _format_column_names
34
32
from cyclops .utils .log import setup_logging
35
33
39
37
40
38
41
39
def evaluate_fairness (
42
- metrics : Union [str , Callable [..., Any ], Metric , MetricCollection ],
40
+ metrics : Union [str , Callable [..., Any ], Metric , MetricDict ],
43
41
dataset : Dataset ,
44
42
groups : Union [str , List [str ]],
45
43
target_columns : Union [str , List [str ]],
@@ -62,7 +60,7 @@ def evaluate_fairness(
62
60
63
61
Parameters
64
62
----------
65
- metrics : Union[str, Callable[..., Any], Metric, MetricCollection ]
63
+ metrics : Union[str, Callable[..., Any], Metric, MetricDict ]
66
64
The metric or metrics to compute. If a string, it should be the name of a
67
65
metric provided by CyclOps. If a callable, it should be a function that
68
66
takes target, prediction, and optionally threshold/thresholds as arguments
@@ -147,18 +145,14 @@ def evaluate_fairness(
147
145
raise TypeError (
148
146
"Expected `dataset` to be of type `Dataset`, but got " f"{ type (dataset )} ." ,
149
147
)
148
+ _validate_thresholds (thresholds )
150
149
151
- _check_thresholds (thresholds )
152
- fmt_thresholds : npt .NDArray [np .float_ ] = _format_thresholds ( # type: ignore
153
- thresholds ,
154
- )
155
-
156
- metrics_ : Union [Callable [..., Any ], MetricCollection ] = _format_metrics (
150
+ metrics_ : Union [Callable [..., Any ], MetricDict ] = _format_metrics (
157
151
metrics ,
158
152
metric_name ,
159
153
** (metric_kwargs or {}),
160
154
)
161
-
155
+ fmt_thresholds = _format_thresholds ( thresholds , xp = array_api_compat . numpy )
162
156
fmt_groups : List [str ] = _format_column_names (groups )
163
157
fmt_target_columns : List [str ] = _format_column_names (target_columns )
164
158
fmt_prediction_columns : List [str ] = _format_column_names (prediction_columns )
@@ -361,15 +355,15 @@ def warn_too_many_unique_values(
361
355
362
356
363
357
def _format_metrics (
364
- metrics : Union [str , Callable [..., Any ], Metric , MetricCollection ],
358
+ metrics : Union [str , Callable [..., Any ], Metric , MetricDict ],
365
359
metric_name : Optional [str ] = None ,
366
360
** metric_kwargs : Any ,
367
- ) -> Union [Callable [..., Any ], Metric , MetricCollection ]:
361
+ ) -> Union [Callable [..., Any ], Metric , MetricDict ]:
368
362
"""Format the metrics argument.
369
363
370
364
Parameters
371
365
----------
372
- metrics : Union[str, Callable[..., Any], Metric, MetricCollection ]
366
+ metrics : Union[str, Callable[..., Any], Metric, MetricDict ]
373
367
The metrics to use for computing the metric results.
374
368
metric_name : str, optional, default=None
375
369
The name of the metric. This is only used if `metrics` is a callable.
@@ -379,23 +373,23 @@ def _format_metrics(
379
373
380
374
Returns
381
375
-------
382
- Union[Callable[..., Any], Metric, MetricCollection ]
376
+ Union[Callable[..., Any], Metric, MetricDict ]
383
377
The formatted metrics.
384
378
385
379
Raises
386
380
------
387
381
TypeError
388
- If `metrics` is not of type `str`, `Callable`, `Metric`, or `MetricCollection `.
382
+ If `metrics` is not of type `str`, `Callable`, `Metric`, or `MetricDict `.
389
383
390
384
"""
391
385
if isinstance (metrics , str ):
392
- metrics = create_metric (metric_name = metrics , ** metric_kwargs )
386
+ metrics = create_metric (metric_name = metrics , experimental = True , ** metric_kwargs )
393
387
if isinstance (metrics , Metric ):
394
388
if metric_name is not None and isinstance (metrics , OperatorMetric ):
395
389
# single metric created from arithmetic operation, with given name
396
- return MetricCollection ({metric_name : metrics })
397
- return MetricCollection (metrics )
398
- if isinstance (metrics , MetricCollection ):
390
+ return MetricDict ({metric_name : metrics })
391
+ return MetricDict (metrics )
392
+ if isinstance (metrics , MetricDict ):
399
393
return metrics
400
394
if callable (metrics ):
401
395
if metric_name is None :
@@ -407,7 +401,7 @@ def _format_metrics(
407
401
return metrics
408
402
409
403
raise TypeError (
410
- f"Expected `metrics` to be of type `str`, `Metric`, `MetricCollection `, or "
404
+ f"Expected `metrics` to be of type `str`, `Metric`, `MetricDict `, or "
411
405
f"`Callable`, but got { type (metrics )} ." ,
412
406
)
413
407
@@ -701,7 +695,7 @@ def _get_slice_spec(
701
695
702
696
703
697
def _compute_metrics ( # noqa: C901, PLR0912
704
- metrics : Union [Callable [..., Any ], MetricCollection ],
698
+ metrics : Union [Callable [..., Any ], MetricDict ],
705
699
dataset : Dataset ,
706
700
target_columns : List [str ],
707
701
prediction_column : str ,
@@ -713,7 +707,7 @@ def _compute_metrics( # noqa: C901, PLR0912
713
707
714
708
Parameters
715
709
----------
716
- metrics : Union[Callable, MetricCollection ]
710
+ metrics : Union[Callable, MetricDict ]
717
711
The metrics to compute.
718
712
dataset : Dataset
719
713
The dataset to compute the metrics on.
@@ -738,7 +732,7 @@ def _compute_metrics( # noqa: C901, PLR0912
738
732
"Encountered empty dataset while computing metrics. "
739
733
"The metric values will be set to `None`."
740
734
)
741
- if isinstance (metrics , MetricCollection ):
735
+ if isinstance (metrics , MetricDict ):
742
736
if threshold is not None :
743
737
# set the threshold for each metric in the collection
744
738
for name , metric in metrics .items ():
@@ -754,7 +748,7 @@ def _compute_metrics( # noqa: C901, PLR0912
754
748
if len (dataset ) == 0 :
755
749
warnings .warn (empty_dataset_msg , RuntimeWarning , stacklevel = 1 )
756
750
results : Dict [str , Any ] = {
757
- metric_name : float ("NaN" ) for metric_name in metrics
751
+ metric_name : float ("NaN" ) for metric_name in metrics # type: ignore[attr-defined]
758
752
}
759
753
elif (
760
754
batch_size is None or batch_size <= 0
@@ -779,11 +773,11 @@ def _compute_metrics( # noqa: C901, PLR0912
779
773
columns = prediction_column ,
780
774
)
781
775
782
- metrics .update_state (targets , predictions )
776
+ metrics .update (targets , predictions )
783
777
784
778
results = metrics .compute ()
785
779
786
- metrics .reset_state ()
780
+ metrics .reset ()
787
781
788
782
return results
789
783
if callable (metrics ):
@@ -817,26 +811,26 @@ def _compute_metrics( # noqa: C901, PLR0912
817
811
return {metric_name .title (): output }
818
812
819
813
raise TypeError (
820
- "The `metrics` argument must be a string, a Metric, a MetricCollection , "
814
+ "The `metrics` argument must be a string, a Metric, a MetricDict , "
821
815
f"or a callable. Got { type (metrics )} ." ,
822
816
)
823
817
824
818
825
819
def _get_metric_results_for_prediction_and_slice (
826
- metrics : Union [Callable [..., Any ], MetricCollection ],
820
+ metrics : Union [Callable [..., Any ], MetricDict ],
827
821
dataset : Dataset ,
828
822
target_columns : List [str ],
829
823
prediction_column : str ,
830
824
slice_name : str ,
831
825
batch_size : Optional [int ] = config .DEFAULT_MAX_BATCH_SIZE ,
832
826
metric_name : Optional [str ] = None ,
833
- thresholds : Optional [npt . NDArray [ np . float_ ] ] = None ,
827
+ thresholds : Optional [Array ] = None ,
834
828
) -> Dict [str , Dict [str , Any ]]:
835
829
"""Compute metrics for a slice of a dataset.
836
830
837
831
Parameters
838
832
----------
839
- metrics : Union[Callable, MetricCollection ]
833
+ metrics : Union[Callable, MetricDict ]
840
834
The metrics to compute.
841
835
dataset : Dataset
842
836
The dataset to compute the metrics on.
@@ -850,7 +844,7 @@ def _get_metric_results_for_prediction_and_slice(
850
844
The batch size to use for the computation.
851
845
metric_name : Optional[str]
852
846
The name of the metric to compute.
853
- thresholds : Optional[List[float] ]
847
+ thresholds : Optional[Array ]
854
848
The thresholds to use for the metrics.
855
849
856
850
Returns
@@ -873,7 +867,7 @@ def _get_metric_results_for_prediction_and_slice(
873
867
return {slice_name : metric_output }
874
868
875
869
results : Dict [str , Dict [str , Any ]] = {}
876
- for threshold in thresholds :
870
+ for threshold in thresholds : # type: ignore[attr-defined]
877
871
metric_output = _compute_metrics (
878
872
metrics = metrics ,
879
873
dataset = dataset ,
@@ -969,11 +963,7 @@ def _compute_parity_metrics(
969
963
)
970
964
971
965
parity_results [key ].setdefault (slice_name , {}).update (
972
- {
973
- parity_metric_name : _get_value_if_singleton_array (
974
- parity_metric_value ,
975
- ),
976
- },
966
+ {parity_metric_name : parity_metric_value },
977
967
)
978
968
979
969
return parity_results
0 commit comments