From ff877e6fcad497da12c9eac9655b70c2e57d5d78 Mon Sep 17 00:00:00 2001 From: Son Hoang Date: Tue, 2 Sep 2025 17:27:11 +0900 Subject: [PATCH] Fixed --- src/anomalib/metrics/aupro.py | 28 ++++++++++++++++++++-------- 1 file changed, 20 insertions(+), 8 deletions(-) diff --git a/src/anomalib/metrics/aupro.py b/src/anomalib/metrics/aupro.py index 926e0d25dd..450b071ef9 100644 --- a/src/anomalib/metrics/aupro.py +++ b/src/anomalib/metrics/aupro.py @@ -135,7 +135,9 @@ def __init__( self.add_state("preds", default=[], dist_reduce_fx="cat") self.add_state("target", default=[], dist_reduce_fx="cat") - self.register_buffer("fpr_limit", torch.tensor(fpr_limit)) + # Determine the device dynamically + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.register_buffer("fpr_limit", torch.tensor(fpr_limit).to(device)) self.num_thresholds = num_thresholds def update(self, preds: torch.Tensor, target: torch.Tensor) -> None: @@ -145,8 +147,9 @@ def update(self, preds: torch.Tensor, target: torch.Tensor) -> None: preds (torch.Tensor): predictions of the model target (torch.Tensor): ground truth targets """ - self.target.append(target) - self.preds.append(preds) + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.target.append(target.to(device)) + self.preds.append(preds.to(device)) def perform_cca(self) -> torch.Tensor: """Perform the Connected Component Analysis on the self.target tensor. @@ -210,12 +213,15 @@ def compute_pro( else: thresholds = None + # Determine the device dynamically + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + # compute the global fpr-size fpr: torch.Tensor = binary_roc( preds=preds, target=target, thresholds=thresholds, - )[0] # only need fpr + )[0].to(device) # only need fpr output_size = torch.where(fpr <= self.fpr_limit)[0].size(0) # compute the PRO curve by aggregating per-region tpr/fpr curves/values. @@ -244,7 +250,9 @@ def compute_pro( preds=preds[background | mask], target=mask[background | mask], thresholds=thresholds, - )[:-1] + )[:-1] # only need fpr and tpr + fpr_ = fpr_.to(device) + tpr_ = tpr_.to(device) # catch edge-case where ROC only has fpr vals > self.fpr_limit if fpr_[fpr_ <= self.fpr_limit].max() == 0: @@ -255,7 +263,7 @@ def compute_pro( fpr_idx_ = torch.where(fpr_ <= fpr_limit_)[0] # if computed roc curve is not specified sufficiently close to # self.fpr_limit, we include the closest higher tpr/fpr pair and - # linearly interpolate the tpr/fpr point at self.fpr_limit + # linearly interpolate the tpr/fpr point at self. if not torch.allclose(fpr_[fpr_idx_].max(), self.fpr_limit): tmp_idx_ = torch.searchsorted(fpr_, self.fpr_limit) fpr_idx_ = torch.cat([fpr_idx_, tmp_idx_.unsqueeze_(0)]) @@ -263,7 +271,7 @@ def compute_pro( interp = True fpr_ = fpr_[fpr_idx_] - tpr_ = tpr_[fpr_idx_] + tpr_ = tpr_[fpr_idx_.detach().cpu()] fpr_idx_ = fpr_idx_.float() fpr_idx_ /= fpr_idx_.max() @@ -342,6 +350,10 @@ def interp1d(old_x: torch.Tensor, old_y: torch.Tensor, new_x: torch.Tensor) -> t Returns: Tensor: y-values at corresponding new_x values. """ + # Ensure all tensors are on the same device + old_x = old_x.to(new_x.device) + old_y = old_y.to(new_x.device) + # Compute slope eps = torch.finfo(old_y.dtype).eps slope = (old_y[1:] - old_y[:-1]) / (eps + (old_x[1:] - old_x[:-1])) @@ -353,7 +365,7 @@ def interp1d(old_x: torch.Tensor, old_y: torch.Tensor, new_x: torch.Tensor) -> t # to preserve order, but we actually want the preceeding index. idx -= 1 # we clamp the index, because the number of intervals = old_x.size(0) -1, - # and the left neighbour should hence be at most number of intervals -1, + # and the left neighbour should be at most number of intervals -1, idx = torch.clamp(idx, 0, old_x.size(0) - 2) # perform actual linear interpolation