Skip to content

Commit acc97fb

Browse files
committed
add mebin postprocessor
1 parent c43e552 commit acc97fb

File tree

6 files changed

+385
-3
lines changed

6 files changed

+385
-3
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
77
## [Unreleased]
88

99
### Added
10+
- 🚀 Add MEBin post-processing method
11+
1012

1113
### Removed
1214

src/anomalib/metrics/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@
5959
from .pimo import AUPIMO, PIMO
6060
from .precision_recall_curve import BinaryPrecisionRecallCurve
6161
from .pro import PRO
62-
from .threshold import F1AdaptiveThreshold, ManualThreshold
62+
from .threshold import F1AdaptiveThreshold, ManualThreshold ,MEBin
6363

6464
__all__ = [
6565
"AUROC",
@@ -78,4 +78,5 @@
7878
"PRO",
7979
"PIMO",
8080
"AUPIMO",
81+
"MEBin",
8182
]

src/anomalib/metrics/threshold/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,5 +24,6 @@
2424
from .base import BaseThreshold, Threshold
2525
from .f1_adaptive_threshold import F1AdaptiveThreshold
2626
from .manual_threshold import ManualThreshold
27+
from .mebin import MEBin
2728

28-
__all__ = ["BaseThreshold", "Threshold", "F1AdaptiveThreshold", "ManualThreshold"]
29+
__all__ = ["BaseThreshold", "Threshold", "F1AdaptiveThreshold", "ManualThreshold", "MEBin"]
Lines changed: 250 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,250 @@
1+
"""MEBin adaptive thresholding algorithm for anomaly detection.
2+
3+
This module provides the ``MEBin`` class which automatically finds the optimal
4+
threshold value by analyzing the stability of connected components across
5+
multiple threshold levels.
6+
7+
The threshold is computed by:
8+
1. Sampling anomaly maps at configurable rates across threshold range
9+
2. Counting connected components at each threshold level
10+
3. Finding stable intervals where component count remains constant
11+
4. Selecting threshold from the longest stable interval
12+
13+
Example:
14+
>>> from anomalib.metrics.threshold import MEBin
15+
>>> import numpy as np
16+
>>> # Create sample anomaly maps
17+
>>> anomaly_maps = [np.random.rand(256, 256) * 255 for _ in range(10)]
18+
>>> # Initialize and compute thresholds
19+
>>> mebin = MEBin(anomaly_map_list=anomaly_maps, sample_rate=4)
20+
>>> binarized_maps, thresholds = mebin.binarize_anomaly_maps()
21+
>>> print(f"Computed {len(thresholds)} thresholds")
22+
23+
Note:
24+
The algorithm works best when anomaly maps contain clear separation between
25+
normal and anomalous regions. The min_interval_len parameter should be tuned
26+
based on the expected stability of anomaly score distributions.
27+
"""
28+
29+
import cv2
30+
import numpy as np
31+
from tqdm import tqdm
32+
33+
34+
35+
class MEBin:
36+
"""MEBin adaptive thresholding algorithm for anomaly detection.
37+
38+
This class implements the MEBin (Minimum Entropy Binarization) algorithm
39+
which automatically determines optimal thresholds for converting continuous
40+
anomaly maps to binary masks by analyzing the stability of connected
41+
component counts across different threshold levels.
42+
43+
The algorithm works by:
44+
- Sampling anomaly maps at configurable rates across threshold range
45+
- Counting connected components at each threshold level
46+
- Identifying stable intervals where component count remains constant
47+
- Selecting the optimal threshold from the longest stable interval
48+
- Optionally applying morphological erosion to reduce noise
49+
50+
Args:
51+
anomaly_map_path_list (list, optional): List of file paths to anomaly maps.
52+
If provided, maps will be loaded as grayscale images.
53+
Defaults to None.
54+
anomaly_map_list (list, optional): List of anomaly map arrays. If provided,
55+
maps should be numpy arrays.
56+
Defaults to None.
57+
sample_rate (int, optional): Sampling rate for threshold search. Higher
58+
values reduce processing time but may affect accuracy.
59+
Defaults to 4.
60+
min_interval_len (int, optional): Minimum length of stable intervals.
61+
Should be tuned based on the expected stability of anomaly score
62+
distributions.
63+
Defaults to 4.
64+
erode (bool, optional): Whether to apply morphological erosion to
65+
binarized results to reduce noise.
66+
Defaults to True.
67+
68+
Example:
69+
>>> from anomalib.metrics.threshold import MEBin
70+
>>> import numpy as np
71+
>>> # Create sample anomaly maps
72+
>>> anomaly_maps = [np.random.rand(256, 256) * 255 for _ in range(10)]
73+
>>> # Initialize MEBin
74+
>>> mebin = MEBin(anomaly_map_list=anomaly_maps, sample_rate=4)
75+
>>> # Compute binary masks and thresholds
76+
>>> binarized_maps, thresholds = mebin.binarize_anomaly_maps()
77+
"""
78+
79+
def __init__(self, anomaly_map_path_list=None, sample_rate=4, min_interval_len=4, erode=True):
80+
81+
self.anomaly_map_path_list = anomaly_map_path_list
82+
# Load anomaly maps as grayscale images if paths are provided
83+
self.anomaly_map_list = [cv2.imread(x, cv2.IMREAD_GRAYSCALE) for x in self.anomaly_map_path_list]
84+
85+
self.sample_rate = sample_rate
86+
self.min_interval_len = min_interval_len
87+
self.erode = erode
88+
89+
# Adaptively determine the threshold search range
90+
self.max_th, self.min_th = self.get_search_range()
91+
92+
93+
def get_search_range(self):
94+
"""Determine the threshold search range adaptively.
95+
96+
This method analyzes all anomaly maps to determine the minimum and maximum
97+
threshold values for the binarization process. The search range is based
98+
on the actual anomaly score distributions in the input maps.
99+
100+
Returns:
101+
max_th (int): Maximum threshold for binarization.
102+
min_th (int): Minimum threshold for binarization.
103+
"""
104+
# Get the anomaly scores of all anomaly maps
105+
anomaly_score_list = [np.max(x) for x in self.anomaly_map_list]
106+
107+
# Select the maximum and minimum anomaly scores from images
108+
max_score, min_score = max(anomaly_score_list), min(anomaly_score_list)
109+
max_th, min_th = max_score, min_score
110+
111+
print(f"Value range: {min_score} - {max_score}")
112+
113+
return max_th, min_th
114+
115+
116+
117+
def get_threshold(self, anomaly_num_sequence, min_interval_len):
118+
"""
119+
Find the 'stable interval' in the anomaly region number sequence.
120+
Stable Interval: A continuous threshold range in which the number of connected components remains constant,
121+
and the length of the threshold range is greater than or equal to the given length threshold (min_interval_len).
122+
123+
Args:
124+
anomaly_num_sequence (list): Sequence of connected component counts
125+
at each threshold level, ordered from high to low threshold.
126+
min_interval_len (int): Minimum length requirement for stable intervals.
127+
Longer intervals indicate more robust threshold selection.
128+
129+
Returns:
130+
threshold (int): The final threshold for binarization.
131+
est_anomaly_num (int): The estimated number of anomalies.
132+
"""
133+
interval_result = {}
134+
current_index = 0
135+
while current_index < len(anomaly_num_sequence):
136+
end = current_index
137+
138+
start = end
139+
140+
# Find the interval where the connected component count remains constant.
141+
if len(set(anomaly_num_sequence[start:end+1])) == 1 and anomaly_num_sequence[start] != 0:
142+
# Move the 'end' pointer forward until a different connected component number is encountered.
143+
while end < len(anomaly_num_sequence)-1 and anomaly_num_sequence[end] == anomaly_num_sequence[end+1]:
144+
end += 1
145+
current_index += 1
146+
# If the length of the current stable interval is greater than or equal to the given threshold (min_interval_len), record this interval.
147+
if end - start + 1 >= min_interval_len:
148+
if anomaly_num_sequence[start] not in interval_result:
149+
interval_result[anomaly_num_sequence[start]] = [(start, end)]
150+
else:
151+
interval_result[anomaly_num_sequence[start]].append((start, end))
152+
current_index += 1
153+
154+
"""
155+
If a 'stable interval' exists, calculate the final threshold based on the longest stable interval.
156+
If no stable interval is found, it indicates that no anomaly regions exist, and 255 is returned.
157+
"""
158+
159+
if interval_result:
160+
# Iterate through the stable intervals, calculating their lengths and corresponding number of connected component.
161+
count_result = {}
162+
for anomaly_num in interval_result:
163+
count_result[anomaly_num] = max([x[1] - x[0] for x in interval_result[anomaly_num]])
164+
est_anomaly_num = max(count_result, key=count_result.get)
165+
est_anomaly_num_interval_result = interval_result[est_anomaly_num]
166+
167+
# Find the longest stable interval.
168+
longest_interval = sorted(est_anomaly_num_interval_result, key=lambda x: x[1] - x[0])[-1]
169+
170+
# Use the endpoint threshold of the longest stable interval as the final threshold.
171+
index = longest_interval[1]
172+
threshold = 255 - index * self.sample_rate
173+
threshold = int(threshold*(self.max_th - self.min_th)/255 + self.min_th)
174+
return threshold, est_anomaly_num
175+
else:
176+
return 255, 0
177+
178+
179+
def bin_and_erode(self, anomaly_map, threshold):
180+
"""Binarize anomaly map and optionally apply erosion.
181+
182+
This method converts a continuous anomaly map to a binary mask using
183+
the specified threshold, and optionally applies morphological erosion
184+
to reduce noise and smooth the boundaries of anomaly regions.
185+
186+
The binarization process:
187+
1. Pixels above threshold become 255 (anomalous)
188+
2. Pixels below threshold become 0 (normal)
189+
3. Optional erosion with 6x6 kernel to reduce noise
190+
191+
Args:
192+
anomaly_map (numpy.ndarray): Input anomaly map with continuous
193+
anomaly scores to be binarized.
194+
threshold (int): Threshold value for binarization. Pixels with
195+
values above this threshold are considered anomalous.
196+
197+
Returns:
198+
numpy.ndarray: Binary mask where 255 indicates anomalous regions
199+
and 0 indicates normal regions. The result is of type uint8.
200+
201+
Note:
202+
Erosion is applied with a 6x6 kernel and 1 iteration to balance
203+
noise reduction with preservation of anomaly boundaries.
204+
"""
205+
bin_result = np.where(anomaly_map > threshold, 255, 0).astype(np.uint8)
206+
207+
# Apply erosion operation to the binarized result
208+
if self.erode:
209+
kernel_size = 6
210+
iter_num = 1
211+
kernel = np.ones((kernel_size, kernel_size), np.uint8)
212+
bin_result = cv2.erode(bin_result, kernel, iterations=iter_num)
213+
return bin_result
214+
215+
216+
def binarize_anomaly_maps(self):
217+
"""
218+
Perform binarization within the given threshold search range,
219+
count the number of connected components in the binarized results.
220+
Adaptively determine the threshold according to the count,
221+
and perform binarization on the anomaly maps.
222+
223+
Returns:
224+
binarized_maps (list): List of binarized images.
225+
thresholds (list): List of thresholds for each image.
226+
"""
227+
self.binarized_maps = []
228+
self.thresholds = []
229+
230+
for i, anomaly_map in enumerate(tqdm(self.anomaly_map_list)):
231+
# Normalize the anomaly map within the given threshold search range.
232+
anomaly_map_norm = np.where(anomaly_map < self.min_th, 0, ((anomaly_map - self.min_th) / (self.max_th - self.min_th)) * 255)
233+
anomaly_num_sequence = []
234+
235+
# Search for the threshold from high to low within the given range using the specified sampling rate.
236+
for score in range(255, 0, -self.sample_rate):
237+
bin_result = self.bin_and_erode(anomaly_map_norm, score)
238+
num_labels, *rest = cv2.connectedComponentsWithStats(bin_result, connectivity=8)
239+
anomaly_num = num_labels - 1
240+
anomaly_num_sequence.append(anomaly_num)
241+
242+
# Adaptively determine the threshold based on the anomaly connected component count sequence.
243+
threshold, est_anomaly_num = self.get_threshold(anomaly_num_sequence, self.min_interval_len)
244+
245+
# Binarize the anomaly image based on the determined threshold.
246+
bin_result = self.bin_and_erode(anomaly_map, threshold)
247+
self.binarized_maps.append(bin_result)
248+
self.thresholds.append(threshold)
249+
250+
return self.binarized_maps, self.thresholds

src/anomalib/post_processing/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,5 +20,6 @@
2020
"""
2121

2222
from .post_processor import PostProcessor
23+
from .mebin_post_processor import MEBinPostProcessor
2324

24-
__all__ = ["PostProcessor"]
25+
__all__ = ["PostProcessor", "MEBinPostProcessor"]

0 commit comments

Comments
 (0)