Skip to content

Commit

Permalink
Add support for filtering objets by size (#333)
Browse files Browse the repository at this point in the history
  • Loading branch information
giswqs authored Oct 6, 2024
1 parent 837f5a5 commit db8a155
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 2 deletions.
37 changes: 36 additions & 1 deletion samgeo/samgeo.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,8 @@ def generate(
erosion_kernel=None,
mask_multiplier=255,
unique=True,
min_size=0,
max_size=None,
**kwargs,
):
"""Generate masks for the input image.
Expand All @@ -180,6 +182,9 @@ def generate(
The parameter is ignored if unique is True.
unique (bool, optional): Whether to assign a unique value to each object. Defaults to True.
The unique value increases from 1 to the number of objects. The larger the number, the larger the object area.
min_size (int, optional): The minimum size of the objects. Defaults to 0.
max_size (int, optional): The maximum size of the objects. Defaults to None.
**kwargs: Other arguments for save_masks().
"""

Expand Down Expand Up @@ -221,10 +226,19 @@ def generate(
masks = mask_generator.generate(image) # Segment the input image
self.masks = masks # Store the masks as a list of dictionaries
self.batch = False
self._min_size = min_size
self._max_size = max_size

# Save the masks to the output path. The output is either a binary mask or a mask of objects with unique values.
self.save_masks(
output, foreground, unique, erosion_kernel, mask_multiplier, **kwargs
output,
foreground,
unique,
erosion_kernel,
mask_multiplier,
min_size,
max_size,
**kwargs,
)

def save_masks(
Expand All @@ -234,6 +248,8 @@ def save_masks(
unique=True,
erosion_kernel=None,
mask_multiplier=255,
min_size=0,
max_size=None,
**kwargs,
):
"""Save the masks to the output path. The output is either a binary mask or a mask of objects with unique values.
Expand All @@ -246,6 +262,9 @@ def save_masks(
Such as (3, 3) or (5, 5). Set to None to disable it. Defaults to None.
mask_multiplier (int, optional): The mask multiplier for the output mask, which is usually a binary mask [0, 1].
You can use this parameter to scale the mask to a larger range, for example [0, 255]. Defaults to 255.
min_size (int, optional): The minimum size of the objects. Defaults to 0.
max_size (int, optional): The maximum size of the objects. Defaults to None.
**kwargs: Other arguments for array_to_image().
"""

Expand Down Expand Up @@ -279,6 +298,10 @@ def save_masks(
count = len(sorted_masks)
for index, ann in enumerate(sorted_masks):
m = ann["segmentation"]
if min_size > 0 and ann["area"] < min_size:
continue
if max_size is not None and ann["area"] > max_size:
continue
objects[m] = count - index

# Generate a binary mask
Expand All @@ -290,6 +313,10 @@ def save_masks(
resulting_borders = np.zeros((h, w), dtype=dtype)

for m in masks:
if min_size > 0 and m["area"] < min_size:
continue
if max_size is not None and m["area"] > max_size:
continue
mask = (m["segmentation"] > 0).astype(dtype)
resulting_mask += mask

Expand Down Expand Up @@ -384,6 +411,14 @@ def show_anns(
)
img[:, :, 3] = 0
for ann in sorted_anns:
if hasattr(self, "_min_size") and (ann["area"] < self._min_size):
continue
if (
hasattr(self, "_max_size")
and isinstance(self._max_size, int)
and ann["area"] > self._max_size
):
continue
m = ann["segmentation"]
color_mask = np.concatenate([np.random.random(3), [alpha]])
img[m] = color_mask
Expand Down
36 changes: 35 additions & 1 deletion samgeo/samgeo2.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,8 @@ def generate(
erosion_kernel: Optional[Tuple[int, int]] = None,
mask_multiplier: int = 255,
unique: bool = True,
min_size: int = 0,
max_size: int = None,
**kwargs: Any,
) -> List[Dict[str, Any]]:
"""
Expand All @@ -215,6 +217,8 @@ def generate(
Defaults to True.
The unique value increases from 1 to the number of objects. The
larger the number, the larger the object area.
min_size (int): The minimum size of the object. Defaults to 0.
max_size (int): The maximum size of the object. Defaults to None.
**kwargs (Any): Additional keyword arguments.
Returns:
Expand All @@ -241,11 +245,20 @@ def generate(
mask_generator = self.mask_generator # The automatic mask generator
masks = mask_generator.generate(image) # Segment the input image
self.masks = masks # Store the masks as a list of dictionaries
self._min_size = min_size
self._max_size = max_size

if output is not None:
# Save the masks to the output path. The output is either a binary mask or a mask of objects with unique values.
self.save_masks(
output, foreground, unique, erosion_kernel, mask_multiplier, **kwargs
output,
foreground,
unique,
erosion_kernel,
mask_multiplier,
min_size,
max_size,
**kwargs,
)

def save_masks(
Expand All @@ -255,6 +268,8 @@ def save_masks(
unique: bool = True,
erosion_kernel: Optional[Tuple[int, int]] = None,
mask_multiplier: int = 255,
min_size: int = 0,
max_size: int = None,
**kwargs: Any,
) -> None:
"""Save the masks to the output path. The output is either a binary mask
Expand All @@ -275,6 +290,9 @@ def save_masks(
mask, which is usually a binary mask [0, 1]. You can use this
parameter to scale the mask to a larger range, for example
[0, 255]. Defaults to 255.
min_size (int, optional): The minimum size of the object. Defaults to 0.
max_size (int, optional): The maximum size of the object. Defaults to None.
**kwargs: Additional keyword arguments for common.array_to_image().
"""

if self.masks is None:
Expand Down Expand Up @@ -307,6 +325,10 @@ def save_masks(
count = len(sorted_masks)
for index, ann in enumerate(sorted_masks):
m = ann["segmentation"]
if min_size > 0 and ann["area"] < min_size:
continue
if max_size is not None and ann["area"] > max_size:
continue
objects[m] = count - index

# Generate a binary mask
Expand All @@ -318,6 +340,10 @@ def save_masks(
resulting_borders = np.zeros((h, w), dtype=dtype)

for m in masks:
if min_size > 0 and m["area"] < min_size:
continue
if max_size is not None and m["area"] > max_size:
continue
mask = (m["segmentation"] > 0).astype(dtype)
resulting_mask += mask

Expand Down Expand Up @@ -415,6 +441,14 @@ def show_anns(
)
img[:, :, 3] = 0
for ann in sorted_anns:
if hasattr(self, "_min_size") and (ann["area"] < self._min_size):
continue
if (
hasattr(self, "_max_size")
and isinstance(self._max_size, int)
and ann["area"] > self._max_size
):
continue
m = ann["segmentation"]
color_mask = np.concatenate([np.random.random(3), [alpha]])
img[m] = color_mask
Expand Down

0 comments on commit db8a155

Please sign in to comment.