From 4f066238d6809d940fc00b587f8c4a0e4d481e4f Mon Sep 17 00:00:00 2001 From: Jordao Bragantini Date: Mon, 24 Nov 2025 19:28:39 -0800 Subject: [PATCH 1/5] adding neighbors_per_frame option --- src/tracksdata/edges/_distance_edges.py | 243 ++++++++++++++---- .../edges/_test/test_distance_edges.py | 161 ++++++++++++ 2 files changed, 347 insertions(+), 57 deletions(-) diff --git a/src/tracksdata/edges/_distance_edges.py b/src/tracksdata/edges/_distance_edges.py index 8d27b0f3..77e46451 100644 --- a/src/tracksdata/edges/_distance_edges.py +++ b/src/tracksdata/edges/_distance_edges.py @@ -29,10 +29,14 @@ class DistanceEdges(BaseEdgesOperator): Maximum number of neighbors to consider for each node when adding edges. For each node at time t, edges will be created to at most n_neighbors closest nodes at time t-1 to t-delta_t. - delta_t : int, default 1 + delta_t : int The number of time points to consider for adding edges. For each node at time t, edges will be created to the closest n_neighbors nodes at time t-1 to t-delta_t. + neighbors_per_frame : bool, default False + Whether to consider the neighbors in the current frame as well as the previous frame. + If True, `n_neighbors` is the number of neighbors per frame. + If False, `n_neighbors` is the number of neighbors in all frames (from t-delta_t to t). output_key : str, default DEFAULT_ATTR_KEYS.EDGE_WEIGHT The attribute key to store the distance values in the edges. attr_keys : Sequence[str] | None, optional @@ -48,6 +52,12 @@ class DistanceEdges(BaseEdgesOperator): This in respect from the current to the previous frame. That means, a node in frame t will have edges to the closest n_neighbors nodes in frame t-1. + delta_t : int, default 1 + The number of time points to consider for adding edges. + For each node at time t, edges will be created to the closest + n_neighbors nodes at time t-delta_t to t. 
+ neighbors_per_frame : bool, default False + Whether `n_neighbors` is the number of neighbors per frame or all frames (from t-delta_t to t). output_key : str The key used to store distance values in edges. attr_keys : Sequence[str] | None @@ -92,6 +102,7 @@ def __init__( distance_threshold: float, n_neighbors: int, delta_t: int = 1, + neighbors_per_frame: bool = False, output_key: str = DEFAULT_ATTR_KEYS.EDGE_DIST, attr_keys: Sequence[str] | None = None, ): @@ -103,6 +114,7 @@ def __init__( self.n_neighbors = n_neighbors self.delta_t = delta_t self.attr_keys = attr_keys + self.neighbors_per_frame = neighbors_per_frame def _init_edge_attrs(self, graph: BaseGraph) -> None: """ @@ -111,93 +123,210 @@ def _init_edge_attrs(self, graph: BaseGraph) -> None: if self.output_key not in graph.edge_attr_keys: graph.add_edge_attr_key(self.output_key, default_value=-99999.0) - def _add_edges_per_time( - self, - t: int, - *, - graph: BaseGraph, - ) -> list[dict[str, Any]]: + def _get_spatial_attr_keys(self, graph: BaseGraph) -> list[str]: """ - Add distance-based edges between nodes at consecutive time points. - - Finds nodes at time t-1 and t, computes pairwise distances using KDTree, - and creates edges between nearby nodes within the distance threshold. - Uses bulk edge insertion for efficiency. + Determine which spatial attribute keys to use for distance calculation. Parameters ---------- - t : int - The current time point. Edges will be created from nodes at - time t-1 to nodes at time t. graph : BaseGraph - The current time point. Edges will be created from nodes at - time t-1 to nodes at time t. + The graph containing node attributes. + + Returns + ------- + list[str] + List of attribute keys to use for spatial coordinates. 
""" if self.attr_keys is None: if "z" in graph.node_attr_keys: - attr_keys = ["z", "y", "x"] + return ["z", "y", "x"] else: - attr_keys = ["y", "x"] + return ["y", "x"] else: - attr_keys = self.attr_keys + return list(self.attr_keys) - if self.delta_t == 1: - # faster than the range filter - prev_filter = graph.filter(NodeAttr(DEFAULT_ATTR_KEYS.T) == t - 1) - else: - prev_filter = graph.filter( - NodeAttr(DEFAULT_ATTR_KEYS.T) >= t - self.delta_t, - NodeAttr(DEFAULT_ATTR_KEYS.T) < t, - ) + def _build_kdtree_data( + self, graph: BaseGraph, time_point: int, attr_keys: Sequence[str] + ) -> tuple[KDTree, Any, list]: + """ + Build KDTree for a specific time point. - if prev_filter.is_empty(): - LOG.warning( - "No nodes found for time point in range (%d <= t < %d)", - t - self.delta_t, - t, - ) - return [] + Parameters + ---------- + graph : BaseGraph + The graph to query. + time_point : int + The time point to build the KDTree for. + attr_keys : Sequence[str] + Attribute keys to use for spatial coordinates. 
+ + Returns + ------- + tuple[KDTree, GraphView, list] + A tuple containing: + - KDTree built from node coordinates + - Node attributes as numpy array + - List of node IDs at this time point + """ + node_filter = graph.filter(NodeAttr(DEFAULT_ATTR_KEYS.T) == time_point) - current_filter = graph.filter(NodeAttr(DEFAULT_ATTR_KEYS.T) == t) + if node_filter.is_empty(): + return None, None, [] - if current_filter.is_empty(): - LOG.warning( - "No nodes found for time point %d", - t, - ) - return [] + node_attrs = node_filter.node_attrs(attr_keys=attr_keys) + node_ids = list(node_filter.node_ids()) + kdtree = KDTree(node_attrs.to_numpy()) - prev_attrs = prev_filter.node_attrs(attr_keys=attr_keys) - cur_attrs = current_filter.node_attrs(attr_keys=attr_keys) + return kdtree, node_attrs, node_ids - prev_kdtree = KDTree(prev_attrs.to_numpy()) + def _query_neighbors_single_kdtree( + self, + kdtree: KDTree, + source_node_ids: np.ndarray, + target_coords: np.ndarray, + target_node_ids: list, + ) -> list[dict[str, Any]]: + """ + Query neighbors from a single KDTree and create edge data. - distances, prev_neigh_ids = prev_kdtree.query( - cur_attrs.to_numpy(), + Parameters + ---------- + kdtree : KDTree + KDTree of source nodes to query. + source_node_ids : np.ndarray + Array of source node IDs corresponding to KDTree points. + target_coords : np.ndarray + Coordinates of target nodes to query for. + target_node_ids : list + List of target node IDs. + + Returns + ------- + list[dict[str, Any]] + List of edge dictionaries with source_id, target_id, and distance. 
+ """ + distances, neighbor_indices = kdtree.query( + target_coords, k=self.n_neighbors, distance_upper_bound=self.distance_threshold, ) + is_valid = ~np.isinf(distances) - prev_node_ids = np.asarray(prev_filter.node_ids()) - # kdtree return from 0 to n-1 - # converting back to arbitrary indexing - prev_neigh_ids[is_valid] = prev_node_ids[prev_neigh_ids[is_valid]] + # Convert KDTree indices (0 to n-1) back to actual node IDs + neighbor_indices_copy = neighbor_indices.copy() + neighbor_indices_copy[is_valid] = source_node_ids[neighbor_indices_copy[is_valid]] edges_data = [] - for cur_id, neigh_ids, neigh_dist, neigh_valid in zip( - current_filter.node_ids(), prev_neigh_ids, distances, is_valid, strict=True + for target_id, neigh_ids, neigh_dist, neigh_valid in zip( + target_node_ids, neighbor_indices_copy, distances, is_valid, strict=True ): - for neigh_id, dist in zip(neigh_ids[neigh_valid].tolist(), neigh_dist[neigh_valid].tolist(), strict=True): + for source_id, dist in zip(neigh_ids[neigh_valid].tolist(), neigh_dist[neigh_valid].tolist(), strict=True): edges_data.append( { - "source_id": neigh_id, - "target_id": cur_id, + "source_id": source_id, + "target_id": target_id, self.output_key: dist, } ) + return edges_data + + def _add_edges_per_time( + self, + t: int, + *, + graph: BaseGraph, + ) -> list[dict[str, Any]]: + """ + Add distance-based edges between nodes at consecutive time points. + + Finds nodes at time t and previous time points (t-1 to t-delta_t), + computes pairwise distances using KDTree, and creates edges between + nearby nodes within the distance threshold. + + The behavior depends on the `neighbors_per_frame` parameter: + - If False (default): Queries all previous frames as one combined KDTree, + returning up to `n_neighbors` total connections. + - If True: Queries each previous frame separately, returning up to + `n_neighbors` connections per frame. + + Parameters + ---------- + t : int + The current time point. 
Edges will be created from nodes at + previous time points to nodes at time t. + graph : BaseGraph + The graph to add edges to. + + Returns + ------- + list[dict[str, Any]] + List of edge dictionaries to be added to the graph. + """ + attr_keys = self._get_spatial_attr_keys(graph) + + # Get current time point nodes + current_filter = graph.filter(NodeAttr(DEFAULT_ATTR_KEYS.T) == t) + + if current_filter.is_empty(): + LOG.warning("No nodes found for time point %d", t) + return [] + + cur_attrs = current_filter.node_attrs(attr_keys=attr_keys) + cur_coords = cur_attrs.to_numpy() + cur_node_ids = list(current_filter.node_ids()) + + edges_data = [] + + if self.neighbors_per_frame: + # Query each previous time frame separately + for prev_t in range(t - self.delta_t, t): + kdtree, _, prev_node_ids = self._build_kdtree_data(graph, prev_t, attr_keys) + + if kdtree is None: + LOG.warning("No nodes found for time point %d", prev_t) + continue + + frame_edges = self._query_neighbors_single_kdtree( + kdtree, + np.asarray(prev_node_ids), + cur_coords, + cur_node_ids, + ) + edges_data.extend(frame_edges) + else: + # Query all previous frames as one combined KDTree (original behavior) + if self.delta_t == 1: + # Faster path for single frame + prev_filter = graph.filter(NodeAttr(DEFAULT_ATTR_KEYS.T) == t - 1) + else: + # Range filter for multiple frames + prev_filter = graph.filter( + NodeAttr(DEFAULT_ATTR_KEYS.T) >= t - self.delta_t, + NodeAttr(DEFAULT_ATTR_KEYS.T) < t, + ) + + if prev_filter.is_empty(): + LOG.warning( + "No nodes found for time point in range (%d <= t < %d)", + t - self.delta_t, + t, + ) + return [] + + prev_attrs = prev_filter.node_attrs(attr_keys=attr_keys) + prev_node_ids = np.asarray(list(prev_filter.node_ids())) + prev_kdtree = KDTree(prev_attrs.to_numpy()) + + edges_data = self._query_neighbors_single_kdtree( + prev_kdtree, + prev_node_ids, + cur_coords, + cur_node_ids, + ) + if len(edges_data) == 0: - LOG.warning("No valid edges found for the pair of time 
point (%d, %d)", t, t - 1) + LOG.warning("No valid edges found for time point %d", t) return edges_data diff --git a/src/tracksdata/edges/_test/test_distance_edges.py b/src/tracksdata/edges/_test/test_distance_edges.py index b5e60756..2f07559a 100644 --- a/src/tracksdata/edges/_test/test_distance_edges.py +++ b/src/tracksdata/edges/_test/test_distance_edges.py @@ -312,3 +312,164 @@ def test_distance_edges_multiprocessing_isolation() -> None: """Test that multiprocessing options don't affect subsequent tests.""" # Verify default n_workers is 1 assert get_options().n_workers == 1 + + +def test_distance_edges_neighbors_per_frame_false() -> None: + """Test neighbors_per_frame=False behavior with delta_t=2.""" + graph = RustWorkXGraph() + + # Register attribute keys + graph.add_node_attr_key("x", 0.0) + graph.add_node_attr_key("y", 0.0) + + # Add nodes at t=0, t=1, t=2 + # At t=0: two nodes close to origin + graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 0.0, "y": 0.0}) + graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 0.5, "y": 0.0}) + + # At t=1: two nodes slightly further + node_1a = graph.add_node({DEFAULT_ATTR_KEYS.T: 1, "x": 1.0, "y": 0.0}) + node_1b = graph.add_node({DEFAULT_ATTR_KEYS.T: 1, "x": 1.5, "y": 0.0}) + + # At t=2: one node at (2, 0) + node_2 = graph.add_node({DEFAULT_ATTR_KEYS.T: 2, "x": 2.0, "y": 0.0}) + + # With neighbors_per_frame=False, delta_t=2, n_neighbors=2: + # Node at t=2 should connect to 2 closest nodes from combined t=0 and t=1 + # (which should be the t=1 nodes since they're closer) + operator = DistanceEdges( + distance_threshold=5.0, + n_neighbors=2, + delta_t=2, + neighbors_per_frame=False, + ) + + operator.add_edges(graph) + + edges_df = graph.edge_attrs() + t2_edges = edges_df.filter(edges_df[DEFAULT_ATTR_KEYS.EDGE_TARGET] == node_2) + + # Should have exactly 2 edges (n_neighbors=2 total) + assert len(t2_edges) == 2 + + # Both edges should be from t=1 nodes (closest to t=2 node) + sources = 
set(t2_edges[DEFAULT_ATTR_KEYS.EDGE_SOURCE].to_list()) + assert sources == {node_1a, node_1b} + + +def test_distance_edges_neighbors_per_frame_true() -> None: + """Test neighbors_per_frame=True behavior with delta_t=2.""" + graph = RustWorkXGraph() + + # Register attribute keys + graph.add_node_attr_key("x", 0.0) + graph.add_node_attr_key("y", 0.0) + + # Add nodes at t=0, t=1, t=2 + # At t=0: two nodes close to origin + node_0a = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 0.0, "y": 0.0}) + node_0b = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 0.5, "y": 0.0}) + + # At t=1: two nodes slightly further + node_1a = graph.add_node({DEFAULT_ATTR_KEYS.T: 1, "x": 1.0, "y": 0.0}) + node_1b = graph.add_node({DEFAULT_ATTR_KEYS.T: 1, "x": 1.5, "y": 0.0}) + + # At t=2: one node at (2, 0) + node_2 = graph.add_node({DEFAULT_ATTR_KEYS.T: 2, "x": 2.0, "y": 0.0}) + + # With neighbors_per_frame=True, delta_t=2, n_neighbors=2: + # Node at t=2 should connect to 2 closest from t=0 AND 2 closest from t=1 + # (4 edges total if all are within distance threshold) + operator = DistanceEdges( + distance_threshold=5.0, + n_neighbors=2, + delta_t=2, + neighbors_per_frame=True, + ) + + operator.add_edges(graph) + + edges_df = graph.edge_attrs() + t2_edges = edges_df.filter(edges_df[DEFAULT_ATTR_KEYS.EDGE_TARGET] == node_2) + + # Should have 4 edges (2 from t=0, 2 from t=1) + assert len(t2_edges) == 4 + + # Should include nodes from both t=0 and t=1 + sources = set(t2_edges[DEFAULT_ATTR_KEYS.EDGE_SOURCE].to_list()) + assert sources == {node_0a, node_0b, node_1a, node_1b} + + +def test_distance_edges_neighbors_per_frame_with_distance_threshold() -> None: + """Test neighbors_per_frame=True respects distance threshold per frame.""" + graph = RustWorkXGraph() + + # Register attribute keys + graph.add_node_attr_key("x", 0.0) + graph.add_node_attr_key("y", 0.0) + + # Add nodes at t=0 (far away), t=1 (close), t=2 + # At t=0: nodes very far from where t=2 node will be + 
graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 0.0, "y": 0.0}) + + # At t=1: node close to where t=2 node will be + node_1 = graph.add_node({DEFAULT_ATTR_KEYS.T: 1, "x": 2.0, "y": 0.0}) + + # At t=2: node at (2.5, 0) + node_2 = graph.add_node({DEFAULT_ATTR_KEYS.T: 2, "x": 2.5, "y": 0.0}) + + # With a tight distance threshold, only t=1 node should connect + operator = DistanceEdges( + distance_threshold=1.0, # Only t=1 node is within 1.0 units + n_neighbors=2, + delta_t=2, + neighbors_per_frame=True, + ) + + operator.add_edges(graph) + + edges_df = graph.edge_attrs() + t2_edges = edges_df.filter(edges_df[DEFAULT_ATTR_KEYS.EDGE_TARGET] == node_2) + + # Should have only 1 edge from t=1 (t=0 is too far) + assert len(t2_edges) == 1 + assert t2_edges[DEFAULT_ATTR_KEYS.EDGE_SOURCE][0] == node_1 + + +def test_distance_edges_neighbors_per_frame_single_delta_t() -> None: + """Test that neighbors_per_frame behaves same as False when delta_t=1.""" + graph = RustWorkXGraph() + + # Register attribute keys + graph.add_node_attr_key("x", 0.0) + graph.add_node_attr_key("y", 0.0) + + # Add nodes at t=0 and t=1 + for i in range(3): + graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": float(i), "y": 0.0}) + + graph.add_node({DEFAULT_ATTR_KEYS.T: 1, "x": 1.0, "y": 0.0}) + + # Test with neighbors_per_frame=False + operator_false = DistanceEdges( + distance_threshold=5.0, + n_neighbors=2, + delta_t=1, + neighbors_per_frame=False, + ) + + graph_copy = graph.copy() + operator_false.add_edges(graph) + + # Test with neighbors_per_frame=True + operator_true = DistanceEdges( + distance_threshold=5.0, + n_neighbors=2, + delta_t=1, + neighbors_per_frame=True, + ) + + operator_true.add_edges(graph_copy) + + # Both should produce the same number of edges for delta_t=1 + assert graph.num_edges == graph_copy.num_edges From ad99c6adc65088d10eb328a3ce47946379830ace Mon Sep 17 00:00:00 2001 From: Jordao Bragantini Date: Mon, 24 Nov 2025 19:37:56 -0800 Subject: [PATCH 2/5] fixing range for non-int values --- 
src/tracksdata/edges/_distance_edges.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tracksdata/edges/_distance_edges.py b/src/tracksdata/edges/_distance_edges.py index 77e46451..2c76bb27 100644 --- a/src/tracksdata/edges/_distance_edges.py +++ b/src/tracksdata/edges/_distance_edges.py @@ -281,7 +281,7 @@ def _add_edges_per_time( if self.neighbors_per_frame: # Query each previous time frame separately - for prev_t in range(t - self.delta_t, t): + for prev_t in range(int(t - self.delta_t), int(t)): kdtree, _, prev_node_ids = self._build_kdtree_data(graph, prev_t, attr_keys) if kdtree is None: From 4d6c41a21ef363d9a9b822703f695a1fffaffde4 Mon Sep 17 00:00:00 2001 From: Jordao Bragantini Date: Thu, 22 Jan 2026 10:05:35 -0800 Subject: [PATCH 3/5] updated docs --- src/tracksdata/edges/_distance_edges.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/tracksdata/edges/_distance_edges.py b/src/tracksdata/edges/_distance_edges.py index 2c76bb27..d3fad165 100644 --- a/src/tracksdata/edges/_distance_edges.py +++ b/src/tracksdata/edges/_distance_edges.py @@ -34,9 +34,12 @@ class DistanceEdges(BaseEdgesOperator): For each node at time t, edges will be created to the closest n_neighbors nodes at time t-1 to t-delta_t. neighbors_per_frame : bool, default False - Whether to consider the neighbors in the current frame as well as the previous frame. - If True, `n_neighbors` is the number of neighbors per frame. - If False, `n_neighbors` is the number of neighbors in all frames (from t-delta_t to t). + Whether `n_neighbors` is applied per frame or in total across frames. + If True, `n_neighbors` is the number of neighbors per frame, meaning that + for each node at time t, edges will be created to the closest + n_neighbors nodes in each of the frames t-delta_t to t-1. + If False, `n_neighbors` is the total number of neighbors, selected from + all frames t-delta_t to t-1 considered together. 
output_key : str, default DEFAULT_ATTR_KEYS.EDGE_WEIGHT The attribute key to store the distance values in the edges. attr_keys : Sequence[str] | None, optional From 66b005cb03443fc861a86b0e664fb5da1b10d1f7 Mon Sep 17 00:00:00 2001 From: Jordao Bragantini Date: Thu, 22 Jan 2026 10:10:23 -0800 Subject: [PATCH 4/5] fixing test --- src/tracksdata/edges/_test/test_distance_edges.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tracksdata/edges/_test/test_distance_edges.py b/src/tracksdata/edges/_test/test_distance_edges.py index e775d5f1..98be809e 100644 --- a/src/tracksdata/edges/_test/test_distance_edges.py +++ b/src/tracksdata/edges/_test/test_distance_edges.py @@ -472,4 +472,4 @@ def test_distance_edges_neighbors_per_frame_single_delta_t() -> None: operator_true.add_edges(graph_copy) # Both should produce the same number of edges for delta_t=1 - assert graph.num_edges == graph_copy.num_edges + assert graph.num_edges() == graph_copy.num_edges() From c165885296dea0fbf6a70853940c84ff6af3e0f7 Mon Sep 17 00:00:00 2001 From: Jordao Bragantini Date: Thu, 22 Jan 2026 10:13:21 -0800 Subject: [PATCH 5/5] improving testing --- src/tracksdata/edges/_test/test_distance_edges.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/tracksdata/edges/_test/test_distance_edges.py b/src/tracksdata/edges/_test/test_distance_edges.py index 98be809e..1232d999 100644 --- a/src/tracksdata/edges/_test/test_distance_edges.py +++ b/src/tracksdata/edges/_test/test_distance_edges.py @@ -390,6 +390,8 @@ def test_distance_edges_neighbors_per_frame_true() -> None: operator.add_edges(graph) edges_df = graph.edge_attrs() + assert len(edges_df) == 4 + 2 + 2 # 4 edges into node_2 and 2 edges into each of node_1a, node_1b + t2_edges = edges_df.filter(edges_df[DEFAULT_ATTR_KEYS.EDGE_TARGET] == node_2) # Should have 4 edges (2 from t=0, 2 from t=1)