Skip to content

Commit c7e9548

Browse files
authored
Merge pull request #2008 from roboflow/feature/filter_segments_by_distance
initial version of `filter_segments_by_distance`
2 parents 0bd6087 + 65f6e41 commit c7e9548

File tree

4 files changed

+372
-0
lines changed

4 files changed

+372
-0
lines changed

docs/detection/utils/masks.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,9 @@ status: new
2222
</div>
2323

2424
:::supervision.detection.utils.masks.contains_multiple_segments
25+
26+
<div class="md-typeset">
27+
<h2><a href="#supervision.detection.utils.masks.filter_segments_by_distance">filter_segments_by_distance</a></h2>
28+
</div>
29+
30+
:::supervision.detection.utils.masks.filter_segments_by_distance

supervision/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@
8787
calculate_masks_centroids,
8888
contains_holes,
8989
contains_multiple_segments,
90+
filter_segments_by_distance,
9091
move_masks,
9192
)
9293
from supervision.detection.utils.polygons import (
@@ -219,6 +220,7 @@
219220
"draw_text",
220221
"edit_distance",
221222
"filter_polygons_by_area",
223+
"filter_segments_by_distance",
222224
"fuzzy_match_index",
223225
"get_coco_class_index_mapping",
224226
"get_polygon_center",

supervision/detection/utils/masks.py

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from __future__ import annotations
22

3+
from typing import Literal
4+
35
import cv2
46
import numpy as np
57
import numpy.typing as npt
@@ -260,3 +262,139 @@ def resize_masks(masks: np.ndarray, max_dimension: int = 640) -> np.ndarray:
260262
resized_masks = masks[:, yv, xv]
261263

262264
return resized_masks.reshape(masks.shape[0], new_height, new_width)
265+
266+
267+
def filter_segments_by_distance(
    mask: npt.NDArray[np.bool_],
    absolute_distance: float | None = 100.0,
    relative_distance: float | None = None,
    connectivity: int = 8,
    mode: Literal["edge", "centroid"] = "edge",
) -> npt.NDArray[np.bool_]:
    """
    Keep the largest connected component and any other components within a distance
    threshold.

    Distance can be absolute in pixels or relative to the image diagonal.

    Args:
        mask: Boolean mask HxW.
        absolute_distance: Max allowed distance in pixels to the main component.
            Ignored if `relative_distance` is provided.
        relative_distance: Fraction of the diagonal. If set,
            threshold = fraction * sqrt(H^2 + W^2).
        connectivity: Defines which neighboring pixels are considered connected.
            - 4-connectedness: Only orthogonal neighbors.
              ```
              [ ][X][ ]
              [X][O][X]
              [ ][X][ ]
              ```
            - 8-connectedness: Includes diagonal neighbors.
              ```
              [X][X][X]
              [X][O][X]
              [X][X][X]
              ```
            Default is 8.
        mode: Defines how distance between components is measured.
            - "edge": Smallest distance from any pixel of the component to the
              nearest pixel of the main component, taken from
              `cv2.distanceTransform(..., cv2.DIST_L2, 3)`. Distances are
              measured between pixel positions, so a gap of `k` empty pixels
              corresponds to `k + 1` pixel steps, and the 3x3 mask only
              approximates the Euclidean distance (a horizontal/vertical step
              costs about 0.955, a diagonal step about 1.3693).
            - "centroid": Uses distance between component centroids.

    Returns:
        Boolean mask after filtering. A copy of `mask` is returned when it
        contains no foreground pixels.

    Raises:
        TypeError: If `mask` is not of boolean dtype.
        ValueError: If `mask` is not two-dimensional, if `mode` is neither
            "edge" nor "centroid", or if both `absolute_distance` and
            `relative_distance` are `None`.

    Examples:
        ```python
        import numpy as np
        import supervision as sv

        mask = np.array([
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 1, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0],
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0],
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        ], dtype=bool)

        sv.filter_segments_by_distance(
            mask,
            absolute_distance=3,
            mode="edge",
            connectivity=8
        ).astype(int)

        # np.array([
        #     [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        #     [0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
        #     [0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
        #     [0, 1, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0],
        #     [0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0],
        #     [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        #     [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        #     [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        #     [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        #     [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        #     [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        #     [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        # ])

        # The nearby 2x2 block at columns 6-7 is separated from the main
        # component by two empty pixels, i.e. its closest pixel is 3 pixel
        # steps away (about 2.87 with the 3x3 approximation), so it is kept.
        # The distant block at columns 9-10 is roughly 7.8 away and is removed.
        ```
    """
    if mask.dtype != bool:
        raise TypeError("mask must be boolean")
    if mask.ndim != 2:
        raise ValueError(f"mask must be 2D (HxW), got shape {mask.shape}")

    # Nothing to filter: hand back a copy so callers can mutate the result.
    if not np.any(mask):
        return mask.copy()

    # Validate cheap arguments before doing any image processing.
    if mode not in ("edge", "centroid"):
        raise ValueError("mode must be 'edge' or 'centroid'")

    # A relative distance, when given, takes precedence over the absolute one.
    if relative_distance is not None:
        height, width = mask.shape
        threshold = float(relative_distance) * float(np.hypot(height, width))
    elif absolute_distance is not None:
        threshold = float(absolute_distance)
    else:
        raise ValueError(
            "either absolute_distance or relative_distance must be provided"
        )

    num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(
        mask.astype(np.uint8), connectivity=connectivity
    )

    # Label 0 is the background; with only that label there is no component.
    if num_labels <= 1:
        return mask.copy()

    # Skip the background row when looking for the largest component.
    areas = stats[1:, cv2.CC_STAT_AREA]
    main_label = 1 + int(np.argmax(areas))

    # Lookup table indexed by label id; background (index 0) stays False.
    keep_labels = np.zeros(num_labels, dtype=bool)
    keep_labels[main_label] = True

    if mode == "centroid":
        differences = centroids[1:] - centroids[main_label]
        distances = np.sqrt(np.sum(differences**2, axis=1))
        # +1 maps positions in the background-free slice back to label ids.
        keep_labels[1 + np.where(distances <= threshold)[0]] = True
    else:  # mode == "edge"
        # distanceTransform measures the distance to the nearest zero pixel,
        # so the main component must be the zero region of its input.
        main_mask = (labels == main_label).astype(np.uint8)
        distance_to_main = cv2.distanceTransform(1 - main_mask, cv2.DIST_L2, 3)
        for label in range(1, num_labels):
            if label == main_label:
                continue
            component_distances = distance_to_main[labels == label]
            if (
                component_distances.size
                and float(component_distances.min()) <= threshold
            ):
                keep_labels[label] = True

    # Map every pixel's label through the lookup table to build the output mask.
    return keep_labels[labels]

0 commit comments

Comments
 (0)