diff --git a/CHANGELOG.md b/CHANGELOG.md
index 56738df2c3..933fd31243 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,6 +7,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
 ## [Unreleased]
 
 ### Added
+- 🚀 Add MEBin post-processing method
+
 
 ### Removed
 
diff --git a/src/anomalib/metrics/__init__.py b/src/anomalib/metrics/__init__.py
index f38d2ffde1..819d54e6cd 100644
--- a/src/anomalib/metrics/__init__.py
+++ b/src/anomalib/metrics/__init__.py
@@ -62,7 +62,7 @@
 from .pimo import AUPIMO, PIMO
 from .precision_recall_curve import BinaryPrecisionRecallCurve
 from .pro import PRO
-from .threshold import F1AdaptiveThreshold, ManualThreshold
+from .threshold import F1AdaptiveThreshold, ManualThreshold, MEBin
 
 __all__ = [
     "AUROC",
@@ -83,4 +83,5 @@
     "PRO",
     "PIMO",
     "AUPIMO",
+    "MEBin",
 ]
diff --git a/src/anomalib/metrics/threshold/__init__.py b/src/anomalib/metrics/threshold/__init__.py
index 1bf10854ec..0cebd87b82 100644
--- a/src/anomalib/metrics/threshold/__init__.py
+++ b/src/anomalib/metrics/threshold/__init__.py
@@ -24,5 +24,6 @@
 from .base import BaseThreshold, Threshold
 from .f1_adaptive_threshold import F1AdaptiveThreshold
 from .manual_threshold import ManualThreshold
+from .mebin import MEBin
 
