Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 34 additions & 31 deletions src/tracksdata/array/_test/test_graph_array.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from collections.abc import Sequence

import numpy as np
import polars as pl
import pytest
from pytest import fixture

Expand Down Expand Up @@ -31,9 +32,9 @@ def test_chain_indices() -> None:
def test_graph_array_view_init(graph_backend: BaseGraph) -> None:
"""Test GraphArrayView initialization."""
# Add a attribute key
graph_backend.add_node_attr_key("label", 0)
graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, None)
graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.BBOX, np.asarray([0, 0, 0, 0, 0, 0]))
graph_backend.add_node_attr_key("label", dtype=pl.Int64)
graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, pl.Object)
graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.BBOX, pl.Array(pl.Int64, 6))

array_view = GraphArrayView(graph=graph_backend, shape=(10, 100, 100), attr_key="label", offset=0)

Expand All @@ -57,9 +58,9 @@ def test_graph_array_view_init_invalid_attr_key(graph_backend: BaseGraph) -> Non
def test_graph_array_view_getitem_empty_time(graph_backend: BaseGraph) -> None:
"""Test __getitem__ with empty time point (no nodes)."""

graph_backend.add_node_attr_key("label", 0)
graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, None)
graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.BBOX, np.asarray([0, 0, 0, 0, 0, 0]))
graph_backend.add_node_attr_key("label", dtype=pl.Int64)
graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, pl.Object)
graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.BBOX, pl.Array(pl.Int64, 6))

array_view = GraphArrayView(graph=graph_backend, shape=(10, 100, 100), attr_key="label")

Expand All @@ -76,11 +77,11 @@ def test_graph_array_view_getitem_with_nodes(graph_backend: BaseGraph) -> None:
"""Test __getitem__ with nodes at time point."""

# Add attribute keys
graph_backend.add_node_attr_key("label", 0)
graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, None)
graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.BBOX, np.asarray([0, 0, 0, 0]))
graph_backend.add_node_attr_key("y", 0)
graph_backend.add_node_attr_key("x", 0)
graph_backend.add_node_attr_key("label", dtype=pl.Int64)
graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, pl.Object)
graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.BBOX, pl.Array(pl.Int64, 4))
graph_backend.add_node_attr_key("y", dtype=pl.Int64)
graph_backend.add_node_attr_key("x", dtype=pl.Int64)

# Create a mask
mask_data = np.array([[True, True], [True, False]], dtype=bool)
Expand Down Expand Up @@ -126,11 +127,11 @@ def test_graph_array_view_getitem_multiple_nodes(graph_backend: BaseGraph) -> No
"""Test __getitem__ with multiple nodes at same time point."""

# Add attribute keys
graph_backend.add_node_attr_key("label", 0)
graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, None)
graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.BBOX, np.asarray([0, 0, 0, 0]))
graph_backend.add_node_attr_key("y", 0)
graph_backend.add_node_attr_key("x", 0)
graph_backend.add_node_attr_key("label", dtype=pl.Int64)
graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, pl.Object)
graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.BBOX, pl.Array(pl.Int64, 4))
graph_backend.add_node_attr_key("y", dtype=pl.Int64)
graph_backend.add_node_attr_key("x", dtype=pl.Int64)

# Create two masks at different locations
mask1_data = np.array([[True, True]], dtype=bool)
Expand Down Expand Up @@ -181,11 +182,11 @@ def test_graph_array_view_getitem_boolean_dtype(graph_backend: BaseGraph) -> Non
"""Test __getitem__ with boolean attribute values."""

# Add attribute keys
graph_backend.add_node_attr_key("is_active", False)
graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, None)
graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.BBOX, np.asarray([0, 0, 0, 0]))
graph_backend.add_node_attr_key("y", 0)
graph_backend.add_node_attr_key("x", 0)
graph_backend.add_node_attr_key("is_active", pl.Boolean)
graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, pl.Object)
graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.BBOX, pl.Array(pl.Int64, 4))
graph_backend.add_node_attr_key("y", dtype=pl.Int64)
graph_backend.add_node_attr_key("x", dtype=pl.Int64)

