28 changes: 20 additions & 8 deletions src/anomalib/metrics/aupro.py
@@ -135,7 +135,9 @@ def __init__(

         self.add_state("preds", default=[], dist_reduce_fx="cat")
         self.add_state("target", default=[], dist_reduce_fx="cat")
-        self.register_buffer("fpr_limit", torch.tensor(fpr_limit))
+        # Determine the device dynamically
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        self.register_buffer("fpr_limit", torch.tensor(fpr_limit).to(device))
         self.num_thresholds = num_thresholds

     def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:

Contributor (on the `device = ...` line): Will this still work for devices that are neither cuda nor cpu? For example, with Intel GPUs the device is called xpu. It would be nice to solve this such that it works independently of the device used.
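To illustrate the reviewer's point, here is an editor's sketch (not part of the PR) of a device-agnostic variant: instead of probing torch.cuda, the buffer simply travels with the metric, so appended state follows whatever device the metric was moved to, be it cuda, xpu, or mps. The class name and the dummy compute() are illustrative only.

import torch
from torchmetrics import Metric


class DeviceAgnosticAUPRO(Metric):
    """Editor's sketch, not the PR's code: no backend is named explicitly."""

    def __init__(self, fpr_limit: float = 0.3) -> None:
        super().__init__()
        self.add_state("preds", default=[], dist_reduce_fx="cat")
        self.add_state("target", default=[], dist_reduce_fx="cat")
        # The buffer starts on CPU and moves with metric.to("cuda"),
        # metric.to("xpu"), metric.to("mps"), ... alike.
        self.register_buffer("fpr_limit", torch.tensor(fpr_limit))

    def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
        # Reuse whatever device the metric currently lives on.
        device = self.fpr_limit.device
        self.preds.append(preds.to(device))
        self.target.append(target.to(device))

    def compute(self) -> torch.Tensor:
        # Placeholder so the sketch is a complete Metric subclass.
        return torch.cat(self.preds).mean()

With this pattern, metric.to("xpu") moves the buffer and all subsequently appended state together, which addresses the comment without hard-coding any device name.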
@@ -145,8 +147,9 @@ def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
             preds (torch.Tensor): predictions of the model
             target (torch.Tensor): ground truth targets
         """
-        self.target.append(target)
-        self.preds.append(preds)
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        self.target.append(target.to(device))
+        self.preds.append(preds.to(device))

     def perform_cca(self) -> torch.Tensor:
         """Perform the Connected Component Analysis on the self.target tensor.
@@ -210,12 +213,15 @@ def compute_pro(
         else:
             thresholds = None

+        # Determine the device dynamically
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
         # compute the global fpr-size
         fpr: torch.Tensor = binary_roc(
             preds=preds,
             target=target,
             thresholds=thresholds,
-        )[0]  # only need fpr
+        )[0].to(device)  # only need fpr
         output_size = torch.where(fpr <= self.fpr_limit)[0].size(0)

         # compute the PRO curve by aggregating per-region tpr/fpr curves/values.
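The torchmetrics call above returns (fpr, tpr, thresholds); a standalone toy run showing how output_size counts the ROC samples that fall within the FPR budget:

import torch
from torchmetrics.functional.classification import binary_roc

preds = torch.rand(1000)
target = (torch.rand(1000) > 0.5).int()
fpr_limit = torch.tensor(0.3)

fpr, tpr, thresholds = binary_roc(preds=preds, target=target, thresholds=None)
# The common x-axis length onto which each per-region curve is later
# interpolated: all ROC samples with fpr at or below the limit.
output_size = torch.where(fpr <= fpr_limit)[0].size(0)
print(output_size, fpr.numel())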
@@ -244,7 +250,9 @@ def compute_pro(
                 preds=preds[background | mask],
                 target=mask[background | mask],
                 thresholds=thresholds,
-            )[:-1]
+            )[:-1]  # only need fpr and tpr
+            fpr_ = fpr_.to(device)
+            tpr_ = tpr_.to(device)

             # catch edge-case where ROC only has fpr vals > self.fpr_limit
             if fpr_[fpr_ <= self.fpr_limit].max() == 0:
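The boolean indexing preds[background | mask] restricts each per-region ROC to one connected component plus the normal pixels, excluding other anomalous regions; a toy example (values invented):

import torch

preds = torch.tensor([0.9, 0.1, 0.8, 0.2, 0.7])
mask = torch.tensor([True, False, False, False, True])        # one CCA region
background = torch.tensor([False, True, False, True, False])  # normal pixels

# Index 2 belongs to a *different* anomalous region and is dropped, so it
# cannot distort this region's false-positive rate.
region_preds = preds[background | mask]
region_target = mask[background | mask].int()
print(region_preds)   # tensor([0.9000, 0.1000, 0.2000, 0.7000])
print(region_target)  # tensor([1, 0, 0, 1])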
@@ -255,15 +263,15 @@ def compute_pro(
             fpr_idx_ = torch.where(fpr_ <= fpr_limit_)[0]
             # if computed roc curve is not specified sufficiently close to
             # self.fpr_limit, we include the closest higher tpr/fpr pair and
-            # linearly interpolate the tpr/fpr point at self.fpr_limit
+            # linearly interpolate the tpr/fpr point at self.
             if not torch.allclose(fpr_[fpr_idx_].max(), self.fpr_limit):
                 tmp_idx_ = torch.searchsorted(fpr_, self.fpr_limit)
                 fpr_idx_ = torch.cat([fpr_idx_, tmp_idx_.unsqueeze_(0)])
                 slope_ = 1 - ((fpr_[tmp_idx_] - self.fpr_limit) / (fpr_[tmp_idx_] - fpr_[tmp_idx_ - 1]))
                 interp = True

             fpr_ = fpr_[fpr_idx_]
-            tpr_ = tpr_[fpr_idx_]
+            tpr_ = tpr_[fpr_idx_.detach().cpu()]

             fpr_idx_ = fpr_idx_.float()
             fpr_idx_ /= fpr_idx_.max()
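To make the interpolation step concrete: slope_ is the fraction of the last ROC interval that lies below fpr_limit, and (in a later part of compute_pro not shown in this hunk) it scales the final TPR step. An isolated sketch with invented values:

import torch

fpr_ = torch.tensor([0.00, 0.10, 0.25, 0.40, 1.00])
tpr_ = torch.tensor([0.00, 0.50, 0.70, 0.80, 1.00])
fpr_limit = torch.tensor(0.3)

tmp_idx_ = torch.searchsorted(fpr_, fpr_limit)  # 3: first fpr above the limit
# Fraction of the [0.25, 0.40] interval that lies below 0.3.
slope_ = 1 - ((fpr_[tmp_idx_] - fpr_limit) / (fpr_[tmp_idx_] - fpr_[tmp_idx_ - 1]))
# Linearly interpolated TPR at exactly fpr_limit.
tpr_at_limit = tpr_[tmp_idx_ - 1] + slope_ * (tpr_[tmp_idx_] - tpr_[tmp_idx_ - 1])
print(slope_.item(), tpr_at_limit.item())  # ~0.333, ~0.733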
@@ -342,6 +350,10 @@ def interp1d(old_x: torch.Tensor, old_y: torch.Tensor, new_x: torch.Tensor) -> t
         Returns:
             Tensor: y-values at corresponding new_x values.
         """
+        # Ensure all tensors are on the same device
+        old_x = old_x.to(new_x.device)
+        old_y = old_y.to(new_x.device)
+
         # Compute slope
         eps = torch.finfo(old_y.dtype).eps
         slope = (old_y[1:] - old_y[:-1]) / (eps + (old_x[1:] - old_x[:-1]))
@@ -353,7 +365,7 @@ def interp1d(old_x: torch.Tensor, old_y: torch.Tensor, new_x: torch.Tensor) -> t
         # to preserve order, but we actually want the preceding index.
         idx -= 1
         # we clamp the index, because the number of intervals = old_x.size(0) - 1,
-        # and the left neighbour should hence be at most number of intervals -1,
+        # and the left neighbour should be at most number of intervals -1,
         idx = torch.clamp(idx, 0, old_x.size(0) - 2)

         # perform actual linear interpolation
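Assembled from the fragments visible in these two hunks, an editor's self-contained reconstruction of the searchsorted-based linear interpolation (details hidden by the collapsed diff may differ):

import torch


def interp1d(old_x: torch.Tensor, old_y: torch.Tensor, new_x: torch.Tensor) -> torch.Tensor:
    """Piecewise-linear interpolation of the curve (old_x, old_y) at new_x."""
    # Ensure all tensors are on the same device
    old_x = old_x.to(new_x.device)
    old_y = old_y.to(new_x.device)

    # Slope of every interval; eps guards against zero-width intervals.
    eps = torch.finfo(old_y.dtype).eps
    slope = (old_y[1:] - old_y[:-1]) / (eps + (old_x[1:] - old_x[:-1]))

    # searchsorted yields the insertion point; step back to the preceding
    # sample, then clamp to the last valid interval (count = size - 1).
    idx = torch.searchsorted(old_x, new_x) - 1
    idx = torch.clamp(idx, 0, old_x.size(0) - 2)

    # Perform the actual linear interpolation.
    return old_y[idx] + slope[idx] * (new_x - old_x[idx])


x = torch.tensor([0.0, 1.0, 2.0])
y = torch.tensor([0.0, 10.0, 20.0])
print(interp1d(x, y, torch.tensor([0.5, 1.5])))  # tensor([ 5.0000, 15.0000])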