From e028449b0c8c9a383531f86b473321acbc73a1a1 Mon Sep 17 00:00:00 2001
From: alfieroddan <51797647+alfieroddan@users.noreply.github.com>
Date: Thu, 28 Aug 2025 12:42:19 +0100
Subject: [PATCH 01/17] =?UTF-8?q?=F0=9F=90=9B=20fix(model):=20Reduce=20mem?=
 =?UTF-8?q?ory=20reserved=20for=20memory=20bank=20based=20models=20PatchCo?=
 =?UTF-8?q?re,=20Padim,=20Dfkde=20(#2913)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* refactor patchcore, use list instead of cat

* typo in patchcore

Signed-off-by: Alfie Roddan <51797647+alfieroddan@users.noreply.github.com>

* padim model to now use list instead of cat

Signed-off-by: Alfie Roddan <51797647+alfieroddan@users.noreply.github.com>

* dfm and padim update cat to list

Signed-off-by: Alfie Roddan <51797647+alfieroddan@users.noreply.github.com>

* add dfkde cat to list changes

Signed-off-by: Alfie Roddan <51797647+alfieroddan@users.noreply.github.com>

* change assert from memory bank to list size, fix pre-commit issues

Signed-off-by: Alfie Roddan <51797647+alfieroddan@users.noreply.github.com>

---------

Signed-off-by: Alfie Roddan <51797647+alfieroddan@users.noreply.github.com>
Signed-off-by: StarPlatinum7 <2732981250@qq.com>
---
 src/anomalib/models/image/dfkde/torch_model.py | 13 +++++--------
 src/anomalib/models/image/dfm/torch_model.py   | 11 ++++-------
 src/anomalib/models/image/padim/torch_model.py | 13 +++++--------
 .../models/image/patchcore/torch_model.py      | 18 +++++++++---------
 4 files changed, 23 insertions(+), 32 deletions(-)

diff --git a/src/anomalib/models/image/dfkde/torch_model.py b/src/anomalib/models/image/dfkde/torch_model.py
index 9af1ef4b65..43ac31bda6 100644
--- a/src/anomalib/models/image/dfkde/torch_model.py
+++ b/src/anomalib/models/image/dfkde/torch_model.py
@@ -89,7 +89,7 @@ def __init__(
             feature_scaling_method=feature_scaling_method,
             max_training_points=max_training_points,
         )
-        self.memory_bank = torch.empty(0)
+        self.memory_bank: list[torch.Tensor] = []

     def get_features(self, batch: torch.Tensor) -> torch.Tensor:
         """Extract features from the pre-trained backbone network.
@@ -142,11 +142,7 @@ def forward(self, batch: torch.Tensor) -> torch.Tensor | InferenceBatch:
         # 1. apply feature extraction
         features = self.get_features(batch)
         if self.training:
-            if self.memory_bank.size(0) == 0:
-                self.memory_bank = features
-            else:
-                new_bank = torch.cat((self.memory_bank, features), dim=0).to(self.memory_bank)
-                self.memory_bank = new_bank
+            self.memory_bank.append(features)
             return features

         # 2. apply density estimation
@@ -164,12 +160,13 @@ def fit(self) -> None:
         Raises:
             ValueError: If the memory bank is empty.
         """
-        if self.memory_bank.size(0) == 0:
+        if len(self.memory_bank) == 0:
             msg = "Memory bank is empty. Cannot perform coreset selection."
            raise ValueError(msg)

+        self.memory_bank = torch.vstack(self.memory_bank)
         # fit gaussian
         self.classifier.fit(self.memory_bank)

         # clear memory bank, reduces gpu size
-        self.memory_bank = torch.empty(0).to(self.memory_bank)
+        self.memory_bank = []
diff --git a/src/anomalib/models/image/dfm/torch_model.py b/src/anomalib/models/image/dfm/torch_model.py
index 3d5a366780..87d8740bb9 100644
--- a/src/anomalib/models/image/dfm/torch_model.py
+++ b/src/anomalib/models/image/dfm/torch_model.py
@@ -153,17 +153,18 @@ def __init__(
             layers=[layer],
         ).eval()

-        self.memory_bank = torch.empty(0)
+        self.memory_bank: list[torch.Tensor] = []

     def fit(self) -> None:
         """Fit PCA and Gaussian model to dataset."""
+        self.memory_bank = torch.vstack(self.memory_bank)
         self.pca_model.fit(self.memory_bank)

         if self.score_type == "nll":
             features_reduced = self.pca_model.transform(self.memory_bank)
             self.gaussian_model.fit(features_reduced.T)

         # clear memory bank, reduces GPU size
-        self.memory_bank = torch.empty(0).to(self.memory_bank)
+        self.memory_bank = []

     def score(self, features: torch.Tensor, feature_shapes: tuple) -> torch.Tensor:
         """Compute anomaly scores.
@@ -228,11 +229,7 @@ def forward(self, batch: torch.Tensor) -> torch.Tensor | InferenceBatch:
         feature_vector, feature_shapes = self.get_features(batch)

         if self.training:
-            if self.memory_bank.size(0) == 0:
-                self.memory_bank = feature_vector
-            else:
-                new_bank = torch.cat((self.memory_bank, feature_vector), dim=0).to(self.memory_bank)
-                self.memory_bank = new_bank
+            self.memory_bank.append(feature_vector)
             return feature_vector

         pred_score, anomaly_map = self.score(feature_vector.view(feature_vector.shape[:2]), feature_shapes)
diff --git a/src/anomalib/models/image/padim/torch_model.py b/src/anomalib/models/image/padim/torch_model.py
index d7367149a0..c0f017d9eb 100644
--- a/src/anomalib/models/image/padim/torch_model.py
+++ b/src/anomalib/models/image/padim/torch_model.py
@@ -147,7 +147,7 @@ def __init__(
         self.anomaly_map_generator = AnomalyMapGenerator()

         self.gaussian = MultiVariateGaussian()
-        self.memory_bank = torch.empty(0)
+        self.memory_bank: list[torch.Tensor] = []

     def forward(self, input_tensor: torch.Tensor) -> torch.Tensor | InferenceBatch:
         """Forward-pass image-batch (N, C, H, W) into model to extract features.
@@ -183,11 +183,7 @@ def forward(self, input_tensor: torch.Tensor) -> torch.Tensor | InferenceBatch:
             embeddings = self.tiler.untile(embeddings)

         if self.training:
-            if self.memory_bank.size(0) == 0:
-                self.memory_bank = embeddings
-            else:
-                new_bank = torch.cat((self.memory_bank, embeddings), dim=0).to(self.memory_bank)
-                self.memory_bank = new_bank
+            self.memory_bank.append(embeddings)
             return embeddings

         anomaly_map = self.anomaly_map_generator(
@@ -234,12 +230,13 @@ def fit(self) -> None:
         Raises:
             ValueError: If the memory bank is empty.
         """
-        if self.memory_bank.size(0) == 0:
+        if len(self.memory_bank) == 0:
             msg = "Memory bank is empty. Cannot perform coreset selection."
            raise ValueError(msg)

+        self.memory_bank = torch.vstack(self.memory_bank)
         # fit gaussian
         self.gaussian.fit(self.memory_bank)

         # clear memory bank, reduces gpu usage
-        self.memory_bank = torch.empty(0).to(self.memory_bank)
+        self.memory_bank = []
diff --git a/src/anomalib/models/image/patchcore/torch_model.py b/src/anomalib/models/image/patchcore/torch_model.py
index 1970e6dcc2..20b09de613 100644
--- a/src/anomalib/models/image/patchcore/torch_model.py
+++ b/src/anomalib/models/image/patchcore/torch_model.py
@@ -126,6 +126,7 @@ def __init__(
         self.anomaly_map_generator = AnomalyMapGenerator()

         self.memory_bank: torch.Tensor
         self.register_buffer("memory_bank", torch.empty(0))
+        self.embedding_store: list[torch.Tensor] = []

     def forward(self, input_tensor: torch.Tensor) -> torch.Tensor | InferenceBatch:
         """Process input tensor through the model.
@@ -169,11 +170,7 @@ def forward(self, input_tensor: torch.Tensor) -> torch.Tensor | InferenceBatch:
         embedding = self.reshape_embedding(embedding)

         if self.training:
-            if self.memory_bank.size(0) == 0:
-                self.memory_bank = embedding
-            else:
-                new_bank = torch.cat((self.memory_bank, embedding), dim=0).to(self.memory_bank)
-                self.memory_bank = new_bank
+            self.embedding_store.append(embedding)
             return embedding

         # Ensure memory bank is not empty
@@ -272,13 +269,16 @@ def subsample_embedding(self, sampling_ratio: float, embeddings: torch.Tensor =
         if embeddings is not None:
             del embeddings

-        if self.memory_bank.size(0) == 0:
-            msg = "Memory bank is empty. Cannot perform coreset selection."
+        if len(self.embedding_store) == 0:
+            msg = "Embedding store is empty. Cannot perform coreset selection."
             raise ValueError(msg)

+        # Coreset Subsampling
+        self.memory_bank = torch.vstack(self.embedding_store)
+        self.embedding_store.clear()
+
         sampler = KCenterGreedy(embedding=self.memory_bank, sampling_ratio=sampling_ratio)
-        coreset = sampler.sample_coreset().to(self.memory_bank)
-        self.memory_bank = coreset
+        self.memory_bank = sampler.sample_coreset()

     @staticmethod
     def euclidean_dist(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
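Patch 01's core change is the same in all four models: instead of re-concatenating the memory bank with `torch.cat` on every training batch, features are appended to a Python list and stacked exactly once when `fit()` is called. Repeated concatenation allocates a fresh bank each step and briefly holds both the old and new copies, so reserved memory grows well beyond the final bank size; appending only stores references. A minimal standalone sketch of the two patterns (illustrative shapes, not anomalib code):

```python
import torch

def accumulate_with_cat(batches: list[torch.Tensor]) -> torch.Tensor:
    """Old pattern: re-concatenate every batch; peak memory is ~2x the bank."""
    bank = torch.empty(0)
    for features in batches:
        if bank.size(0) == 0:
            bank = features
        else:
            # Each cat allocates a brand-new tensor holding bank + features.
            bank = torch.cat((bank, features), dim=0)
    return bank

def accumulate_with_list(batches: list[torch.Tensor]) -> torch.Tensor:
    """New pattern: keep references, allocate once at fit time."""
    store: list[torch.Tensor] = []
    for features in batches:
        store.append(features)  # no copy, just a reference
    return torch.vstack(store)  # single allocation

batches = [torch.randn(8, 384) for _ in range(10)]
assert torch.equal(accumulate_with_cat(batches), accumulate_with_list(batches))
```

Both produce the identical bank; only the allocation pattern differs, which is why the patch also swaps the emptiness check to `len(self.memory_bank) == 0` and clears the bank by rebinding it to `[]`.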
From ce4097b3241b3302bec29edf73612e4da75c7932 Mon Sep 17 00:00:00 2001
From: StarPlatinum7 <2732981250@qq.com>
Date: Wed, 3 Sep 2025 16:45:34 +0800
Subject: [PATCH 02/17] add mebin postprocessor

Signed-off-by: StarPlatinum7 <2732981250@qq.com>
---
 CHANGELOG.md                                  |   2 +
 src/anomalib/metrics/__init__.py              |   3 +-
 src/anomalib/metrics/threshold/__init__.py    |   3 +-
 src/anomalib/metrics/threshold/mebin.py       | 250 ++++++++++++++++++
 src/anomalib/post_processing/__init__.py      |   3 +-
 .../post_processing/mebin_post_processor.py   | 127 +++++++++
 6 files changed, 385 insertions(+), 3 deletions(-)
 create mode 100644 src/anomalib/metrics/threshold/mebin.py
 create mode 100644 src/anomalib/post_processing/mebin_post_processor.py

diff --git a/CHANGELOG.md b/CHANGELOG.md
index c4933a7d86..74cf439169 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,6 +7,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).

 ## [Unreleased]

 ### Added
+- 🚀 Add MEBin post-processing method
+
 ### Removed

diff --git a/src/anomalib/metrics/__init__.py b/src/anomalib/metrics/__init__.py
index 6d4d0b64e0..ab57b7fca7 100644
--- a/src/anomalib/metrics/__init__.py
+++ b/src/anomalib/metrics/__init__.py
@@ -59,7 +59,7 @@
 from .pimo import AUPIMO, PIMO
 from .precision_recall_curve import BinaryPrecisionRecallCurve
 from .pro import PRO
-from .threshold import F1AdaptiveThreshold, ManualThreshold
+from .threshold import F1AdaptiveThreshold, ManualThreshold ,MEBin

 __all__ = [
     "AUROC",
@@ -78,4 +78,5 @@
     "PRO",
     "PIMO",
     "AUPIMO",
+    "MEBin",
 ]
diff --git a/src/anomalib/metrics/threshold/__init__.py b/src/anomalib/metrics/threshold/__init__.py
index 1bf10854ec..0cebd87b82 100644
--- a/src/anomalib/metrics/threshold/__init__.py
+++ b/src/anomalib/metrics/threshold/__init__.py
@@ -24,5 +24,6 @@
 from .base import BaseThreshold, Threshold
 from .f1_adaptive_threshold import F1AdaptiveThreshold
 from .manual_threshold import ManualThreshold
+from .mebin import MEBin

-__all__ = ["BaseThreshold", "Threshold", "F1AdaptiveThreshold", "ManualThreshold"]
+__all__ = ["BaseThreshold", "Threshold", "F1AdaptiveThreshold", "ManualThreshold", "MEBin"]
diff --git a/src/anomalib/metrics/threshold/mebin.py b/src/anomalib/metrics/threshold/mebin.py
new file mode 100644
index 0000000000..c823ea3b5f
--- /dev/null
+++ b/src/anomalib/metrics/threshold/mebin.py
@@ -0,0 +1,250 @@
+"""MEBin adaptive thresholding algorithm for anomaly detection.
+
+This module provides the ``MEBin`` class which automatically finds the optimal
+threshold value by analyzing the stability of connected components across
+multiple threshold levels.
+
+The threshold is computed by:
+1. Sampling anomaly maps at configurable rates across threshold range
+2. Counting connected components at each threshold level
+3. Finding stable intervals where component count remains constant
+4. Selecting threshold from the longest stable interval
+
+Example:
+    >>> from anomalib.metrics.threshold import MEBin
+    >>> import numpy as np
+    >>> # Create sample anomaly maps
+    >>> anomaly_maps = [np.random.rand(256, 256) * 255 for _ in range(10)]
+    >>> # Initialize and compute thresholds
+    >>> mebin = MEBin(anomaly_map_list=anomaly_maps, sample_rate=4)
+    >>> binarized_maps, thresholds = mebin.binarize_anomaly_maps()
+    >>> print(f"Computed {len(thresholds)} thresholds")
+
+Note:
+    The algorithm works best when anomaly maps contain clear separation between
+    normal and anomalous regions. The min_interval_len parameter should be tuned
+    based on the expected stability of anomaly score distributions.
+"""
+
+import cv2
+import numpy as np
+from tqdm import tqdm
+
+
+class MEBin:
+    """MEBin adaptive thresholding algorithm for anomaly detection.
+
+    This class implements the MEBin (Minimum Entropy Binarization) algorithm
+    which automatically determines optimal thresholds for converting continuous
+    anomaly maps to binary masks by analyzing the stability of connected
+    component counts across different threshold levels.
+
+    The algorithm works by:
+    - Sampling anomaly maps at configurable rates across threshold range
+    - Counting connected components at each threshold level
+    - Identifying stable intervals where component count remains constant
+    - Selecting the optimal threshold from the longest stable interval
+    - Optionally applying morphological erosion to reduce noise
+
+    Args:
+        anomaly_map_path_list (list, optional): List of file paths to anomaly maps.
+            If provided, maps will be loaded as grayscale images.
+            Defaults to None.
+        anomaly_map_list (list, optional): List of anomaly map arrays. If provided,
+            maps should be numpy arrays.
+            Defaults to None.
+        sample_rate (int, optional): Sampling rate for threshold search. Higher
+            values reduce processing time but may affect accuracy.
+            Defaults to 4.
+        min_interval_len (int, optional): Minimum length of stable intervals.
+            Should be tuned based on the expected stability of anomaly score
+            distributions.
+            Defaults to 4.
+        erode (bool, optional): Whether to apply morphological erosion to
+            binarized results to reduce noise.
+            Defaults to True.
+
+    Example:
+        >>> from anomalib.metrics.threshold import MEBin
+        >>> import numpy as np
+        >>> # Create sample anomaly maps
+        >>> anomaly_maps = [np.random.rand(256, 256) * 255 for _ in range(10)]
+        >>> # Initialize MEBin
+        >>> mebin = MEBin(anomaly_map_list=anomaly_maps, sample_rate=4)
+        >>> # Compute binary masks and thresholds
+        >>> binarized_maps, thresholds = mebin.binarize_anomaly_maps()
+    """
+
+    def __init__(self, anomaly_map_path_list=None, anomaly_map_list=None, sample_rate=4, min_interval_len=4, erode=True):
+        self.anomaly_map_path_list = anomaly_map_path_list
+        if anomaly_map_path_list is not None:
+            # Load anomaly maps as grayscale images if paths are provided
+            self.anomaly_map_list = [cv2.imread(x, cv2.IMREAD_GRAYSCALE) for x in anomaly_map_path_list]
+        else:
+            # Otherwise use the given anomaly map arrays directly
+            self.anomaly_map_list = anomaly_map_list
+
+        self.sample_rate = sample_rate
+        self.min_interval_len = min_interval_len
+        self.erode = erode
+
+        # Adaptively determine the threshold search range
+        self.max_th, self.min_th = self.get_search_range()
+
+    def get_search_range(self):
+        """Determine the threshold search range adaptively.
+
+        This method analyzes all anomaly maps to determine the minimum and maximum
+        threshold values for the binarization process. The search range is based
+        on the actual anomaly score distributions in the input maps.
+
+        Returns:
+            max_th (int): Maximum threshold for binarization.
+            min_th (int): Minimum threshold for binarization.
+        """
+        # Get the anomaly scores of all anomaly maps
+        anomaly_score_list = [np.max(x) for x in self.anomaly_map_list]
+
+        # Select the maximum and minimum anomaly scores from images
+        max_score, min_score = max(anomaly_score_list), min(anomaly_score_list)
+        max_th, min_th = max_score, min_score
+
+        print(f"Value range: {min_score} - {max_score}")
+
+        return max_th, min_th
+
+    def get_threshold(self, anomaly_num_sequence, min_interval_len):
+        """
+        Find the 'stable interval' in the anomaly region number sequence.
+        Stable Interval: A continuous threshold range in which the number of connected components remains constant,
+        and the length of the threshold range is greater than or equal to the given length threshold (min_interval_len).
+
+        Args:
+            anomaly_num_sequence (list): Sequence of connected component counts
+                at each threshold level, ordered from high to low threshold.
+            min_interval_len (int): Minimum length requirement for stable intervals.
+                Longer intervals indicate more robust threshold selection.
+
+        Returns:
+            threshold (int): The final threshold for binarization.
+            est_anomaly_num (int): The estimated number of anomalies.
+        """
+        interval_result = {}
+        current_index = 0
+        while current_index < len(anomaly_num_sequence):
+            end = current_index
+            start = end
+
+            # Find the interval where the connected component count remains constant.
+ if len(set(anomaly_num_sequence[start:end+1])) == 1 and anomaly_num_sequence[start] != 0: + # Move the 'end' pointer forward until a different connected component number is encountered. + while end < len(anomaly_num_sequence)-1 and anomaly_num_sequence[end] == anomaly_num_sequence[end+1]: + end += 1 + current_index += 1 + # If the length of the current stable interval is greater than or equal to the given threshold (min_interval_len), record this interval. + if end - start + 1 >= min_interval_len: + if anomaly_num_sequence[start] not in interval_result: + interval_result[anomaly_num_sequence[start]] = [(start, end)] + else: + interval_result[anomaly_num_sequence[start]].append((start, end)) + current_index += 1 + + """ + If a 'stable interval' exists, calculate the final threshold based on the longest stable interval. + If no stable interval is found, it indicates that no anomaly regions exist, and 255 is returned. + """ + + if interval_result: + # Iterate through the stable intervals, calculating their lengths and corresponding number of connected component. + count_result = {} + for anomaly_num in interval_result: + count_result[anomaly_num] = max([x[1] - x[0] for x in interval_result[anomaly_num]]) + est_anomaly_num = max(count_result, key=count_result.get) + est_anomaly_num_interval_result = interval_result[est_anomaly_num] + + # Find the longest stable interval. + longest_interval = sorted(est_anomaly_num_interval_result, key=lambda x: x[1] - x[0])[-1] + + # Use the endpoint threshold of the longest stable interval as the final threshold. + index = longest_interval[1] + threshold = 255 - index * self.sample_rate + threshold = int(threshold*(self.max_th - self.min_th)/255 + self.min_th) + return threshold, est_anomaly_num + else: + return 255, 0 + + + def bin_and_erode(self, anomaly_map, threshold): + """Binarize anomaly map and optionally apply erosion. + + This method converts a continuous anomaly map to a binary mask using + the specified threshold, and optionally applies morphological erosion + to reduce noise and smooth the boundaries of anomaly regions. + + The binarization process: + 1. Pixels above threshold become 255 (anomalous) + 2. Pixels below threshold become 0 (normal) + 3. Optional erosion with 6x6 kernel to reduce noise + + Args: + anomaly_map (numpy.ndarray): Input anomaly map with continuous + anomaly scores to be binarized. + threshold (int): Threshold value for binarization. Pixels with + values above this threshold are considered anomalous. + + Returns: + numpy.ndarray: Binary mask where 255 indicates anomalous regions + and 0 indicates normal regions. The result is of type uint8. + + Note: + Erosion is applied with a 6x6 kernel and 1 iteration to balance + noise reduction with preservation of anomaly boundaries. + """ + bin_result = np.where(anomaly_map > threshold, 255, 0).astype(np.uint8) + + # Apply erosion operation to the binarized result + if self.erode: + kernel_size = 6 + iter_num = 1 + kernel = np.ones((kernel_size, kernel_size), np.uint8) + bin_result = cv2.erode(bin_result, kernel, iterations=iter_num) + return bin_result + + + def binarize_anomaly_maps(self): + """ + Perform binarization within the given threshold search range, + count the number of connected components in the binarized results. + Adaptively determine the threshold according to the count, + and perform binarization on the anomaly maps. + + Returns: + binarized_maps (list): List of binarized images. + thresholds (list): List of thresholds for each image. 
+        """
+        self.binarized_maps = []
+        self.thresholds = []
+
+        for anomaly_map in tqdm(self.anomaly_map_list):
+            # Normalize the anomaly map within the given threshold search range.
+            anomaly_map_norm = np.where(anomaly_map < self.min_th, 0, ((anomaly_map - self.min_th) / (self.max_th - self.min_th)) * 255)
+            anomaly_num_sequence = []
+
+            # Search for the threshold from high to low within the given range using the specified sampling rate.
+            for score in range(255, 0, -self.sample_rate):
+                bin_result = self.bin_and_erode(anomaly_map_norm, score)
+                num_labels, *_ = cv2.connectedComponentsWithStats(bin_result, connectivity=8)
+                anomaly_num = num_labels - 1
+                anomaly_num_sequence.append(anomaly_num)
+
+            # Adaptively determine the threshold based on the anomaly connected component count sequence.
+            threshold, est_anomaly_num = self.get_threshold(anomaly_num_sequence, self.min_interval_len)
+
+            # Binarize the anomaly image based on the determined threshold.
+            bin_result = self.bin_and_erode(anomaly_map, threshold)
+            self.binarized_maps.append(bin_result)
+            self.thresholds.append(threshold)
+
+        return self.binarized_maps, self.thresholds
\ No newline at end of file
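The stable-interval heuristic that `get_threshold` and `binarize_anomaly_maps` implement can be observed in isolation with a few lines of OpenCV: sweep the threshold from high to low and watch the connected-component count. A self-contained sketch on synthetic data (illustrative values, not part of the patch):

```python
import cv2
import numpy as np

rng = np.random.default_rng(0)
amap = (rng.random((128, 128)) * 80).astype(np.uint8)  # noisy background (scores < 80)
amap[40:70, 40:70] = 230                                # one clearly anomalous blob

counts = []
for th in range(255, 0, -4):                            # high -> low sweep, sample_rate=4
    binary = np.where(amap > th, 255, 0).astype(np.uint8)
    num_labels, _ = cv2.connectedComponents(binary, connectivity=8)
    counts.append(num_labels - 1)                       # exclude the background label

# Expect a long constant run of 1s (the stable interval around the blob),
# then an explosion of counts once the threshold drops into the noise floor.
print(counts)
```

MEBin picks the threshold at the end of the longest such run and maps it back to the original score range via `threshold = int((255 - index * sample_rate) * (max_th - min_th) / 255 + min_th)`.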
diff --git a/src/anomalib/post_processing/__init__.py b/src/anomalib/post_processing/__init__.py
index 5ca7bf598b..7e57c6d82d 100644
--- a/src/anomalib/post_processing/__init__.py
+++ b/src/anomalib/post_processing/__init__.py
@@ -20,5 +20,6 @@
 """

 from .post_processor import PostProcessor
+from .mebin_post_processor import MEBinPostProcessor

-__all__ = ["PostProcessor"]
+__all__ = ["PostProcessor", "MEBinPostProcessor"]
diff --git a/src/anomalib/post_processing/mebin_post_processor.py b/src/anomalib/post_processing/mebin_post_processor.py
new file mode 100644
index 0000000000..8b71fd68c4
--- /dev/null
+++ b/src/anomalib/post_processing/mebin_post_processor.py
@@ -0,0 +1,127 @@
+"""Post-processing module for MEBin-based anomaly detection results.
+
+This module provides post-processing functionality for anomaly detection
+outputs through the :class:`MEBinPostProcessor` class.
+
+The MEBin post-processor handles:
+    - Converting anomaly maps to binary masks using MEBin algorithm
+    - Sampling anomaly maps at configurable rates for efficient processing
+    - Applying morphological operations (erosion) to refine binary masks
+    - Maintaining minimum interval lengths for consistent mask generation
+    - Formatting results for downstream use
+
+Example:
+    >>> from anomalib.post_processing import MEBinPostProcessor
+    >>> post_processor = MEBinPostProcessor(sample_rate=4, min_interval_len=4)
+    >>> processed = post_processor(predictions)  # predictions is an InferenceBatch
+"""
+
+import numpy as np
+import torch
+
+from anomalib.data import InferenceBatch
+from anomalib.metrics import MEBin
+from anomalib.post_processing.post_processor import PostProcessor
+
+
+class MEBinPostProcessor(PostProcessor):
+    """Post-processor for MEBin-based anomaly detection.
+
+    This class handles post-processing of anomaly detection results by:
+    - Converting continuous anomaly maps to binary masks using MEBin algorithm
+    - Sampling anomaly maps at configurable rates for efficient processing
+    - Applying morphological operations (erosion) to refine binary masks
+    - Maintaining minimum interval lengths for consistent mask generation
+    - Formatting results for downstream use
+
+    Args:
+        sample_rate (int, optional): Threshold sampling step size.
+            Defaults to 4.
+        min_interval_len (int, optional): Minimum length of the stable interval.
+            Can be fine-tuned according to the gap between the normal and abnormal
+            score distributions in the anomaly score maps: decrease it if there are
+            many false negatives, increase it if there are many false positives.
+            Defaults to 4.
+        erode (bool, optional): Whether to perform erosion after binarization to
+            eliminate noise. This operation smooths how the number of anomalous
+            connected components changes across thresholds.
+            Defaults to True.
+        **kwargs: Additional keyword arguments passed to parent class.
+
+    Example:
+        >>> from anomalib.post_processing import MEBinPostProcessor
+        >>> post_processor = MEBinPostProcessor(sample_rate=4, min_interval_len=4)
+        >>> processed = post_processor(predictions)  # predictions is an InferenceBatch
+    """
+
+    def __init__(
+        self,
+        sample_rate: int = 4,
+        min_interval_len: int = 4,
+        erode: bool = True,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+
+        self.sample_rate = sample_rate
+        self.min_interval_len = min_interval_len
+        self.erode = erode
+
+    def forward(self, predictions: InferenceBatch) -> InferenceBatch:
+        """Post-process model predictions using MEBin algorithm.
+
+        This method converts continuous anomaly maps to binary masks using the MEBin
+        algorithm, which provides efficient and accurate binarization of anomaly
+        detection results.
+
+        Args:
+            predictions (InferenceBatch): Batch containing model predictions with
+                anomaly maps to be processed.
+
+        Returns:
+            InferenceBatch: Post-processed batch with binary masks generated from
+                anomaly maps using MEBin algorithm.
+
+        Note:
+            The method automatically handles tensor-to-numpy conversion and back,
+            ensuring compatibility with the original tensor device and dtype.
+ """ + anomaly_maps = predictions.anomaly_map + if isinstance(anomaly_maps, torch.Tensor): + anomaly_maps = anomaly_maps.detach().cpu().numpy() + if anomaly_maps.ndim == 4: + anomaly_maps = anomaly_maps[:, 0, :, :] # Remove channel dimension + + # Normalize to 0-255 and convert to uint8 + norm_maps = [] + for amap in anomaly_maps: + amap_norm = (amap - amap.min()) / (amap.max() - amap.min() + 1e-8) * 255 + norm_maps.append(amap_norm.astype(np.uint8)) + + + mebin = MEBin(anomaly_map_list=norm_maps, sample_rate=self.sample_rate, min_interval_len=self.min_interval_len, erode=self.erode) + binarized_maps, thresholds = mebin.binarize_anomaly_maps() + + # Convert back to torch.Tensor and normalize to 0/1 + pred_masks = torch.stack([torch.from_numpy(bm).to(predictions.anomaly_map.device) for bm in binarized_maps]) + pred_masks = (pred_masks > 0).to(predictions.anomaly_map.dtype) + + return InferenceBatch( + pred_label=predictions.pred_label, + pred_score=predictions.pred_score, + pred_mask=pred_masks, + anomaly_map=predictions.anomaly_map, + ) + + \ No newline at end of file From 1e68949b1faca6913b27a9aad6efc0eb057e2fd0 Mon Sep 17 00:00:00 2001 From: StarPlatinum7 <2732981250@qq.com> Date: Mon, 8 Sep 2025 17:15:40 +0800 Subject: [PATCH 03/17] test mebin postprocessor Signed-off-by: StarPlatinum7 <2732981250@qq.com> --- tests/unit/post_processing/test_mebin_post_processor.py | 4 ++++ 1 file changed, 4 insertions(+) create mode 100644 tests/unit/post_processing/test_mebin_post_processor.py diff --git a/tests/unit/post_processing/test_mebin_post_processor.py b/tests/unit/post_processing/test_mebin_post_processor.py new file mode 100644 index 0000000000..88dc0d5a1b --- /dev/null +++ b/tests/unit/post_processing/test_mebin_post_processor.py @@ -0,0 +1,4 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""Test the MeBinPostProcessor class.""" \ No newline at end of file From 3603157878206f111a35ab9120d0affb4684fc79 Mon Sep 17 00:00:00 2001 From: StarPlatinum7 <2732981250@qq.com> Date: Mon, 8 Sep 2025 18:27:51 +0800 Subject: [PATCH 04/17] test mebin postprocessor Signed-off-by: StarPlatinum7 <2732981250@qq.com> --- .../test_mebin_post_processor.py | 213 +++++++++++++++++- 1 file changed, 212 insertions(+), 1 deletion(-) diff --git a/tests/unit/post_processing/test_mebin_post_processor.py b/tests/unit/post_processing/test_mebin_post_processor.py index 88dc0d5a1b..25925d40cd 100644 --- a/tests/unit/post_processing/test_mebin_post_processor.py +++ b/tests/unit/post_processing/test_mebin_post_processor.py @@ -1,4 +1,215 @@ # Copyright (C) 2025 Intel Corporation # SPDX-License-Identifier: Apache-2.0 -"""Test the MeBinPostProcessor class.""" \ No newline at end of file +"""Test the MEBinPostProcessor class.""" + +import pytest +import torch +import numpy as np +from unittest.mock import Mock, patch + +from anomalib.data import InferenceBatch +from anomalib.post_processing import MEBinPostProcessor + + +class TestMEBinPostProcessor: + """Test the MEBinPostProcessor class.""" + + @staticmethod + def test_initialization_default_params() -> None: + """Test MEBinPostProcessor initialization with default parameters.""" + processor = MEBinPostProcessor() + + assert processor.sample_rate == 4 + assert processor.min_interval_len == 4 + assert processor.erode is True + + @staticmethod + @pytest.mark.parametrize( + ("sample_rate", "min_interval_len", "erode"), + [ + (2, 3, True), + (8, 6, False), + (1, 1, True), + ], + ) + def test_initialization_custom_params( + sample_rate: int, + 
min_interval_len: int, + erode: bool + ) -> None: + """Test MEBinPostProcessor initialization with custom parameters.""" + processor = MEBinPostProcessor( + sample_rate=sample_rate, + min_interval_len=min_interval_len, + erode=erode + ) + + assert processor.sample_rate == sample_rate + assert processor.min_interval_len == min_interval_len + assert processor.erode == erode + + @staticmethod + @patch('anomalib.post_processing.mebin_post_processor.MEBin') + def test_forward_single_anomaly_map(mock_mebin) -> None: + """Test forward method with single anomaly map.""" + # Setup mock + mock_mebin_instance = Mock() + mock_mebin_instance.binarize_anomaly_maps.return_value = ( + [np.array([[0, 0], [1, 1]], dtype=np.uint8)], + [0.5] + ) + mock_mebin.return_value = mock_mebin_instance + + # Create test data + anomaly_map = torch.rand(1, 1, 4, 4) + predictions = InferenceBatch( + pred_score=torch.tensor([0.8]), + pred_label=torch.tensor([1]), + anomaly_map=anomaly_map, + pred_mask=None + ) + + # Test forward pass + processor = MEBinPostProcessor() + result = processor.forward(predictions) + + # Verify results + assert isinstance(result, InferenceBatch) + assert result.pred_mask is not None + assert result.pred_mask.shape == (1, 2, 2) + assert result.pred_mask.dtype == anomaly_map.dtype + + @staticmethod + @patch('anomalib.post_processing.mebin_post_processor.MEBin') + def test_forward_batch_anomaly_maps(mock_mebin) -> None: + """Test forward method with batch of anomaly maps.""" + # Setup mock + mock_mebin_instance = Mock() + mock_mebin_instance.binarize_anomaly_maps.return_value = ( + [ + np.array([[0, 0], [1, 1]], dtype=np.uint8), + np.array([[1, 0], [0, 1]], dtype=np.uint8) + ], + [0.5, 0.6] + ) + mock_mebin.return_value = mock_mebin_instance + + # Create test data + anomaly_maps = torch.rand(2, 1, 4, 4) + predictions = InferenceBatch( + pred_score=torch.tensor([0.8, 0.9]), + pred_label=torch.tensor([1, 1]), + anomaly_map=anomaly_maps, + pred_mask=None + ) + + # Test forward pass + processor = MEBinPostProcessor() + result = processor.forward(predictions) + + # Verify results + assert isinstance(result, InferenceBatch) + assert result.pred_mask.shape == (2, 2, 2) + + @staticmethod + @patch('anomalib.post_processing.mebin_post_processor.MEBin') + def test_forward_normalization(mock_mebin) -> None: + """Test that anomaly maps are properly normalized to 0-255 range.""" + # Setup mock + mock_mebin_instance = Mock() + mock_mebin_instance.binarize_anomaly_maps.return_value = ( + [np.array([[0, 0], [1, 1]], dtype=np.uint8)], + [0.5] + ) + mock_mebin.return_value = mock_mebin_instance + + # Create test data with specific range + anomaly_maps = torch.tensor([[[[0.0, 0.5], [1.0, 0.2]]]]) + predictions = InferenceBatch( + pred_score=torch.tensor([0.8]), + pred_label=torch.tensor([1]), + anomaly_map=anomaly_maps, + pred_mask=None + ) + + # Test forward pass + processor = MEBinPostProcessor() + result = processor.forward(predictions) + + # Verify MEBin was called with normalized data + mock_mebin.assert_called_once() + call_args = mock_mebin.call_args + anomaly_map_list = call_args[1]['anomaly_map_list'] + + # Check that the data is normalized to 0-255 range + assert len(anomaly_map_list) == 1 + assert anomaly_map_list[0].dtype == np.uint8 + assert anomaly_map_list[0].min() >= 0 + assert anomaly_map_list[0].max() <= 255 + + @staticmethod + @patch('anomalib.post_processing.mebin_post_processor.MEBin') + def test_forward_mebin_parameters(mock_mebin) -> None: + """Test that MEBin is called with correct 
parameters.""" + # Setup mock + mock_mebin_instance = Mock() + mock_mebin_instance.binarize_anomaly_maps.return_value = ( + [np.array([[0, 0], [1, 1]], dtype=np.uint8)], + [0.5] + ) + mock_mebin.return_value = mock_mebin_instance + + # Create test data + anomaly_maps = torch.rand(1, 1, 4, 4) + predictions = InferenceBatch( + pred_score=torch.tensor([0.8]), + pred_label=torch.tensor([1]), + anomaly_map=anomaly_maps, + pred_mask=None + ) + + # Test with custom parameters + processor = MEBinPostProcessor( + sample_rate=8, + min_interval_len=6, + erode=False + ) + result = processor.forward(predictions) + + # Verify MEBin was called with correct parameters + mock_mebin.assert_called_once_with( + anomaly_map_list=mock_mebin.call_args[1]['anomaly_map_list'], + sample_rate=8, + min_interval_len=6, + erode=False + ) + + @staticmethod + @patch('anomalib.post_processing.mebin_post_processor.MEBin') + def test_forward_binary_mask_conversion(mock_mebin) -> None: + """Test that binary masks are properly converted to 0/1 values.""" + # Setup mock to return masks with values > 0 + mock_mebin_instance = Mock() + mock_mebin_instance.binarize_anomaly_maps.return_value = ( + [np.array([[0, 128], [255, 64]], dtype=np.uint8)], + [0.5] + ) + mock_mebin.return_value = mock_mebin_instance + + # Create test data + anomaly_maps = torch.rand(1, 1, 2, 2) + predictions = InferenceBatch( + pred_score=torch.tensor([0.8]), + pred_label=torch.tensor([1]), + anomaly_map=anomaly_maps, + pred_mask=None + ) + + # Test forward pass + processor = MEBinPostProcessor() + result = processor.forward(predictions) + + # Verify that all values are either 0 or 1 + unique_values = torch.unique(result.pred_mask) + assert torch.all((unique_values == 0) | (unique_values == 1)) From 075352025d1dac8e1f34cc8bfe8472f75bdff9ad Mon Sep 17 00:00:00 2001 From: Aimira Baitieva <63813435+abc-125@users.noreply.github.com> Date: Fri, 5 Sep 2025 14:11:03 +0200 Subject: [PATCH 05/17] =?UTF-8?q?=F0=9F=9A=80=20feat(metric):=20added=20hi?= =?UTF-8?q?stogram=20of=20anomaly=20scores=20(#2920)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * added histogram of anomaly scores * Update src/anomalib/metrics/anomaly_score_distribution.py Co-authored-by: Samet Akcay Signed-off-by: Aimira Baitieva <63813435+abc-125@users.noreply.github.com> * Update anomaly_score_distribution.py * Update anomaly_score_distribution.py * Fixed pre-commit checks --------- Signed-off-by: Aimira Baitieva <63813435+abc-125@users.noreply.github.com> Co-authored-by: Samet Akcay Signed-off-by: StarPlatinum7 <2732981250@qq.com> --- .../metrics/anomaly_score_distribution.py | 78 +++++++++++++++++-- src/anomalib/metrics/utils.py | 60 +++++++++++++- 2 files changed, 129 insertions(+), 9 deletions(-) diff --git a/src/anomalib/metrics/anomaly_score_distribution.py b/src/anomalib/metrics/anomaly_score_distribution.py index 05a271d063..c63ab8811f 100644 --- a/src/anomalib/metrics/anomaly_score_distribution.py +++ b/src/anomalib/metrics/anomaly_score_distribution.py @@ -3,9 +3,11 @@ """Compute statistics of anomaly score distributions. -This module provides the ``AnomalyScoreDistribution`` class which computes mean -and standard deviation statistics of anomaly scores from normal training data. +This module provides the ``AnomalyScoreDistribution`` class, which computes the mean +and standard deviation statistics of anomaly scores. Statistics are computed for both image-level and pixel-level scores. 
+The ``plot`` method generates a histogram of anomaly scores, +separated by label, to visualize score distributions for normal and abnormal samples. The class tracks: - Image-level statistics: Mean and std of image anomaly scores @@ -17,29 +19,34 @@ >>> # Create sample data >>> scores = torch.tensor([0.1, 0.2, 0.15]) # Image anomaly scores >>> maps = torch.tensor([[0.1, 0.2], [0.15, 0.25]]) # Pixel anomaly maps + >>> labels = torch.tensor([0, 1, 0]) # Binary labels >>> # Initialize and compute stats >>> dist = AnomalyScoreDistribution() - >>> dist.update(anomaly_scores=scores, anomaly_maps=maps) + >>> dist.update(anomaly_scores=scores, anomaly_maps=maps, labels=labels) >>> image_mean, image_std, pixel_mean, pixel_std = dist.compute() + >>> fig, title = dist.plot() Note: The input scores and maps are log-transformed before computing statistics. - Both image-level scores and pixel-level maps are optional inputs. + Image-level scores, pixel-level maps, and labels are optional inputs. """ import torch +from matplotlib.figure import Figure from torchmetrics import Metric +from .utils import plot_score_histogram + class AnomalyScoreDistribution(Metric): """Compute distribution statistics of anomaly scores. This class tracks and computes the mean and standard deviation of anomaly - scores from the normal samples in the training set. Statistics are computed - for both image-level scores and pixel-level anomaly maps. + scores. Statistics are computed for both image-level scores and pixel-level + anomaly maps. - The metric maintains internal state to accumulate scores and maps across - batches before computing final statistics. + The metric maintains internal state to accumulate scores, anomaly maps, + and labels across batches before computing final statistics. Example: >>> dist = AnomalyScoreDistribution() @@ -59,6 +66,7 @@ def __init__(self, **kwargs) -> None: super().__init__(**kwargs) self.anomaly_maps: list[torch.Tensor] = [] self.anomaly_scores: list[torch.Tensor] = [] + self.labels: list[torch.Tensor] = [] self.add_state("image_mean", torch.empty(0), persistent=True) self.add_state("image_std", torch.empty(0), persistent=True) @@ -75,6 +83,7 @@ def update( *args, anomaly_scores: torch.Tensor | None = None, anomaly_maps: torch.Tensor | None = None, + labels: torch.Tensor | None = None, **kwargs, ) -> None: """Update the internal state with new scores and maps. @@ -83,6 +92,7 @@ def update( *args: Unused positional arguments. anomaly_scores: Batch of image-level anomaly scores. anomaly_maps: Batch of pixel-level anomaly maps. + labels: Batch of binary labels. **kwargs: Unused keyword arguments. """ del args, kwargs # These variables are not used. @@ -91,6 +101,8 @@ def update( self.anomaly_maps.append(anomaly_maps) if anomaly_scores is not None: self.anomaly_scores.append(anomaly_scores) + if labels is not None: + self.labels.append(labels) def compute(self) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """Compute distribution statistics from accumulated scores and maps. 
@@ -116,3 +128,53 @@ def compute(self) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tenso self.pixel_std = anomaly_maps.std(dim=0).squeeze() return self.image_mean, self.image_std, self.pixel_mean, self.pixel_std + + def plot( + self, + bins: int = 30, + good_color: str = "skyblue", + bad_color: str = "salmon", + xlabel: str = "Score", + ylabel: str = "Relative Count", + title: str = "Score Histogram", + legend_labels: tuple[str, str] = ("Good", "Bad"), + ) -> tuple[Figure, str]: + """Generate a histogram of scores. + + Args: + bins (int, optional): Number of histogram bins. Defaults to 30. + good_color (str, optional): Color for good samples. Defaults to "skyblue". + bad_color (str, optional): Color for bad samples. Defaults to "salmon". + xlabel (str, optional): Label for the x-axis. Defaults to "Score". + ylabel (str, optional): Label for the y-axis. Defaults to "Relative Count". + title (str, optional): Title of the plot. Defaults to "Score Histogram". + legend_labels (tuple[str, str], optional): Legend labels for good and bad samples. + Defaults to ("Good", "Bad"). + + Returns: + tuple[Figure, str]: Tuple containing both the figure and the figure + title to be used for logging + + Raises: + ValueError: If no anomaly scores or labels are available. + """ + if len(self.anomaly_scores) == 0: + msg = "No anomaly scores available." + raise ValueError(msg) + if len(self.labels) == 0: + msg = "No labels available." + raise ValueError(msg) + + fig, _ = plot_score_histogram( + scores=torch.hstack(self.anomaly_scores), + labels=torch.hstack(self.labels), + bins=bins, + good_color=good_color, + bad_color=bad_color, + xlabel=xlabel, + ylabel=ylabel, + title=title, + legend_labels=legend_labels, + ) + + return fig, title diff --git a/src/anomalib/metrics/utils.py b/src/anomalib/metrics/utils.py index d4b3c5126e..1cefe4027a 100644 --- a/src/anomalib/metrics/utils.py +++ b/src/anomalib/metrics/utils.py @@ -1,7 +1,7 @@ # Copyright (C) 2022-2025 Intel Corporation # SPDX-License-Identifier: Apache-2.0 -"""Helper functions to generate ROC-style plots of various metrics. +"""Helper functions to generate plots of various metrics. This module provides utility functions for generating ROC-style plots and other visualization helpers used by metrics in Anomalib. @@ -15,6 +15,64 @@ from anomalib.utils.deprecation import deprecate +def plot_score_histogram( + scores: torch.Tensor, + labels: torch.Tensor, + bins: int = 30, + good_color: str = "skyblue", + bad_color: str = "salmon", + xlabel: str = "Score", + ylabel: str = "Relative Count", + title: str = "Score Histogram", + legend_labels: tuple[str, str] = ("Good", "Bad"), +) -> tuple[Figure, Axes]: + """Plot a histogram of scores using two colors for good and bad images. + + Args: + scores (torch.Tensor): 1D tensor of scores for all images. + labels (torch.Tensor): 1D tensor of binary labels (0=good, 1=bad). + bins (int, optional): Number of histogram bins. Defaults to ``30``. + good_color (str, optional): Color for good images. Defaults to "skyblue". + bad_color (str, optional): Color for bad images. Defaults to "salmon". + xlabel (str, optional): Label for x-axis. Defaults to "Score". + ylabel (str, optional): Label for y-axis. Defaults to "Relative Count". + title (str, optional): Title of the plot. Defaults to "Score Histogram". + legend_labels (tuple[str, str], optional): Labels for legend. Defaults to ("Good", "Bad"). + + Returns: + tuple[Figure, Axes]: Tuple containing the figure and its main axis. 
+ """ + fig, axis = plt.subplots() + scores = scores.detach().cpu().numpy() + labels = labels.detach().cpu().numpy() + + good_scores = scores[labels == 0] + bad_scores = scores[labels == 1] + + axis.hist( + good_scores, + bins=bins, + color=good_color, + alpha=0.7, + label=legend_labels[0], + density=True, + ) + axis.hist( + bad_scores, + bins=bins, + color=bad_color, + alpha=0.7, + label=legend_labels[1], + density=True, + ) + + axis.set_xlabel(xlabel) + axis.set_ylabel(ylabel) + axis.set_title(title) + axis.legend() + return fig, axis + + def plot_metric_curve( x_vals: torch.Tensor, y_vals: torch.Tensor, From 575632ffd965afe9467c3fd9dec615865e0cabff Mon Sep 17 00:00:00 2001 From: Aimira Baitieva <63813435+abc-125@users.noreply.github.com> Date: Fri, 5 Sep 2025 18:03:38 +0200 Subject: [PATCH 06/17] =?UTF-8?q?=F0=9F=9A=80=20feat(metric):=20added=20PG?= =?UTF-8?q?n,=20PBn=20metrics=20(#2889)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * added pg and pb metrics * fixed typos * Update __init__.py * Update __init__.py - removed duplicate * Update pg_pb.py * Fixed pre-commit checks --------- Co-authored-by: Samet Akcay Co-authored-by: Rajesh Gangireddy Signed-off-by: StarPlatinum7 <2732981250@qq.com> --- src/anomalib/metrics/__init__.py | 5 + src/anomalib/metrics/pg_pb.py | 219 +++++++++++++++++++++++++++++++ tests/unit/metrics/test_pg_pb.py | 68 ++++++++++ 3 files changed, 292 insertions(+) create mode 100644 src/anomalib/metrics/pg_pb.py create mode 100644 tests/unit/metrics/test_pg_pb.py diff --git a/src/anomalib/metrics/__init__.py b/src/anomalib/metrics/__init__.py index ab57b7fca7..819d54e6cd 100644 --- a/src/anomalib/metrics/__init__.py +++ b/src/anomalib/metrics/__init__.py @@ -25,6 +25,8 @@ - ``BinaryPrecisionRecallCurve``: Computes precision-recall curves - ``Evaluator``: Combines multiple metrics for evaluation - ``MinMax``: Normalizes scores to [0,1] range + - ``PBn``: Presorted bad with n% good samples misclassified + - ``PGn``: Presorted good with n% bad samples missed - ``PRO``: Per-Region Overlap score - ``PIMO``: Per-Image Missed Overlap score @@ -56,6 +58,7 @@ from .evaluator import Evaluator from .f1_score import F1Max, F1Score from .min_max import MinMax +from .pg_pb import PBn, PGn from .pimo import AUPIMO, PIMO from .precision_recall_curve import BinaryPrecisionRecallCurve from .pro import PRO @@ -75,6 +78,8 @@ "F1Score", "ManualThreshold", "MinMax", + "PGn", + "PBn", "PRO", "PIMO", "AUPIMO", diff --git a/src/anomalib/metrics/pg_pb.py b/src/anomalib/metrics/pg_pb.py new file mode 100644 index 0000000000..49114e4ac4 --- /dev/null +++ b/src/anomalib/metrics/pg_pb.py @@ -0,0 +1,219 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""PGn and PBn metrics for binary image-level classification tasks. + +This module provides two metrics for evaluating binary image-level classification performance +on the assumption that bad (anomalous) samples are considered to be the positive class: + +- ``PGn``: Presorted good with n% bad samples missed, can be interpreted as true negative rate +at a fixed false negative rate (TNR@nFNR). +- ``PBn``: Presorted bad with n% good samples misclassified, can be interpreted as true positive rate +at a fixed false positive rate (TPR@nFPR). + +These metrics emphasize the practical applications of anomaly detection models by showing their potential +to reduce human operator workload while maintaining an acceptable level of misclassification. 
+ +Example: + >>> from anomalib.metrics import PGn, PBn + >>> from anomalib.data import ImageBatch + >>> import torch + >>> # Create sample batch + >>> batch = ImageBatch( + ... image=torch.rand(4, 3, 32, 32), + ... pred_score=torch.tensor([0.1, 0.4, 0.35, 0.8]), + ... gt_label=torch.tensor([0, 0, 1, 1]) + ... ) + >>> pg = PGn(fnr=0.2) + >>> # Print name of the metric + >>> print(pg.name) + PG20 + >>> # Compute PGn score + >>> pg.update(batch) + >>> pg.compute() + tensor(1.0) + >>> pb = PBn(fpr=0.2) + >>> # Print name of the metric + >>> print(pb.name) + PB20 + >>> # Compute PBn score + >>> pb.update(batch) + >>> pb.compute() + tensor(1.0) + +Note: + Scores for both metrics range from 0 to 1, with 1 indicating perfect separation + of the respective class with ``n``% or less of the other class misclassified. + +Reference: + Aimira Baitieva, Yacine Bouaouni, Alexandre Briot, Dick Ameln, Souhaiel Khalfaoui, + Samet Akcay; Beyond Academic Benchmarks: Critical Analysis and Best Practices + for Visual Industrial Anomaly Detection; in: Proceedings of the IEEE/CVF Conference + on Computer Vision and Pattern Recognition (CVPR) Workshops, 2025, pp. 4024-4034, + https://arxiv.org/abs/2503.23451 +""" + +import torch +from torchmetrics import Metric +from torchmetrics.utilities import dim_zero_cat + +from anomalib.metrics.base import AnomalibMetric + + +class _PGn(Metric): + """Presorted good metric. + + This class calculates the Presorted good (PGn) metric, which is the true negative rate + at a fixed false negative rate. + + Args: + **kwargs: Additional arguments passed to the parent ``Metric`` class. + + Attributes: + fnr (torch.Tensor): Fixed false negative rate (bad parts misclassified). + Defaults to ``0.05``. + + Example: + >>> from anomalib.metrics.pg_pb import _PGn + >>> import torch + >>> # Create sample data + >>> preds = torch.tensor([0.1, 0.4, 0.35, 0.8]) + >>> target = torch.tensor([0, 0, 1, 1]) + >>> # Compute PGn score + >>> pg = _PGn(fnr=0.2) + >>> pg.update(preds, target) + >>> pg.compute() + tensor(1.0) + """ + + def __init__(self, fnr: float = 0.05, **kwargs) -> None: + super().__init__(**kwargs) + if fnr < 0 or fnr > 1: + msg = f"False negative rate must be in the range between 0 and 1, got {fnr}." + raise ValueError(msg) + + self.fnr = torch.tensor(fnr, dtype=torch.float32) + self.name = "PG" + str(int(fnr * 100)) + + self.add_state("preds", default=[], dist_reduce_fx="cat") + self.add_state("target", default=[], dist_reduce_fx="cat") + + def update(self, preds: torch.Tensor, target: torch.Tensor) -> None: + """Update state with new values. + + Args: + preds (torch.Tensor): predictions of the model + target (torch.Tensor): ground truth targets + """ + self.target.append(target) + self.preds.append(preds) + + def compute(self) -> torch.Tensor: + """Compute the PGn score at a given false negative rate. + + Returns: + torch.Tensor: PGn score value. + + Raises: + ValueError: If no negative samples are found. + """ + preds = dim_zero_cat(self.preds) + target = dim_zero_cat(self.target) + + pos_scores = preds[target == 1] + thr_accept = torch.quantile(pos_scores, self.fnr) + + neg_scores = preds[target == 0] + if neg_scores.numel() == 0: + msg = "No negative samples found. Cannot compute PGn score." + raise ValueError(msg) + pg = neg_scores[neg_scores < thr_accept].numel() / neg_scores.numel() + + return torch.tensor(pg, dtype=preds.dtype) + + +class PGn(AnomalibMetric, _PGn): # type: ignore[misc] + """Wrapper to add AnomalibMetric functionality to PGn metric. 
+
+    This class wraps the internal ``_PGn`` metric to make it compatible with
+    Anomalib's batch processing capabilities.
+    """
+
+    default_fields = ("pred_score", "gt_label")
+
+
+class _PBn(Metric):
+    """Presorted bad metric.
+
+    This class calculates the Presorted bad (PBn) metric, which is the true positive rate
+    at a fixed false positive rate.
+
+    Args:
+        fpr (float): Fixed false positive rate (good parts misclassified). Defaults to ``0.05``.
+        **kwargs: Additional arguments passed to the parent ``Metric`` class.
+
+    Example:
+        >>> from anomalib.metrics.pg_pb import _PBn
+        >>> import torch
+        >>> preds = torch.tensor([0.1, 0.4, 0.35, 0.8])
+        >>> target = torch.tensor([0, 0, 1, 1])
+        >>> pb = _PBn(fpr=0.2)
+        >>> pb.update(preds, target)
+        >>> pb.compute()
+        tensor(1.0)
+    """
+
+    def __init__(self, fpr: float = 0.05, **kwargs) -> None:
+        super().__init__(**kwargs)
+        if fpr < 0 or fpr > 1:
+            msg = f"False positive rate must be in the range between 0 and 1, got {fpr}."
+            raise ValueError(msg)
+
+        self.fpr = torch.tensor(fpr, dtype=torch.float32)
+        self.name = "PB" + str(int(fpr * 100))
+
+        self.add_state("preds", default=[], dist_reduce_fx="cat")
+        self.add_state("target", default=[], dist_reduce_fx="cat")
+
+    def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
+        """Update state with new values.
+
+        Args:
+            preds (torch.Tensor): predictions of the model
+            target (torch.Tensor): ground truth targets
+        """
+        self.target.append(target)
+        self.preds.append(preds)
+
+    def compute(self) -> torch.Tensor:
+        """Compute the PBn score at a given false positive rate.
+
+        Returns:
+            torch.Tensor: PBn score value.
+
+        Raises:
+            ValueError: If no positive samples are found.
+        """
+        preds = dim_zero_cat(self.preds)
+        target = dim_zero_cat(self.target)
+
+        neg_scores = preds[target == 0]
+        thr_accept = torch.quantile(neg_scores, 1 - self.fpr)
+
+        pos_scores = preds[target == 1]
+        if pos_scores.numel() == 0:
+            msg = "No positive samples found. Cannot compute PBn score."
+            raise ValueError(msg)
+        pb = pos_scores[pos_scores > thr_accept].numel() / pos_scores.numel()
+
+        return torch.tensor(pb, dtype=preds.dtype)
+
+
+class PBn(AnomalibMetric, _PBn):  # type: ignore[misc]
+    """Wrapper to add AnomalibMetric functionality to PBn metric.
+
+    This class wraps the internal ``_PBn`` metric to make it compatible with
+    Anomalib's batch processing capabilities.
+    """
+
+    default_fields = ("pred_score", "gt_label")
diff --git a/tests/unit/metrics/test_pg_pb.py b/tests/unit/metrics/test_pg_pb.py
new file mode 100644
index 0000000000..a1164dd877
--- /dev/null
+++ b/tests/unit/metrics/test_pg_pb.py
@@ -0,0 +1,68 @@
+# Copyright (C) 2025 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+"""Test PGn and PBn metrics."""
+
+import pytest
+import torch
+
+from anomalib.metrics.pg_pb import _PBn as PBn
+from anomalib.metrics.pg_pb import _PGn as PGn
+
+
+def test_pg_basic() -> None:
+    """Test PGn metric with simple binary classification."""
+    metric = PGn(fnr=0.2)
+    preds = torch.tensor([0.1, 0.4, 0.35, 0.8])
+    labels = torch.tensor([0, 0, 1, 1])
+    metric.update(preds, labels)
+    result = metric.compute()
+    assert result == torch.tensor(1.0)
+    assert metric.name == "PG20"
+
+
+def test_pb_basic() -> None:
+    """Test PBn metric with simple binary classification."""
+    metric = PBn(fpr=0.2)
+    preds = torch.tensor([0.1, 0.4, 0.35, 0.8])
+    labels = torch.tensor([0, 0, 1, 1])
+    metric.update(preds, labels)
+    result = metric.compute()
+    assert result == torch.tensor(1.0)
+    assert metric.name == "PB20"
+
+
+def test_pg_invalid_fnr() -> None:
+    """Test PGn metric raises ValueError for invalid fnr."""
+    with pytest.raises(ValueError, match="False negative rate must be in the range between 0 and 1"):
+        PGn(fnr=-0.1)
+    with pytest.raises(ValueError, match="False negative rate must be in the range between 0 and 1"):
+        PGn(fnr=1.1)
+
+
+def test_pb_invalid_fpr() -> None:
+    """Test PBn metric raises ValueError for invalid fpr."""
+    with pytest.raises(ValueError, match="False positive rate must be in the range between 0 and 1"):
+        PBn(fpr=-0.1)
+    with pytest.raises(ValueError, match="False positive rate must be in the range between 0 and 1"):
+        PBn(fpr=1.1)
+
+
+def test_pg_no_negatives() -> None:
+    """Test PGn metric raises ValueError if no negative samples."""
+    metric = PGn(fnr=0.1)
+    preds = torch.tensor([0.5, 0.7])
+    labels = torch.tensor([1, 1])
+    metric.update(preds, labels)
+    with pytest.raises(ValueError, match="No negative samples found"):
+        metric.compute()
+
+
+def test_pb_no_positives() -> None:
+    """Test PBn metric raises ValueError if no positive samples."""
+    metric = PBn(fpr=0.1)
+    preds = torch.tensor([0.2, 0.3])
+    labels = torch.tensor([0, 0])
+    metric.update(preds, labels)
+    with pytest.raises(ValueError, match="No positive samples found"):
+        metric.compute()
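Both metrics in patch 06 reduce to a single quantile over one class's scores followed by a counting step over the other class. A small walk-through of the PGn arithmetic using the same values as the doctests (a sketch of the computation, not a replacement for the metric classes):

```python
import torch

preds = torch.tensor([0.10, 0.40, 0.35, 0.80])  # anomaly scores
target = torch.tensor([0, 0, 1, 1])             # 0 = good, 1 = bad

fnr = 0.2
pos_scores = preds[target == 1]               # tensor([0.35, 0.80])
thr_accept = torch.quantile(pos_scores, fnr)  # 0.35 + 0.2 * (0.80 - 0.35) = 0.44
neg_scores = preds[target == 0]               # tensor([0.10, 0.40])

# PG20: fraction of good parts scoring below the acceptance threshold.
pg = (neg_scores < thr_accept).float().mean()
print(thr_accept.item(), pg.item())  # 0.44 1.0 -> all good parts auto-accepted
```

PBn mirrors this with the roles swapped: the threshold is the (1 - fpr)-quantile of the good-part scores, and the score is the fraction of bad parts above it.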
From fa69cf61116fe94fdafbc2022421a01adb71379f Mon Sep 17 00:00:00 2001
From: Rajesh Gangireddy <rajesh.gangireddy@intel.com>
Date: Wed, 15 Oct 2025 16:05:41 +0200
Subject: [PATCH 07/17] =?UTF-8?q?=E2=9C=A8=20feat(post-processing):=20Make?=
 =?UTF-8?q?=20=20minor=20changes=20as=20per=20comments?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 CHANGELOG.md                                  |   2 +-
 .../guides/how_to/models/post_processor.md    |  46 ++++-
 .../guides/reference/post_processing/index.md |  20 +++
 src/anomalib/metrics/__init__.py              |   2 +-
 src/anomalib/metrics/threshold/mebin.py       | 164 ++++++++++--------
 src/anomalib/post_processing/__init__.py      |   2 +-
 .../post_processing/mebin_post_processor.py   |  68 ++++----
 .../test_mebin_post_processor.py              | 119 +++++++------
 8 files changed, 265 insertions(+), 158 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 933fd31243..545f6a2f1b 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,8 +7,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).

 ## [Unreleased]

 ### Added
-- 🚀 Add MEBin post-processing method

+- 🚀 Add MEBin post-processing method
 ### Removed

diff --git a/docs/source/markdown/guides/how_to/models/post_processor.md b/docs/source/markdown/guides/how_to/models/post_processor.md
index d2b57aaa22..162164c358 100644
--- a/docs/source/markdown/guides/how_to/models/post_processor.md
+++ b/docs/source/markdown/guides/how_to/models/post_processor.md
@@ -83,7 +83,7 @@ Normalization and thresholding only works when your datamodule contains a valida

 ## Basic Usage

-To use the `PostProcessor`, simply add it to any Anomalib model when creating the model:
+Anomalib provides two main post-processors: the default `PostProcessor` and the `MEBinPostProcessor`. To use the default `PostProcessor`, simply add it to any Anomalib model when creating the model:

 ```python
 from anomalib.models import Padim
@@ -93,6 +93,16 @@ post_processor = PostProcessor()
 model = Padim(post_processor=post_processor)
 ```

+For alternative thresholding approaches, you can use the `MEBinPostProcessor`:
+
+```python
+from anomalib.models import Padim
+from anomalib.post_processing import MEBinPostProcessor
+
+post_processor = MEBinPostProcessor()
+model = Padim(post_processor=post_processor)
+```
+
 The post-processor can be configured using its constructor arguments. In the case of the `PostProcessor`, the only configuration parameters are the sensitivity for the thresholding operation on the image- and pixel-level:

 ```python
@@ -180,6 +190,40 @@ pred_labels = results[..., 2]  # Already thresholded (0/1)
 pred_masks = results[..., 3]  # Already thresholded masks (if applicable)
 ```

+## MEBinPostProcessor
+
+Anomalib provides the `MEBinPostProcessor`, which implements the MEBin (Main Element Binarization) algorithm from AnomalyNCD. This post-processor is designed for industrial anomaly detection scenarios where traditional thresholding may not perform optimally.
+
+MEBin was introduced in "AnomalyNCD: Towards Novel Anomaly Class Discovery in Industrial Scenarios" ([arXiv:2410.14379](https://arxiv.org/abs/2410.14379)).
+
+### Basic Usage
+
+```python
+from anomalib.models import Padim
+from anomalib.post_processing import MEBinPostProcessor
+
+# Create MEBin post-processor with custom parameters
+post_processor = MEBinPostProcessor(
+    sample_rate=4,        # Threshold sampling step size
+    min_interval_len=4,   # Minimum stable interval length
+    erode=True            # Apply erosion to reduce noise
+)
+
+model = Padim(post_processor=post_processor)
+```
+
+For more details on the MEBin algorithm, see the [AnomalyNCD paper](https://arxiv.org/abs/2410.14379).
+
 ## Creating Custom Post-processors

 Advanced users may want to define their own post-processing pipeline. This can be useful when the default post-processing behaviour of the `PostProcessor` is not suitable for the model and its predictions. To create a custom post-processor, inherit from `nn.Module` and `Callback`, and implement your post-processing logic using lightning hooks.
Don't forget to also include the post-processing steps in the `forward` method of your class to ensure that the post-processing is included when exporting your model: diff --git a/docs/source/markdown/guides/reference/post_processing/index.md b/docs/source/markdown/guides/reference/post_processing/index.md index 02bdb4d638..493a293413 100644 --- a/docs/source/markdown/guides/reference/post_processing/index.md +++ b/docs/source/markdown/guides/reference/post_processing/index.md @@ -23,6 +23,16 @@ Post-processor for one-class anomaly detection. +++ [Learn more ยป](one-class-post-processor) ::: + +:::{grid-item-card} {octicon}`gear` MEBin Post-processor +:link: mebin-post-processor +:link-type: ref + +MEBin post-processor from AnomalyNCD. + ++++ +[Learn more ยป](mebin-post-processor) +::: :::: (base-post-processor)= @@ -44,3 +54,13 @@ Post-processor for one-class anomaly detection. :members: :show-inheritance: ``` + +(mebin-post-processor)= + +## MEBin Post-processor + +```{eval-rst} +.. automodule:: anomalib.post_processing.mebin_post_processor + :members: + :show-inheritance: +``` diff --git a/src/anomalib/metrics/__init__.py b/src/anomalib/metrics/__init__.py index 819d54e6cd..96017e6eeb 100644 --- a/src/anomalib/metrics/__init__.py +++ b/src/anomalib/metrics/__init__.py @@ -62,7 +62,7 @@ from .pimo import AUPIMO, PIMO from .precision_recall_curve import BinaryPrecisionRecallCurve from .pro import PRO -from .threshold import F1AdaptiveThreshold, ManualThreshold ,MEBin +from .threshold import F1AdaptiveThreshold, ManualThreshold, MEBin __all__ = [ "AUROC", diff --git a/src/anomalib/metrics/threshold/mebin.py b/src/anomalib/metrics/threshold/mebin.py index c823ea3b5f..7524877ce8 100644 --- a/src/anomalib/metrics/threshold/mebin.py +++ b/src/anomalib/metrics/threshold/mebin.py @@ -1,14 +1,27 @@ -"""MEBin adaptive thresholding algorithm for anomaly detection. +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +"""MEBin (Main Element Binarization) adaptive thresholding for anomaly detection. -This module provides the ``MEBin`` class which automatically finds the optimal -threshold value by analyzing the stability of connected components across -multiple threshold levels. +This module provides the ``MEBin`` class which implements the Main Element +Binarization algorithm designed to address the non-prominence of anomalies +in anomaly maps. MEBin obtains anomaly-centered images by analyzing the +stability of connected components across multiple threshold levels. + +The algorithm is particularly effective for: +- Industrial anomaly detection scenarios +- Multi-class anomaly classification tasks +- Cases where anomalies are non-prominent in anomaly maps +- Avoiding the impact of incorrect detections The threshold is computed by: -1. Sampling anomaly maps at configurable rates across threshold range -2. Counting connected components at each threshold level -3. Finding stable intervals where component count remains constant -4. Selecting threshold from the longest stable interval +1. Adaptively determining threshold search range from anomaly map statistics +2. Sampling anomaly maps at configurable rates across threshold range +3. Counting connected components at each threshold level +4. Finding stable intervals where component count remains constant +5. Selecting threshold from the longest stable interval + +MEBin was introduced in "AnomalyNCD: Towards Novel Anomaly Class Discovery +in Industrial Scenarios" (https://arxiv.org/abs/2410.14379). 
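Steps 2-3 of that procedure can be pictured with plain OpenCV operations; a toy sketch of one sweep (the random map, sampling step, and 3x3 kernel are arbitrary illustrative choices, not the committed defaults):

```python
import cv2
import numpy as np

# Toy anomaly map and a high-to-low threshold sweep (mirrors steps 2-4).
anomaly_map = (np.random.rand(128, 128) * 255).astype(np.uint8)
sample_rate = 4

counts = []
for score in range(255, 0, -sample_rate):
    bin_result = np.where(anomaly_map >= score, 255, 0).astype(np.uint8)
    # A light erosion suppresses single-pixel noise before counting.
    bin_result = cv2.erode(bin_result, np.ones((3, 3), np.uint8), iterations=1)
    num_labels, *_ = cv2.connectedComponentsWithStats(bin_result, connectivity=8)
    counts.append(num_labels - 1)  # exclude the background component

# Stable runs in ``counts`` are what steps 4-5 search for.
print(counts[:10])
```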
Example: >>> from anomalib.metrics.threshold import MEBin @@ -16,31 +29,34 @@ >>> # Create sample anomaly maps >>> anomaly_maps = [np.random.rand(256, 256) * 255 for _ in range(10)] >>> # Initialize and compute thresholds - >>> mebin = MEBin(anomaly_map_list=anomaly_maps, sample_rate=4) + >>> mebin = MEBin(anomaly_maps, sample_rate=4) >>> binarized_maps, thresholds = mebin.binarize_anomaly_maps() >>> print(f"Computed {len(thresholds)} thresholds") Note: - The algorithm works best when anomaly maps contain clear separation between - normal and anomalous regions. The min_interval_len parameter should be tuned - based on the expected stability of anomaly score distributions. + MEBin is designed for industrial scenarios where anomalies may be + non-prominent. The min_interval_len parameter should be tuned based + on the expected stability of connected component counts. """ +from __future__ import annotations + import cv2 import numpy as np from tqdm import tqdm - class MEBin: - """MEBin adaptive thresholding algorithm for anomaly detection. + """MEBin (Main Element Binarization) adaptive thresholding algorithm. - This class implements the MEBin (Minimum Entropy Binarization) algorithm - which automatically determines optimal thresholds for converting continuous - anomaly maps to binary masks by analyzing the stability of connected - component counts across different threshold levels. + This class implements the Main Element Binarization algorithm designed + to address non-prominent anomalies in industrial anomaly detection scenarios. + MEBin determines optimal thresholds by analyzing the stability of connected + component counts across different threshold levels to obtain anomaly-centered + binary representations. The algorithm works by: + - Adaptively determining threshold search ranges from anomaly statistics - Sampling anomaly maps at configurable rates across threshold range - Counting connected components at each threshold level - Identifying stable intervals where component count remains constant @@ -48,21 +64,16 @@ class MEBin: - Optionally applying morphological erosion to reduce noise Args: - anomaly_map_path_list (list, optional): List of file paths to anomaly maps. - If provided, maps will be loaded as grayscale images. - Defaults to None. - anomaly_map_list (list, optional): List of anomaly map arrays. If provided, - maps should be numpy arrays. - Defaults to None. + anomaly_map_list (list[np.ndarray]): List of anomaly map arrays as numpy arrays. sample_rate (int, optional): Sampling rate for threshold search. Higher - values reduce processing time but may affect accuracy. + values reduce processing time but may affect accuracy. Defaults to 4. min_interval_len (int, optional): Minimum length of stable intervals. Should be tuned based on the expected stability of anomaly score - distributions. + distributions. Defaults to 4. erode (bool, optional): Whether to apply morphological erosion to - binarized results to reduce noise. + binarized results to reduce noise. Defaults to True. 
Example: @@ -71,26 +82,28 @@ class MEBin: >>> # Create sample anomaly maps >>> anomaly_maps = [np.random.rand(256, 256) * 255 for _ in range(10)] >>> # Initialize MEBin - >>> mebin = MEBin(anomaly_map_list=anomaly_maps, sample_rate=4) + >>> mebin = MEBin(anomaly_maps, sample_rate=4) >>> # Compute binary masks and thresholds >>> binarized_maps, thresholds = mebin.binarize_anomaly_maps() """ - def __init__(self, anomaly_map_path_list=None, sample_rate=4, min_interval_len=4, erode=True): + def __init__( + self, + anomaly_map_list: list[np.ndarray], + sample_rate: int = 4, + min_interval_len: int = 4, + erode: bool = True, + ) -> None: + self.anomaly_map_list = anomaly_map_list - self.anomaly_map_path_list = anomaly_map_path_list - # Load anomaly maps as grayscale images if paths are provided - self.anomaly_map_list = [cv2.imread(x, cv2.IMREAD_GRAYSCALE) for x in self.anomaly_map_path_list] - self.sample_rate = sample_rate self.min_interval_len = min_interval_len self.erode = erode - + # Adaptively determine the threshold search range self.max_th, self.min_th = self.get_search_range() - - - def get_search_range(self): + + def get_search_range(self) -> tuple[float, float]: """Determine the threshold search range adaptively. This method analyzes all anomaly maps to determine the minimum and maximum @@ -112,14 +125,17 @@ def get_search_range(self): return max_th, min_th + def get_threshold( + self, + anomaly_num_sequence: list[int], + min_interval_len: int, + ) -> tuple[int, int]: + """Find the 'stable interval' in the anomaly region number sequence. + Stable Interval: A continuous threshold range in which the number of connected components remains constant, + and the length of the threshold range is greater than or equal to the given length threshold + (min_interval_len). - def get_threshold(self, anomaly_num_sequence, min_interval_len): - """ - Find the 'stable interval' in the anomaly region number sequence. - Stable Interval: A continuous threshold range in which the number of connected components remains constant, - and the length of the threshold range is greater than or equal to the given length threshold (min_interval_len). - Args: anomaly_num_sequence (list): Sequence of connected component counts at each threshold level, ordered from high to low threshold. @@ -133,17 +149,21 @@ def get_threshold(self, anomaly_num_sequence, min_interval_len): interval_result = {} current_index = 0 while current_index < len(anomaly_num_sequence): - end = current_index + end = current_index - start = end + start = end # Find the interval where the connected component count remains constant. - if len(set(anomaly_num_sequence[start:end+1])) == 1 and anomaly_num_sequence[start] != 0: + sequence_slice = anomaly_num_sequence[start : end + 1] + if len(set(sequence_slice)) == 1 and anomaly_num_sequence[start] != 0: # Move the 'end' pointer forward until a different connected component number is encountered. - while end < len(anomaly_num_sequence)-1 and anomaly_num_sequence[end] == anomaly_num_sequence[end+1]: + while ( + end < len(anomaly_num_sequence) - 1 and anomaly_num_sequence[end] == anomaly_num_sequence[end + 1] + ): end += 1 current_index += 1 - # If the length of the current stable interval is greater than or equal to the given threshold (min_interval_len), record this interval. + # If the length of the current stable interval is greater than or equal to the given + # threshold (min_interval_len), record this interval. 
if end - start + 1 >= min_interval_len: if anomaly_num_sequence[start] not in interval_result: interval_result[anomaly_num_sequence[start]] = [(start, end)] @@ -157,11 +177,12 @@ def get_threshold(self, anomaly_num_sequence, min_interval_len): """ if interval_result: - # Iterate through the stable intervals, calculating their lengths and corresponding number of connected component. + # Iterate through the stable intervals, calculating their lengths and corresponding + # number of connected component. count_result = {} for anomaly_num in interval_result: - count_result[anomaly_num] = max([x[1] - x[0] for x in interval_result[anomaly_num]]) - est_anomaly_num = max(count_result, key=count_result.get) + count_result[anomaly_num] = max(x[1] - x[0] for x in interval_result[anomaly_num]) + est_anomaly_num = max(count_result, key=lambda x: count_result[x]) est_anomaly_num_interval_result = interval_result[est_anomaly_num] # Find the longest stable interval. @@ -170,13 +191,11 @@ def get_threshold(self, anomaly_num_sequence, min_interval_len): # Use the endpoint threshold of the longest stable interval as the final threshold. index = longest_interval[1] threshold = 255 - index * self.sample_rate - threshold = int(threshold*(self.max_th - self.min_th)/255 + self.min_th) + threshold = int(threshold * (self.max_th - self.min_th) / 255 + self.min_th) return threshold, est_anomaly_num - else: - return 255, 0 - - - def bin_and_erode(self, anomaly_map, threshold): + return 255, 0 + + def bin_and_erode(self, anomaly_map: np.ndarray, threshold: int) -> np.ndarray: """Binarize anomaly map and optionally apply erosion. This method converts a continuous anomaly map to a binary mask using @@ -211,40 +230,47 @@ def bin_and_erode(self, anomaly_map, threshold): kernel = np.ones((kernel_size, kernel_size), np.uint8) bin_result = cv2.erode(bin_result, kernel, iterations=iter_num) return bin_result - - def binarize_anomaly_maps(self): - """ - Perform binarization within the given threshold search range, - count the number of connected components in the binarized results. + def binarize_anomaly_maps(self) -> tuple[list[np.ndarray], list[int]]: + """Perform binarization within the given threshold search range. + + Count the number of connected components in the binarized results. Adaptively determine the threshold according to the count, and perform binarization on the anomaly maps. - + Returns: binarized_maps (list): List of binarized images. thresholds (list): List of thresholds for each image. """ self.binarized_maps = [] self.thresholds = [] - - for i, anomaly_map in enumerate(tqdm(self.anomaly_map_list)): + + for anomaly_map in tqdm(self.anomaly_map_list): # Normalize the anomaly map within the given threshold search range. - anomaly_map_norm = np.where(anomaly_map < self.min_th, 0, ((anomaly_map - self.min_th) / (self.max_th - self.min_th)) * 255) + if self.max_th == self.min_th: + # Rare case where all anomaly maps have identical max values + anomaly_map_norm = np.where(anomaly_map < self.min_th, 0, 255) + else: + anomaly_map_norm = np.where( + anomaly_map < self.min_th, + 0, + ((anomaly_map - self.min_th) / (self.max_th - self.min_th)) * 255, + ) anomaly_num_sequence = [] # Search for the threshold from high to low within the given range using the specified sampling rate. 
for score in range(255, 0, -self.sample_rate): bin_result = self.bin_and_erode(anomaly_map_norm, score) - num_labels, *rest = cv2.connectedComponentsWithStats(bin_result, connectivity=8) + num_labels, *_rest = cv2.connectedComponentsWithStats(bin_result, connectivity=8) anomaly_num = num_labels - 1 anomaly_num_sequence.append(anomaly_num) # Adaptively determine the threshold based on the anomaly connected component count sequence. - threshold, est_anomaly_num = self.get_threshold(anomaly_num_sequence, self.min_interval_len) + threshold, _est_anomaly_num = self.get_threshold(anomaly_num_sequence, self.min_interval_len) # Binarize the anomaly image based on the determined threshold. bin_result = self.bin_and_erode(anomaly_map, threshold) self.binarized_maps.append(bin_result) self.thresholds.append(threshold) - return self.binarized_maps, self.thresholds \ No newline at end of file + return self.binarized_maps, self.thresholds diff --git a/src/anomalib/post_processing/__init__.py b/src/anomalib/post_processing/__init__.py index 7e57c6d82d..bc7a639eea 100644 --- a/src/anomalib/post_processing/__init__.py +++ b/src/anomalib/post_processing/__init__.py @@ -19,7 +19,7 @@ >>> predictions = post_processor(anomaly_maps=anomaly_maps) """ -from .post_processor import PostProcessor from .mebin_post_processor import MEBinPostProcessor +from .post_processor import PostProcessor __all__ = ["PostProcessor", "MEBinPostProcessor"] diff --git a/src/anomalib/post_processing/mebin_post_processor.py b/src/anomalib/post_processing/mebin_post_processor.py index 8b71fd68c4..f17e783601 100644 --- a/src/anomalib/post_processing/mebin_post_processor.py +++ b/src/anomalib/post_processing/mebin_post_processor.py @@ -1,8 +1,13 @@ +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + """Post-processing module for MEBin-based anomaly detection results. This module provides post-processing functionality for anomaly detection outputs through the :class:`MEBinPostProcessor` class. +MEBin was introduced in AnomalyNCD : https://arxiv.org/pdf/2410.14379 + The MEBin post-processor handles: - Converting anomaly maps to binary masks using MEBin algorithm - Sampling anomaly maps at configurable rates for efficient processing @@ -16,24 +21,14 @@ >>> predictions = post_processor(anomaly_maps=anomaly_maps) """ -import argparse -import yaml -import os -import json -import shutil -import cv2 -from tqdm import tqdm -import csv -import sys -sys.path.append(os.getcwd()) - -from anomalib.post_processing import PostProcessor -from anomalib.data import InferenceBatch -import torch import numpy as np +import torch +from anomalib.data import InferenceBatch from anomalib.metrics import MEBin +from .post_processor import PostProcessor + class MEBinPostProcessor(PostProcessor): """Post-processor for MEBin-based anomaly detection. @@ -46,16 +41,14 @@ class MEBinPostProcessor(PostProcessor): - Formatting results for downstream use Args: - sample_rate (int, optional): Threshold sampling step size, - Default to 4 - min_interval_len (int, optional): Minimum length of the stable interval, - can be fine-tuned according to the interval between normal and abnormal score distributions in the anomaly score maps, - decrease if there are many false negatives, increase if there are many false positives. - Default to 4 - erode (bool, optional): Whether to perform erosion after binarization to eliminate noise, - this operation can smooth the change process of the number of abnormal - connected components. 
- Default to True + sample_rate (int): Threshold sampling step size. + Defaults to 4 + min_interval_len (int): Minimum length of the stable interval. Can be adjusted based on the interval + between normal and abnormal score distributions in the anomaly score maps. + Decrease if there are many false negatives, increase if there are many false positives. + Defaults to 4 + erode (bool): Whether to perform erosion after binarization to eliminate noise. + Defaults to True **kwargs: Additional keyword arguments passed to parent class. Example: @@ -78,6 +71,7 @@ def __init__( self.erode = erode """Custom post-processor using MEBin algorithm""" + def forward(self, predictions: InferenceBatch) -> InferenceBatch: """Post-process model predictions using MEBin algorithm. @@ -97,9 +91,13 @@ def forward(self, predictions: InferenceBatch) -> InferenceBatch: The method automatically handles tensor-to-numpy conversion and back, ensuring compatibility with the original tensor device and dtype. """ - anomaly_maps = predictions.anomaly_map - if isinstance(anomaly_maps, torch.Tensor): - anomaly_maps = anomaly_maps.detach().cpu().numpy() + if predictions.anomaly_map is None: + msg = "Anomaly map is required for MEBin post-processing" + raise ValueError(msg) + + # Store the original tensor for device and dtype info + original_anomaly_map = predictions.anomaly_map + anomaly_maps = original_anomaly_map.detach().cpu().numpy() if anomaly_maps.ndim == 4: anomaly_maps = anomaly_maps[:, 0, :, :] # Remove channel dimension @@ -109,13 +107,17 @@ def forward(self, predictions: InferenceBatch) -> InferenceBatch: amap_norm = (amap - amap.min()) / (amap.max() - amap.min() + 1e-8) * 255 norm_maps.append(amap_norm.astype(np.uint8)) - - mebin = MEBin(anomaly_map_list=norm_maps, sample_rate=self.sample_rate, min_interval_len=self.min_interval_len, erode=self.erode) - binarized_maps, thresholds = mebin.binarize_anomaly_maps() + mebin = MEBin( + anomaly_map_list=norm_maps, + sample_rate=self.sample_rate, + min_interval_len=self.min_interval_len, + erode=self.erode, + ) + binarized_maps, _ = mebin.binarize_anomaly_maps() # Convert back to torch.Tensor and normalize to 0/1 - pred_masks = torch.stack([torch.from_numpy(bm).to(predictions.anomaly_map.device) for bm in binarized_maps]) - pred_masks = (pred_masks > 0).to(predictions.anomaly_map.dtype) + pred_masks = torch.stack([torch.from_numpy(bm).to(original_anomaly_map.device) for bm in binarized_maps]) + pred_masks = (pred_masks > 0).to(original_anomaly_map.dtype) return InferenceBatch( pred_label=predictions.pred_label, @@ -123,5 +125,3 @@ def forward(self, predictions: InferenceBatch) -> InferenceBatch: pred_mask=pred_masks, anomaly_map=predictions.anomaly_map, ) - - \ No newline at end of file diff --git a/tests/unit/post_processing/test_mebin_post_processor.py b/tests/unit/post_processing/test_mebin_post_processor.py index 25925d40cd..2d0813d320 100644 --- a/tests/unit/post_processing/test_mebin_post_processor.py +++ b/tests/unit/post_processing/test_mebin_post_processor.py @@ -3,10 +3,11 @@ """Test the MEBinPostProcessor class.""" +from unittest.mock import MagicMock, Mock, patch + +import numpy as np import pytest import torch -import numpy as np -from unittest.mock import Mock, patch from anomalib.data import InferenceBatch from anomalib.post_processing import MEBinPostProcessor @@ -19,7 +20,7 @@ class TestMEBinPostProcessor: def test_initialization_default_params() -> None: """Test MEBinPostProcessor initialization with default parameters.""" processor = MEBinPostProcessor() 
- + assert processor.sample_rate == 4 assert processor.min_interval_len == 4 assert processor.erode is True @@ -34,46 +35,46 @@ def test_initialization_default_params() -> None: ], ) def test_initialization_custom_params( - sample_rate: int, - min_interval_len: int, - erode: bool + sample_rate: int, + min_interval_len: int, + erode: bool, ) -> None: """Test MEBinPostProcessor initialization with custom parameters.""" processor = MEBinPostProcessor( sample_rate=sample_rate, min_interval_len=min_interval_len, - erode=erode + erode=erode, ) - + assert processor.sample_rate == sample_rate assert processor.min_interval_len == min_interval_len assert processor.erode == erode @staticmethod - @patch('anomalib.post_processing.mebin_post_processor.MEBin') - def test_forward_single_anomaly_map(mock_mebin) -> None: + @patch("anomalib.post_processing.mebin_post_processor.MEBin") + def test_forward_single_anomaly_map(mock_mebin: MagicMock) -> None: """Test forward method with single anomaly map.""" # Setup mock mock_mebin_instance = Mock() mock_mebin_instance.binarize_anomaly_maps.return_value = ( [np.array([[0, 0], [1, 1]], dtype=np.uint8)], - [0.5] + [0.5], ) mock_mebin.return_value = mock_mebin_instance - + # Create test data anomaly_map = torch.rand(1, 1, 4, 4) predictions = InferenceBatch( pred_score=torch.tensor([0.8]), pred_label=torch.tensor([1]), anomaly_map=anomaly_map, - pred_mask=None + pred_mask=None, ) - + # Test forward pass processor = MEBinPostProcessor() result = processor.forward(predictions) - + # Verify results assert isinstance(result, InferenceBatch) assert result.pred_mask is not None @@ -81,67 +82,67 @@ def test_forward_single_anomaly_map(mock_mebin) -> None: assert result.pred_mask.dtype == anomaly_map.dtype @staticmethod - @patch('anomalib.post_processing.mebin_post_processor.MEBin') - def test_forward_batch_anomaly_maps(mock_mebin) -> None: + @patch("anomalib.post_processing.mebin_post_processor.MEBin") + def test_forward_batch_anomaly_maps(mock_mebin: MagicMock) -> None: """Test forward method with batch of anomaly maps.""" # Setup mock mock_mebin_instance = Mock() mock_mebin_instance.binarize_anomaly_maps.return_value = ( [ np.array([[0, 0], [1, 1]], dtype=np.uint8), - np.array([[1, 0], [0, 1]], dtype=np.uint8) + np.array([[1, 0], [0, 1]], dtype=np.uint8), ], - [0.5, 0.6] + [0.5, 0.6], ) mock_mebin.return_value = mock_mebin_instance - + # Create test data anomaly_maps = torch.rand(2, 1, 4, 4) predictions = InferenceBatch( pred_score=torch.tensor([0.8, 0.9]), pred_label=torch.tensor([1, 1]), anomaly_map=anomaly_maps, - pred_mask=None + pred_mask=None, ) - + # Test forward pass processor = MEBinPostProcessor() result = processor.forward(predictions) - + # Verify results assert isinstance(result, InferenceBatch) assert result.pred_mask.shape == (2, 2, 2) @staticmethod - @patch('anomalib.post_processing.mebin_post_processor.MEBin') - def test_forward_normalization(mock_mebin) -> None: + @patch("anomalib.post_processing.mebin_post_processor.MEBin") + def test_forward_normalization(mock_mebin: MagicMock) -> None: """Test that anomaly maps are properly normalized to 0-255 range.""" # Setup mock mock_mebin_instance = Mock() mock_mebin_instance.binarize_anomaly_maps.return_value = ( [np.array([[0, 0], [1, 1]], dtype=np.uint8)], - [0.5] + [0.5], ) mock_mebin.return_value = mock_mebin_instance - + # Create test data with specific range anomaly_maps = torch.tensor([[[[0.0, 0.5], [1.0, 0.2]]]]) predictions = InferenceBatch( pred_score=torch.tensor([0.8]), pred_label=torch.tensor([1]), 
anomaly_map=anomaly_maps, - pred_mask=None + pred_mask=None, ) - + # Test forward pass processor = MEBinPostProcessor() - result = processor.forward(predictions) - + processor.forward(predictions) + # Verify MEBin was called with normalized data mock_mebin.assert_called_once() call_args = mock_mebin.call_args - anomaly_map_list = call_args[1]['anomaly_map_list'] - + anomaly_map_list = call_args[1]["anomaly_map_list"] + # Check that the data is normalized to 0-255 range assert len(anomaly_map_list) == 1 assert anomaly_map_list[0].dtype == np.uint8 @@ -149,67 +150,83 @@ def test_forward_normalization(mock_mebin) -> None: assert anomaly_map_list[0].max() <= 255 @staticmethod - @patch('anomalib.post_processing.mebin_post_processor.MEBin') - def test_forward_mebin_parameters(mock_mebin) -> None: + @patch("anomalib.post_processing.mebin_post_processor.MEBin") + def test_forward_mebin_parameters(mock_mebin: MagicMock) -> None: """Test that MEBin is called with correct parameters.""" # Setup mock mock_mebin_instance = Mock() mock_mebin_instance.binarize_anomaly_maps.return_value = ( [np.array([[0, 0], [1, 1]], dtype=np.uint8)], - [0.5] + [0.5], ) mock_mebin.return_value = mock_mebin_instance - + # Create test data anomaly_maps = torch.rand(1, 1, 4, 4) predictions = InferenceBatch( pred_score=torch.tensor([0.8]), pred_label=torch.tensor([1]), anomaly_map=anomaly_maps, - pred_mask=None + pred_mask=None, ) - + # Test with custom parameters processor = MEBinPostProcessor( sample_rate=8, min_interval_len=6, - erode=False + erode=False, ) - result = processor.forward(predictions) - + _ = processor.forward(predictions) + # Verify MEBin was called with correct parameters mock_mebin.assert_called_once_with( - anomaly_map_list=mock_mebin.call_args[1]['anomaly_map_list'], + anomaly_map_list=mock_mebin.call_args[1]["anomaly_map_list"], sample_rate=8, min_interval_len=6, - erode=False + erode=False, ) @staticmethod - @patch('anomalib.post_processing.mebin_post_processor.MEBin') - def test_forward_binary_mask_conversion(mock_mebin) -> None: + @patch("anomalib.post_processing.mebin_post_processor.MEBin") + def test_forward_binary_mask_conversion(mock_mebin: MagicMock) -> None: """Test that binary masks are properly converted to 0/1 values.""" # Setup mock to return masks with values > 0 mock_mebin_instance = Mock() mock_mebin_instance.binarize_anomaly_maps.return_value = ( [np.array([[0, 128], [255, 64]], dtype=np.uint8)], - [0.5] + [0.5], ) mock_mebin.return_value = mock_mebin_instance - + # Create test data anomaly_maps = torch.rand(1, 1, 2, 2) predictions = InferenceBatch( pred_score=torch.tensor([0.8]), pred_label=torch.tensor([1]), anomaly_map=anomaly_maps, - pred_mask=None + pred_mask=None, ) - + # Test forward pass processor = MEBinPostProcessor() result = processor.forward(predictions) - + # Verify that all values are either 0 or 1 unique_values = torch.unique(result.pred_mask) assert torch.all((unique_values == 0) | (unique_values == 1)) + + @staticmethod + def test_forward_missing_anomaly_map() -> None: + """Test that ValueError is raised when anomaly_map is None.""" + # Create test data without anomaly_map + predictions = InferenceBatch( + pred_score=torch.tensor([0.8]), + pred_label=torch.tensor([1]), + anomaly_map=None, + pred_mask=None, + ) + + # Test forward pass should raise ValueError + processor = MEBinPostProcessor() + with pytest.raises(ValueError, match="Anomaly map is required for MEBin post-processing"): + processor.forward(predictions) From ab8a5afa6b069e737b5fdfd12b7428d341701506 Mon Sep 
17 00:00:00 2001 From: Rajesh Gangireddy Date: Wed, 15 Oct 2025 16:40:33 +0200 Subject: [PATCH 08/17] =?UTF-8?q?=F0=9F=93=A6=20docs(post-processor):=20re?= =?UTF-8?q?vert=20docs=20for=20MEBIN?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../guides/how_to/models/post_processor.md | 46 +------------------ 1 file changed, 1 insertion(+), 45 deletions(-) diff --git a/docs/source/markdown/guides/how_to/models/post_processor.md b/docs/source/markdown/guides/how_to/models/post_processor.md index 162164c358..d2b57aaa22 100644 --- a/docs/source/markdown/guides/how_to/models/post_processor.md +++ b/docs/source/markdown/guides/how_to/models/post_processor.md @@ -83,7 +83,7 @@ Normalization and thresholding only works when your datamodule contains a valida ## Basic Usage -Anomalib provides two main post-processors: the default `PostProcessor` and the `MEBinPostProcessor`. To use the default `PostProcessor`, simply add it to any Anomalib model when creating the model: +To use the `PostProcessor`, simply add it to any Anomalib model when creating the model: ```python from anomalib.models import Padim @@ -93,16 +93,6 @@ post_processor = PostProcessor() model = Padim(post_processor=post_processor) ``` -For alternative thresholding approaches, you can use the `MEBinPostProcessor`: - -```python -from anomalib.models import Padim -from anomalib.post_processing import MEBinPostProcessor - -post_processor = MEBinPostProcessor() -model = Padim(post_processor=post_processor) -``` - The post-processor can be configured using its constructor arguments. In the case of the `PostProcessor`, the only configuration parameters are the sensitivity for the thresholding operation on the image- and pixel-level: ```python @@ -190,40 +180,6 @@ pred_labels = results[..., 2] # Already thresholded (0/1) pred_masks = results[..., 3] # Already thresholded masks (if applicable) ``` -## MEBinPostProcessor - -Anomalib provides the `MEBinPostProcessor` which implements the MEBin (Main Element Binarization) algorithm from AnomalyNCD. This post-processor is designed for industrial anomaly detection scenarios where traditional thresholding may not perform optimally. - -MEBin was introduced in "AnomalyNCD: Towards Novel Anomaly Class Discovery in Industrial Scenarios" ([arXiv:2410.14379](https://arxiv.org/abs/2410.14379)). - -### Basic Usage - -```python -from anomalib.models import Padim -from anomalib.post_processing import MEBinPostProcessor - -# Create MEBin post-processor with custom parameters -post_processor = MEBinPostProcessor( - sample_rate=4, # Threshold sampling step size - min_interval_len=4, # Minimum stable interval length - erode=True # Apply erosion to reduce noise -) - -model = Padim(post_processor=post_processor) -``` - -### Basic Usage - -```python -from anomalib.models import Padim -from anomalib.post_processing import MEBinPostProcessor - -post_processor = MEBinPostProcessor() -model = Padim(post_processor=post_processor) -``` - -For more details on the MEBin algorithm, see the [AnomalyNCD paper](https://arxiv.org/abs/2410.14379). - ## Creating Custom Post-processors Advanced users may want to define their own post-processing pipeline. This can be useful when the default post-processing behaviour of the `PostProcessor` is not suitable for the model and its predictions. To create a custom post-processor, inherit from `nn.Module` and `Callback`, and implement your post-processing logic using lightning hooks. 
Don't forget to also include the post-processing steps in the `forward` method of your class to ensure that the post-processing is included when exporting your model: From 1b6e8455ac71bc749e57ddbcad57966bfbf219ab Mon Sep 17 00:00:00 2001 From: Rajesh Gangireddy Date: Wed, 15 Oct 2025 16:44:36 +0200 Subject: [PATCH 09/17] Update src/anomalib/post_processing/mebin_post_processor.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Signed-off-by: Rajesh Gangireddy --- src/anomalib/post_processing/mebin_post_processor.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/anomalib/post_processing/mebin_post_processor.py b/src/anomalib/post_processing/mebin_post_processor.py index f17e783601..e9b13bb43e 100644 --- a/src/anomalib/post_processing/mebin_post_processor.py +++ b/src/anomalib/post_processing/mebin_post_processor.py @@ -70,7 +70,6 @@ def __init__( self.min_interval_len = min_interval_len self.erode = erode - """Custom post-processor using MEBin algorithm""" def forward(self, predictions: InferenceBatch) -> InferenceBatch: """Post-process model predictions using MEBin algorithm. From 46c0ae82e4022870df39f8009aadc666d1506c0c Mon Sep 17 00:00:00 2001 From: Rajesh Gangireddy Date: Wed, 15 Oct 2025 16:44:48 +0200 Subject: [PATCH 10/17] Update src/anomalib/metrics/threshold/mebin.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Signed-off-by: Rajesh Gangireddy --- src/anomalib/metrics/threshold/mebin.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/anomalib/metrics/threshold/mebin.py b/src/anomalib/metrics/threshold/mebin.py index 7524877ce8..89833d2392 100644 --- a/src/anomalib/metrics/threshold/mebin.py +++ b/src/anomalib/metrics/threshold/mebin.py @@ -171,10 +171,8 @@ def get_threshold( interval_result[anomaly_num_sequence[start]].append((start, end)) current_index += 1 - """ - If a 'stable interval' exists, calculate the final threshold based on the longest stable interval. - If no stable interval is found, it indicates that no anomaly regions exist, and 255 is returned. - """ + # If a 'stable interval' exists, calculate the final threshold based on the longest stable interval. + # If no stable interval is found, it indicates that no anomaly regions exist, and 255 is returned. if interval_result: # Iterate through the stable intervals, calculating their lengths and corresponding From b757b271dabdc97000448191adf2cfee43b951a7 Mon Sep 17 00:00:00 2001 From: Rajesh Gangireddy Date: Wed, 15 Oct 2025 16:46:34 +0200 Subject: [PATCH 11/17] Update src/anomalib/post_processing/mebin_post_processor.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Signed-off-by: Rajesh Gangireddy --- src/anomalib/post_processing/mebin_post_processor.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/anomalib/post_processing/mebin_post_processor.py b/src/anomalib/post_processing/mebin_post_processor.py index e9b13bb43e..aaffbff365 100644 --- a/src/anomalib/post_processing/mebin_post_processor.py +++ b/src/anomalib/post_processing/mebin_post_processor.py @@ -70,7 +70,6 @@ def __init__( self.min_interval_len = min_interval_len self.erode = erode - def forward(self, predictions: InferenceBatch) -> InferenceBatch: """Post-process model predictions using MEBin algorithm. 
From 1f29065c7f7ba986ef63c05b896e5d12071ae80c Mon Sep 17 00:00:00 2001 From: Rajesh Gangireddy Date: Wed, 15 Oct 2025 16:50:03 +0200 Subject: [PATCH 12/17] =?UTF-8?q?=F0=9F=90=9B=20fix(mebin):=20remove=20deb?= =?UTF-8?q?ug=20print=20statement=20for=20value=20range?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/anomalib/metrics/threshold/mebin.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/anomalib/metrics/threshold/mebin.py b/src/anomalib/metrics/threshold/mebin.py index 89833d2392..02f9ef43f9 100644 --- a/src/anomalib/metrics/threshold/mebin.py +++ b/src/anomalib/metrics/threshold/mebin.py @@ -2,14 +2,14 @@ # SPDX-License-Identifier: Apache-2.0 """MEBin (Main Element Binarization) adaptive thresholding for anomaly detection. -This module provides the ``MEBin`` class which implements the Main Element -Binarization algorithm designed to address the non-prominence of anomalies -in anomaly maps. MEBin obtains anomaly-centered images by analyzing the +This module provides the ``MEBin`` class which implements the Main Element +Binarization algorithm designed to address the non-prominence of anomalies +in anomaly maps. MEBin obtains anomaly-centered images by analyzing the stability of connected components across multiple threshold levels. The algorithm is particularly effective for: - Industrial anomaly detection scenarios -- Multi-class anomaly classification tasks +- Multi-class anomaly classification tasks - Cases where anomalies are non-prominent in anomaly maps - Avoiding the impact of incorrect detections @@ -20,7 +20,7 @@ 4. Finding stable intervals where component count remains constant 5. Selecting threshold from the longest stable interval -MEBin was introduced in "AnomalyNCD: Towards Novel Anomaly Class Discovery +MEBin was introduced in "AnomalyNCD: Towards Novel Anomaly Class Discovery in Industrial Scenarios" (https://arxiv.org/abs/2410.14379). Example: @@ -34,8 +34,8 @@ >>> print(f"Computed {len(thresholds)} thresholds") Note: - MEBin is designed for industrial scenarios where anomalies may be - non-prominent. The min_interval_len parameter should be tuned based + MEBin is designed for industrial scenarios where anomalies may be + non-prominent. The min_interval_len parameter should be tuned based on the expected stability of connected component counts. """ @@ -49,14 +49,14 @@ class MEBin: """MEBin (Main Element Binarization) adaptive thresholding algorithm. - This class implements the Main Element Binarization algorithm designed + This class implements the Main Element Binarization algorithm designed to address non-prominent anomalies in industrial anomaly detection scenarios. MEBin determines optimal thresholds by analyzing the stability of connected component counts across different threshold levels to obtain anomaly-centered binary representations. 
The algorithm works by: - - Adaptively determining threshold search ranges from anomaly statistics + - Adaptively determining threshold search ranges from anomaly statistics - Sampling anomaly maps at configurable rates across threshold range - Counting connected components at each threshold level - Identifying stable intervals where component count remains constant From 188275bea480df412483b51e561a58cee0c81d94 Mon Sep 17 00:00:00 2001 From: Rajesh Gangireddy Date: Wed, 15 Oct 2025 16:54:23 +0200 Subject: [PATCH 13/17] =?UTF-8?q?=F0=9F=90=9B=20fix(mebin):=20update=20thr?= =?UTF-8?q?eshold=20return=20types=20to=20float=20and=20remove=20debug=20p?= =?UTF-8?q?rint=20statement?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/anomalib/metrics/threshold/mebin.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/src/anomalib/metrics/threshold/mebin.py b/src/anomalib/metrics/threshold/mebin.py index 02f9ef43f9..f3224194dd 100644 --- a/src/anomalib/metrics/threshold/mebin.py +++ b/src/anomalib/metrics/threshold/mebin.py @@ -111,8 +111,8 @@ def get_search_range(self) -> tuple[float, float]: on the actual anomaly score distributions in the input maps. Returns: - max_th (int): Maximum threshold for binarization. - min_th (int): Minimum threshold for binarization. + max_th (float): Maximum threshold for binarization. + min_th (float): Minimum threshold for binarization. """ # Get the anomaly scores of all anomaly maps anomaly_score_list = [np.max(x) for x in self.anomaly_map_list] @@ -121,8 +121,6 @@ def get_search_range(self) -> tuple[float, float]: max_score, min_score = max(anomaly_score_list), min(anomaly_score_list) max_th, min_th = max_score, min_score - print(f"Value range: {min_score} - {max_score}") - return max_th, min_th def get_threshold( @@ -151,7 +149,7 @@ def get_threshold( while current_index < len(anomaly_num_sequence): end = current_index - start = end + start = current_index # Find the interval where the connected component count remains constant. sequence_slice = anomaly_num_sequence[start : end + 1] From bf9ecfa2a63d6d443e0874e97bd6b4aa38d91c68 Mon Sep 17 00:00:00 2001 From: Rajesh Gangireddy Date: Wed, 15 Oct 2025 16:59:49 +0200 Subject: [PATCH 14/17] =?UTF-8?q?=F0=9F=90=9B=20fix(mebin):=20update=20ret?= =?UTF-8?q?urn=20types=20in=20MEBin=20methods=20to=20tuples=20for=20better?= =?UTF-8?q?=20clarity?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/anomalib/metrics/threshold/mebin.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/anomalib/metrics/threshold/mebin.py b/src/anomalib/metrics/threshold/mebin.py index f3224194dd..462ddff002 100644 --- a/src/anomalib/metrics/threshold/mebin.py +++ b/src/anomalib/metrics/threshold/mebin.py @@ -141,8 +141,9 @@ def get_threshold( Longer intervals indicate more robust threshold selection. Returns: - threshold (int): The final threshold for binarization. - est_anomaly_num (int): The estimated number of anomalies. + tuple[int, int]: A tuple containing (threshold, est_anomaly_num) where + threshold is the final threshold for binarization and est_anomaly_num + is the estimated number of anomalies. """ interval_result = {} current_index = 0 @@ -235,8 +236,9 @@ def binarize_anomaly_maps(self) -> tuple[list[np.ndarray], list[int]]: and perform binarization on the anomaly maps. Returns: - binarized_maps (list): List of binarized images. - thresholds (list): List of thresholds for each image. 
+ tuple[list[np.ndarray], list[int]]: A tuple containing (binarized_maps, thresholds) + where binarized_maps is a list of binarized images and thresholds is a list + of thresholds for each image. """ self.binarized_maps = [] self.thresholds = [] From 473c6dbca742a34f424ca7eaf1059974428e8c3d Mon Sep 17 00:00:00 2001 From: Rajesh Gangireddy Date: Wed, 15 Oct 2025 17:06:15 +0200 Subject: [PATCH 15/17] =?UTF-8?q?=F0=9F=90=9B=20fix(mebin):=20simplify=20m?= =?UTF-8?q?ax=20calculation=20for=20connected=20component=20counts?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/anomalib/metrics/threshold/mebin.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/anomalib/metrics/threshold/mebin.py b/src/anomalib/metrics/threshold/mebin.py index 462ddff002..94c9c57e2e 100644 --- a/src/anomalib/metrics/threshold/mebin.py +++ b/src/anomalib/metrics/threshold/mebin.py @@ -178,8 +178,8 @@ def get_threshold( # number of connected component. count_result = {} for anomaly_num in interval_result: - count_result[anomaly_num] = max(x[1] - x[0] for x in interval_result[anomaly_num]) - est_anomaly_num = max(count_result, key=lambda x: count_result[x]) + count_result[anomaly_num] = max([x[1] - x[0] for x in interval_result[anomaly_num]]) + est_anomaly_num = max(count_result, key=count_result.get) est_anomaly_num_interval_result = interval_result[est_anomaly_num] # Find the longest stable interval. From b1ea898b05e620bb107cad5aaccb0337f2ce0964 Mon Sep 17 00:00:00 2001 From: Rajesh Gangireddy Date: Thu, 16 Oct 2025 14:10:27 +0200 Subject: [PATCH 16/17] =?UTF-8?q?=F0=9F=90=9B=20fix(mebin):=20update=20exa?= =?UTF-8?q?mples=20in=20MEBin=20and=20MEBinPostProcessor=20for=20clarity?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/anomalib/metrics/threshold/mebin.py | 90 +++++++++++-------- .../post_processing/mebin_post_processor.py | 19 +++- 2 files changed, 68 insertions(+), 41 deletions(-) diff --git a/src/anomalib/metrics/threshold/mebin.py b/src/anomalib/metrics/threshold/mebin.py index 94c9c57e2e..fa13f94709 100644 --- a/src/anomalib/metrics/threshold/mebin.py +++ b/src/anomalib/metrics/threshold/mebin.py @@ -24,14 +24,23 @@ in Industrial Scenarios" (https://arxiv.org/abs/2410.14379). Example: - >>> from anomalib.metrics.threshold import MEBin >>> import numpy as np - >>> # Create sample anomaly maps - >>> anomaly_maps = [np.random.rand(256, 256) * 255 for _ in range(10)] - >>> # Initialize and compute thresholds - >>> mebin = MEBin(anomaly_maps, sample_rate=4) + >>> from anomalib.metrics.threshold import MEBin + >>> + >>> # Create sample anomaly maps with simulated anomalous regions + >>> anomaly_maps = [] + >>> for i in range(5): + ... amap = np.random.rand(128, 128) * 50 # Background noise + ... amap[40:80, 40:80] = np.random.rand(40, 40) * 200 + 55 # Anomalous region + ... anomaly_maps.append(amap) + >>> + >>> # Initialize MEBin with appropriate parameters + >>> mebin = MEBin(anomaly_maps, sample_rate=8, min_interval_len=3) + >>> + >>> # Compute binary masks and thresholds >>> binarized_maps, thresholds = mebin.binarize_anomaly_maps() - >>> print(f"Computed {len(thresholds)} thresholds") + >>> print(f"Processed {len(binarized_maps)} maps, thresholds: {thresholds}") + Processed 5 maps, thresholds: [...] Note: MEBin is designed for industrial scenarios where anomalies may be @@ -77,14 +86,25 @@ class MEBin: Defaults to True. 
Example: - >>> from anomalib.metrics.threshold import MEBin >>> import numpy as np - >>> # Create sample anomaly maps - >>> anomaly_maps = [np.random.rand(256, 256) * 255 for _ in range(10)] - >>> # Initialize MEBin - >>> mebin = MEBin(anomaly_maps, sample_rate=4) - >>> # Compute binary masks and thresholds - >>> binarized_maps, thresholds = mebin.binarize_anomaly_maps() + >>> from anomalib.metrics.threshold import MEBin + >>> + >>> # Create sample anomaly maps with realistic structure + >>> anomaly_maps = [] + >>> for i in range(3): + ... # Background with low anomaly scores + ... amap = np.random.rand(64, 64) * 30 + ... # Add anomalous regions with higher scores + ... amap[20:40, 20:40] = np.random.rand(20, 20) * 150 + 100 + ... anomaly_maps.append(amap) + >>> + >>> # Initialize MEBin with custom parameters + >>> mebin = MEBin(anomaly_maps, sample_rate=4, min_interval_len=3, erode=True) + >>> + >>> # Binarize anomaly maps + >>> binary_masks, thresholds = mebin.binarize_anomaly_maps() + >>> print(f"Generated {len(binary_masks)} binary masks") + Generated 3 binary masks """ def __init__( @@ -111,8 +131,9 @@ def get_search_range(self) -> tuple[float, float]: on the actual anomaly score distributions in the input maps. Returns: - max_th (float): Maximum threshold for binarization. - min_th (float): Minimum threshold for binarization. + tuple[float, float]: A tuple containing: + - max_th: Maximum threshold for binarization. + - min_th: Minimum threshold for binarization. """ # Get the anomaly scores of all anomaly maps anomaly_score_list = [np.max(x) for x in self.anomaly_map_list] @@ -148,27 +169,22 @@ def get_threshold( interval_result = {} current_index = 0 while current_index < len(anomaly_num_sequence): - end = current_index - start = current_index - - # Find the interval where the connected component count remains constant. - sequence_slice = anomaly_num_sequence[start : end + 1] - if len(set(sequence_slice)) == 1 and anomaly_num_sequence[start] != 0: - # Move the 'end' pointer forward until a different connected component number is encountered. - while ( - end < len(anomaly_num_sequence) - 1 and anomaly_num_sequence[end] == anomaly_num_sequence[end + 1] - ): - end += 1 - current_index += 1 - # If the length of the current stable interval is greater than or equal to the given - # threshold (min_interval_len), record this interval. - if end - start + 1 >= min_interval_len: - if anomaly_num_sequence[start] not in interval_result: - interval_result[anomaly_num_sequence[start]] = [(start, end)] - else: - interval_result[anomaly_num_sequence[start]].append((start, end)) - current_index += 1 + value = anomaly_num_sequence[start] + end = start + # Move the 'end' pointer forward until a different connected component number is encountered. + while ( + end < len(anomaly_num_sequence) - 1 and anomaly_num_sequence[end] == anomaly_num_sequence[end + 1] + ): + end += 1 + # If the length of the current stable interval is greater than or equal to the given + # threshold (min_interval_len), and the value is not zero, record this interval. + if end - start + 1 >= min_interval_len and value != 0: + if value not in interval_result: + interval_result[value] = [(start, end)] + else: + interval_result[value].append((start, end)) + current_index = end + 1 # If a 'stable interval' exists, calculate the final threshold based on the longest stable interval. # If no stable interval is found, it indicates that no anomaly regions exist, and 255 is returned. 
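The rewritten loop above is easiest to sanity-check in isolation; a standalone sketch of the stable-interval selection on a toy count sequence (the helper name and the numbers are illustrative, not the committed implementation):

```python
def longest_stable_interval(counts: list[int], min_interval_len: int) -> tuple[int, int] | None:
    """Return (start, end) of the longest constant, non-zero run in ``counts``."""
    best = None
    start = 0
    while start < len(counts):
        end = start
        # Extend the run while the component count stays constant.
        while end + 1 < len(counts) and counts[end + 1] == counts[start]:
            end += 1
        if counts[start] != 0 and end - start + 1 >= min_interval_len:
            if best is None or (end - start) > (best[1] - best[0]):
                best = (start, end)
        start = end + 1
    return best


# Component counts recorded while sweeping the threshold from high to low:
counts = [0, 1, 1, 1, 1, 2, 3, 3, 3, 3, 3, 7]
print(longest_stable_interval(counts, min_interval_len=4))  # -> (6, 10)
```

The right endpoint of the winning run is then mapped back to a score via `255 - index * self.sample_rate`, as in the committed `get_threshold`.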
@@ -178,8 +194,8 @@ def get_threshold(
             # number of connected component.
             count_result = {}
             for anomaly_num in interval_result:
-                count_result[anomaly_num] = max([x[1] - x[0] for x in interval_result[anomaly_num]])
-            est_anomaly_num = max(count_result, key=count_result.get)
+                count_result[anomaly_num] = max(x[1] - x[0] for x in interval_result[anomaly_num])
+            est_anomaly_num = max(count_result, key=lambda k: count_result[k])
             est_anomaly_num_interval_result = interval_result[est_anomaly_num]

             # Find the longest stable interval.
diff --git a/src/anomalib/post_processing/mebin_post_processor.py b/src/anomalib/post_processing/mebin_post_processor.py
index aaffbff365..bfd695b2a7 100644
--- a/src/anomalib/post_processing/mebin_post_processor.py
+++ b/src/anomalib/post_processing/mebin_post_processor.py
@@ -16,9 +16,14 @@
 - Formatting results for downstream use

 Example:
-    >>> from anomalib.post_processing import MEBinPostProcessor
-    >>> post_processor = MEBinPostProcessor(sample_rate=4, min_interval_len=4)
-    >>> predictions = post_processor(anomaly_maps=anomaly_maps)
+    >>> from anomalib.post_processing import MEBinPostProcessor
+    >>> from anomalib.data import InferenceBatch
+    >>> import torch
+    >>> # Create sample anomaly maps
+    >>> anomaly_maps = torch.rand(4, 1, 256, 256)
+    >>> predictions = InferenceBatch(anomaly_map=anomaly_maps)
+    >>> post_processor = MEBinPostProcessor(sample_rate=4, min_interval_len=4)
+    >>> results = post_processor(predictions)
 """

 import numpy as np
@@ -53,8 +58,13 @@ class MEBinPostProcessor(PostProcessor):

     Example:
         >>> from anomalib.post_processing import MEBinPostProcessor
+        >>> from anomalib.data import InferenceBatch
+        >>> import torch
+        >>> # Create sample predictions
+        >>> anomaly_maps = torch.rand(4, 1, 256, 256)
+        >>> predictions = InferenceBatch(anomaly_map=anomaly_maps)
         >>> post_processor = MEBinPostProcessor(sample_rate=4, min_interval_len=4)
-        >>> predictions = post_processor(anomaly_maps=anomaly_maps)
+        >>> results = post_processor(predictions)
     """

     def __init__(
From d91e68504b96786fc96478842144b7ff405b93a5 Mon Sep 17 00:00:00 2001
From: rajeshgangireddy
Date: Tue, 21 Oct 2025 17:10:03 +0200
Subject: [PATCH 17/17] =?UTF-8?q?=F0=9F=90=9B=20fix(mebin):=20refactor=20M?=
 =?UTF-8?q?EBin=20post-processing=20.=20Works=20now.=20Need=20to=20re-chec?=
 =?UTF-8?q?k=20if=20implementation=20matches=20author's=20implementation?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../post_processing/mebin_post_processor.py   | 114 ++++++++++++++++--
 1 file changed, 106 insertions(+), 8 deletions(-)

diff --git a/src/anomalib/post_processing/mebin_post_processor.py b/src/anomalib/post_processing/mebin_post_processor.py
index bfd695b2a7..ee6db455ff 100644
--- a/src/anomalib/post_processing/mebin_post_processor.py
+++ b/src/anomalib/post_processing/mebin_post_processor.py
@@ -29,8 +29,9 @@

 import numpy as np
 import torch
+from lightning import LightningModule, Trainer

-from anomalib.data import InferenceBatch
+from anomalib.data import Batch, InferenceBatch
 from anomalib.metrics import MEBin

 from .post_processor import PostProcessor
@@ -110,14 +111,12 @@ def forward(self, predictions: InferenceBatch) -> InferenceBatch:
         if anomaly_maps.ndim == 4:
             anomaly_maps = anomaly_maps[:, 0, :, :]  # Remove channel dimension

-        # Normalize to 0-255 and convert to uint8
-        norm_maps = []
-        for amap in anomaly_maps:
-            amap_norm = (amap - amap.min()) / (amap.max() - amap.min() + 1e-8) * 255
-            norm_maps.append(amap_norm.astype(np.uint8))
+        # 
Convert to proper format for MEBin (don't normalize individually) + # MEBin will handle normalization after determining the global min/max range + anomaly_maps_list = [amap.astype(np.float32) for amap in anomaly_maps] mebin = MEBin( - anomaly_map_list=norm_maps, + anomaly_map_list=anomaly_maps_list, sample_rate=self.sample_rate, min_interval_len=self.min_interval_len, erode=self.erode, @@ -128,9 +127,108 @@ def forward(self, predictions: InferenceBatch) -> InferenceBatch: pred_masks = torch.stack([torch.from_numpy(bm).to(original_anomaly_map.device) for bm in binarized_maps]) pred_masks = (pred_masks > 0).to(original_anomaly_map.dtype) - return InferenceBatch( + # Create result with MEBin pred_mask + result = InferenceBatch( pred_label=predictions.pred_label, pred_score=predictions.pred_score, pred_mask=pred_masks, anomaly_map=predictions.anomaly_map, ) + + # Apply parent class post-processing for normalization and thresholding + # This will compute pred_label from pred_score if needed + return super().forward(result) + + def on_test_batch_end( + self, + trainer: Trainer, + pl_module: LightningModule, + outputs: Batch, + *args, + **kwargs, + ) -> None: + """Apply MEBin post-processing to test batch predictions. + + Args: + trainer (Trainer): PyTorch Lightning trainer instance. + pl_module (LightningModule): PyTorch Lightning module instance. + outputs (Batch): Batch containing model predictions. + *args: Variable length argument list. + **kwargs: Arbitrary keyword arguments. + """ + del trainer, pl_module, args, kwargs # Unused arguments + self.post_process_batch(outputs) + + def on_predict_batch_end( + self, + trainer: Trainer, + pl_module: LightningModule, + outputs: Batch, + *args, + **kwargs, + ) -> None: + """Apply MEBin post-processing to prediction batch. + + Args: + trainer (Trainer): PyTorch Lightning trainer instance. + pl_module (LightningModule): PyTorch Lightning module instance. + outputs (Batch): Batch containing model predictions. + *args: Variable length argument list. + **kwargs: Arbitrary keyword arguments. + """ + del trainer, pl_module, args, kwargs # Unused arguments + self.post_process_batch(outputs) + + def post_process_batch(self, batch: Batch) -> None: + """Post-process a batch of predictions using MEBin algorithm. + + This method applies MEBin binarization to anomaly maps in the batch and + updates the pred_mask field with the binarized results. + + Args: + batch (Batch): Batch containing model predictions to be processed. 
+ """ + if batch.anomaly_map is None: + return + + # Store the original tensor for device and dtype info + original_anomaly_map = batch.anomaly_map + anomaly_maps = original_anomaly_map.detach().cpu().numpy() + + # Handle different tensor shapes + if anomaly_maps.ndim == 4: + anomaly_maps = anomaly_maps[:, 0, :, :] # Remove channel dimension if present + elif anomaly_maps.ndim == 3: + # Already in correct format (batch, height, width) + pass + else: + msg = f"Unsupported anomaly map shape: {anomaly_maps.shape}" + raise ValueError(msg) + + # Convert to proper format for MEBin (don't normalize individually) + # MEBin will handle normalization after determining the global min/max range + anomaly_maps_list = [amap.astype(np.float32) for amap in anomaly_maps] + + # Apply MEBin binarization + mebin = MEBin( + anomaly_map_list=anomaly_maps_list, + sample_rate=self.sample_rate, + min_interval_len=self.min_interval_len, + erode=self.erode, + ) + binarized_maps, _ = mebin.binarize_anomaly_maps() + + # Convert back to torch.Tensor and normalize to 0/1 + pred_masks = torch.stack([torch.from_numpy(bm).to(original_anomaly_map.device) for bm in binarized_maps]) + pred_masks = (pred_masks > 0).to(original_anomaly_map.dtype) + + # Add channel dimension if original had one + if original_anomaly_map.ndim == 4: + pred_masks = pred_masks.unsqueeze(1) + + # Update the batch with binarized masks + batch.pred_mask = pred_masks + + # Apply parent class post-processing for normalization and thresholding + # This will compute pred_label from pred_score if needed + super().post_process_batch(batch)
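With the final refactor in place, the post-processor can be exercised end to end. A usage sketch built on the docstring example from this series (random tensors stand in for real model output):

```python
import torch

from anomalib.data import InferenceBatch
from anomalib.post_processing import MEBinPostProcessor

# Random maps stand in for real model output in this sketch.
anomaly_maps = torch.rand(4, 1, 256, 256)
predictions = InferenceBatch(anomaly_map=anomaly_maps)

post_processor = MEBinPostProcessor(sample_rate=4, min_interval_len=4, erode=True)
results = post_processor(predictions)  # results.pred_mask holds the MEBin binarizations
```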