Skip to content

Commit ca6a26b

Browse files
authored
Fix SliceSpec and evaluator (#619)
* fixes for slicer and evaluators * fix state reference creation in `MetricDict`
1 parent d781e3d commit ca6a26b

File tree

7 files changed

+25
-15
lines changed

7 files changed

+25
-15
lines changed

cyclops/data/slicer.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import copy
44
import datetime
55
import itertools
6+
import json
67
from dataclasses import dataclass, field
78
from functools import partial
89
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
@@ -248,6 +249,19 @@ def _create_intersections(self) -> None:
248249
)
249250
self.spec_list.extend(intersect_list)
250251

252+
# remove duplicates
253+
seen = set()
254+
result = []
255+
256+
for spec in self.spec_list:
257+
spec_str = json.dumps(spec, sort_keys=True)
258+
if spec_str not in seen:
259+
seen.add(spec_str)
260+
result.append(spec)
261+
262+
seen.clear()
263+
self.spec_list = result
264+
251265
def _parse_and_register_slice_specs(
252266
self,
253267
slice_spec: Dict[str, Dict[str, Any]],

cyclops/evaluate/evaluator.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,9 @@ def evaluate(
152152
fairness_config.batch_size = batch_size
153153
fairness_config.remove_columns = ignore_columns
154154

155-
fairness_results = evaluate_fairness(**asdict(fairness_config))
155+
fairness_results = evaluate_fairness(
156+
**asdict(fairness_config), array_lib=array_lib
157+
)
156158
results["fairness"] = fairness_results
157159

158160
return results
@@ -304,7 +306,7 @@ def _compute_metrics(
304306
metrics.update(targets, predictions)
305307

306308
metric_output = metrics.compute()
307-
metrics.reset()
309+
metrics.reset()
308310

309311
model_name: str = "model_for_%s" % prediction_column
310312
results.setdefault(model_name, {})

cyclops/evaluate/fairness/evaluator.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -728,6 +728,9 @@ def _compute_metrics( # noqa: C901, PLR0912
728728
The batch size to use for the computation.
729729
metric_name : Optional[str]
730730
The name of the metric to compute.
731+
array_lib : {"torch", "numpy, "cupy"}, default="numpy"
732+
The array library to use for the metric computation. The metric results
733+
will be returned in the format of `array_lib`.
731734
732735
Returns
733736
-------

cyclops/evaluate/metrics/experimental/metric.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -431,6 +431,7 @@ def reset(self) -> None:
431431
"object or a list of array API objects. But got "
432432
f"`{type(default_value)} instead.",
433433
)
434+
self._defaults = {}
434435

435436
self._update_count = 0
436437
self._computed = None

cyclops/evaluate/metrics/experimental/metric_dict.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -361,7 +361,7 @@ def deepcopy_state(obj: Any) -> Any:
361361
for metric_names in self._metric_groups.values():
362362
base_metric = self.data[metric_names[0]]
363363
for metric_name in metric_names[1:]:
364-
for state in self.data[metric_name]._defaults:
364+
for state in base_metric._defaults:
365365
base_metric_state = getattr(base_metric, state)
366366
setattr(
367367
self.data[metric_name],

docs/source/tutorials/synthea/los_prediction.ipynb

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1068,16 +1068,6 @@
10681068
")"
10691069
]
10701070
},
1071-
{
1072-
"cell_type": "code",
1073-
"execution_count": null,
1074-
"id": "172a1654",
1075-
"metadata": {},
1076-
"outputs": [],
1077-
"source": [
1078-
"results"
1079-
]
1080-
},
10811071
{
10821072
"cell_type": "markdown",
10831073
"id": "7d2d1d75-f7d8-44d3-a782-2aba9a4fbac0",

tests/cyclops/evaluate/metrics/experimental/test_metric.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -331,7 +331,7 @@ def test_reset_compute():
331331
anp.asarray(42, dtype=anp.float32),
332332
)
333333
metric.reset()
334-
assert metric.state_vars == {"x": anp.asarray(0, dtype=anp.float32)}
334+
assert metric.state_vars == {}
335335

336336

337337
def test_error_on_compute_before_update():
@@ -397,7 +397,7 @@ def test_call():
397397
assert metric._computed is None
398398

399399
metric.reset()
400-
assert metric.state_vars == {"x": anp.asarray(0, dtype=anp.float32)}
400+
assert metric.state_vars == {}
401401
assert metric._computed is None
402402

403403

0 commit comments

Comments
 (0)