# Create a mask
mask_data = np.array([[True]], dtype=bool)
Expand Down Expand Up @@ -218,11 +219,11 @@ def test_graph_array_view_dtype_inference(graph_backend: BaseGraph) -> None:
"""Test that dtype is properly inferred from data."""

# Add attribute keys
graph_backend.add_node_attr_key("float_label", 0.0)
graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, None)
graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.BBOX, np.asarray([0, 0, 0, 0]))
graph_backend.add_node_attr_key("y", 0)
graph_backend.add_node_attr_key("x", 0)
graph_backend.add_node_attr_key("float_label", dtype=pl.Float64)
graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, pl.Object)
graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.BBOX, pl.Array(pl.Int64, 4))
graph_backend.add_node_attr_key("y", dtype=pl.Int64)
graph_backend.add_node_attr_key("x", dtype=pl.Int64)

# Create a mask
mask_data = np.array([[True]], dtype=bool)
Expand Down Expand Up @@ -340,9 +341,9 @@ def test_graph_array_view_getitem_time_index_nested(multi_node_graph_from_image,

def test_graph_array_set_options(graph_backend: BaseGraph) -> None:
with Options(gav_chunk_shape=(512, 512), gav_default_dtype=np.int16):
graph_backend.add_node_attr_key("label", 0)
graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, None)
graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.BBOX, np.asarray([0, 0, 0, 0]))
graph_backend.add_node_attr_key("label", dtype=pl.Int64)
graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, pl.Object)
graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.BBOX, pl.Array(pl.Int64, 4))
array_view = GraphArrayView(graph=graph_backend, shape=(10, 100, 100), attr_key="label")
assert array_view.chunk_shape == (512, 512)
assert array_view.dtype == np.int16
Expand All @@ -363,8 +364,10 @@ def test_graph_array_raise_error_on_non_scalar_attr_key(graph_backend: BaseGraph
"""Test that GraphArrayView raises error if attr_key values are non-scalar."""

# Add a attribute key
graph_backend.add_node_attr_key("label", np.array([0, 1])) # Non-scalar default value
graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, None)
graph_backend.add_node_attr_key(
"label", dtype=pl.Object, default_value=np.array([0, 1])
) # Non-scalar default value
graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, pl.Object)
graph_backend.add_node(
{
DEFAULT_ATTR_KEYS.T: 0,
Expand Down
3 changes: 2 additions & 1 deletion src/tracksdata/edges/_distance_edges.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Any

import numpy as np
import polars as pl
from scipy.spatial import KDTree

from tracksdata.attrs import NodeAttr
Expand Down Expand Up @@ -125,7 +126,7 @@ def _init_edge_attrs(self, graph: BaseGraph) -> None:
Initialize the edge attributes for the graph.
"""
if self.output_key not in graph.edge_attr_keys():
graph.add_edge_attr_key(self.output_key, default_value=-99999.0)
graph.add_edge_attr_key(self.output_key, pl.Float64, default_value=-99999.0)

def _get_spatial_attr_keys(self, graph: BaseGraph) -> list[str]:
"""
Expand Down
3 changes: 2 additions & 1 deletion src/tracksdata/edges/_generic_edges.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Any

import numpy as np
import polars as pl

from tracksdata.attrs import NodeAttr
from tracksdata.constants import DEFAULT_ATTR_KEYS
Expand Down Expand Up @@ -54,7 +55,7 @@ def _init_edge_attrs(self, graph: BaseGraph) -> None:
Initialize the edge attributes for the graph.
"""
if self.output_key not in graph.edge_attr_keys():
graph.add_edge_attr_key(self.output_key, default_value=-99999.0)
graph.add_edge_attr_key(self.output_key, pl.Float64, default_value=-99999.0)

def _edge_attrs_per_time(
self,
Expand Down
43 changes: 22 additions & 21 deletions src/tracksdata/edges/_test/test_distance_edges.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import polars as pl
import pytest

from tracksdata.constants import DEFAULT_ATTR_KEYS
Expand Down Expand Up @@ -48,8 +49,8 @@ def test_distance_edges_add_edges_single_timepoint_no_previous() -> None:
graph = RustWorkXGraph()

# Register attribute keys
graph.add_node_attr_key("x", 0.0)
graph.add_node_attr_key("y", 0.0)
graph.add_node_attr_key("x", dtype=pl.Float64)
graph.add_node_attr_key("y", dtype=pl.Float64)

# Add nodes only at t=1 (no t=0)
graph.add_node({DEFAULT_ATTR_KEYS.T: 1, "x": 0.0, "y": 0.0})
Expand All @@ -67,8 +68,8 @@ def test_distance_edges_add_edges_single_timepoint_no_current() -> None:
graph = RustWorkXGraph()

# Register attribute keys
graph.add_node_attr_key("x", 0.0)
graph.add_node_attr_key("y", 0.0)
graph.add_node_attr_key("x", dtype=pl.Float64)
graph.add_node_attr_key("y", dtype=pl.Float64)

# Add nodes only at t=0 (no t=1)
graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 0.0, "y": 0.0})
Expand All @@ -86,8 +87,8 @@ def test_distance_edges_add_edges_2d_coordinates() -> None:
graph = RustWorkXGraph()

# Register attribute keys
graph.add_node_attr_key("x", 0.0)
graph.add_node_attr_key("y", 0.0)
graph.add_node_attr_key("x", dtype=pl.Float64)
graph.add_node_attr_key("y", dtype=pl.Float64)

# Add nodes at t=0
_ = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 0.0, "y": 0.0})
Expand All @@ -114,9 +115,9 @@ def test_distance_edges_add_edges_3d_coordinates() -> None:
graph = RustWorkXGraph()

# Register attribute keys
graph.add_node_attr_key("x", 0.0)
graph.add_node_attr_key("y", 0.0)
graph.add_node_attr_key("z", 0.0)
graph.add_node_attr_key("x", dtype=pl.Float64)
graph.add_node_attr_key("y", dtype=pl.Float64)
graph.add_node_attr_key("z", dtype=pl.Float64)

# Add nodes at t=0
_ = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 0.0, "y": 0.0, "z": 0.0})
Expand All @@ -139,8 +140,8 @@ def test_distance_edges_add_edges_custom_attr_keys() -> None:
graph = RustWorkXGraph()

# Register attribute keys
graph.add_node_attr_key("pos_x", 0.0)
graph.add_node_attr_key("pos_y", 0.0)
graph.add_node_attr_key("pos_x", dtype=pl.Float64)
graph.add_node_attr_key("pos_y", dtype=pl.Float64)

# Add nodes at t=0
_ = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "pos_x": 0.0, "pos_y": 0.0})
Expand All @@ -163,8 +164,8 @@ def test_distance_edges_add_edges_distance_threshold() -> None:
graph = RustWorkXGraph()

# Register attribute keys
graph.add_node_attr_key("x", 0.0)
graph.add_node_attr_key("y", 0.0)
graph.add_node_attr_key("x", dtype=pl.Float64)
graph.add_node_attr_key("y", dtype=pl.Float64)

# Add nodes at t=0
_ = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 0.0, "y": 0.0})
Expand All @@ -186,8 +187,8 @@ def test_distance_edges_add_edges_multiple_timepoints(n_workers: int) -> None:
graph = RustWorkXGraph()

# Register attribute keys
graph.add_node_attr_key("x", 0.0)
graph.add_node_attr_key("y", 0.0)
graph.add_node_attr_key("x", dtype=pl.Float64)
graph.add_node_attr_key("y", dtype=pl.Float64)

# Add nodes at multiple timepoints
for t in range(3):
Expand All @@ -209,8 +210,8 @@ def test_distance_edges_add_edges_custom_weight_key() -> None:
graph = RustWorkXGraph()

# Register attribute keys
graph.add_node_attr_key("x", 0.0)
graph.add_node_attr_key("y", 0.0)
graph.add_node_attr_key("x", dtype=pl.Float64)
graph.add_node_attr_key("y", dtype=pl.Float64)

# Add nodes at t=0
_ = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 0.0, "y": 0.0})
Expand All @@ -237,8 +238,8 @@ def test_distance_edges_n_neighbors_limit() -> None:
graph = RustWorkXGraph()

# Register attribute keys
graph.add_node_attr_key("x", 0.0)
graph.add_node_attr_key("y", 0.0)
graph.add_node_attr_key("x", dtype=pl.Float64)
graph.add_node_attr_key("y", dtype=pl.Float64)

# Add many nodes at t=0
for i in range(5):
Expand All @@ -264,8 +265,8 @@ def test_distance_edges_add_edges_with_delta_t(n_workers: int) -> None:
graph = RustWorkXGraph()

# Register attribute keys
graph.add_node_attr_key("x", 0.0)
graph.add_node_attr_key("y", 0.0)
graph.add_node_attr_key("x", dtype=pl.Float64)
graph.add_node_attr_key("y", dtype=pl.Float64)

# Add nodes at t=0, t=1, t=2
node_0 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 0.0, "y": 0.0})
Expand Down
27 changes: 14 additions & 13 deletions src/tracksdata/edges/_test/test_generic_edges.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import numpy as np
import polars as pl

from tracksdata.constants import DEFAULT_ATTR_KEYS
from tracksdata.edges import GenericFuncEdgeAttrs
Expand Down Expand Up @@ -41,8 +42,8 @@ def test_generic_edges_add_weights_single_attr_key() -> None:
graph = RustWorkXGraph()

# Register attribute keys
graph.add_node_attr_key("x", 0.0)
graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, 0.0)
graph.add_node_attr_key("x", dtype=pl.Float64)
graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, dtype=pl.Float64)

# Add nodes at time 0
node0 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 1.0})
Expand Down Expand Up @@ -78,9 +79,9 @@ def test_generic_edges_add_weights_multiple_attr_keys() -> None:
graph = RustWorkXGraph()

# Register attribute keys
graph.add_node_attr_key("x", 0.0)
graph.add_node_attr_key("y", 0.0)
graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, 0.0)
graph.add_node_attr_key("x", dtype=pl.Float64)
graph.add_node_attr_key("y", dtype=pl.Float64)
graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, dtype=pl.Float64)

# Add nodes at time 0
node0 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 0.0, "y": 0.0})
Expand Down Expand Up @@ -112,8 +113,8 @@ def test_generic_edges_add_weights_all_time_points() -> None:
graph = RustWorkXGraph()

# Register attribute keys
graph.add_node_attr_key("x", 0.0)
graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, 0.0)
graph.add_node_attr_key("x", dtype=pl.Float64)
graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, dtype=pl.Float64)

# Add nodes at different time points
node0_t0 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 1.0})
Expand Down Expand Up @@ -142,7 +143,7 @@ def test_generic_edges_no_edges_at_time_point() -> None:
graph = RustWorkXGraph()

# Register attribute keys
graph.add_node_attr_key("x", 0.0)
graph.add_node_attr_key("x", dtype=pl.Float64)

# Add nodes but no edges at time 0
graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 1.0})
Expand All @@ -163,8 +164,8 @@ def test_generic_edges_creates_output_key() -> None:
graph = RustWorkXGraph()

# Register attribute keys
graph.add_node_attr_key("x", 0.0)
graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, 0.0)
graph.add_node_attr_key("x", dtype=pl.Float64)
graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, dtype=pl.Float64)

# Add nodes and edge
node0 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 1.0})
Expand All @@ -191,9 +192,9 @@ def test_generic_edges_dict_input_function() -> None:
graph = RustWorkXGraph()

# Register attribute keys
graph.add_node_attr_key("value", 0.0)
graph.add_node_attr_key("weight", 0.0)
graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, 0.0)
graph.add_node_attr_key("value", dtype=pl.Float64)
graph.add_node_attr_key("weight", dtype=pl.Float64)
graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, dtype=pl.Float64)

# Add nodes
node0 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "value": 10.0, "weight": 2.0})
Expand Down
Loading
Loading