From 46ef98024333d6a323f011e8e6079e17f688c11a Mon Sep 17 00:00:00 2001 From: Franklin <41602287+fcogidi@users.noreply.github.com> Date: Mon, 29 Jan 2024 18:13:01 -0500 Subject: [PATCH] Fix type annotations for metric values in ClassificationPlotter --- cyclops/report/plot/classification.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/cyclops/report/plot/classification.py b/cyclops/report/plot/classification.py index f58efd3c1..04e0130f6 100644 --- a/cyclops/report/plot/classification.py +++ b/cyclops/report/plot/classification.py @@ -131,7 +131,7 @@ def roc_curve( if auroc is not None: assert isinstance( auroc, - float, + (float, np.floating), ), "AUROCs must be a float for binary tasks" name = f"Model (AUC = {auroc:.2f})" else: @@ -401,7 +401,7 @@ def precision_recall_curve_comparison( if auprcs and slice_name in auprcs: assert isinstance( auprcs[slice_name], - float, + (float, np.floating), ), "AUPRCs must be a float for binary tasks" name = f"{slice_name} (AUC = {auprcs[slice_name]:.2f})" else: @@ -707,7 +707,8 @@ def metrics_comparison_radar( for slice_name, metrics in slice_metrics.items(): metric_names = list(metrics.keys()) assert all( - not isinstance(value, (list, np.ndarray)) + not isinstance(value, list) + and not (isinstance(value, np.ndarray) and value.ndim > 0) for value in metrics.values() ), ( "Generic metrics must not be of type list or np.ndarray for" @@ -727,7 +728,9 @@ def metrics_comparison_radar( radial_data: List[float] = [] theta_data: List[float] = [] for metric_name, metric_values in metrics.items(): - if isinstance(metric_values, (list, np.ndarray)): + if isinstance(metric_values, list) or ( + isinstance(metric_values, np.ndarray) and metric_values.ndim > 0 + ): assert ( len(metric_values) == self.class_num ), "Metric values must be of length class_num for \ @@ -738,7 +741,7 @@ def metrics_comparison_radar( for i in range(self.class_num) ] theta_data.extend(theta) # type: ignore[arg-type] - elif isinstance(metric_values, float): + elif isinstance(metric_values, (float, np.floating)): radial_data.append(metric_values) theta_data.append(metric_name) # type: ignore[arg-type] else: @@ -859,7 +862,10 @@ def metrics_comparison_bar( metric_names = list(metrics.keys()) for num in range(self.class_num): for metric_name in metric_names: - if isinstance(metrics[metric_name], (list, np.ndarray)): + if isinstance(metrics[metric_name], list) or ( + isinstance(metrics[metric_name], np.ndarray) + and metrics[metric_name].ndim > 0 + ): metric_values = metrics[metric_name][num] # type: ignore else: metric_values = metrics[metric_name] # type: ignore