1
1
import logging
2
- from typing import Any , Iterable
2
+ from collections .abc import Iterable
3
+ from typing import Any
3
4
4
5
import networkx as nx
5
6
import numpy as np
@@ -64,7 +65,7 @@ def nodes_from_segmentation(
64
65
cand_graph = nx .DiGraph ()
65
66
# also construct a dictionary from time frame to node_id for efficiency
66
67
node_frame_dict : dict [int , list [Any ]] = {}
67
- print ("Extracting nodes from segmentation" )
68
+ logger . info ("Extracting nodes from segmentation" )
68
69
num_hypotheses = segmentation .shape [1 ]
69
70
if scale is None :
70
71
scale = [
@@ -101,6 +102,7 @@ def nodes_from_segmentation(
101
102
102
103
def nodes_from_points_list (
103
104
points_list : np .ndarray ,
105
+ scale : list [float ] | None = None ,
104
106
) -> tuple [nx .DiGraph , dict [int , list [Any ]]]:
105
107
"""Extract candidate nodes from a list of points. Uses the index of the
106
108
point in the list as its unique id.
@@ -110,6 +112,10 @@ def nodes_from_points_list(
110
112
Args:
111
113
points_list (np.ndarray): An NxD numpy array with N points and D
112
114
(3 or 4) dimensions. Dimensions should be in order (t, [z], y, x).
115
+ scale (list[float] | None, optional): Amount to scale the points in each
116
+ dimension. Only needed if the provided points are in "voxel" coordinates
117
+ instead of world coordinates. Defaults to None, which implies the data is
118
+ isotropic.
113
119
114
120
Returns:
115
121
tuple[nx.DiGraph, dict[int, list[Any]]]: A candidate graph with only nodes,
@@ -118,7 +124,16 @@ def nodes_from_points_list(
118
124
cand_graph = nx .DiGraph ()
119
125
# also construct a dictionary from time frame to node_id for efficiency
120
126
node_frame_dict : dict [int , list [Any ]] = {}
121
- print ("Extracting nodes from points list" )
127
+ logger .info ("Extracting nodes from points list" )
128
+
129
+ # scale points
130
+ if scale is not None :
131
+ assert (
132
+ len (scale ) == points_list .shape [1 ]
133
+ ), f"Cannot scale points with { points_list .shape [1 ]} dims by factor { scale } "
134
+ points_list = points_list * np .array (scale )
135
+
136
+ # add points to graph
122
137
for i , point in enumerate (points_list ):
123
138
# assume t, [z], y, x
124
139
t = point [0 ]
@@ -187,7 +202,7 @@ def add_cand_edges(
187
202
to node ids. If not provided, it will be computed from cand_graph. Defaults
188
203
to None.
189
204
"""
190
- print ("Extracting candidate edges" )
205
+ logger . info ("Extracting candidate edges" )
191
206
if not node_frame_dict :
192
207
node_frame_dict = _compute_node_frame_dict (cand_graph )
193
208
@@ -202,7 +217,9 @@ def add_cand_edges(
202
217
203
218
matched_indices = prev_kdtree .query_ball_tree (next_kdtree , max_edge_distance )
204
219
205
- for prev_node_id , next_node_indices in zip (prev_node_ids , matched_indices ):
220
+ for prev_node_id , next_node_indices in zip (
221
+ prev_node_ids , matched_indices , strict = False
222
+ ):
206
223
for next_node_index in next_node_indices :
207
224
next_node_id = next_node_ids [next_node_index ]
208
225
cand_graph .add_edge (prev_node_id , next_node_id )
0 commit comments