Skip to content

Commit be7b3e9

Browse files
authored
Merge pull request #5 from funkelab/multi_hypothesis
Multi hypothesis
2 parents 968cb2f + ab75d9a commit be7b3e9

18 files changed

+1027
-610
lines changed

.pre-commit-config.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,11 @@ repos:
1414
- id: check-yaml
1515
- id: check-added-large-files
1616

17+
- repo: https://github.com/psf/black
18+
rev: 23.1.0
19+
hooks:
20+
- id: black
21+
1722
- repo: https://github.com/charliermarsh/ruff-pre-commit
1823
rev: v0.2.2
1924
hooks:

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ dev = [
3535
'pdoc',
3636
'pre-commit',
3737
'types-tqdm',
38+
'pytest-unordered'
3839
]
3940

4041
[project.urls]
Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from .compute_graph import get_candidate_graph
12
from .graph_attributes import EdgeAttr, NodeAttr
2-
from .graph_from_segmentation import graph_from_segmentation
33
from .graph_to_nx import graph_to_nx
4+
from .iou import add_iou
5+
from .utils import add_cand_edges, get_node_id, nodes_from_segmentation
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
import logging
2+
from typing import Any
3+
4+
import networkx as nx
5+
import numpy as np
6+
7+
from .conflict_sets import compute_conflict_sets
8+
from .iou import add_iou
9+
from .utils import add_cand_edges, nodes_from_segmentation
10+
11+
logger = logging.getLogger(__name__)
12+
13+
14+
def get_candidate_graph(
15+
segmentation: np.ndarray,
16+
max_edge_distance: float,
17+
iou: bool = False,
18+
multihypo: bool = False,
19+
) -> tuple[nx.DiGraph, list[set[Any]] | None]:
20+
"""Construct a candidate graph from a segmentation array. Nodes are placed at the
21+
centroid of each segmentation and edges are added for all nodes in adjacent frames
22+
within max_edge_distance. If segmentation contains multiple hypotheses, will also
23+
return a list of conflicting node ids that cannot be selected together.
24+
25+
Args:
26+
segmentation (np.ndarray): A numpy array with integer labels and dimensions
27+
(t, [h], [z], y, x), where h is the number of hypotheses.
28+
max_edge_distance (float): Maximum distance that objects can travel between
29+
frames. All nodes with centroids within this distance in adjacent frames
30+
will by connected with a candidate edge.
31+
iou (bool, optional): Whether to include IOU on the candidate graph.
32+
Defaults to False.
33+
multihypo (bool, optional): Whether the segmentation contains multiple
34+
hypotheses. Defaults to False.
35+
36+
Returns:
37+
tuple[nx.DiGraph, list[set[Any]] | None]: A candidate graph that can be passed
38+
to the motile solver, and a list of conflicting node ids.
39+
"""
40+
# add nodes
41+
if multihypo:
42+
cand_graph = nx.DiGraph()
43+
num_frames = segmentation.shape[0]
44+
node_frame_dict: dict[int, list[Any]] = {t: [] for t in range(num_frames)}
45+
num_hypotheses = segmentation.shape[1]
46+
for hypo_id in range(num_hypotheses):
47+
hypothesis = segmentation[:, hypo_id]
48+
node_graph, frame_dict = nodes_from_segmentation(
49+
hypothesis, hypo_id=hypo_id
50+
)
51+
cand_graph.update(node_graph)
52+
for t in range(num_frames):
53+
if t in frame_dict:
54+
node_frame_dict[t].extend(frame_dict[t])
55+
else:
56+
cand_graph, node_frame_dict = nodes_from_segmentation(segmentation)
57+
logger.info(f"Candidate nodes: {cand_graph.number_of_nodes()}")
58+
59+
# add edges
60+
add_cand_edges(
61+
cand_graph,
62+
max_edge_distance=max_edge_distance,
63+
node_frame_dict=node_frame_dict,
64+
)
65+
if iou:
66+
add_iou(cand_graph, segmentation, node_frame_dict, multihypo=multihypo)
67+
68+
logger.info(f"Candidate edges: {cand_graph.number_of_edges()}")
69+
70+
# Compute conflict sets between segmentations
71+
if multihypo:
72+
conflicts = []
73+
for time, segs in enumerate(segmentation):
74+
conflicts.extend(compute_conflict_sets(segs, time))
75+
else:
76+
conflicts = None
77+
78+
return cand_graph, conflicts
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
from itertools import combinations
2+
3+
import numpy as np
4+
5+
from .utils import (
6+
get_node_id,
7+
)
8+
9+
10+
def compute_conflict_sets(segmentation_frame: np.ndarray, time: int) -> list[set]:
11+
"""Compute all sets of node ids that conflict with each other.
12+
Note: Results might include redundant sets, for example {a, b, c} and {a, b}
13+
might both appear in the results.
14+
15+
Args:
16+
segmentation_frame (np.ndarray): One frame of the multiple hypothesis
17+
segmentation. Dimensions are (h, [z], y, x), where h is the number of
18+
hypotheses.
19+
time (int): Time frame, for computing node_ids.
20+
21+
Returns:
22+
list[set]: list of sets of node ids that overlap. Might include some sets
23+
that are subsets of others.
24+
"""
25+
flattened_segs = [seg.flatten() for seg in segmentation_frame]
26+
27+
# get locations where at least two hypotheses have labels
28+
# This approach may be inefficient, but likely doesn't matter compared to np.unique
29+
conflict_indices = np.zeros(flattened_segs[0].shape, dtype=bool)
30+
for seg1, seg2 in combinations(flattened_segs, 2):
31+
non_zero_indices = np.logical_and(seg1, seg2)
32+
conflict_indices = np.logical_or(conflict_indices, non_zero_indices)
33+
34+
flattened_stacked = np.array([seg[conflict_indices] for seg in flattened_segs])
35+
values = np.unique(flattened_stacked, axis=1)
36+
values = np.transpose(values)
37+
conflict_sets = []
38+
for conflicting_labels in values:
39+
id_set = set()
40+
for hypo_id, label in enumerate(conflicting_labels):
41+
if label != 0:
42+
id_set.add(get_node_id(time, label, hypo_id))
43+
conflict_sets.append(id_set)
44+
return conflict_sets

src/motile_toolbox/candidate_graph/graph_attributes.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,10 @@ class NodeAttr(Enum):
77
implementations of commonly used ones, listed here.
88
"""
99

10-
SEG_ID = "segmentation_id"
10+
POS = "pos"
11+
TIME = "time"
12+
SEG_ID = "seg_id"
13+
SEG_HYPO = "seg_hypo"
1114

1215

1316
class EdgeAttr(Enum):

0 commit comments

Comments
 (0)