-__all__ = ["BaseThreshold", "Threshold", "F1AdaptiveThreshold", "ManualThreshold"]
+__all__ = ["BaseThreshold", "Threshold", "F1AdaptiveThreshold", "ManualThreshold", "MEBin"]
diff --git a/src/anomalib/metrics/threshold/mebin.py b/src/anomalib/metrics/threshold/mebin.py
new file mode 100644
index 0000000000..c823ea3b5f
--- /dev/null
+++ b/src/anomalib/metrics/threshold/mebin.py
@@ -0,0 +1,272 @@
+# Copyright (C) 2025 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+"""MEBin adaptive thresholding algorithm for anomaly detection.
+
+This module provides the ``MEBin`` class which automatically finds a
+binarization threshold for each anomaly map by analyzing the stability of the
+connected-component count across multiple threshold levels.
+
+The threshold is computed by:
+1. Sampling thresholds at a configurable rate across the search range
+2. Counting connected components at each threshold level
+3. Finding stable intervals where the component count remains constant
+4. Selecting the threshold at the end of the longest stable interval
+
+Example:
+    >>> from anomalib.metrics.threshold import MEBin
+    >>> import numpy as np
+    >>> # Create sample anomaly maps
+    >>> anomaly_maps = [np.random.rand(256, 256) * 255 for _ in range(10)]
+    >>> # Initialize and compute thresholds
+    >>> mebin = MEBin(anomaly_map_list=anomaly_maps, sample_rate=4)
+    >>> binarized_maps, thresholds = mebin.binarize_anomaly_maps()
+    >>> print(f"Computed {len(thresholds)} thresholds")
+
+Note:
+    The algorithm works best when anomaly maps contain clear separation between
+    normal and anomalous regions. The ``min_interval_len`` parameter should be
+    tuned based on the expected stability of anomaly score distributions.
+"""
+
+from __future__ import annotations
+
+import logging
+from collections.abc import Sequence
+from itertools import groupby
+
+import cv2
+import numpy as np
+from tqdm import tqdm
+
+logger = logging.getLogger(__name__)
+
+
+class MEBin:
+    """MEBin adaptive thresholding algorithm for anomaly detection.
+
+    MEBin (Main Element Binarization) determines a threshold for converting
+    each continuous anomaly map to a binary mask by analyzing how the number
+    of connected components changes as the threshold is lowered.
+
+    The algorithm works by:
+    - Deriving a shared threshold search range from the peak score of each map
+    - Counting connected components at each sampled threshold level
+    - Identifying stable intervals where the component count stays constant
+    - Selecting the threshold at the end of the longest stable interval
+    - Optionally applying morphological erosion to reduce noise
+
+    Args:
+        anomaly_map_path_list (Sequence[str] | None, optional): File paths of
+            anomaly maps, loaded as grayscale images. Ignored when
+            ``anomaly_map_list`` is given. Defaults to ``None``.
+        sample_rate (int, optional): Step between sampled thresholds on the
+            0-255 scale. Higher values reduce processing time but may affect
+            accuracy. Defaults to ``4``.
+        min_interval_len (int, optional): Minimum number of consecutive
+            sampled thresholds sharing one component count for an interval to
+            count as stable. Defaults to ``4``.
+        erode (bool, optional): Whether to apply morphological erosion to
+            binarized results to reduce noise. Defaults to ``True``.
+        anomaly_map_list (Sequence[np.ndarray] | None, optional): In-memory
+            anomaly maps on a 0-255 scale (keyword-only). Takes precedence
+            over ``anomaly_map_path_list``. Defaults to ``None``.
+
+    Raises:
+        ValueError: If ``sample_rate`` is smaller than 1, no anomaly map is
+            provided, or a path cannot be read as an image.
+
+    Example:
+        >>> from anomalib.metrics.threshold import MEBin
+        >>> import numpy as np
+        >>> # Create sample anomaly maps
+        >>> anomaly_maps = [np.random.rand(256, 256) * 255 for _ in range(10)]
+        >>> # Initialize MEBin
+        >>> mebin = MEBin(anomaly_map_list=anomaly_maps, sample_rate=4)
+        >>> # Compute binary masks and thresholds
+        >>> binarized_maps, thresholds = mebin.binarize_anomaly_maps()
+    """
+
+    # Erosion settings used to remove small noisy regions after binarization.
+    _ERODE_KERNEL_SIZE = 6
+    _ERODE_ITERATIONS = 1
+
+    def __init__(
+        self,
+        anomaly_map_path_list: Sequence[str] | None = None,
+        sample_rate: int = 4,
+        min_interval_len: int = 4,
+        erode: bool = True,
+        *,
+        anomaly_map_list: Sequence[np.ndarray] | None = None,
+    ) -> None:
+        if sample_rate < 1:
+            msg = f"sample_rate must be a positive integer, got {sample_rate}."
+            raise ValueError(msg)
+
+        self.anomaly_map_path_list = anomaly_map_path_list
+        if anomaly_map_list is not None:
+            self.anomaly_map_list = list(anomaly_map_list)
+        elif anomaly_map_path_list is not None:
+            self.anomaly_map_list = [self._read_anomaly_map(path) for path in anomaly_map_path_list]
+        else:
+            msg = "Either anomaly_map_list or anomaly_map_path_list must be provided."
+            raise ValueError(msg)
+        if not self.anomaly_map_list:
+            msg = "At least one anomaly map is required."
+            raise ValueError(msg)
+
+        self.sample_rate = sample_rate
+        self.min_interval_len = min_interval_len
+        self.erode = erode
+
+        # Adaptively determine the threshold search range
+        self.max_th, self.min_th = self.get_search_range()
+
+    @staticmethod
+    def _read_anomaly_map(path: str) -> np.ndarray:
+        """Load one anomaly map from disk as a grayscale image.
+
+        Args:
+            path (str): Location of the image file.
+
+        Returns:
+            np.ndarray: The decoded grayscale image.
+
+        Raises:
+            ValueError: If the file cannot be read as an image.
+        """
+        anomaly_map = cv2.imread(str(path), cv2.IMREAD_GRAYSCALE)
+        if anomaly_map is None:
+            # ``cv2.imread`` reports failure by returning ``None``.
+            msg = f"Could not read anomaly map from '{path}'."
+            raise ValueError(msg)
+        return anomaly_map
+
+    def get_search_range(self) -> tuple[float, float]:
+        """Determine the threshold search range adaptively.
+
+        The anomaly score of a map is its maximum value. The search range spans
+        from the smallest to the largest of these scores. When all maps share
+        the same score (for example when only one map is given) that range is
+        empty, so the lower bound falls back to the smallest value of all maps.
+
+        Returns:
+            tuple[float, float]: ``(max_th, min_th)``, the upper and lower
+            bound of the search range as native Python numbers.
+        """
+        # ``.item()`` yields Python numbers, so later arithmetic cannot
+        # overflow a narrow dtype such as ``uint8``.
+        anomaly_score_list = [np.max(anomaly_map).item() for anomaly_map in self.anomaly_map_list]
+        max_th, min_th = max(anomaly_score_list), min(anomaly_score_list)
+
+        if max_th <= min_th:
+            min_th = min(np.min(anomaly_map).item() for anomaly_map in self.anomaly_map_list)
+
+        logger.info("Value range: %s - %s", min_th, max_th)
+        return max_th, min_th
+
+    def get_threshold(self, anomaly_num_sequence: Sequence[int], min_interval_len: int) -> tuple[int, int]:
+        """Select the threshold at the end of the longest stable interval.
+
+        A stable interval is a run of consecutive sampled thresholds over which
+        the number of connected components stays constant and non-zero, and
+        whose length is at least ``min_interval_len``.
+
+        Args:
+            anomaly_num_sequence (Sequence[int]): Connected component counts
+                at each sampled threshold, ordered from high to low threshold.
+            min_interval_len (int): Minimum length of a stable interval.
+
+        Returns:
+            tuple[int, int]: ``(threshold, est_anomaly_num)``, the threshold on
+            the scale of the input maps and the estimated number of anomalies.
+            ``(255, 0)`` is returned when no stable interval exists, which
+            indicates that no anomalous region was found.
+        """
+        # Stable intervals as (start, end) indices, grouped by component count.
+        intervals_by_count: dict[int, list[tuple[int, int]]] = {}
+        start = 0
+        for anomaly_num, run in groupby(anomaly_num_sequence):
+            end = start + sum(1 for _ in run) - 1
+            if anomaly_num != 0 and end - start + 1 >= min_interval_len:
+                intervals_by_count.setdefault(anomaly_num, []).append((start, end))
+            start = end + 1
+
+        if not intervals_by_count:
+            return 255, 0
+
+        def span(interval: tuple[int, int]) -> int:
+            return interval[1] - interval[0]
+
+        # The estimated anomaly count is the one owning the longest interval.
+        longest_span_by_count = {
+            anomaly_num: max(map(span, intervals)) for anomaly_num, intervals in intervals_by_count.items()
+        }
+        est_anomaly_num = max(longest_span_by_count, key=longest_span_by_count.get)
+        # ``sorted`` is stable, so ties resolve to the lowest threshold.
+        longest_interval = sorted(intervals_by_count[est_anomaly_num], key=span)[-1]
+
+        # Map the interval end from the sampled 0-255 scale back to the input scale.
+        threshold = 255 - longest_interval[1] * self.sample_rate
+        threshold = int(threshold * (self.max_th - self.min_th) / 255 + self.min_th)
+        return threshold, est_anomaly_num
+
+    def bin_and_erode(self, anomaly_map: np.ndarray, threshold: float) -> np.ndarray:
+        """Binarize an anomaly map and optionally apply erosion.
+
+        Args:
+            anomaly_map (np.ndarray): Anomaly map with continuous scores.
+            threshold (float): Values strictly above this threshold are
+                considered anomalous.
+
+        Returns:
+            np.ndarray: ``uint8`` mask where 255 marks anomalous pixels and 0
+            marks normal pixels. When ``erode`` is enabled the mask is eroded
+            once with a 6x6 kernel.
+        """
+        bin_result = np.where(anomaly_map > threshold, 255, 0).astype(np.uint8)
+
+        if self.erode:
+            kernel = np.ones((self._ERODE_KERNEL_SIZE, self._ERODE_KERNEL_SIZE), np.uint8)
+            bin_result = cv2.erode(bin_result, kernel, iterations=self._ERODE_ITERATIONS)
+        return bin_result
+
+    def binarize_anomaly_maps(self) -> tuple[list[np.ndarray], list[int]]:
+        """Binarize every anomaly map with its adaptively chosen threshold.
+
+        For each map the connected components are counted at every sampled
+        threshold of the search range, the threshold is derived from that
+        count sequence, and the original map is binarized with it.
+
+        Returns:
+            tuple[list[np.ndarray], list[int]]: The binarized maps and the
+            threshold used for each map.
+        """
+        self.binarized_maps = []
+        self.thresholds = []
+        score_range = self.max_th - self.min_th
+
+        for anomaly_map in tqdm(self.anomaly_map_list):
+            # Rescale the search range to 0-255. Work in floating point so
+            # unsigned inputs cannot wrap around below ``min_th``.
+            anomaly_map_norm = np.zeros(np.shape(anomaly_map), dtype=np.float64)
+            if score_range > 0:
+                shifted = np.asarray(anomaly_map, dtype=np.float64) - self.min_th
+                anomaly_map_norm = np.clip(shifted, 0, None) / score_range * 255
+
+            # Search from high to low threshold at the configured sampling rate.
+            anomaly_num_sequence = []
+            for score in range(255, 0, -self.sample_rate):
+                bin_result = self.bin_and_erode(anomaly_map_norm, score)
+                num_labels, _ = cv2.connectedComponents(bin_result, connectivity=8)
+                # Label 0 is the background, so it is not an anomalous region.
+                anomaly_num_sequence.append(num_labels - 1)
+
+            threshold, _ = self.get_threshold(anomaly_num_sequence, self.min_interval_len)
+
+            # Binarize the original map with the selected threshold.
+            self.binarized_maps.append(self.bin_and_erode(anomaly_map, threshold))
+            self.thresholds.append(threshold)
+
+        return self.binarized_maps, self.thresholds
diff --git a/src/anomalib/post_processing/__init__.py b/src/anomalib/post_processing/__init__.py
index 5ca7bf598b..7e57c6d82d 100644
--- a/src/anomalib/post_processing/__init__.py
+++ b/src/anomalib/post_processing/__init__.py
@@ -20,5 +20,6 @@
 """
 
 from .post_processor import PostProcessor
