Skip to content

Commit

Permalink
Fix type annotations for metric values in ClassificationPlotter
Browse files Browse the repository at this point in the history
  • Loading branch information
fcogidi committed Jan 29, 2024
1 parent 7fa71d7 commit 46ef980
Showing 1 changed file with 12 additions and 6 deletions.
18 changes: 12 additions & 6 deletions cyclops/report/plot/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def roc_curve(
if auroc is not None:
assert isinstance(
auroc,
float,
(float, np.floating),
), "AUROCs must be a float for binary tasks"
name = f"Model (AUC = {auroc:.2f})"
else:
Expand Down Expand Up @@ -401,7 +401,7 @@ def precision_recall_curve_comparison(
if auprcs and slice_name in auprcs:
assert isinstance(
auprcs[slice_name],
float,
(float, np.floating),
), "AUPRCs must be a float for binary tasks"
name = f"{slice_name} (AUC = {auprcs[slice_name]:.2f})"
else:
Expand Down Expand Up @@ -707,7 +707,8 @@ def metrics_comparison_radar(
for slice_name, metrics in slice_metrics.items():
metric_names = list(metrics.keys())
assert all(
not isinstance(value, (list, np.ndarray))
not isinstance(value, list)
and not (isinstance(value, np.ndarray) and value.ndim > 0)
for value in metrics.values()
), (
"Generic metrics must not be of type list or np.ndarray for"
Expand All @@ -727,7 +728,9 @@ def metrics_comparison_radar(
radial_data: List[float] = []
theta_data: List[float] = []
for metric_name, metric_values in metrics.items():
if isinstance(metric_values, (list, np.ndarray)):
if isinstance(metric_values, list) or (

Check warning on line 731 in cyclops/report/plot/classification.py

View check run for this annotation

Codecov / codecov/patch

cyclops/report/plot/classification.py#L731

Added line #L731 was not covered by tests
isinstance(metric_values, np.ndarray) and metric_values.ndim > 0
):
assert (
len(metric_values) == self.class_num
), "Metric values must be of length class_num for \
Expand All @@ -738,7 +741,7 @@ def metrics_comparison_radar(
for i in range(self.class_num)
]
theta_data.extend(theta) # type: ignore[arg-type]
elif isinstance(metric_values, float):
elif isinstance(metric_values, (float, np.floating)):

Check warning on line 744 in cyclops/report/plot/classification.py

View check run for this annotation

Codecov / codecov/patch

cyclops/report/plot/classification.py#L744

Added line #L744 was not covered by tests
radial_data.append(metric_values)
theta_data.append(metric_name) # type: ignore[arg-type]
else:
Expand Down Expand Up @@ -859,7 +862,10 @@ def metrics_comparison_bar(
metric_names = list(metrics.keys())
for num in range(self.class_num):
for metric_name in metric_names:
if isinstance(metrics[metric_name], (list, np.ndarray)):
if isinstance(metrics[metric_name], list) or (

Check warning on line 865 in cyclops/report/plot/classification.py

View check run for this annotation

Codecov / codecov/patch

cyclops/report/plot/classification.py#L865

Added line #L865 was not covered by tests
isinstance(metrics[metric_name], np.ndarray)
and metrics[metric_name].ndim > 0
):
metric_values = metrics[metric_name][num] # type: ignore
else:
metric_values = metrics[metric_name] # type: ignore
Expand Down

0 comments on commit 46ef980

Please sign in to comment.