Skip to content

Commit

Permalink
Enable type checks for the ffn.inference.segmentation module.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 568984962
  • Loading branch information
mjanusz authored and copybara-github committed Sep 27, 2023
1 parent 0b30abb commit a6d91d1
Showing 1 changed file with 72 additions and 17 deletions.
89 changes: 72 additions & 17 deletions ffn/inference/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import skimage.measure


def clear_dust(data, min_size=10):
def clear_dust(data: np.ndarray, min_size: int = 10):
"""Removes small objects from a segmentation array.
Replaces objects smaller than `min_size` with 0 (background).
Expand All @@ -39,7 +39,7 @@ def clear_dust(data, min_size=10):
return data


def reduce_id_bits(segmentation):
def reduce_id_bits(segmentation: np.ndarray):
"""Reduces the number of bits used for IDs.
Assumes that one additional ID beyond the max of 'segmentation' is necessary
Expand All @@ -59,9 +59,10 @@ def reduce_id_bits(segmentation):
return segmentation.astype(np.uint16)
elif max_id <= np.iinfo(np.uint32).max:
return segmentation.astype(np.uint32)
return segmentation


def split_disconnected_components(labels: np.ndarray, connectivity=1):
  """Relabels the connected components of a 3-D integer array.

  Two nonzero voxels belong to the same component only when they are
  neighbors under the requested connectivity and carry the same original
  label. Voxels with label 0 are treated as background and keep label 0.

  Args:
    labels: 3-D integer numpy array.
    connectivity: 1, 2, or 3; for 6-, 18-, or 26-connectivity respectively.

  Returns:
    The relabeled numpy array, same dtype as `labels`.
  """
  fixed_labels = skimage.measure.label(
      labels, connectivity=connectivity, background=0)
  # Preserve 0 as background: if the relabeling disagrees with the original
  # about which voxels are background, shift every component ID up by one and
  # force the original background voxels back to 0.
  if np.any((fixed_labels == 0) != (labels == 0)):
    fixed_labels[...] += 1
    fixed_labels[labels == 0] = 0
  # np.cast was removed in NumPy 2.0; ndarray.astype is the supported
  # equivalent of the original `np.cast[labels.dtype](fixed_labels)`.
  return fixed_labels.astype(labels.dtype)


def clean_up(seg: np.ndarray,
             split_cc=True,
             connectivity=1,
             min_size=0,
             return_id_map=False):
  """Runs connected components and removes small objects.

  Thin wrapper over `clean_up_and_count` that skips voxel counting.

  Args:
    seg: segmentation to clean as a uint64 ndarray; modified in place
    split_cc: whether to recompute connected components
    connectivity: used for split_cc; 1, 2, or 3; for 6-, 18-, or 26-connectivity
      respectively.
    min_size: connected components smaller than this value get removed from the
      segmentation; if 0, no filtering by size is done
    return_id_map: whether to compute and return a map from new IDs to original
      IDs

  Returns:
    None if not return_id_map, otherwise a dictionary mapping
    new IDs to original IDs. `seg` is modified in place.
  """
  cc_to_orig, _ = clean_up_and_count(
      seg,
      split_cc,
      connectivity,
      min_size,
      compute_id_map=return_id_map,
      compute_counts=False)
  if return_id_map:
    return cc_to_orig


def clean_up_and_count(seg: np.ndarray,
                       split_cc=True,
                       connectivity=1,
                       min_size=0,
                       compute_id_map=True,
                       compute_counts=True):
  """Runs connected components and removes small objects, returns metadata.

  Args:
    seg: segmentation to clean as a uint64 ndarray. Mutated in place.
    split_cc: whether to recompute connected components
    connectivity: used for split_cc; 1, 2, or 3; for 6-, 18-, or 26-connectivity
      respectively.
    min_size: connected components smaller than this value get removed from the
      segmentation; if 0, no filtering by size is done
    compute_id_map: whether to compute a mapping of new CC ID to old ID. If
      False, None is returned instead.
    compute_counts: whether to compute a mapping of new CC ID to voxel count. If
      False, None is returned instead.

  Returns:
    tuple of (dict of new ID to original ID, dict of new ID to voxel count). If
    compute_id_map or compute_counts is False, the respective returned tuple
    member will be None.
  """
  if compute_id_map:
    # Snapshot the pre-cleanup labels so new IDs can be mapped back to them.
    seg_orig = seg.copy()

  if split_cc:
    seg[...] = split_disconnected_components(seg, connectivity)
  if min_size > 0:
    clear_dust(seg, min_size)

  cc_to_orig, cc_to_count = None, None

  if compute_id_map or compute_counts:
    # With at least one of return_index/return_counts set, np.unique returns a
    # tuple; index [0] is always the array of unique (new) IDs.
    unique_result_tuple = np.unique(
        seg.ravel(), return_index=compute_id_map, return_counts=compute_counts)
    cc_ids = unique_result_tuple[0]
    if compute_id_map:
      cc_idx = unique_result_tuple[1]
      # The first occurrence of each new ID identifies the original label at
      # the same flat position.
      orig_ids = seg_orig.ravel()[cc_idx]
      cc_to_orig = dict(zip(cc_ids, orig_ids))
    if compute_counts:
      # Counts are the last tuple member whether or not indices were requested.
      cc_counts = unique_result_tuple[-1]
      cc_to_count = dict(zip(cc_ids, cc_counts))

  return cc_to_orig, cc_to_count


def split_segmentation_by_intersection(a, b, min_size):
def split_segmentation_by_intersection(a: np.ndarray, b: np.ndarray,
min_size: int):
"""Computes the intersection of two segmentations.
Intersects two spatially overlapping segmentations and assigns a new ID to
Expand Down Expand Up @@ -213,9 +269,8 @@ def remap_input(x):

# Relabel map to apply to remapped_joint_labels to obtain the output ids.
new_labels = np.zeros(len(unique_joint_labels), np.uint64)
for i, (label_a, label_b, count) in enumerate(zip(unique_joint_labels_a,
unique_joint_labels_b,
joint_counts)):
for i, (label_a, label_b, count) in enumerate(
zip(unique_joint_labels_a, unique_joint_labels_b, joint_counts)):
if count < min_size or label_a == 0:
new_label = 0
elif label_b == max_overlap_ids[label_a][0]:
Expand Down

0 comments on commit a6d91d1

Please sign in to comment.