+from .mebin_post_processor import MEBinPostProcessor
 
-__all__ = ["PostProcessor"]
+__all__ = ["PostProcessor", "MEBinPostProcessor"]
diff --git a/src/anomalib/post_processing/mebin_post_processor.py b/src/anomalib/post_processing/mebin_post_processor.py
new file mode 100644
index 0000000000..8b71fd68c4
--- /dev/null
+++ b/src/anomalib/post_processing/mebin_post_processor.py
@@ -0,0 +1,130 @@
+# Copyright (C) 2025 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+"""Post-processing module for MEBin-based anomaly detection results.
+
+This module provides post-processing functionality for anomaly detection
+outputs through the :class:`MEBinPostProcessor` class.
+
+The MEBin post-processor handles:
+    - Converting anomaly maps to binary masks using the MEBin algorithm
+    - Sampling thresholds at configurable rates for efficient processing
+    - Applying morphological operations (erosion) to refine binary masks
+    - Maintaining minimum interval lengths for consistent mask generation
+    - Formatting results for downstream use
+
+Example:
+    >>> import torch
+    >>> from anomalib.data import InferenceBatch
+    >>> from anomalib.post_processing import MEBinPostProcessor
+    >>> post_processor = MEBinPostProcessor(sample_rate=4, min_interval_len=4)
+    >>> predictions = InferenceBatch(
+    ...     pred_score=torch.tensor([0.8]),
+    ...     pred_label=torch.tensor([1]),
+    ...     anomaly_map=torch.rand(1, 1, 256, 256),
+    ...     pred_mask=None,
+    ... )
+    >>> predictions = post_processor.forward(predictions)
+"""
+
+import numpy as np
+import torch
+
+from anomalib.data import InferenceBatch
+from anomalib.metrics import MEBin
+
+from .post_processor import PostProcessor
+
+
+class MEBinPostProcessor(PostProcessor):
+    """Post-processor for MEBin-based anomaly detection.
+
+    This class handles post-processing of anomaly detection results by:
+    - Converting continuous anomaly maps to binary masks using the MEBin algorithm
+    - Sampling thresholds at configurable rates for efficient processing
+    - Applying morphological operations (erosion) to refine binary masks
+    - Maintaining minimum interval lengths for consistent mask generation
+    - Formatting results for downstream use
+
+    Args:
+        sample_rate (int, optional): Threshold sampling step size.
+            Defaults to ``4``.
+        min_interval_len (int, optional): Minimum length of the stable
+            interval. Can be tuned according to the gap between the normal and
+            abnormal score distributions in the anomaly maps: decrease it if
+            there are many false negatives, increase it if there are many
+            false positives. Defaults to ``4``.
+        erode (bool, optional): Whether to perform erosion after binarization
+            to eliminate noise. This smooths how the number of abnormal
+            connected components changes across thresholds.
+            Defaults to ``True``.
+        **kwargs: Additional keyword arguments passed to the parent class.
+
+    Example:
+        >>> from anomalib.post_processing import MEBinPostProcessor
+        >>> post_processor = MEBinPostProcessor(sample_rate=4, min_interval_len=4)
+    """
+
+    def __init__(
+        self,
+        sample_rate: int = 4,
+        min_interval_len: int = 4,
+        erode: bool = True,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+
+        self.sample_rate = sample_rate
+        self.min_interval_len = min_interval_len
+        self.erode = erode
+
+    def forward(self, predictions: InferenceBatch) -> InferenceBatch:
+        """Post-process model predictions using the MEBin algorithm.
+
+        Args:
+            predictions (InferenceBatch): Batch containing the anomaly maps to
+                binarize, shaped ``(B, 1, H, W)`` or ``(B, H, W)``.
+
+        Returns:
+            InferenceBatch: The input batch with ``pred_mask`` replaced by the
+            ``(B, H, W)`` masks produced by MEBin. The masks hold 0/1 values
+            and share the device and dtype of the anomaly maps.
+
+        Raises:
+            ValueError: If ``predictions.anomaly_map`` is ``None``.
+        """
+        anomaly_map = predictions.anomaly_map
+        if anomaly_map is None:
+            msg = "MEBinPostProcessor requires predictions.anomaly_map to be set."
+            raise ValueError(msg)
+
+        # ``float()`` makes half-precision maps convertible to NumPy.
+        anomaly_maps = anomaly_map.detach().float().cpu().numpy()
+        if anomaly_maps.ndim == 4:
+            anomaly_maps = anomaly_maps[:, 0, :, :]  # Remove channel dimension
+
+        # Scale the whole batch to 0-255 with one shared minimum and maximum.
+        # MEBin derives its search range from the peak score of every map, so
+        # scaling each map separately would erase the differences it relies on.
+        batch_min, batch_max = anomaly_maps.min(), anomaly_maps.max()
+        scaled_maps = (anomaly_maps - batch_min) / (batch_max - batch_min + 1e-8) * 255
+        norm_maps = list(scaled_maps.astype(np.uint8))
+
+        mebin = MEBin(
+            anomaly_map_list=norm_maps,
+            sample_rate=self.sample_rate,
+            min_interval_len=self.min_interval_len,
+            erode=self.erode,
+        )
+        binarized_maps, _ = mebin.binarize_anomaly_maps()
+
+        # Convert back to a tensor holding 0/1 values.
+        pred_masks = torch.from_numpy(np.stack(binarized_maps) > 0)
+        pred_masks = pred_masks.to(device=anomaly_map.device, dtype=anomaly_map.dtype)
+
+        return InferenceBatch(
+            pred_label=predictions.pred_label,
+            pred_score=predictions.pred_score,
+            pred_mask=pred_masks,
+            anomaly_map=predictions.anomaly_map,
+        )
diff --git a/tests/unit/post_processing/test_mebin_post_processor.py b/tests/unit/post_processing/test_mebin_post_processor.py
new file mode 100644
index 0000000000..25925d40cd
--- /dev/null
+++ b/tests/unit/post_processing/test_mebin_post_processor.py
@@ -0,0 +1,229 @@
+# Copyright (C) 2025 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+"""Test the MEBinPostProcessor class."""
+
+from unittest.mock import Mock, patch
+
+import numpy as np
+import pytest
+import torch
+
+from anomalib.data import InferenceBatch
+from anomalib.post_processing import MEBinPostProcessor
+
+
+class TestMEBinPostProcessor:
+    """Test the MEBinPostProcessor class."""
+
+    @staticmethod
+    def test_initialization_default_params() -> None:
+        """Test MEBinPostProcessor initialization with default parameters."""
+        processor = MEBinPostProcessor()
+
+        assert processor.sample_rate == 4
+        assert processor.min_interval_len == 4
+        assert processor.erode is True
+
+    @staticmethod
+    @pytest.mark.parametrize(
+        ("sample_rate", "min_interval_len", "erode"),
+        [
+            (2, 3, True),
+            (8, 6, False),
+            (1, 1, True),
+        ],
+    )
+    def test_initialization_custom_params(
+        sample_rate: int,
+        min_interval_len: int,
+        erode: bool,
+    ) -> None:
+        """Test MEBinPostProcessor initialization with custom parameters."""
+        processor = MEBinPostProcessor(
+            sample_rate=sample_rate,
+            min_interval_len=min_interval_len,
+            erode=erode,
+        )
+
+        assert processor.sample_rate == sample_rate
+        assert processor.min_interval_len == min_interval_len
+        assert processor.erode == erode
+
+    @staticmethod
+    @patch("anomalib.post_processing.mebin_post_processor.MEBin")
+    def test_forward_single_anomaly_map(mock_mebin) -> None:
+        """Test forward method with single anomaly map."""
+        # Setup mock
+        mock_mebin_instance = Mock()
+        mock_mebin_instance.binarize_anomaly_maps.return_value = (
+            [np.array([[0, 0], [1, 1]], dtype=np.uint8)],
+            [0.5],
+        )
+        mock_mebin.return_value = mock_mebin_instance
+
+        # Create test data
+        anomaly_map = torch.rand(1, 1, 4, 4)
+        predictions = InferenceBatch(
+            pred_score=torch.tensor([0.8]),
+            pred_label=torch.tensor([1]),
+            anomaly_map=anomaly_map,
+            pred_mask=None,
+        )
+
+        # Test forward pass
+        processor = MEBinPostProcessor()
+        result = processor.forward(predictions)
+
+        # Verify results
+        assert isinstance(result, InferenceBatch)
+        assert result.pred_mask is not None
+        assert result.pred_mask.shape == (1, 2, 2)
+        assert result.pred_mask.dtype == anomaly_map.dtype
+
+    @staticmethod
+    @patch("anomalib.post_processing.mebin_post_processor.MEBin")
+    def test_forward_batch_anomaly_maps(mock_mebin) -> None:
+        """Test forward method with batch of anomaly maps."""
+        # Setup mock
+        mock_mebin_instance = Mock()
+        mock_mebin_instance.binarize_anomaly_maps.return_value = (
+            [
+                np.array([[0, 0], [1, 1]], dtype=np.uint8),
+                np.array([[1, 0], [0, 1]], dtype=np.uint8),
+            ],
+            [0.5, 0.6],
+        )
+        mock_mebin.return_value = mock_mebin_instance
+
+        # Create test data
+        anomaly_maps = torch.rand(2, 1, 4, 4)
+        predictions = InferenceBatch(
+            pred_score=torch.tensor([0.8, 0.9]),
+            pred_label=torch.tensor([1, 1]),
+            anomaly_map=anomaly_maps,
+            pred_mask=None,
+        )
+
+        # Test forward pass
+        processor = MEBinPostProcessor()
+        result = processor.forward(predictions)
+
+        # Verify results
+        assert isinstance(result, InferenceBatch)
+        assert result.pred_mask.shape == (2, 2, 2)
+
+    @staticmethod
+    @patch("anomalib.post_processing.mebin_post_processor.MEBin")
+    def test_forward_normalization(mock_mebin) -> None:
+        """Test that anomaly maps are properly normalized to 0-255 range."""
+        # Setup mock
+        mock_mebin_instance = Mock()
+        mock_mebin_instance.binarize_anomaly_maps.return_value = (
+            [np.array([[0, 0], [1, 1]], dtype=np.uint8)],
+            [0.5],
+        )
+        mock_mebin.return_value = mock_mebin_instance
+
+        # Create test data with specific range
+        anomaly_maps = torch.tensor([[[[0.0, 0.5], [1.0, 0.2]]]])
+        predictions = InferenceBatch(
+            pred_score=torch.tensor([0.8]),
+            pred_label=torch.tensor([1]),
+            anomaly_map=anomaly_maps,
+            pred_mask=None,
+        )
+
+        # Test forward pass
+        processor = MEBinPostProcessor()
+        processor.forward(predictions)
+
+        # Verify MEBin was called with normalized data
+        mock_mebin.assert_called_once()
+        call_args = mock_mebin.call_args
+        anomaly_map_list = call_args[1]["anomaly_map_list"]
+
+        # Check that the data is normalized to 0-255 range
+        assert len(anomaly_map_list) == 1
+        assert anomaly_map_list[0].dtype == np.uint8
+        assert anomaly_map_list[0].min() >= 0
+        assert anomaly_map_list[0].max() <= 255
+
+    @staticmethod
+    @patch("anomalib.post_processing.mebin_post_processor.MEBin")
+    def test_forward_mebin_parameters(mock_mebin) -> None:
+        """Test that MEBin is called with correct parameters."""
+        # Setup mock
+        mock_mebin_instance = Mock()
+        mock_mebin_instance.binarize_anomaly_maps.return_value = (
+            [np.array([[0, 0], [1, 1]], dtype=np.uint8)],
+            [0.5],
+        )
+        mock_mebin.return_value = mock_mebin_instance
+
+        # Create test data
+        anomaly_maps = torch.rand(1, 1, 4, 4)
+        predictions = InferenceBatch(
+            pred_score=torch.tensor([0.8]),
+            pred_label=torch.tensor([1]),
+            anomaly_map=anomaly_maps,
+            pred_mask=None,
+        )
+
+        # Test with custom parameters
+        processor = MEBinPostProcessor(
+            sample_rate=8,
+            min_interval_len=6,
+            erode=False,
+        )
+        processor.forward(predictions)
+
+        # Verify MEBin was called with correct parameters
+        mock_mebin.assert_called_once_with(
+            anomaly_map_list=mock_mebin.call_args[1]["anomaly_map_list"],
+            sample_rate=8,
+            min_interval_len=6,
+            erode=False,
+        )
+
+    @staticmethod
+    @patch("anomalib.post_processing.mebin_post_processor.MEBin")
+    def test_forward_binary_mask_conversion(mock_mebin) -> None:
+        """Test that binary masks are properly converted to 0/1 values."""
+        # Setup mock to return masks with values > 0
+        mock_mebin_instance = Mock()
+        mock_mebin_instance.binarize_anomaly_maps.return_value = (
+            [np.array([[0, 128], [255, 64]], dtype=np.uint8)],
+            [0.5],
+        )
+        mock_mebin.return_value = mock_mebin_instance
+
+        # Create test data
+        anomaly_maps = torch.rand(1, 1, 2, 2)
+        predictions = InferenceBatch(
+            pred_score=torch.tensor([0.8]),
+            pred_label=torch.tensor([1]),
+            anomaly_map=anomaly_maps,
+            pred_mask=None,
+        )
+
+        # Test forward pass
+        processor = MEBinPostProcessor()
+        result = processor.forward(predictions)
+
+        # Verify that all values are either 0 or 1
+        unique_values = torch.unique(result.pred_mask)
+        assert torch.all((unique_values == 0) | (unique_values == 1))
+
+    @staticmethod
+    def test_forward_requires_anomaly_map() -> None:
+        """Test that a missing anomaly map raises a ValueError."""
+        predictions = InferenceBatch(
+            pred_score=torch.tensor([0.8]),
+            pred_label=torch.tensor([1]),
+            anomaly_map=None,
+            pred_mask=None,
+        )
+
+        with pytest.raises(ValueError, match="anomaly_map"):
+            MEBinPostProcessor().forward(predictions)