Skip to content

Commit 6c337e4

Browse files
authored
Merge pull request #25 from funkelab/13-unique-label-ids
13 unique label ids
2 parents 4986526 + 3a8489b commit 6c337e4

File tree

13 files changed

+253
-233
lines changed

13 files changed

+253
-233
lines changed

src/motile_toolbox/candidate_graph/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,4 @@
66
from .graph_attributes import EdgeAttr, NodeAttr
77
from .graph_to_nx import graph_to_nx
88
from .iou import add_iou
9-
from .utils import add_cand_edges, get_node_id, nodes_from_segmentation
9+
from .utils import add_cand_edges, nodes_from_segmentation

src/motile_toolbox/candidate_graph/compute_graph.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ def compute_graph_from_multiseg(
114114
conflicts = []
115115
for time in range(segmentations.shape[1]):
116116
segs = segmentations[:, time]
117-
conflicts.extend(compute_conflict_sets(segs, time))
117+
conflicts.extend(compute_conflict_sets(segs))
118118

119119
return cand_graph, conflicts
120120

src/motile_toolbox/candidate_graph/conflict_sets.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,8 @@
22

33
import numpy as np
44

5-
from .utils import (
6-
get_node_id,
7-
)
85

9-
10-
def compute_conflict_sets(segmentation_frame: np.ndarray, time: int) -> list[set]:
6+
def compute_conflict_sets(segmentation_frame: np.ndarray) -> list[set]:
117
"""Compute all sets of node ids that conflict with each other.
128
Note: Results might include redundant sets, for example {a, b, c} and {a, b}
139
might both appear in the results.
@@ -36,9 +32,6 @@ def compute_conflict_sets(segmentation_frame: np.ndarray, time: int) -> list[set
3632
values = np.transpose(values)
3733
conflict_sets = []
3834
for conflicting_labels in values:
39-
id_set = set()
40-
for hypo_id, label in enumerate(conflicting_labels):
41-
if label != 0:
42-
id_set.add(get_node_id(time, label, hypo_id))
35+
id_set = {label for label in conflicting_labels if label != 0}
4336
conflict_sets.append(id_set)
4437
return conflict_sets

src/motile_toolbox/candidate_graph/iou.py

Lines changed: 8 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
from itertools import product
2-
from typing import Any
32

43
import networkx as nx
54
import numpy as np
65
from tqdm import tqdm
76

87
from .graph_attributes import EdgeAttr
9-
from .utils import _compute_node_frame_dict, get_node_id
8+
from .utils import _compute_node_frame_dict
109

1110

1211
def _compute_ious(
@@ -45,7 +44,7 @@ def _compute_ious(
4544
return ious
4645

4746

48-
def _get_iou_dict(segmentation, multiseg=False) -> dict[str, dict[str, float]]:
47+
def _get_iou_dict(segmentation, multiseg=False) -> dict[int, dict[int, float]]:
4948
"""Get all ious values for the provided segmentations (all frames).
5049
Will return as map from node_id -> dict[node_id] -> iou for easy
5150
navigation when adding to candidate graph.
@@ -58,10 +57,10 @@ def _get_iou_dict(segmentation, multiseg=False) -> dict[str, dict[str, float]]:
5857
multiple hypothesis segmentations. Defaults to False.
5958
6059
Returns:
61-
dict[str, dict[str, float]]: A map from node id to another dictionary, which
60+
dict[int, dict[int, float]]: A map from node id to another dictionary, which
6261
contains node_ids to iou values.
6362
"""
64-
iou_dict: dict[str, dict[str, float]] = {}
63+
iou_dict: dict[int, dict[int, float]] = {}
6564
hypo_pairs: list[tuple[int, ...]] = [(0, 0)]
6665
if multiseg:
6766
num_hypotheses = segmentation.shape[0]
@@ -76,23 +75,16 @@ def _get_iou_dict(segmentation, multiseg=False) -> dict[str, dict[str, float]]:
7675
seg2 = segmentation[hypo2][frame + 1]
7776
ious = _compute_ious(seg1, seg2)
7877
for label1, label2, iou in ious:
79-
if multiseg:
80-
node_id1 = get_node_id(frame, label1, hypo1)
81-
node_id2 = get_node_id(frame + 1, label2, hypo2)
82-
else:
83-
node_id1 = get_node_id(frame, label1)
84-
node_id2 = get_node_id(frame + 1, label2)
85-
86-
if node_id1 not in iou_dict:
87-
iou_dict[node_id1] = {}
88-
iou_dict[node_id1][node_id2] = iou
78+
if label1 not in iou_dict:
79+
iou_dict[label1] = {}
80+
iou_dict[label1][label2] = iou
8981
return iou_dict
9082

9183

9284
def add_iou(
9385
cand_graph: nx.DiGraph,
9486
segmentation: np.ndarray,
95-
node_frame_dict: dict[int, list[Any]] | None = None,
87+
node_frame_dict: dict[int, list[int]] | None = None,
9688
multiseg=False,
9789
) -> None:
9890
"""Add IOU to the candidate graph.

src/motile_toolbox/candidate_graph/utils.py

Lines changed: 4 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -13,28 +13,6 @@
1313
logger = logging.getLogger(__name__)
1414

1515

16-
def get_node_id(time: int, label_id: int, hypothesis_id: int | None = None) -> str:
17-
"""Construct a node id given the time frame, segmentation label id, and
18-
optionally the hypothesis id. This function is not designed for candidate graphs
19-
that do not come from segmentations, but could be used if there is a similar
20-
"detection id" that is unique for all cells detected in a given frame.
21-
22-
Args:
23-
time (int): The time frame the node is in
24-
label_id (int): The label the node has in the segmentation.
25-
hypothesis_id (int | None, optional): An integer representing which hypothesis
26-
the segmentation came from, if applicable. Defaults to None.
27-
28-
Returns:
29-
str: A string to use as the node id in the candidate graph. Assuming that label
30-
ids are not repeated in the same time frame and hypothesis, it is unique.
31-
"""
32-
if hypothesis_id is not None:
33-
return f"{time}_{hypothesis_id}_{label_id}"
34-
else:
35-
return f"{time}_{label_id}"
36-
37-
3816
def nodes_from_segmentation(
3917
segmentation: np.ndarray,
4018
scale: list[float] | None = None,
@@ -52,7 +30,9 @@ def nodes_from_segmentation(
5230
5331
Args:
5432
segmentation (np.ndarray): A numpy array with integer labels and dimensions
55-
(t, [z], y, x).
33+
(t, [z], y, x). Labels must be unique across time, and the label
34+
will be used as the node id. If the labels are not unique, preprocess
35+
with motile_toolbox.utils.ensure_unqiue_ids before calling this function.
5636
scale (list[float] | None, optional): The scale of the segmentation data in all
5737
dimensions (including time, which should have a dummy 1 value).
5838
Will be used to rescale the point locations and attribute computations.
@@ -82,7 +62,7 @@ def nodes_from_segmentation(
8262
nodes_in_frame = []
8363
props = regionprops(segs, spacing=tuple(scale[1:]))
8464
for regionprop in props:
85-
node_id = get_node_id(t, regionprop.label, hypothesis_id=seg_hypo)
65+
node_id = regionprop.label
8666
attrs = {NodeAttr.TIME.value: t, NodeAttr.AREA.value: regionprop.area}
8767
attrs[NodeAttr.SEG_ID.value] = regionprop.label
8868
if seg_hypo:

src/motile_toolbox/utils/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,4 @@
1-
from .relabel_segmentation import relabel_segmentation
1+
from .relabel_segmentation import (
2+
ensure_unique_labels,
3+
relabel_segmentation_with_track_id,
4+
)

src/motile_toolbox/utils/relabel_segmentation.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from motile_toolbox.candidate_graph import NodeAttr
55

66

7-
def relabel_segmentation(
7+
def relabel_segmentation_with_track_id(
88
solution_nx_graph: nx.DiGraph,
99
segmentation: np.ndarray,
1010
) -> np.ndarray:
@@ -37,3 +37,32 @@ def relabel_segmentation(
3737
tracked_masks[time_frame][previous_seg_mask] = id_counter
3838
id_counter += 1
3939
return tracked_masks
40+
41+
42+
def ensure_unique_labels(
43+
segmentation: np.ndarray,
44+
multiseg: bool = False,
45+
) -> np.ndarray:
46+
"""Relabels the segmentation in place to ensure that label ids are unique across
47+
time. This means that every detection will have a unique label id.
48+
Useful for combining predictions made in each frame independently, or multiple
49+
segmentation outputs that repeat label IDs.
50+
51+
Args:
52+
segmentation (np.ndarray): Segmentation with dimensions ([h], t, [z], y, x).
53+
multiseg (bool, optional): Flag indicating if the segmentation contains
54+
multiple hypotheses in the first dimension. Defaults to False.
55+
"""
56+
segmentation = segmentation.astype(np.uint64)
57+
orig_shape = segmentation.shape
58+
if multiseg:
59+
segmentation = segmentation.reshape((-1, *orig_shape[2:]))
60+
curr_max = 0
61+
for idx in range(segmentation.shape[0]):
62+
frame = segmentation[idx]
63+
frame[frame != 0] += curr_max
64+
curr_max = int(np.max(frame))
65+
segmentation[idx] = frame
66+
if multiseg:
67+
segmentation = segmentation.reshape(orig_shape)
68+
return segmentation

0 commit comments

Comments
 (0)