Skip to content

Commit f5809fa

Browse files
committed
Fix type annotations and formatting issues
1 parent 8c968e1 commit f5809fa

File tree

7 files changed

+17
-15
lines changed

7 files changed

+17
-15
lines changed

cyclops/evaluate/fairness/evaluator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -744,7 +744,7 @@ def _compute_metrics( # noqa: C901, PLR0912
744744
"threshold",
745745
):
746746
metric.metric_a.threshold = threshold
747-
metric.metric_b.threshold = threshold
747+
metric.metric_b.threshold = threshold # type: ignore[union-attr]
748748
else:
749749
LOGGER.warning(
750750
"Metric %s does not have a threshold attribute. "

cyclops/evaluate/metrics/experimental/auroc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def __init__(
7373
self.max_fpr = max_fpr
7474

7575
def _compute_metric(self) -> Array: # type: ignore[override]
76-
"""Compute the AUROC.""" ""
76+
"""Compute the AUROC."""
7777
state = (
7878
(dim_zero_cat(self.target), dim_zero_cat(self.preds)) # type: ignore[attr-defined]
7979
if self.thresholds is None

cyclops/evaluate/metrics/experimental/average_precision.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,18 +43,18 @@ class BinaryAveragePrecision(
4343
>>> metric(target, preds)
4444
Array(0.75, dtype=float32)
4545
>>> metric.reset()
46-
>>> target = anp.asarray([[0, 1, 0, 1], [1, 1, 0, 0]])
47-
>>> preds = anp.asarray([[0.1, 0.4, 0.35, 0.8], [0.6, 0.3, 0.1, 0.7]])
46+
>>> target = [[0, 1, 0, 1], [1, 1, 0, 0]]
47+
>>> preds = [[0.1, 0.4, 0.35, 0.8], [0.6, 0.3, 0.1, 0.7]]
4848
>>> for t, p in zip(target, preds):
49-
... metric.update(t, p)
49+
... metric.update(anp.asarray(t), anp.asarray(p))
5050
>>> metric.compute()
51-
Array(0.5833333333333333, dtype=float32)
51+
Array(0.5833334, dtype=float32)
5252
5353
"""
5454

5555
name: str = "Average Precision"
5656

57-
def _compute_metric(self) -> Array:
57+
def _compute_metric(self) -> Array: # type: ignore[override]
5858
"""Compute the metric."""
5959
state = (
6060
(dim_zero_cat(self.target), dim_zero_cat(self.preds)) # type: ignore[attr-defined]

cyclops/evaluate/metrics/experimental/functional/average_precision.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def _binary_average_precision_compute(
5050
pos_label,
5151
)
5252
xp = apc.array_namespace(precision, recall)
53-
return -xp.sum(_diff(recall) * precision[:-1]) # type: ignore
53+
return -xp.sum(_diff(recall) * precision[:-1], dtype=xp.float32) # type: ignore
5454

5555

5656
def binary_average_precision(
@@ -128,13 +128,13 @@ def binary_average_precision(
128128
Examples
129129
--------
130130
>>> import numpy.array_api as anp
131-
>>> from cyclops.evaluate.metrics..experimental.functional import (
131+
>>> from cyclops.evaluate.metrics.experimental.functional import (
132132
... binary_average_precision
133133
... )
134134
>>> target = anp.asarray([0, 1, 1, 0])
135135
>>> preds = anp.asarray([0, 0.5, 0.7, 0.8])
136136
>>> binary_average_precision(target, preds, thresholds=None)
137-
Array(0.5833333333333333, dtype=float32)
137+
Array(0.5833334, dtype=float32)
138138
139139
"""
140140
_binary_precision_recall_curve_validate_args(thresholds, ignore_index)

cyclops/tasks/classification.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -348,9 +348,11 @@ def evaluate(
348348

349349
# select the probability scores of the positive class since metrics
350350
# expect a single column of probabilities
351-
dataset = dataset.map(
351+
dataset = dataset.map( # type: ignore[union-attr]
352352
lambda examples: {
353-
f"{prediction_column_prefix}.{model_name}": np.array(examples)[
353+
f"{prediction_column_prefix}.{model_name}": np.array( # noqa: B023
354+
examples,
355+
)[
354356
:,
355357
1,
356358
].tolist(),

docs/source/tutorials/kaggle/heart_failure_prediction.ipynb

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -858,8 +858,9 @@
858858
" metrics=MetricDict(\n",
859859
" {\n",
860860
" \"BinaryAccuracy\": create_metric(\n",
861-
" metric_name=\"binary_accuracy\", experimental=True\n",
862-
" )\n",
861+
" metric_name=\"binary_accuracy\",\n",
862+
" experimental=True,\n",
863+
" ),\n",
863864
" },\n",
864865
" ),\n",
865866
" model_names=model_name,\n",

docs/source/tutorials/nihcxr/cxr_classification.ipynb

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929
"\n",
3030
"import shutil\n",
3131
"from functools import partial\n",
32-
"from typing import Optional\n",
3332
"\n",
3433
"import numpy as np\n",
3534
"import plotly.express as px\n",

0 commit comments

Comments
 (0)