Skip to content

Commit cc36034

Browse files
authored
Merge pull request #22 from funkelab/scale_points
Scale points
2 parents bf5c52b + f588698 commit cc36034

File tree

4 files changed

+42
-10
lines changed

4 files changed

+42
-10
lines changed

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ requires = ["setuptools>=64", "wheel", "setuptools_scm>=8"]
66
name = "motile_toolbox"
77
description = "A toolbox for tracking with (motile)[https://github.com/funkelab/motile]."
88
readme = "README.md"
9-
requires-python = ">=3.7"
9+
requires-python = ">=3.10"
1010
classifiers = [
1111
"Programming Language :: Python :: 3",
1212
]
@@ -51,7 +51,7 @@ omit = ["src/motile_toolbox/visualization/*"]
5151
# https://github.com/charliermarsh/ruff
5252
[tool.ruff]
5353
line-length = 88
54-
target-version = "py38"
54+
target-version = "py310"
5555

5656
[tool.ruff.lint]
5757
select = [

src/motile_toolbox/candidate_graph/compute_graph.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,7 @@ def get_candidate_graph(
5252
)
5353
if iou:
5454
# Scale does not matter to IOU, because both numerator and denominator
55-
# are scaled by the anisotropy. It would matter to compare IOUs across
56-
# multiple scales of data, but this is not the current use case.
55+
# are scaled by the anisotropy.
5756
add_iou(cand_graph, segmentation, node_frame_dict)
5857

5958
logger.info(f"Candidate edges: {cand_graph.number_of_edges()}")
@@ -70,6 +69,7 @@ def get_candidate_graph(
7069
def get_candidate_graph_from_points_list(
7170
points_list: np.ndarray,
7271
max_edge_distance: float,
72+
scale: list[float] | None = None,
7373
) -> nx.DiGraph:
7474
"""Construct a candidate graph from a points list.
7575
@@ -79,13 +79,17 @@ def get_candidate_graph_from_points_list(
7979
max_edge_distance (float): Maximum distance that objects can travel between
8080
frames. All nodes with centroids within this distance in adjacent frames
8181
will by connected with a candidate edge.
82+
scale (list[float] | None, optional): Amount to scale the points in each
83+
dimension. Only needed if the provided points are in "voxel" coordinates
84+
instead of world coordinates. Defaults to None, which implies the data is
85+
isotropic.
8286
8387
Returns:
8488
nx.DiGraph: A candidate graph that can be passed to the motile solver.
8589
Multiple hypotheses not supported for points input.
8690
"""
8791
# add nodes
88-
cand_graph, node_frame_dict = nodes_from_points_list(points_list)
92+
cand_graph, node_frame_dict = nodes_from_points_list(points_list, scale=scale)
8993
logger.info(f"Candidate nodes: {cand_graph.number_of_nodes()}")
9094
# add edges
9195
add_cand_edges(

src/motile_toolbox/candidate_graph/utils.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import logging
2-
from typing import Any, Iterable
2+
from collections.abc import Iterable
3+
from typing import Any
34

45
import networkx as nx
56
import numpy as np
@@ -64,7 +65,7 @@ def nodes_from_segmentation(
6465
cand_graph = nx.DiGraph()
6566
# also construct a dictionary from time frame to node_id for efficiency
6667
node_frame_dict: dict[int, list[Any]] = {}
67-
print("Extracting nodes from segmentation")
68+
logger.info("Extracting nodes from segmentation")
6869
num_hypotheses = segmentation.shape[1]
6970
if scale is None:
7071
scale = [
@@ -101,6 +102,7 @@ def nodes_from_segmentation(
101102

102103
def nodes_from_points_list(
103104
points_list: np.ndarray,
105+
scale: list[float] | None = None,
104106
) -> tuple[nx.DiGraph, dict[int, list[Any]]]:
105107
"""Extract candidate nodes from a list of points. Uses the index of the
106108
point in the list as its unique id.
@@ -110,6 +112,10 @@ def nodes_from_points_list(
110112
Args:
111113
points_list (np.ndarray): An NxD numpy array with N points and D
112114
(3 or 4) dimensions. Dimensions should be in order (t, [z], y, x).
115+
scale (list[float] | None, optional): Amount to scale the points in each
116+
dimension. Only needed if the provided points are in "voxel" coordinates
117+
instead of world coordinates. Defaults to None, which implies the data is
118+
isotropic.
113119
114120
Returns:
115121
tuple[nx.DiGraph, dict[int, list[Any]]]: A candidate graph with only nodes,
@@ -118,7 +124,16 @@ def nodes_from_points_list(
118124
cand_graph = nx.DiGraph()
119125
# also construct a dictionary from time frame to node_id for efficiency
120126
node_frame_dict: dict[int, list[Any]] = {}
121-
print("Extracting nodes from points list")
127+
logger.info("Extracting nodes from points list")
128+
129+
# scale points
130+
if scale is not None:
131+
assert (
132+
len(scale) == points_list.shape[1]
133+
), f"Cannot scale points with {points_list.shape[1]} dims by factor {scale}"
134+
points_list = points_list * np.array(scale)
135+
136+
# add points to graph
122137
for i, point in enumerate(points_list):
123138
# assume t, [z], y, x
124139
t = point[0]
@@ -187,7 +202,7 @@ def add_cand_edges(
187202
to node ids. If not provided, it will be computed from cand_graph. Defaults
188203
to None.
189204
"""
190-
print("Extracting candidate edges")
205+
logger.info("Extracting candidate edges")
191206
if not node_frame_dict:
192207
node_frame_dict = _compute_node_frame_dict(cand_graph)
193208

@@ -202,7 +217,9 @@ def add_cand_edges(
202217

203218
matched_indices = prev_kdtree.query_ball_tree(next_kdtree, max_edge_distance)
204219

205-
for prev_node_id, next_node_indices in zip(prev_node_ids, matched_indices):
220+
for prev_node_id, next_node_indices in zip(
221+
prev_node_ids, matched_indices, strict=False
222+
):
206223
for next_node_index in next_node_indices:
207224
next_node_id = next_node_ids[next_node_index]
208225
cand_graph.add_edge(prev_node_id, next_node_id)

tests/test_candidate_graph/test_compute_graph.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from motile_toolbox.candidate_graph.compute_graph import (
77
get_candidate_graph_from_points_list,
88
)
9+
from motile_toolbox.candidate_graph.graph_attributes import NodeAttr
910

1011

1112
def test_graph_from_segmentation_2d(segmentation_2d, graph_2d):
@@ -91,6 +92,7 @@ def test_graph_from_multi_segmentation_2d(
9192
def test_graph_from_points_list():
9293
points_list = np.array(
9394
[
95+
# t, z, y, x
9496
[0, 1, 1, 1],
9597
[2, 3, 3, 3],
9698
[1, 2, 2, 2],
@@ -101,3 +103,12 @@ def test_graph_from_points_list():
101103
cand_graph = get_candidate_graph_from_points_list(points_list, max_edge_distance=3)
102104
assert cand_graph.number_of_edges() == 3
103105
assert len(cand_graph.in_edges(3)) == 0
106+
107+
# test scale
108+
cand_graph = get_candidate_graph_from_points_list(
109+
points_list, max_edge_distance=3, scale=[1, 1, 1, 5]
110+
)
111+
assert cand_graph.number_of_edges() == 0
112+
assert len(cand_graph.in_edges(3)) == 0
113+
assert cand_graph.nodes[0][NodeAttr.POS.value] == [1, 1, 5]
114+
assert cand_graph.nodes[0][NodeAttr.TIME.value] == 0

0 commit comments

Comments
 (0)