From e028449b0c8c9a383531f86b473321acbc73a1a1 Mon Sep 17 00:00:00 2001 From: alfieroddan <51797647+alfieroddan@users.noreply.github.com> Date: Thu, 28 Aug 2025 12:42:19 +0100 Subject: [PATCH 1/6] =?UTF-8?q?=F0=9F=90=9B=20fix(model):=20Reduce=20memor?= =?UTF-8?q?y=20reserved=20for=20memory=20bank=20based=20models=20PatchCore?= =?UTF-8?q?,=20Padim,=20Dfkde=20(#2913)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * refactor patchcore, use list instead of cat * typo in patchcore Signed-off-by: Alfie Roddan <51797647+alfieroddan@users.noreply.github.com> * padim model to now use cat instead of list Signed-off-by: Alfie Roddan <51797647+alfieroddan@users.noreply.github.com> * dfm and padim update cat to list Signed-off-by: Alfie Roddan <51797647+alfieroddan@users.noreply.github.com> * add dfkde cat to list changes Signed-off-by: Alfie Roddan <51797647+alfieroddan@users.noreply.github.com> * change assert from memory bank to list size, fix pre-commit issues Signed-off-by: Alfie Roddan <51797647+alfieroddan@users.noreply.github.com> --------- Signed-off-by: Alfie Roddan <51797647+alfieroddan@users.noreply.github.com> Signed-off-by: StarPlatinum7 <2732981250@qq.com> --- src/anomalib/models/image/dfkde/torch_model.py | 13 +++++-------- src/anomalib/models/image/dfm/torch_model.py | 11 ++++------- src/anomalib/models/image/padim/torch_model.py | 13 +++++-------- .../models/image/patchcore/torch_model.py | 18 +++++++++--------- 4 files changed, 23 insertions(+), 32 deletions(-) diff --git a/src/anomalib/models/image/dfkde/torch_model.py b/src/anomalib/models/image/dfkde/torch_model.py index 9af1ef4b65..43ac31bda6 100644 --- a/src/anomalib/models/image/dfkde/torch_model.py +++ b/src/anomalib/models/image/dfkde/torch_model.py @@ -89,7 +89,7 @@ def __init__( feature_scaling_method=feature_scaling_method, max_training_points=max_training_points, ) - self.memory_bank = torch.empty(0) + self.memory_bank: list[torch.tensor] = [] def 
get_features(self, batch: torch.Tensor) -> torch.Tensor: """Extract features from the pre-trained backbone network. @@ -142,11 +142,7 @@ def forward(self, batch: torch.Tensor) -> torch.Tensor | InferenceBatch: # 1. apply feature extraction features = self.get_features(batch) if self.training: - if self.memory_bank.size(0) == 0: - self.memory_bank = features - else: - new_bank = torch.cat((self.memory_bank, features), dim=0).to(self.memory_bank) - self.memory_bank = new_bank + self.memory_bank.append(features) return features # 2. apply density estimation @@ -164,12 +160,13 @@ def fit(self) -> None: Raises: ValueError: If the memory bank is empty. """ - if self.memory_bank.size(0) == 0: + if len(self.memory_bank) == 0: msg = "Memory bank is empty. Cannot perform coreset selection." raise ValueError(msg) + self.memory_bank = torch.vstack(self.memory_bank) # fit gaussian self.classifier.fit(self.memory_bank) # clear memory bank, redcues gpu size - self.memory_bank = torch.empty(0).to(self.memory_bank) + self.memory_bank = [] diff --git a/src/anomalib/models/image/dfm/torch_model.py b/src/anomalib/models/image/dfm/torch_model.py index 3d5a366780..87d8740bb9 100644 --- a/src/anomalib/models/image/dfm/torch_model.py +++ b/src/anomalib/models/image/dfm/torch_model.py @@ -153,17 +153,18 @@ def __init__( layers=[layer], ).eval() - self.memory_bank = torch.empty(0) + self.memory_bank: list[torch.tensor] = [] def fit(self) -> None: """Fit PCA and Gaussian model to dataset.""" + self.memory_bank = torch.vstack(self.memory_bank) self.pca_model.fit(self.memory_bank) if self.score_type == "nll": features_reduced = self.pca_model.transform(self.memory_bank) self.gaussian_model.fit(features_reduced.T) # clear memory bank, reduces GPU size - self.memory_bank = torch.empty(0).to(self.memory_bank) + self.memory_bank = [] def score(self, features: torch.Tensor, feature_shapes: tuple) -> torch.Tensor: """Compute anomaly scores. 
@@ -228,11 +229,7 @@ def forward(self, batch: torch.Tensor) -> torch.Tensor | InferenceBatch: feature_vector, feature_shapes = self.get_features(batch) if self.training: - if self.memory_bank.size(0) == 0: - self.memory_bank = feature_vector - else: - new_bank = torch.cat((self.memory_bank, feature_vector), dim=0).to(self.memory_bank) - self.memory_bank = new_bank + self.memory_bank.append(feature_vector) return feature_vector pred_score, anomaly_map = self.score(feature_vector.view(feature_vector.shape[:2]), feature_shapes) diff --git a/src/anomalib/models/image/padim/torch_model.py b/src/anomalib/models/image/padim/torch_model.py index d7367149a0..c0f017d9eb 100644 --- a/src/anomalib/models/image/padim/torch_model.py +++ b/src/anomalib/models/image/padim/torch_model.py @@ -147,7 +147,7 @@ def __init__( self.anomaly_map_generator = AnomalyMapGenerator() self.gaussian = MultiVariateGaussian() - self.memory_bank = torch.empty(0) + self.memory_bank: list[torch.tensor] = [] def forward(self, input_tensor: torch.Tensor) -> torch.Tensor | InferenceBatch: """Forward-pass image-batch (N, C, H, W) into model to extract features. @@ -183,11 +183,7 @@ def forward(self, input_tensor: torch.Tensor) -> torch.Tensor | InferenceBatch: embeddings = self.tiler.untile(embeddings) if self.training: - if self.memory_bank.size(0) == 0: - self.memory_bank = embeddings - else: - new_bank = torch.cat((self.memory_bank, embeddings), dim=0).to(self.memory_bank) - self.memory_bank = new_bank + self.memory_bank.append(embeddings) return embeddings anomaly_map = self.anomaly_map_generator( @@ -234,12 +230,13 @@ def fit(self) -> None: Raises: ValueError: If the memory bank is empty. """ - if self.memory_bank.size(0) == 0: + if len(self.memory_bank) == 0: msg = "Memory bank is empty. Cannot perform coreset selection." 
raise ValueError(msg) + self.memory_bank = torch.vstack(self.memory_bank) # fit gaussian self.gaussian.fit(self.memory_bank) # clear memory bank, redcues gpu usage - self.memory_bank = torch.empty(0).to(self.memory_bank) + self.memory_bank = [] diff --git a/src/anomalib/models/image/patchcore/torch_model.py b/src/anomalib/models/image/patchcore/torch_model.py index 1970e6dcc2..20b09de613 100644 --- a/src/anomalib/models/image/patchcore/torch_model.py +++ b/src/anomalib/models/image/patchcore/torch_model.py @@ -126,6 +126,7 @@ def __init__( self.anomaly_map_generator = AnomalyMapGenerator() self.memory_bank: torch.Tensor self.register_buffer("memory_bank", torch.empty(0)) + self.embedding_store: list[torch.tensor] = [] def forward(self, input_tensor: torch.Tensor) -> torch.Tensor | InferenceBatch: """Process input tensor through the model. @@ -169,11 +170,7 @@ def forward(self, input_tensor: torch.Tensor) -> torch.Tensor | InferenceBatch: embedding = self.reshape_embedding(embedding) if self.training: - if self.memory_bank.size(0) == 0: - self.memory_bank = embedding - else: - new_bank = torch.cat((self.memory_bank, embedding), dim=0).to(self.memory_bank) - self.memory_bank = new_bank + self.embedding_store.append(embedding) return embedding # Ensure memory bank is not empty @@ -272,13 +269,16 @@ def subsample_embedding(self, sampling_ratio: float, embeddings: torch.Tensor = if embeddings is not None: del embeddings - if self.memory_bank.size(0) == 0: - msg = "Memory bank is empty. Cannot perform coreset selection." + if len(self.embedding_store) == 0: + msg = "Embedding store is empty. Cannot perform coreset selection." 
raise ValueError(msg) + # Coreset Subsampling + self.memory_bank = torch.vstack(self.embedding_store) + self.embedding_store.clear() + sampler = KCenterGreedy(embedding=self.memory_bank, sampling_ratio=sampling_ratio) - coreset = sampler.sample_coreset().to(self.memory_bank) - self.memory_bank = coreset + self.memory_bank = sampler.sample_coreset() @staticmethod def euclidean_dist(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: From ce4097b3241b3302bec29edf73612e4da75c7932 Mon Sep 17 00:00:00 2001 From: StarPlatinum7 <2732981250@qq.com> Date: Wed, 3 Sep 2025 16:45:34 +0800 Subject: [PATCH 2/6] add mebin postprocessor Signed-off-by: StarPlatinum7 <2732981250@qq.com> --- CHANGELOG.md | 2 + src/anomalib/metrics/__init__.py | 3 +- src/anomalib/metrics/threshold/__init__.py | 3 +- src/anomalib/metrics/threshold/mebin.py | 250 ++++++++++++++++++ src/anomalib/post_processing/__init__.py | 3 +- .../post_processing/mebin_post_processor.py | 127 +++++++++ 6 files changed, 385 insertions(+), 3 deletions(-) create mode 100644 src/anomalib/metrics/threshold/mebin.py create mode 100644 src/anomalib/post_processing/mebin_post_processor.py diff --git a/CHANGELOG.md b/CHANGELOG.md index c4933a7d86..74cf439169 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/). 
## [Unreleased] ### Added +- 🚀 Add MEBin post-processing method + ### Removed diff --git a/src/anomalib/metrics/__init__.py b/src/anomalib/metrics/__init__.py index 6d4d0b64e0..ab57b7fca7 100644 --- a/src/anomalib/metrics/__init__.py +++ b/src/anomalib/metrics/__init__.py @@ -59,7 +59,7 @@ from .pimo import AUPIMO, PIMO from .precision_recall_curve import BinaryPrecisionRecallCurve from .pro import PRO -from .threshold import F1AdaptiveThreshold, ManualThreshold +from .threshold import F1AdaptiveThreshold, ManualThreshold ,MEBin __all__ = [ "AUROC", @@ -78,4 +78,5 @@ "PRO", "PIMO", "AUPIMO", + "MEBin", ] diff --git a/src/anomalib/metrics/threshold/__init__.py b/src/anomalib/metrics/threshold/__init__.py index 1bf10854ec..0cebd87b82 100644 --- a/src/anomalib/metrics/threshold/__init__.py +++ b/src/anomalib/metrics/threshold/__init__.py @@ -24,5 +24,6 @@ from .base import BaseThreshold, Threshold from .f1_adaptive_threshold import F1AdaptiveThreshold from .manual_threshold import ManualThreshold +from .mebin import MEBin -__all__ = ["BaseThreshold", "Threshold", "F1AdaptiveThreshold", "ManualThreshold"] +__all__ = ["BaseThreshold", "Threshold", "F1AdaptiveThreshold", "ManualThreshold", "MEBin"] diff --git a/src/anomalib/metrics/threshold/mebin.py b/src/anomalib/metrics/threshold/mebin.py new file mode 100644 index 0000000000..c823ea3b5f --- /dev/null +++ b/src/anomalib/metrics/threshold/mebin.py @@ -0,0 +1,250 @@ +"""MEBin adaptive thresholding algorithm for anomaly detection. + +This module provides the ``MEBin`` class which automatically finds the optimal +threshold value by analyzing the stability of connected components across +multiple threshold levels. + +The threshold is computed by: +1. Sampling anomaly maps at configurable rates across threshold range +2. Counting connected components at each threshold level +3. Finding stable intervals where component count remains constant +4. 
Selecting threshold from the longest stable interval + +Example: + >>> from anomalib.metrics.threshold import MEBin + >>> import numpy as np + >>> # Create sample anomaly maps + >>> anomaly_maps = [np.random.rand(256, 256) * 255 for _ in range(10)] + >>> # Initialize and compute thresholds + >>> mebin = MEBin(anomaly_map_list=anomaly_maps, sample_rate=4) + >>> binarized_maps, thresholds = mebin.binarize_anomaly_maps() + >>> print(f"Computed {len(thresholds)} thresholds") + +Note: + The algorithm works best when anomaly maps contain clear separation between + normal and anomalous regions. The min_interval_len parameter should be tuned + based on the expected stability of anomaly score distributions. +""" + +import cv2 +import numpy as np +from tqdm import tqdm + + + +class MEBin: + """MEBin adaptive thresholding algorithm for anomaly detection. + + This class implements the MEBin (Minimum Entropy Binarization) algorithm + which automatically determines optimal thresholds for converting continuous + anomaly maps to binary masks by analyzing the stability of connected + component counts across different threshold levels. + + The algorithm works by: + - Sampling anomaly maps at configurable rates across threshold range + - Counting connected components at each threshold level + - Identifying stable intervals where component count remains constant + - Selecting the optimal threshold from the longest stable interval + - Optionally applying morphological erosion to reduce noise + + Args: + anomaly_map_path_list (list, optional): List of file paths to anomaly maps. + If provided, maps will be loaded as grayscale images. + Defaults to None. + anomaly_map_list (list, optional): List of anomaly map arrays. If provided, + maps should be numpy arrays. + Defaults to None. + sample_rate (int, optional): Sampling rate for threshold search. Higher + values reduce processing time but may affect accuracy. + Defaults to 4. 
+ min_interval_len (int, optional): Minimum length of stable intervals. + Should be tuned based on the expected stability of anomaly score + distributions. + Defaults to 4. + erode (bool, optional): Whether to apply morphological erosion to + binarized results to reduce noise. + Defaults to True. + + Example: + >>> from anomalib.metrics.threshold import MEBin + >>> import numpy as np + >>> # Create sample anomaly maps + >>> anomaly_maps = [np.random.rand(256, 256) * 255 for _ in range(10)] + >>> # Initialize MEBin + >>> mebin = MEBin(anomaly_map_list=anomaly_maps, sample_rate=4) + >>> # Compute binary masks and thresholds + >>> binarized_maps, thresholds = mebin.binarize_anomaly_maps() + """ + + def __init__(self, anomaly_map_path_list=None, sample_rate=4, min_interval_len=4, erode=True): + + self.anomaly_map_path_list = anomaly_map_path_list + # Load anomaly maps as grayscale images if paths are provided + self.anomaly_map_list = [cv2.imread(x, cv2.IMREAD_GRAYSCALE) for x in self.anomaly_map_path_list] + + self.sample_rate = sample_rate + self.min_interval_len = min_interval_len + self.erode = erode + + # Adaptively determine the threshold search range + self.max_th, self.min_th = self.get_search_range() + + + def get_search_range(self): + """Determine the threshold search range adaptively. + + This method analyzes all anomaly maps to determine the minimum and maximum + threshold values for the binarization process. The search range is based + on the actual anomaly score distributions in the input maps. + + Returns: + max_th (int): Maximum threshold for binarization. + min_th (int): Minimum threshold for binarization. 
+ """ + # Get the anomaly scores of all anomaly maps + anomaly_score_list = [np.max(x) for x in self.anomaly_map_list] + + # Select the maximum and minimum anomaly scores from images + max_score, min_score = max(anomaly_score_list), min(anomaly_score_list) + max_th, min_th = max_score, min_score + + print(f"Value range: {min_score} - {max_score}") + + return max_th, min_th + + + + def get_threshold(self, anomaly_num_sequence, min_interval_len): + """ + Find the 'stable interval' in the anomaly region number sequence. + Stable Interval: A continuous threshold range in which the number of connected components remains constant, + and the length of the threshold range is greater than or equal to the given length threshold (min_interval_len). + + Args: + anomaly_num_sequence (list): Sequence of connected component counts + at each threshold level, ordered from high to low threshold. + min_interval_len (int): Minimum length requirement for stable intervals. + Longer intervals indicate more robust threshold selection. + + Returns: + threshold (int): The final threshold for binarization. + est_anomaly_num (int): The estimated number of anomalies. + """ + interval_result = {} + current_index = 0 + while current_index < len(anomaly_num_sequence): + end = current_index + + start = end + + # Find the interval where the connected component count remains constant. + if len(set(anomaly_num_sequence[start:end+1])) == 1 and anomaly_num_sequence[start] != 0: + # Move the 'end' pointer forward until a different connected component number is encountered. + while end < len(anomaly_num_sequence)-1 and anomaly_num_sequence[end] == anomaly_num_sequence[end+1]: + end += 1 + current_index += 1 + # If the length of the current stable interval is greater than or equal to the given threshold (min_interval_len), record this interval. 
+ if end - start + 1 >= min_interval_len: + if anomaly_num_sequence[start] not in interval_result: + interval_result[anomaly_num_sequence[start]] = [(start, end)] + else: + interval_result[anomaly_num_sequence[start]].append((start, end)) + current_index += 1 + + """ + If a 'stable interval' exists, calculate the final threshold based on the longest stable interval. + If no stable interval is found, it indicates that no anomaly regions exist, and 255 is returned. + """ + + if interval_result: + # Iterate through the stable intervals, calculating their lengths and corresponding number of connected component. + count_result = {} + for anomaly_num in interval_result: + count_result[anomaly_num] = max([x[1] - x[0] for x in interval_result[anomaly_num]]) + est_anomaly_num = max(count_result, key=count_result.get) + est_anomaly_num_interval_result = interval_result[est_anomaly_num] + + # Find the longest stable interval. + longest_interval = sorted(est_anomaly_num_interval_result, key=lambda x: x[1] - x[0])[-1] + + # Use the endpoint threshold of the longest stable interval as the final threshold. + index = longest_interval[1] + threshold = 255 - index * self.sample_rate + threshold = int(threshold*(self.max_th - self.min_th)/255 + self.min_th) + return threshold, est_anomaly_num + else: + return 255, 0 + + + def bin_and_erode(self, anomaly_map, threshold): + """Binarize anomaly map and optionally apply erosion. + + This method converts a continuous anomaly map to a binary mask using + the specified threshold, and optionally applies morphological erosion + to reduce noise and smooth the boundaries of anomaly regions. + + The binarization process: + 1. Pixels above threshold become 255 (anomalous) + 2. Pixels below threshold become 0 (normal) + 3. Optional erosion with 6x6 kernel to reduce noise + + Args: + anomaly_map (numpy.ndarray): Input anomaly map with continuous + anomaly scores to be binarized. + threshold (int): Threshold value for binarization. 
Pixels with + values above this threshold are considered anomalous. + + Returns: + numpy.ndarray: Binary mask where 255 indicates anomalous regions + and 0 indicates normal regions. The result is of type uint8. + + Note: + Erosion is applied with a 6x6 kernel and 1 iteration to balance + noise reduction with preservation of anomaly boundaries. + """ + bin_result = np.where(anomaly_map > threshold, 255, 0).astype(np.uint8) + + # Apply erosion operation to the binarized result + if self.erode: + kernel_size = 6 + iter_num = 1 + kernel = np.ones((kernel_size, kernel_size), np.uint8) + bin_result = cv2.erode(bin_result, kernel, iterations=iter_num) + return bin_result + + + def binarize_anomaly_maps(self): + """ + Perform binarization within the given threshold search range, + count the number of connected components in the binarized results. + Adaptively determine the threshold according to the count, + and perform binarization on the anomaly maps. + + Returns: + binarized_maps (list): List of binarized images. + thresholds (list): List of thresholds for each image. + """ + self.binarized_maps = [] + self.thresholds = [] + + for i, anomaly_map in enumerate(tqdm(self.anomaly_map_list)): + # Normalize the anomaly map within the given threshold search range. + anomaly_map_norm = np.where(anomaly_map < self.min_th, 0, ((anomaly_map - self.min_th) / (self.max_th - self.min_th)) * 255) + anomaly_num_sequence = [] + + # Search for the threshold from high to low within the given range using the specified sampling rate. + for score in range(255, 0, -self.sample_rate): + bin_result = self.bin_and_erode(anomaly_map_norm, score) + num_labels, *rest = cv2.connectedComponentsWithStats(bin_result, connectivity=8) + anomaly_num = num_labels - 1 + anomaly_num_sequence.append(anomaly_num) + + # Adaptively determine the threshold based on the anomaly connected component count sequence. 
+ threshold, est_anomaly_num = self.get_threshold(anomaly_num_sequence, self.min_interval_len) + + # Binarize the anomaly image based on the determined threshold. + bin_result = self.bin_and_erode(anomaly_map, threshold) + self.binarized_maps.append(bin_result) + self.thresholds.append(threshold) + + return self.binarized_maps, self.thresholds \ No newline at end of file diff --git a/src/anomalib/post_processing/__init__.py b/src/anomalib/post_processing/__init__.py index 5ca7bf598b..7e57c6d82d 100644 --- a/src/anomalib/post_processing/__init__.py +++ b/src/anomalib/post_processing/__init__.py @@ -20,5 +20,6 @@ """ from .post_processor import PostProcessor +from .mebin_post_processor import MEBinPostProcessor -__all__ = ["PostProcessor"] +__all__ = ["PostProcessor", "MEBinPostProcessor"] diff --git a/src/anomalib/post_processing/mebin_post_processor.py b/src/anomalib/post_processing/mebin_post_processor.py new file mode 100644 index 0000000000..8b71fd68c4 --- /dev/null +++ b/src/anomalib/post_processing/mebin_post_processor.py @@ -0,0 +1,127 @@ +"""Post-processing module for MEBin-based anomaly detection results. + +This module provides post-processing functionality for anomaly detection +outputs through the :class:`MEBinPostProcessor` class. 
+ +The MEBin post-processor handles: + - Converting anomaly maps to binary masks using MEBin algorithm + - Sampling anomaly maps at configurable rates for efficient processing + - Applying morphological operations (erosion) to refine binary masks + - Maintaining minimum interval lengths for consistent mask generation + - Formatting results for downstream use + +Example: + >>> from anomalib.post_processing import MEBinPostProcessor + >>> post_processor = MEBinPostProcessor(sample_rate=4, min_interval_len=4) + >>> predictions = post_processor(anomaly_maps=anomaly_maps) +""" + +import argparse +import yaml +import os +import json +import shutil +import cv2 +from tqdm import tqdm +import csv +import sys +sys.path.append(os.getcwd()) + +from anomalib.post_processing import PostProcessor +from anomalib.data import InferenceBatch +import torch +import numpy as np + +from anomalib.metrics import MEBin + + +class MEBinPostProcessor(PostProcessor): + """Post-processor for MEBin-based anomaly detection. + + This class handles post-processing of anomaly detection results by: + - Converting continuous anomaly maps to binary masks using MEBin algorithm + - Sampling anomaly maps at configurable rates for efficient processing + - Applying morphological operations (erosion) to refine binary masks + - Maintaining minimum interval lengths for consistent mask generation + - Formatting results for downstream use + + Args: + sample_rate (int, optional): Threshold sampling step size, + Default to 4 + min_interval_len (int, optional): Minimum length of the stable interval, + can be fine-tuned according to the interval between normal and abnormal score distributions in the anomaly score maps, + decrease if there are many false negatives, increase if there are many false positives. + Default to 4 + erode (bool, optional): Whether to perform erosion after binarization to eliminate noise, + this operation can smooth the change process of the number of abnormal + connected components. 
+ Default to True + **kwargs: Additional keyword arguments passed to parent class. + + Example: + >>> from anomalib.post_processing import MEBinPostProcessor + >>> post_processor = MEBinPostProcessor(sample_rate=4, min_interval_len=4) + >>> predictions = post_processor(anomaly_maps=anomaly_maps) + """ + + def __init__( + self, + sample_rate: int = 4, + min_interval_len: int = 4, + erode: bool = True, + **kwargs, + ) -> None: + super().__init__(**kwargs) + + self.sample_rate = sample_rate + self.min_interval_len = min_interval_len + self.erode = erode + + """Custom post-processor using MEBin algorithm""" + def forward(self, predictions: InferenceBatch) -> InferenceBatch: + """Post-process model predictions using MEBin algorithm. + + This method converts continuous anomaly maps to binary masks using the MEBin + algorithm, which provides efficient and accurate binarization of anomaly + detection results. + + Args: + predictions (InferenceBatch): Batch containing model predictions with + anomaly maps to be processed. + + Returns: + InferenceBatch: Post-processed batch with binary masks generated from + anomaly maps using MEBin algorithm. + + Note: + The method automatically handles tensor-to-numpy conversion and back, + ensuring compatibility with the original tensor device and dtype. 
+ """ + anomaly_maps = predictions.anomaly_map + if isinstance(anomaly_maps, torch.Tensor): + anomaly_maps = anomaly_maps.detach().cpu().numpy() + if anomaly_maps.ndim == 4: + anomaly_maps = anomaly_maps[:, 0, :, :] # Remove channel dimension + + # Normalize to 0-255 and convert to uint8 + norm_maps = [] + for amap in anomaly_maps: + amap_norm = (amap - amap.min()) / (amap.max() - amap.min() + 1e-8) * 255 + norm_maps.append(amap_norm.astype(np.uint8)) + + + mebin = MEBin(anomaly_map_list=norm_maps, sample_rate=self.sample_rate, min_interval_len=self.min_interval_len, erode=self.erode) + binarized_maps, thresholds = mebin.binarize_anomaly_maps() + + # Convert back to torch.Tensor and normalize to 0/1 + pred_masks = torch.stack([torch.from_numpy(bm).to(predictions.anomaly_map.device) for bm in binarized_maps]) + pred_masks = (pred_masks > 0).to(predictions.anomaly_map.dtype) + + return InferenceBatch( + pred_label=predictions.pred_label, + pred_score=predictions.pred_score, + pred_mask=pred_masks, + anomaly_map=predictions.anomaly_map, + ) + + \ No newline at end of file From 1e68949b1faca6913b27a9aad6efc0eb057e2fd0 Mon Sep 17 00:00:00 2001 From: StarPlatinum7 <2732981250@qq.com> Date: Mon, 8 Sep 2025 17:15:40 +0800 Subject: [PATCH 3/6] test mebin postprocessor Signed-off-by: StarPlatinum7 <2732981250@qq.com> --- tests/unit/post_processing/test_mebin_post_processor.py | 4 ++++ 1 file changed, 4 insertions(+) create mode 100644 tests/unit/post_processing/test_mebin_post_processor.py diff --git a/tests/unit/post_processing/test_mebin_post_processor.py b/tests/unit/post_processing/test_mebin_post_processor.py new file mode 100644 index 0000000000..88dc0d5a1b --- /dev/null +++ b/tests/unit/post_processing/test_mebin_post_processor.py @@ -0,0 +1,4 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""Test the MeBinPostProcessor class.""" \ No newline at end of file From 3603157878206f111a35ab9120d0affb4684fc79 Mon Sep 17 00:00:00 2001 
From: StarPlatinum7 <2732981250@qq.com> Date: Mon, 8 Sep 2025 18:27:51 +0800 Subject: [PATCH 4/6] test mebin postprocessor Signed-off-by: StarPlatinum7 <2732981250@qq.com> --- .../test_mebin_post_processor.py | 213 +++++++++++++++++- 1 file changed, 212 insertions(+), 1 deletion(-) diff --git a/tests/unit/post_processing/test_mebin_post_processor.py b/tests/unit/post_processing/test_mebin_post_processor.py index 88dc0d5a1b..25925d40cd 100644 --- a/tests/unit/post_processing/test_mebin_post_processor.py +++ b/tests/unit/post_processing/test_mebin_post_processor.py @@ -1,4 +1,215 @@ # Copyright (C) 2025 Intel Corporation # SPDX-License-Identifier: Apache-2.0 -"""Test the MeBinPostProcessor class.""" \ No newline at end of file +"""Test the MEBinPostProcessor class.""" + +import pytest +import torch +import numpy as np +from unittest.mock import Mock, patch + +from anomalib.data import InferenceBatch +from anomalib.post_processing import MEBinPostProcessor + + +class TestMEBinPostProcessor: + """Test the MEBinPostProcessor class.""" + + @staticmethod + def test_initialization_default_params() -> None: + """Test MEBinPostProcessor initialization with default parameters.""" + processor = MEBinPostProcessor() + + assert processor.sample_rate == 4 + assert processor.min_interval_len == 4 + assert processor.erode is True + + @staticmethod + @pytest.mark.parametrize( + ("sample_rate", "min_interval_len", "erode"), + [ + (2, 3, True), + (8, 6, False), + (1, 1, True), + ], + ) + def test_initialization_custom_params( + sample_rate: int, + min_interval_len: int, + erode: bool + ) -> None: + """Test MEBinPostProcessor initialization with custom parameters.""" + processor = MEBinPostProcessor( + sample_rate=sample_rate, + min_interval_len=min_interval_len, + erode=erode + ) + + assert processor.sample_rate == sample_rate + assert processor.min_interval_len == min_interval_len + assert processor.erode == erode + + @staticmethod + 
@patch('anomalib.post_processing.mebin_post_processor.MEBin') + def test_forward_single_anomaly_map(mock_mebin) -> None: + """Test forward method with single anomaly map.""" + # Setup mock + mock_mebin_instance = Mock() + mock_mebin_instance.binarize_anomaly_maps.return_value = ( + [np.array([[0, 0], [1, 1]], dtype=np.uint8)], + [0.5] + ) + mock_mebin.return_value = mock_mebin_instance + + # Create test data + anomaly_map = torch.rand(1, 1, 4, 4) + predictions = InferenceBatch( + pred_score=torch.tensor([0.8]), + pred_label=torch.tensor([1]), + anomaly_map=anomaly_map, + pred_mask=None + ) + + # Test forward pass + processor = MEBinPostProcessor() + result = processor.forward(predictions) + + # Verify results + assert isinstance(result, InferenceBatch) + assert result.pred_mask is not None + assert result.pred_mask.shape == (1, 2, 2) + assert result.pred_mask.dtype == anomaly_map.dtype + + @staticmethod + @patch('anomalib.post_processing.mebin_post_processor.MEBin') + def test_forward_batch_anomaly_maps(mock_mebin) -> None: + """Test forward method with batch of anomaly maps.""" + # Setup mock + mock_mebin_instance = Mock() + mock_mebin_instance.binarize_anomaly_maps.return_value = ( + [ + np.array([[0, 0], [1, 1]], dtype=np.uint8), + np.array([[1, 0], [0, 1]], dtype=np.uint8) + ], + [0.5, 0.6] + ) + mock_mebin.return_value = mock_mebin_instance + + # Create test data + anomaly_maps = torch.rand(2, 1, 4, 4) + predictions = InferenceBatch( + pred_score=torch.tensor([0.8, 0.9]), + pred_label=torch.tensor([1, 1]), + anomaly_map=anomaly_maps, + pred_mask=None + ) + + # Test forward pass + processor = MEBinPostProcessor() + result = processor.forward(predictions) + + # Verify results + assert isinstance(result, InferenceBatch) + assert result.pred_mask.shape == (2, 2, 2) + + @staticmethod + @patch('anomalib.post_processing.mebin_post_processor.MEBin') + def test_forward_normalization(mock_mebin) -> None: + """Test that anomaly maps are properly normalized to 0-255 
range.""" + # Setup mock + mock_mebin_instance = Mock() + mock_mebin_instance.binarize_anomaly_maps.return_value = ( + [np.array([[0, 0], [1, 1]], dtype=np.uint8)], + [0.5] + ) + mock_mebin.return_value = mock_mebin_instance + + # Create test data with specific range + anomaly_maps = torch.tensor([[[[0.0, 0.5], [1.0, 0.2]]]]) + predictions = InferenceBatch( + pred_score=torch.tensor([0.8]), + pred_label=torch.tensor([1]), + anomaly_map=anomaly_maps, + pred_mask=None + ) + + # Test forward pass + processor = MEBinPostProcessor() + result = processor.forward(predictions) + + # Verify MEBin was called with normalized data + mock_mebin.assert_called_once() + call_args = mock_mebin.call_args + anomaly_map_list = call_args[1]['anomaly_map_list'] + + # Check that the data is normalized to 0-255 range + assert len(anomaly_map_list) == 1 + assert anomaly_map_list[0].dtype == np.uint8 + assert anomaly_map_list[0].min() >= 0 + assert anomaly_map_list[0].max() <= 255 + + @staticmethod + @patch('anomalib.post_processing.mebin_post_processor.MEBin') + def test_forward_mebin_parameters(mock_mebin) -> None: + """Test that MEBin is called with correct parameters.""" + # Setup mock + mock_mebin_instance = Mock() + mock_mebin_instance.binarize_anomaly_maps.return_value = ( + [np.array([[0, 0], [1, 1]], dtype=np.uint8)], + [0.5] + ) + mock_mebin.return_value = mock_mebin_instance + + # Create test data + anomaly_maps = torch.rand(1, 1, 4, 4) + predictions = InferenceBatch( + pred_score=torch.tensor([0.8]), + pred_label=torch.tensor([1]), + anomaly_map=anomaly_maps, + pred_mask=None + ) + + # Test with custom parameters + processor = MEBinPostProcessor( + sample_rate=8, + min_interval_len=6, + erode=False + ) + result = processor.forward(predictions) + + # Verify MEBin was called with correct parameters + mock_mebin.assert_called_once_with( + anomaly_map_list=mock_mebin.call_args[1]['anomaly_map_list'], + sample_rate=8, + min_interval_len=6, + erode=False + ) + + @staticmethod + 
@patch('anomalib.post_processing.mebin_post_processor.MEBin') + def test_forward_binary_mask_conversion(mock_mebin) -> None: + """Test that binary masks are properly converted to 0/1 values.""" + # Setup mock to return masks with values > 0 + mock_mebin_instance = Mock() + mock_mebin_instance.binarize_anomaly_maps.return_value = ( + [np.array([[0, 128], [255, 64]], dtype=np.uint8)], + [0.5] + ) + mock_mebin.return_value = mock_mebin_instance + + # Create test data + anomaly_maps = torch.rand(1, 1, 2, 2) + predictions = InferenceBatch( + pred_score=torch.tensor([0.8]), + pred_label=torch.tensor([1]), + anomaly_map=anomaly_maps, + pred_mask=None + ) + + # Test forward pass + processor = MEBinPostProcessor() + result = processor.forward(predictions) + + # Verify that all values are either 0 or 1 + unique_values = torch.unique(result.pred_mask) + assert torch.all((unique_values == 0) | (unique_values == 1)) From 075352025d1dac8e1f34cc8bfe8472f75bdff9ad Mon Sep 17 00:00:00 2001 From: Aimira Baitieva <63813435+abc-125@users.noreply.github.com> Date: Fri, 5 Sep 2025 14:11:03 +0200 Subject: [PATCH 5/6] =?UTF-8?q?=F0=9F=9A=80=20feat(metric):=20added=20hist?= =?UTF-8?q?ogram=20of=20anomaly=20scores=20(#2920)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * added histogram of anomaly scores * Update src/anomalib/metrics/anomaly_score_distribution.py Co-authored-by: Samet Akcay Signed-off-by: Aimira Baitieva <63813435+abc-125@users.noreply.github.com> * Update anomaly_score_distribution.py * Update anomaly_score_distribution.py * Fixed pre-commit checks --------- Signed-off-by: Aimira Baitieva <63813435+abc-125@users.noreply.github.com> Co-authored-by: Samet Akcay Signed-off-by: StarPlatinum7 <2732981250@qq.com> --- .../metrics/anomaly_score_distribution.py | 78 +++++++++++++++++-- src/anomalib/metrics/utils.py | 60 +++++++++++++- 2 files changed, 129 insertions(+), 9 deletions(-) diff --git 
a/src/anomalib/metrics/anomaly_score_distribution.py b/src/anomalib/metrics/anomaly_score_distribution.py index 05a271d063..c63ab8811f 100644 --- a/src/anomalib/metrics/anomaly_score_distribution.py +++ b/src/anomalib/metrics/anomaly_score_distribution.py @@ -3,9 +3,11 @@ """Compute statistics of anomaly score distributions. -This module provides the ``AnomalyScoreDistribution`` class which computes mean -and standard deviation statistics of anomaly scores from normal training data. +This module provides the ``AnomalyScoreDistribution`` class, which computes the mean +and standard deviation statistics of anomaly scores. Statistics are computed for both image-level and pixel-level scores. +The ``plot`` method generates a histogram of anomaly scores, +separated by label, to visualize score distributions for normal and abnormal samples. The class tracks: - Image-level statistics: Mean and std of image anomaly scores @@ -17,29 +19,34 @@ >>> # Create sample data >>> scores = torch.tensor([0.1, 0.2, 0.15]) # Image anomaly scores >>> maps = torch.tensor([[0.1, 0.2], [0.15, 0.25]]) # Pixel anomaly maps + >>> labels = torch.tensor([0, 1, 0]) # Binary labels >>> # Initialize and compute stats >>> dist = AnomalyScoreDistribution() - >>> dist.update(anomaly_scores=scores, anomaly_maps=maps) + >>> dist.update(anomaly_scores=scores, anomaly_maps=maps, labels=labels) >>> image_mean, image_std, pixel_mean, pixel_std = dist.compute() + >>> fig, title = dist.plot() Note: The input scores and maps are log-transformed before computing statistics. - Both image-level scores and pixel-level maps are optional inputs. + Image-level scores, pixel-level maps, and labels are optional inputs. """ import torch +from matplotlib.figure import Figure from torchmetrics import Metric +from .utils import plot_score_histogram + class AnomalyScoreDistribution(Metric): """Compute distribution statistics of anomaly scores. 
This class tracks and computes the mean and standard deviation of anomaly - scores from the normal samples in the training set. Statistics are computed - for both image-level scores and pixel-level anomaly maps. + scores. Statistics are computed for both image-level scores and pixel-level + anomaly maps. - The metric maintains internal state to accumulate scores and maps across - batches before computing final statistics. + The metric maintains internal state to accumulate scores, anomaly maps, + and labels across batches before computing final statistics. Example: >>> dist = AnomalyScoreDistribution() @@ -59,6 +66,7 @@ def __init__(self, **kwargs) -> None: super().__init__(**kwargs) self.anomaly_maps: list[torch.Tensor] = [] self.anomaly_scores: list[torch.Tensor] = [] + self.labels: list[torch.Tensor] = [] self.add_state("image_mean", torch.empty(0), persistent=True) self.add_state("image_std", torch.empty(0), persistent=True) @@ -75,6 +83,7 @@ def update( *args, anomaly_scores: torch.Tensor | None = None, anomaly_maps: torch.Tensor | None = None, + labels: torch.Tensor | None = None, **kwargs, ) -> None: """Update the internal state with new scores and maps. @@ -83,6 +92,7 @@ def update( *args: Unused positional arguments. anomaly_scores: Batch of image-level anomaly scores. anomaly_maps: Batch of pixel-level anomaly maps. + labels: Batch of binary labels. **kwargs: Unused keyword arguments. """ del args, kwargs # These variables are not used. @@ -91,6 +101,8 @@ def update( self.anomaly_maps.append(anomaly_maps) if anomaly_scores is not None: self.anomaly_scores.append(anomaly_scores) + if labels is not None: + self.labels.append(labels) def compute(self) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """Compute distribution statistics from accumulated scores and maps. 
@@ -116,3 +128,53 @@ def compute(self) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tenso self.pixel_std = anomaly_maps.std(dim=0).squeeze() return self.image_mean, self.image_std, self.pixel_mean, self.pixel_std + + def plot( + self, + bins: int = 30, + good_color: str = "skyblue", + bad_color: str = "salmon", + xlabel: str = "Score", + ylabel: str = "Relative Count", + title: str = "Score Histogram", + legend_labels: tuple[str, str] = ("Good", "Bad"), + ) -> tuple[Figure, str]: + """Generate a histogram of scores. + + Args: + bins (int, optional): Number of histogram bins. Defaults to 30. + good_color (str, optional): Color for good samples. Defaults to "skyblue". + bad_color (str, optional): Color for bad samples. Defaults to "salmon". + xlabel (str, optional): Label for the x-axis. Defaults to "Score". + ylabel (str, optional): Label for the y-axis. Defaults to "Relative Count". + title (str, optional): Title of the plot. Defaults to "Score Histogram". + legend_labels (tuple[str, str], optional): Legend labels for good and bad samples. + Defaults to ("Good", "Bad"). + + Returns: + tuple[Figure, str]: Tuple containing both the figure and the figure + title to be used for logging + + Raises: + ValueError: If no anomaly scores or labels are available. + """ + if len(self.anomaly_scores) == 0: + msg = "No anomaly scores available." + raise ValueError(msg) + if len(self.labels) == 0: + msg = "No labels available." 
+ raise ValueError(msg) + + fig, _ = plot_score_histogram( + scores=torch.hstack(self.anomaly_scores), + labels=torch.hstack(self.labels), + bins=bins, + good_color=good_color, + bad_color=bad_color, + xlabel=xlabel, + ylabel=ylabel, + title=title, + legend_labels=legend_labels, + ) + + return fig, title diff --git a/src/anomalib/metrics/utils.py b/src/anomalib/metrics/utils.py index d4b3c5126e..1cefe4027a 100644 --- a/src/anomalib/metrics/utils.py +++ b/src/anomalib/metrics/utils.py @@ -1,7 +1,7 @@ # Copyright (C) 2022-2025 Intel Corporation # SPDX-License-Identifier: Apache-2.0 -"""Helper functions to generate ROC-style plots of various metrics. +"""Helper functions to generate plots of various metrics. This module provides utility functions for generating ROC-style plots and other visualization helpers used by metrics in Anomalib. @@ -15,6 +15,64 @@ from anomalib.utils.deprecation import deprecate +def plot_score_histogram( + scores: torch.Tensor, + labels: torch.Tensor, + bins: int = 30, + good_color: str = "skyblue", + bad_color: str = "salmon", + xlabel: str = "Score", + ylabel: str = "Relative Count", + title: str = "Score Histogram", + legend_labels: tuple[str, str] = ("Good", "Bad"), +) -> tuple[Figure, Axes]: + """Plot a histogram of scores using two colors for good and bad images. + + Args: + scores (torch.Tensor): 1D tensor of scores for all images. + labels (torch.Tensor): 1D tensor of binary labels (0=good, 1=bad). + bins (int, optional): Number of histogram bins. Defaults to ``30``. + good_color (str, optional): Color for good images. Defaults to "skyblue". + bad_color (str, optional): Color for bad images. Defaults to "salmon". + xlabel (str, optional): Label for x-axis. Defaults to "Score". + ylabel (str, optional): Label for y-axis. Defaults to "Relative Count". + title (str, optional): Title of the plot. Defaults to "Score Histogram". + legend_labels (tuple[str, str], optional): Labels for legend. Defaults to ("Good", "Bad"). 
+ + Returns: + tuple[Figure, Axes]: Tuple containing the figure and its main axis. + """ + fig, axis = plt.subplots() + scores = scores.detach().cpu().numpy() + labels = labels.detach().cpu().numpy() + + good_scores = scores[labels == 0] + bad_scores = scores[labels == 1] + + axis.hist( + good_scores, + bins=bins, + color=good_color, + alpha=0.7, + label=legend_labels[0], + density=True, + ) + axis.hist( + bad_scores, + bins=bins, + color=bad_color, + alpha=0.7, + label=legend_labels[1], + density=True, + ) + + axis.set_xlabel(xlabel) + axis.set_ylabel(ylabel) + axis.set_title(title) + axis.legend() + return fig, axis + + def plot_metric_curve( x_vals: torch.Tensor, y_vals: torch.Tensor, From 575632ffd965afe9467c3fd9dec615865e0cabff Mon Sep 17 00:00:00 2001 From: Aimira Baitieva <63813435+abc-125@users.noreply.github.com> Date: Fri, 5 Sep 2025 18:03:38 +0200 Subject: [PATCH 6/6] =?UTF-8?q?=F0=9F=9A=80=20feat(metric):=20added=20PGn,?= =?UTF-8?q?=20PBn=20metrics=20(#2889)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * added pg and pb metrics * fixed typos * Update __init__.py * Update __init__.py - removed duplicate * Update pg_pb.py * Fixed pre-commit checks --------- Co-authored-by: Samet Akcay Co-authored-by: Rajesh Gangireddy Signed-off-by: StarPlatinum7 <2732981250@qq.com> --- src/anomalib/metrics/__init__.py | 5 + src/anomalib/metrics/pg_pb.py | 219 +++++++++++++++++++++++++++++++ tests/unit/metrics/test_pg_pb.py | 68 ++++++++++ 3 files changed, 292 insertions(+) create mode 100644 src/anomalib/metrics/pg_pb.py create mode 100644 tests/unit/metrics/test_pg_pb.py diff --git a/src/anomalib/metrics/__init__.py b/src/anomalib/metrics/__init__.py index ab57b7fca7..819d54e6cd 100644 --- a/src/anomalib/metrics/__init__.py +++ b/src/anomalib/metrics/__init__.py @@ -25,6 +25,8 @@ - ``BinaryPrecisionRecallCurve``: Computes precision-recall curves - ``Evaluator``: Combines multiple metrics for evaluation - ``MinMax``: 
Normalizes scores to [0,1] range + - ``PBn``: Presorted bad with n% good samples misclassified + - ``PGn``: Presorted good with n% bad samples missed - ``PRO``: Per-Region Overlap score - ``PIMO``: Per-Image Missed Overlap score @@ -56,6 +58,7 @@ from .evaluator import Evaluator from .f1_score import F1Max, F1Score from .min_max import MinMax +from .pg_pb import PBn, PGn from .pimo import AUPIMO, PIMO from .precision_recall_curve import BinaryPrecisionRecallCurve from .pro import PRO @@ -75,6 +78,8 @@ "F1Score", "ManualThreshold", "MinMax", + "PGn", + "PBn", "PRO", "PIMO", "AUPIMO", diff --git a/src/anomalib/metrics/pg_pb.py b/src/anomalib/metrics/pg_pb.py new file mode 100644 index 0000000000..49114e4ac4 --- /dev/null +++ b/src/anomalib/metrics/pg_pb.py @@ -0,0 +1,219 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""PGn and PBn metrics for binary image-level classification tasks. + +This module provides two metrics for evaluating binary image-level classification performance +on the assumption that bad (anomalous) samples are considered to be the positive class: + +- ``PGn``: Presorted good with n% bad samples missed, can be interpreted as true negative rate +at a fixed false negative rate (TNR@nFNR). +- ``PBn``: Presorted bad with n% good samples misclassified, can be interpreted as true positive rate +at a fixed false positive rate (TPR@nFPR). + +These metrics emphasize the practical applications of anomaly detection models by showing their potential +to reduce human operator workload while maintaining an acceptable level of misclassification. + +Example: + >>> from anomalib.metrics import PGn, PBn + >>> from anomalib.data import ImageBatch + >>> import torch + >>> # Create sample batch + >>> batch = ImageBatch( + ... image=torch.rand(4, 3, 32, 32), + ... pred_score=torch.tensor([0.1, 0.4, 0.35, 0.8]), + ... gt_label=torch.tensor([0, 0, 1, 1]) + ... 
) + >>> pg = PGn(fnr=0.2) + >>> # Print name of the metric + >>> print(pg.name) + PG20 + >>> # Compute PGn score + >>> pg.update(batch) + >>> pg.compute() + tensor(1.0) + >>> pb = PBn(fpr=0.2) + >>> # Print name of the metric + >>> print(pb.name) + PB20 + >>> # Compute PBn score + >>> pb.update(batch) + >>> pb.compute() + tensor(1.0) + +Note: + Scores for both metrics range from 0 to 1, with 1 indicating perfect separation + of the respective class with ``n``% or less of the other class misclassified. + +Reference: + Aimira Baitieva, Yacine Bouaouni, Alexandre Briot, Dick Ameln, Souhaiel Khalfaoui, + Samet Akcay; Beyond Academic Benchmarks: Critical Analysis and Best Practices + for Visual Industrial Anomaly Detection; in: Proceedings of the IEEE/CVF Conference + on Computer Vision and Pattern Recognition (CVPR) Workshops, 2025, pp. 4024-4034, + https://arxiv.org/abs/2503.23451 +""" + +import torch +from torchmetrics import Metric +from torchmetrics.utilities import dim_zero_cat + +from anomalib.metrics.base import AnomalibMetric + + +class _PGn(Metric): + """Presorted good metric. + + This class calculates the Presorted good (PGn) metric, which is the true negative rate + at a fixed false negative rate. + + Args: + **kwargs: Additional arguments passed to the parent ``Metric`` class. + + Attributes: + fnr (torch.Tensor): Fixed false negative rate (bad parts misclassified). + Defaults to ``0.05``. + + Example: + >>> from anomalib.metrics.pg_pb import _PGn + >>> import torch + >>> # Create sample data + >>> preds = torch.tensor([0.1, 0.4, 0.35, 0.8]) + >>> target = torch.tensor([0, 0, 1, 1]) + >>> # Compute PGn score + >>> pg = _PGn(fnr=0.2) + >>> pg.update(preds, target) + >>> pg.compute() + tensor(1.0) + """ + + def __init__(self, fnr: float = 0.05, **kwargs) -> None: + super().__init__(**kwargs) + if fnr < 0 or fnr > 1: + msg = f"False negative rate must be in the range between 0 and 1, got {fnr}." 
+ raise ValueError(msg) + + self.fnr = torch.tensor(fnr, dtype=torch.float32) + self.name = "PG" + str(int(fnr * 100)) + + self.add_state("preds", default=[], dist_reduce_fx="cat") + self.add_state("target", default=[], dist_reduce_fx="cat") + + def update(self, preds: torch.Tensor, target: torch.Tensor) -> None: + """Update state with new values. + + Args: + preds (torch.Tensor): predictions of the model + target (torch.Tensor): ground truth targets + """ + self.target.append(target) + self.preds.append(preds) + + def compute(self) -> torch.Tensor: + """Compute the PGn score at a given false negative rate. + + Returns: + torch.Tensor: PGn score value. + + Raises: + ValueError: If no negative samples are found. + """ + preds = dim_zero_cat(self.preds) + target = dim_zero_cat(self.target) + + pos_scores = preds[target == 1] + thr_accept = torch.quantile(pos_scores, self.fnr) + + neg_scores = preds[target == 0] + if neg_scores.numel() == 0: + msg = "No negative samples found. Cannot compute PGn score." + raise ValueError(msg) + pg = neg_scores[neg_scores < thr_accept].numel() / neg_scores.numel() + + return torch.tensor(pg, dtype=preds.dtype) + + +class PGn(AnomalibMetric, _PGn): # type: ignore[misc] + """Wrapper to add AnomalibMetric functionality to PGn metric. + + This class wraps the internal ``_PGn`` metric to make it compatible with + Anomalib's batch processing capabilities. + """ + + default_fields = ("pred_score", "gt_label") + + +class _PBn(Metric): + """Presorted bad metric. + + This class calculates the Presorted bad (PBn) metric, which is the true positive rate + at a fixed false positive rate. + + Args: + fpr (float): Fixed false positive rate (good parts misclassified). Defaults to ``0.05``. + **kwargs: Additional arguments passed to the parent ``Metric`` class. 
+ + Example: + >>> from anomalib.metrics.pg_pb import _PBn + >>> import torch + >>> preds = torch.tensor([0.1, 0.4, 0.35, 0.8]) + >>> target = torch.tensor([0, 0, 1, 1]) + >>> pb = _PBn(fpr=0.2) + >>> pb.update(preds, target) + >>> pb.compute() + tensor(1.0) + """ + + def __init__(self, fpr: float = 0.05, **kwargs) -> None: + super().__init__(**kwargs) + if fpr < 0 or fpr > 1: + msg = f"False positive rate must be in the range between 0 and 1, got {fpr}." + raise ValueError(msg) + + self.fpr = torch.tensor(fpr, dtype=torch.float32) + self.name = "PB" + str(int(fpr * 100)) + + self.add_state("preds", default=[], dist_reduce_fx="cat") + self.add_state("target", default=[], dist_reduce_fx="cat") + + def update(self, preds: torch.Tensor, target: torch.Tensor) -> None: + """Update state with new values. + + Args: + preds (torch.Tensor): predictions of the model + target (torch.Tensor): ground truth targets + """ + self.target.append(target) + self.preds.append(preds) + + def compute(self) -> torch.Tensor: + """Compute the PBn score at a given false positive rate. + + Returns: + torch.Tensor: PBn score value. + + Raises: + ValueError: If no positive samples are found. + """ + preds = dim_zero_cat(self.preds) + target = dim_zero_cat(self.target) + + neg_scores = preds[target == 0] + thr_accept = torch.quantile(neg_scores, 1 - self.fpr) + + pos_scores = preds[target == 1] + if pos_scores.numel() == 0: + msg = "No positive samples found. Cannot compute PBn score." + raise ValueError(msg) + pb = pos_scores[pos_scores > thr_accept].numel() / pos_scores.numel() + + return torch.tensor(pb, dtype=preds.dtype) + + +class PBn(AnomalibMetric, _PBn): # type: ignore[misc] + """Wrapper to add AnomalibMetric functionality to PBn metric. + + This class wraps the internal ``_PBn`` metric to make it compatible with + Anomalib's batch processing capabilities. 
+ """ + + default_fields = ("pred_score", "gt_label") diff --git a/tests/unit/metrics/test_pg_pb.py b/tests/unit/metrics/test_pg_pb.py new file mode 100644 index 0000000000..a1164dd877 --- /dev/null +++ b/tests/unit/metrics/test_pg_pb.py @@ -0,0 +1,68 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""Test PGn and PBn metrics.""" + +import pytest +import torch + +from anomalib.metrics.pg_pb import _PBn as PBn +from anomalib.metrics.pg_pb import _PGn as PGn + + +def test_pg_basic() -> None: + """Test PGn metric with simple binary classification.""" + metric = PGn(fnr=0.2) + preds = torch.tensor([0.1, 0.4, 0.35, 0.8]) + labels = torch.tensor([0, 0, 1, 1]) + metric.update(preds, labels) + result = metric.compute() + assert result == torch.tensor(1.0) + assert metric.name == "PG20" + + +def test_pb_basic() -> None: + """Test PBn metric with simple binary classification.""" + metric = PBn(fpr=0.2) + preds = torch.tensor([0.1, 0.4, 0.35, 0.8]) + labels = torch.tensor([0, 0, 1, 1]) + metric.update(preds, labels) + result = metric.compute() + assert result == torch.tensor(1.0) + assert metric.name == "PB20" + + +def test_pg_invalid_fnr() -> None: + """Test PGn metric raises ValueError for invalid fnr.""" + with pytest.raises(ValueError, match="False negative rate must be in the range between 0 and 1"): + PGn(fnr=-0.1) + with pytest.raises(ValueError, match="False negative rate must be in the range between 0 and 1"): + PGn(fnr=1.1) + + +def test_pb_invalid_fpr() -> None: + """Test PBn metric raises ValueError for invalid fpr.""" + with pytest.raises(ValueError, match="False positive rate must be in the range between 0 and 1"): + PBn(fpr=-0.1) + with pytest.raises(ValueError, match="False positive rate must be in the range between 0 and 1"): + PBn(fpr=1.1) + + +def test_pg_no_negatives() -> None: + """Test PGn metric raises ValueError if no negative samples.""" + metric = PGn(fnr=0.1) + preds = torch.tensor([0.5, 0.7]) + labels = 
torch.tensor([1, 1]) + metric.update(preds, labels) + with pytest.raises(ValueError, match="No negative samples found"): + metric.compute() + + +def test_pb_no_positives() -> None: + """Test PBn metric raises ValueError if no positive samples.""" + metric = PBn(fpr=0.1) + preds = torch.tensor([0.2, 0.3]) + labels = torch.tensor([0, 0]) + metric.update(preds, labels) + with pytest.raises(ValueError, match="No positive samples found"): + metric.compute()