From 715d80f37ee8a28fa0bdef4f5ccc1dee8a672dd7 Mon Sep 17 00:00:00 2001 From: Jordao Bragantini Date: Tue, 20 Jan 2026 09:47:37 -0800 Subject: [PATCH 01/14] Refactor add_node_attr_key/add_edge_attr_key with dtype parameter - Add AttrSchema dataclass with defensive copying - Create process_attr_key_args() helper to eliminate duplication - Refactor dtype inference and SQLAlchemy mapping to use dictionaries - Add overloaded method signatures for convenience and schema modes - Replace _node/edge_attr_keys lists with _attr_schemas dicts - Add comprehensive test suite (16 tests passing) Breaking change: dtype parameter now required. Next phase will update all calling sites throughout the codebase. --- src/tracksdata/graph/_base_graph.py | 159 ++++++++- src/tracksdata/graph/_graph_view.py | 51 ++- src/tracksdata/graph/_rustworkx_graph.py | 115 +++++-- src/tracksdata/graph/_sql_graph.py | 163 ++++++--- .../graph/_test/test_attr_key_dtype.py | 225 +++++++++++++ src/tracksdata/utils/_dtypes.py | 308 ++++++++++++++++++ 6 files changed, 939 insertions(+), 82 deletions(-) create mode 100644 src/tracksdata/graph/_test/test_attr_key_dtype.py diff --git a/src/tracksdata/graph/_base_graph.py b/src/tracksdata/graph/_base_graph.py index 3bb44411..7237047a 100644 --- a/src/tracksdata/graph/_base_graph.py +++ b/src/tracksdata/graph/_base_graph.py @@ -19,6 +19,7 @@ from tracksdata.constants import DEFAULT_ATTR_KEYS from tracksdata.utils._cache import cache_method from tracksdata.utils._dtypes import ( + AttrSchema, column_to_numpy, infer_default_value, polars_dtype_to_numpy_dtype, @@ -637,11 +638,93 @@ def edge_attr_keys(self) -> list[str]: Get the keys of the attributes of the edges. """ + @overload + @abc.abstractmethod + def add_node_attr_key(self, schema: AttrSchema) -> None: ... + + @overload + @abc.abstractmethod + def add_node_attr_key( + self, + key: str, + dtype: pl.DataType, + default_value: Any = None, + ) -> None: ... + @abc.abstractmethod - def add_node_attr_key(self, key: str, default_value: Any) -> None: + def add_node_attr_key( + self, + key_or_schema: str | AttrSchema, + dtype: pl.DataType | None = None, + default_value: Any = None, + ) -> None: """ Add a new attribute key to the graph. - All existing nodes will have the default value for the new attribute key. + + All existing nodes will have the specified default value for the new attribute key. + + Parameters + ---------- + key_or_schema : str | AttrSchema + Either the attribute key name (str) or an AttrSchema object containing + the key, dtype, and default value. + dtype : pl.DataType, optional + The polars data type for this attribute. Required when key_or_schema is a string. + If provided with default_value, compatibility will be validated. + default_value : Any, optional + The default value for existing nodes. If None and dtype is provided, + a default will be inferred from the dtype based on these rules: + - Unsigned integers (pl.UInt*) → 0 + - Signed integers (pl.Int*) → -1 + - Floats (pl.Float*) → -1.0 + - Boolean → False + - String → "" + - Arrays (pl.Array) → np.zeros with correct shape and dtype + - Objects/Lists → None + + Raises + ------ + TypeError + If dtype is not provided when using string key. + ValueError + If the attribute key already exists or if default_value and dtype are incompatible. 
+ + Examples + -------- + Add attribute with dtype only (default inferred): + + ```python + import polars as pl + + graph.add_node_attr_key("count", pl.UInt32) # default=0 + ``` + + Add attribute with dtype and custom default: + + ```python + graph.add_node_attr_key("score", pl.Float64, default_value=-99.0) + ``` + + Add array attribute (zeros default): + + ```python + graph.add_node_attr_key("bbox", pl.Array(pl.Float64, 4)) + # default=np.zeros(4, dtype=float64) + ``` + + Using AttrSchema for reusability: + + ```python + from tracksdata.utils import AttrSchema + + schema = AttrSchema(key="intensity", dtype=pl.Float64) + graph.add_node_attr_key(schema) + ``` + + See Also + -------- + remove_node_attr_key : Remove an attribute key from the graph + add_edge_attr_key : Add an edge attribute key """ @abc.abstractmethod @@ -655,11 +738,79 @@ def remove_node_attr_key(self, key: str) -> None: The attribute key to remove. """ + @overload + @abc.abstractmethod + def add_edge_attr_key(self, schema: AttrSchema) -> None: ... + + @overload + @abc.abstractmethod + def add_edge_attr_key( + self, + key: str, + dtype: pl.DataType, + default_value: Any = None, + ) -> None: ... + @abc.abstractmethod - def add_edge_attr_key(self, key: str, default_value: Any) -> None: + def add_edge_attr_key( + self, + key_or_schema: str | AttrSchema, + dtype: pl.DataType | None = None, + default_value: Any = None, + ) -> None: """ Add a new attribute key to the graph. - All existing edges will have the default value for the new attribute key. + + All existing edges will have the specified default value for the new attribute key. + + Parameters + ---------- + key_or_schema : str | AttrSchema + Either the attribute key name (str) or an AttrSchema object containing + the key, dtype, and default value. + dtype : pl.DataType, optional + The polars data type for this attribute. Required when key_or_schema is a string. + If provided with default_value, compatibility will be validated. + default_value : Any, optional + The default value for existing edges. If None and dtype is provided, + a default will be inferred from the dtype. See add_node_attr_key for inference rules. + + Raises + ------ + TypeError + If dtype is not provided when using string key. + ValueError + If the attribute key already exists or if default_value and dtype are incompatible. 
+ + Examples + -------- + Add edge attribute with dtype only: + + ```python + import polars as pl + + graph.add_edge_attr_key("weight", pl.Float64) # default=-1.0 + ``` + + Add edge attribute with custom default: + + ```python + graph.add_edge_attr_key("distance", pl.Float64, default_value=0.0) + ``` + + Using AttrSchema: + + ```python + from tracksdata.utils import AttrSchema + + schema = AttrSchema(key="cost", dtype=pl.Float64, default_value=1.0) + graph.add_edge_attr_key(schema) + ``` + + See Also + -------- + remove_edge_attr_key : Remove an edge attribute key from the graph + add_node_attr_key : Add a node attribute key """ @abc.abstractmethod diff --git a/src/tracksdata/graph/_graph_view.py b/src/tracksdata/graph/_graph_view.py index 33ec5332..3e35978f 100644 --- a/src/tracksdata/graph/_graph_view.py +++ b/src/tracksdata/graph/_graph_view.py @@ -11,6 +11,7 @@ from tracksdata.graph._mapped_graph_mixin import MappedGraphMixin from tracksdata.graph._rustworkx_graph import IndexedRXGraph, RustWorkXGraph, RXFilter from tracksdata.graph.filters._indexed_filter import IndexRXFilter +from tracksdata.utils._dtypes import AttrSchema from tracksdata.utils._signal import is_signal_on @@ -235,16 +236,33 @@ def node_attr_keys(self) -> list[str]: def edge_attr_keys(self) -> list[str]: return self._root.edge_attr_keys() if self._edge_attr_keys is None else self._edge_attr_keys - def add_node_attr_key(self, key: str, default_value: Any) -> None: - self._root.add_node_attr_key(key, default_value) + def add_node_attr_key( + self, + key_or_schema: str | AttrSchema, + dtype: pl.DataType | None = None, + default_value: Any = None, + ) -> None: + # Delegate to root with all parameters (root handles overloading) + self._root.add_node_attr_key(key_or_schema, dtype, default_value) + + # Extract key for local tracking + if isinstance(key_or_schema, AttrSchema): + key = key_or_schema.key + else: + key = key_or_schema + if self._node_attr_keys is not None: self._node_attr_keys.append(key) - # because attributes are passed by reference, we need don't need if both are rustworkx graphs + + # Sync logic if not self._is_root_rx_graph: if self.sync: + # Get the schema from root to get the actual default value used + schema = self._root._node_attr_schemas[key] + # Apply to local rx_graph rx_graph = self.rx_graph for node_id in rx_graph.node_indices(): - rx_graph[node_id][key] = default_value + rx_graph[node_id][key] = schema.default_value else: self._out_of_sync = True @@ -260,15 +278,32 @@ def remove_node_attr_key(self, key: str) -> None: else: self._out_of_sync = True - def add_edge_attr_key(self, key: str, default_value: Any) -> None: - self._root.add_edge_attr_key(key, default_value) + def add_edge_attr_key( + self, + key_or_schema: str | AttrSchema, + dtype: pl.DataType | None = None, + default_value: Any = None, + ) -> None: + # Delegate to root with all parameters (root handles overloading) + self._root.add_edge_attr_key(key_or_schema, dtype, default_value) + + # Extract key for local tracking + if isinstance(key_or_schema, AttrSchema): + key = key_or_schema.key + else: + key = key_or_schema + if self._edge_attr_keys is not None: self._edge_attr_keys.append(key) - # because attributes are passed by reference, we need don't need if both are rustworkx graphs + + # Sync logic if not self._is_root_rx_graph: if self.sync: + # Get the schema from root to get the actual default value used + schema = self._root._edge_attr_schemas[key] + # Apply to local rx_graph for _, _, edge_attr in self.rx_graph.weighted_edge_list(): 
- edge_attr[key] = default_value + edge_attr[key] = schema.default_value else: self._out_of_sync = True diff --git a/src/tracksdata/graph/_rustworkx_graph.py b/src/tracksdata/graph/_rustworkx_graph.py index c379ab1f..11c792e0 100644 --- a/src/tracksdata/graph/_rustworkx_graph.py +++ b/src/tracksdata/graph/_rustworkx_graph.py @@ -14,6 +14,7 @@ from tracksdata.graph.filters._base_filter import BaseFilter from tracksdata.utils._cache import cache_method from tracksdata.utils._dataframe import unpack_array_attrs +from tracksdata.utils._dtypes import AttrSchema, process_attr_key_args from tracksdata.utils._logging import LOG from tracksdata.utils._signal import is_signal_on @@ -317,15 +318,22 @@ def __init__(self, rx_graph: rx.PyDiGraph | None = None) -> None: super().__init__() self._time_to_nodes: dict[int, list[int]] = {} - self._node_attr_keys: list[str] = [] - self._edge_attr_keys: list[str] = [] + self._node_attr_schemas: dict[str, AttrSchema] = {} + self._edge_attr_schemas: dict[str, AttrSchema] = {} self._overlaps: list[list[int, 2]] = [] + # Add default node attributes with inferred schemas + self._node_attr_schemas[DEFAULT_ATTR_KEYS.NODE_ID] = AttrSchema( + key=DEFAULT_ATTR_KEYS.NODE_ID, + dtype=pl.Int64, + ) + self._node_attr_schemas[DEFAULT_ATTR_KEYS.T] = AttrSchema( + key=DEFAULT_ATTR_KEYS.T, + dtype=pl.Int64, + ) + if rx_graph is None: self._graph = rx.PyDiGraph(attrs={}) - self._node_attr_keys.append(DEFAULT_ATTR_KEYS.NODE_ID) - self._node_attr_keys.append(DEFAULT_ATTR_KEYS.T) - else: self._graph = rx_graph @@ -341,9 +349,8 @@ def __init__(self, rx_graph: rx.PyDiGraph | None = None) -> None: "old_attrs": self._graph.attrs, } - unique_node_attr_keys = set() - unique_edge_attr_keys = set() - + # Process nodes: build time index and infer schemas + first_node_attrs = None for node_id in self._graph.node_indices(): node_attrs = self._graph[node_id] try: @@ -355,15 +362,35 @@ def __init__(self, rx_graph: rx.PyDiGraph | None = None) -> None: self._time_to_nodes.setdefault(int(t), []).append(node_id) - unique_node_attr_keys.update(node_attrs.keys()) + # Store first node's attrs to infer schemas + if first_node_attrs is None: + first_node_attrs = node_attrs + + # Infer node schemas from first node + if first_node_attrs is not None: + for key, value in first_node_attrs.items(): + if key == DEFAULT_ATTR_KEYS.NODE_ID: + continue + dtype = pl.Series([value]).dtype + self._node_attr_schemas[key] = AttrSchema(key=key, dtype=dtype) + # Process edges: set edge IDs and infer schemas edge_idx_map = self._graph.edge_index_map() + first_edge_attrs = None for edge_idx, (_, _, attr) in edge_idx_map.items(): - unique_edge_attr_keys.update(attr.keys()) attr[DEFAULT_ATTR_KEYS.EDGE_ID] = edge_idx - - self._node_attr_keys = [DEFAULT_ATTR_KEYS.NODE_ID, *unique_node_attr_keys] - self._edge_attr_keys = list(unique_edge_attr_keys) + # Store first edge's attrs to infer schemas + if first_edge_attrs is None: + first_edge_attrs = attr + + # Infer edge schemas from first edge + if first_edge_attrs is not None: + for key, value in first_edge_attrs.items(): + # TODO: check if EDGE_SOURCE and EDGE_TARGET should be also ignored or in the schema + if key == DEFAULT_ATTR_KEYS.EDGE_ID: + continue + dtype = pl.Series([value]).dtype + self._edge_attr_schemas[key] = AttrSchema(key=key, dtype=dtype) @property def rx_graph(self) -> rx.PyDiGraph: @@ -864,33 +891,44 @@ def node_attr_keys(self) -> list[str]: """ Get the keys of the attributes of the nodes. 
""" - return self._node_attr_keys.copy() + return list(self._node_attr_schemas.keys()) def edge_attr_keys(self) -> list[str]: """ Get the keys of the attributes of the edges. """ - return self._edge_attr_keys.copy() + return list(self._edge_attr_schemas.keys()) - def add_node_attr_key(self, key: str, default_value: Any) -> None: + def add_node_attr_key( + self, + key_or_schema: str | AttrSchema, + dtype: pl.DataType | None = None, + default_value: Any = None, + ) -> None: """ Add a new attribute key to the graph. All existing nodes will have the default value for the new attribute key. Parameters ---------- - key : str - The key of the new attribute. - default_value : Any + key_or_schema : str | AttrSchema + Either the key name (str) or an AttrSchema object containing key, dtype, and default_value. + dtype : pl.DataType | None + The polars data type for this attribute. Required when key_or_schema is a string. + default_value : Any, optional The default value for existing nodes for the new attribute key. + If None, will be inferred from dtype. """ - if key in self.node_attr_keys(): - raise ValueError(f"Attribute key {key} already exists") + # Process arguments and create validated schema + schema = process_attr_key_args(key_or_schema, dtype, default_value, self._node_attr_schemas) + + # Store schema + self._node_attr_schemas[schema.key] = schema - self._node_attr_keys.append(key) + # Add to all existing nodes rx_graph = self.rx_graph for node_id in rx_graph.node_indices(): - rx_graph[node_id][key] = default_value + rx_graph[node_id][schema.key] = schema.default_value def remove_node_attr_key(self, key: str) -> None: """ @@ -902,28 +940,39 @@ def remove_node_attr_key(self, key: str) -> None: if key in (DEFAULT_ATTR_KEYS.NODE_ID, DEFAULT_ATTR_KEYS.T): raise ValueError(f"Cannot remove required node attribute key {key}") - self._node_attr_keys.remove(key) + del self._node_attr_schemas[key] for node_attr in self.rx_graph.nodes(): node_attr.pop(key, None) - def add_edge_attr_key(self, key: str, default_value: Any) -> None: + def add_edge_attr_key( + self, + key_or_schema: str | AttrSchema, + dtype: pl.DataType | None = None, + default_value: Any = None, + ) -> None: """ Add a new attribute key to the graph. All existing edges will have the default value for the new attribute key. Parameters ---------- - key : str - The key of the new attribute. - default_value : Any + key_or_schema : str | AttrSchema + Either the key name (str) or an AttrSchema object containing key, dtype, and default_value. + dtype : pl.DataType | None + The polars data type for this attribute. Required when key_or_schema is a string. + default_value : Any, optional The default value for existing edges for the new attribute key. + If None, will be inferred from dtype. 
""" - if key in self.edge_attr_keys(): - raise ValueError(f"Attribute key {key} already exists") + # Process arguments and create validated schema + schema = process_attr_key_args(key_or_schema, dtype, default_value, self._edge_attr_schemas) + + # Store schema + self._edge_attr_schemas[schema.key] = schema - self._edge_attr_keys.append(key) + # Add to all existing edges for edge_attr in self.rx_graph.edges(): - edge_attr[key] = default_value + edge_attr[schema.key] = schema.default_value def remove_edge_attr_key(self, key: str) -> None: """ @@ -932,7 +981,7 @@ def remove_edge_attr_key(self, key: str) -> None: if key not in self.edge_attr_keys(): raise ValueError(f"Edge attribute key {key} does not exist") - self._edge_attr_keys.remove(key) + del self._edge_attr_schemas[key] for edge_attr in self.rx_graph.edges(): edge_attr.pop(key, None) diff --git a/src/tracksdata/graph/_sql_graph.py b/src/tracksdata/graph/_sql_graph.py index 0fdc22f7..859f5136 100644 --- a/src/tracksdata/graph/_sql_graph.py +++ b/src/tracksdata/graph/_sql_graph.py @@ -9,7 +9,6 @@ import rustworkx as rx import sqlalchemy as sa from polars._typing import SchemaDict -from polars.datatypes.convert import numpy_char_code_to_dtype from sqlalchemy.orm import DeclarativeBase, Session, aliased, load_only from sqlalchemy.sql.type_api import TypeEngine @@ -19,6 +18,12 @@ from tracksdata.graph.filters._base_filter import BaseFilter from tracksdata.utils._cache import cache_method from tracksdata.utils._dataframe import unpack_array_attrs, unpickle_bytes_columns +from tracksdata.utils._dtypes import ( + AttrSchema, + infer_default_value_from_dtype, + polars_dtype_to_sqlalchemy_type, + process_attr_key_args, +) from tracksdata.utils._logging import LOG from tracksdata.utils._signal import is_signal_on @@ -464,14 +469,17 @@ def __init__( # Create unique classes for this instance self._define_schema(overwrite=overwrite) - self._boolean_columns: dict[str, SchemaDict] = {self.Node.__tablename__: {}, self.Edge.__tablename__: {}} - self._array_columns: dict[str, SchemaDict] = {self.Node.__tablename__: {}, self.Edge.__tablename__: {}} + self._node_attr_schemas: dict[str, AttrSchema] = {} + self._edge_attr_schemas: dict[str, AttrSchema] = {} if overwrite: self.Base.metadata.drop_all(self._engine) self.Base.metadata.create_all(self._engine) + # Initialize schemas from existing table columns + self._init_schemas_from_tables() + self._max_id_per_time = {} self._update_max_id_per_time() @@ -543,22 +551,92 @@ class Metadata(Base): self.Overlap = Overlap self.Metadata = Metadata + def _init_schemas_from_tables(self) -> None: + """ + Initialize AttrSchema objects from existing database table columns. + This is used when loading an existing graph from the database. 
+ """ + # Initialize node schemas from Node table columns + for column_name in self.Node.__table__.columns.keys(): + if column_name not in self._node_attr_schemas: + column = self.Node.__table__.columns[column_name] + # Infer polars dtype from SQLAlchemy type + pl_dtype = self._sqlalchemy_type_to_polars_dtype(column.type) + default_value = infer_default_value_from_dtype(pl_dtype) + self._node_attr_schemas[column_name] = AttrSchema( + key=column_name, + dtype=pl_dtype, + default_value=default_value, + ) + + # Initialize edge schemas from Edge table columns + for column_name in self.Edge.__table__.columns.keys(): + # Skip internal edge columns + if column_name in [DEFAULT_ATTR_KEYS.EDGE_ID, DEFAULT_ATTR_KEYS.EDGE_SOURCE, DEFAULT_ATTR_KEYS.EDGE_TARGET]: + continue + if column_name not in self._edge_attr_schemas: + column = self.Edge.__table__.columns[column_name] + # Infer polars dtype from SQLAlchemy type + pl_dtype = self._sqlalchemy_type_to_polars_dtype(column.type) + default_value = infer_default_value_from_dtype(pl_dtype) + self._edge_attr_schemas[column_name] = AttrSchema( + key=column_name, + dtype=pl_dtype, + default_value=default_value, + ) + + def _sqlalchemy_type_to_polars_dtype(self, sa_type: TypeEngine) -> pl.DataType: + """ + Convert a SQLAlchemy type to a polars dtype. + This is a best-effort conversion for loading existing schemas. + """ + if isinstance(sa_type, sa.Boolean): + return pl.Boolean + elif isinstance(sa_type, sa.SmallInteger): + return pl.Int16 + elif isinstance(sa_type, sa.Integer): + return pl.Int32 + elif isinstance(sa_type, sa.BigInteger): + return pl.Int64 + elif isinstance(sa_type, sa.Float): + return pl.Float64 + elif isinstance(sa_type, sa.String | sa.Text): + return pl.String + elif isinstance(sa_type, sa.LargeBinary | sa.PickleType): + # For pickled/binary types, default to Object + # Array types will need to be re-added explicitly + return pl.Object + else: + # Fallback to Object for unknown types + return pl.Object + def _restore_pickled_column_types(self, table: sa.Table) -> None: for column in table.columns: if isinstance(column.type, sa.LargeBinary): column.type = sa.PickleType() def _polars_schema_override(self, table_class: type[DeclarativeBase]) -> SchemaDict: - return { - **self._boolean_columns[table_class.__tablename__], - } + # Get the appropriate schema dict based on table class + if table_class.__tablename__ == self.Node.__tablename__: + schemas = self._node_attr_schemas + else: + schemas = self._edge_attr_schemas + + # Return schema overrides for special types that need explicit casting + return {key: schema.dtype for key, schema in schemas.items() if isinstance(schema.dtype, pl.Boolean)} def _cast_array_columns(self, table_class: type[DeclarativeBase], df: pl.DataFrame) -> pl.DataFrame: - # this operation cannot be done with schema_overrides because they are blobs at the database level + # Get the appropriate schema dict based on table class + if table_class.__tablename__ == self.Node.__tablename__: + schemas = self._node_attr_schemas + else: + schemas = self._edge_attr_schemas + + # Cast array columns (stored as blobs in database) df = df.with_columns( - pl.Series(col, df[col].to_list(), dtype=pl_dtype) - for col, pl_dtype in self._array_columns[table_class.__tablename__].items() - if col in df.columns + pl.Series(key, df[key].to_list(), dtype=schema.dtype) + for key, schema in schemas.items() + if isinstance(schema.dtype, pl.Array) and key in df.columns ) return df @@ -1480,23 +1558,18 @@ def _sqlalchemy_type_inference(self, 
default_value: Any) -> TypeEngine: def _add_new_column( self, table_class: type[DeclarativeBase], - key: str, - default_value: Any, + schema: AttrSchema, ) -> None: - sa_type = self._sqlalchemy_type_inference(default_value) + # Convert polars dtype to SQLAlchemy type + sa_type = polars_dtype_to_sqlalchemy_type(schema.dtype) - if sa_type == sa.Boolean: - self._boolean_columns[table_class.__tablename__][key] = pl.Boolean + # Handle special cases for default value encoding + default_value = schema.default_value + if isinstance(sa_type, sa.PickleType) and default_value is not None: + # Pickle complex types for database storage + default_value = blob_default(self._engine, cloudpickle.dumps(default_value)) - if sa_type == sa.PickleType and default_value is not None: - if isinstance(default_value, np.ndarray): - self._array_columns[table_class.__tablename__][key] = pl.Array( - numpy_char_code_to_dtype(default_value.dtype.char), default_value.shape - ) - # The following is required for all non-None PickleType columns - default_value = blob_default(self._engine, cloudpickle.dumps(default_value)) # None - - sa_column = sa.Column(key, sa_type, default=default_value) + sa_column = sa.Column(schema.key, sa_type, default=default_value) str_dialect_type = sa_column.type.compile(dialect=self._engine.dialect) @@ -1521,7 +1594,7 @@ def _add_new_column( session.commit() # register the new column in the Node class - setattr(table_class, key, sa_column) + setattr(table_class, schema.key, sa_column) table_class.__table__.append_column(sa_column) def _drop_column(self, table_class: type[DeclarativeBase], key: str) -> None: @@ -1535,11 +1608,20 @@ def _drop_column(self, table_class: type[DeclarativeBase], key: str) -> None: # refresh ORM schema to reflect database changes self._define_schema(overwrite=False) - def add_node_attr_key(self, key: str, default_value: Any) -> None: - if key in self.node_attr_keys(): - raise ValueError(f"Node attribute key {key} already exists") + def add_node_attr_key( + self, + key_or_schema: str | AttrSchema, + dtype: pl.DataType | None = None, + default_value: Any = None, + ) -> None: + # Process arguments and create validated schema + schema = process_attr_key_args(key_or_schema, dtype, default_value, self._node_attr_schemas) + + # Store schema + self._node_attr_schemas[schema.key] = schema - self._add_new_column(self.Node, key, default_value) + # Add column to database + self._add_new_column(self.Node, schema) def remove_node_attr_key(self, key: str) -> None: if key not in self.node_attr_keys(): @@ -1548,23 +1630,30 @@ def remove_node_attr_key(self, key: str) -> None: if key in (DEFAULT_ATTR_KEYS.NODE_ID, DEFAULT_ATTR_KEYS.T): raise ValueError(f"Cannot remove required node attribute key {key}") - self._boolean_columns[self.Node.__tablename__].pop(key, None) - self._array_columns[self.Node.__tablename__].pop(key, None) self._drop_column(self.Node, key) + self._node_attr_schemas.pop(key, None) + + def add_edge_attr_key( + self, + key_or_schema: str | AttrSchema, + dtype: pl.DataType | None = None, + default_value: Any = None, + ) -> None: + # Process arguments and create validated schema + schema = process_attr_key_args(key_or_schema, dtype, default_value, self._edge_attr_schemas) - def add_edge_attr_key(self, key: str, default_value: Any) -> None: - if key in self.edge_attr_keys(): - raise ValueError(f"Edge attribute key {key} already exists") + # Store schema + self._edge_attr_schemas[schema.key] = schema - self._add_new_column(self.Edge, key, default_value) + # Add column to 
database + self._add_new_column(self.Edge, schema) def remove_edge_attr_key(self, key: str) -> None: if key not in self.edge_attr_keys(): raise ValueError(f"Edge attribute key {key} does not exist") - self._boolean_columns[self.Edge.__tablename__].pop(key, None) - self._array_columns[self.Edge.__tablename__].pop(key, None) self._drop_column(self.Edge, key) + self._edge_attr_schemas.pop(key, None) def num_edges(self) -> int: with Session(self._engine) as session: diff --git a/src/tracksdata/graph/_test/test_attr_key_dtype.py b/src/tracksdata/graph/_test/test_attr_key_dtype.py new file mode 100644 index 00000000..5c0bff70 --- /dev/null +++ b/src/tracksdata/graph/_test/test_attr_key_dtype.py @@ -0,0 +1,225 @@ +"""Tests for dtype parameter in add_node_attr_key and add_edge_attr_key.""" + +import numpy as np +import polars as pl +import pytest + +from tracksdata.graph import RustWorkXGraph +from tracksdata.utils._dtypes import AttrSchema + + +class TestRustWorkXGraphDtype: + """Test dtype functionality in RustWorkXGraph.""" + + def test_add_node_attr_key_with_dtype_only(self): + """Test adding node attribute with dtype only (default value inferred).""" + graph = RustWorkXGraph() + + # Add various dtypes + graph.add_node_attr_key("count", pl.UInt32) + graph.add_node_attr_key("score", pl.Float64) + graph.add_node_attr_key("flag", pl.Boolean) + graph.add_node_attr_key("name", pl.String) + + # Verify schemas are stored + assert "count" in graph._node_attr_schemas + assert graph._node_attr_schemas["count"].dtype == pl.UInt32 + assert graph._node_attr_schemas["count"].default_value == 0 + + assert graph._node_attr_schemas["score"].dtype == pl.Float64 + assert graph._node_attr_schemas["score"].default_value == -1.0 + + assert graph._node_attr_schemas["flag"].dtype == pl.Boolean + assert graph._node_attr_schemas["flag"].default_value is False + + assert graph._node_attr_schemas["name"].dtype == pl.String + assert graph._node_attr_schemas["name"].default_value == "" + + def test_add_node_attr_key_with_dtype_and_default(self): + """Test adding node attribute with both dtype and default value.""" + graph = RustWorkXGraph() + + graph.add_node_attr_key("score", pl.Float64, default_value=0.0) + + assert graph._node_attr_schemas["score"].dtype == pl.Float64 + assert graph._node_attr_schemas["score"].default_value == 0.0 + + def test_add_node_attr_key_with_array_dtype(self): + """Test adding node attribute with array dtype (zeros default).""" + graph = RustWorkXGraph() + + graph.add_node_attr_key("bbox", pl.Array(pl.Float64, 4)) + + assert graph._node_attr_schemas["bbox"].dtype == pl.Array(pl.Float64, 4) + default = graph._node_attr_schemas["bbox"].default_value + assert isinstance(default, np.ndarray) + assert default.shape == (4,) + assert default.dtype == np.float64 + np.testing.assert_array_equal(default, np.zeros(4, dtype=np.float64)) + + def test_add_node_attr_key_with_schema_object(self): + """Test adding node attribute using AttrSchema object.""" + graph = RustWorkXGraph() + + schema = AttrSchema(key="intensity", dtype=pl.Float64) + graph.add_node_attr_key(schema) + + assert "intensity" in graph._node_attr_schemas + assert graph._node_attr_schemas["intensity"].dtype == pl.Float64 + assert graph._node_attr_schemas["intensity"].default_value == -1.0 + + def test_add_node_attr_key_missing_dtype_raises(self): + """Test that missing dtype raises TypeError.""" + graph = RustWorkXGraph() + + with pytest.raises(TypeError, match="dtype is required"): + graph.add_node_attr_key("score") + + def 
test_add_node_attr_key_duplicate_raises(self): + """Test that adding duplicate key raises ValueError.""" + graph = RustWorkXGraph() + + graph.add_node_attr_key("score", pl.Float64) + + with pytest.raises(ValueError, match="already exists"): + graph.add_node_attr_key("score", pl.Float64) + + def test_add_node_attr_key_incompatible_default_raises(self): + """Test that incompatible dtype and default raises ValueError.""" + graph = RustWorkXGraph() + + with pytest.raises(ValueError, match="incompatible"): + graph.add_node_attr_key("score", pl.Int64, default_value="string") + + def test_defaults_applied_to_existing_nodes(self): + """Test that defaults are applied to existing nodes.""" + graph = RustWorkXGraph() + + # Add a node + node_id = graph.add_node({"t": 0}) + + # Add new attribute + graph.add_node_attr_key("score", pl.Float64) + + # Verify the default was applied + node_attrs = graph.rx_graph[node_id] + assert "score" in node_attrs + assert node_attrs["score"] == -1.0 + + def test_add_edge_attr_key_with_dtype_only(self): + """Test adding edge attribute with dtype only.""" + graph = RustWorkXGraph() + + graph.add_edge_attr_key("weight", pl.Float64) + + assert "weight" in graph._edge_attr_schemas + assert graph._edge_attr_schemas["weight"].dtype == pl.Float64 + assert graph._edge_attr_schemas["weight"].default_value == -1.0 + + def test_add_edge_attr_key_with_schema(self): + """Test adding edge attribute using AttrSchema.""" + graph = RustWorkXGraph() + + schema = AttrSchema(key="distance", dtype=pl.Float64, default_value=0.0) + graph.add_edge_attr_key(schema) + + assert "distance" in graph._edge_attr_schemas + assert graph._edge_attr_schemas["distance"].default_value == 0.0 + + def test_defaults_applied_to_existing_edges(self): + """Test that defaults are applied to existing edges.""" + graph = RustWorkXGraph() + + # Add nodes and edge + n1 = graph.add_node({"t": 0}) + n2 = graph.add_node({"t": 1}) + graph.add_edge(n1, n2, {}) + + # Add new edge attribute + graph.add_edge_attr_key("weight", pl.Float64, default_value=1.0) + + # Verify the default was applied + edge_attrs = graph.rx_graph.get_edge_data(n1, n2) + assert "weight" in edge_attrs + assert edge_attrs["weight"] == 1.0 + + def test_node_attr_keys_returns_keys(self): + """Test that node_attr_keys returns the correct keys.""" + graph = RustWorkXGraph() + + graph.add_node_attr_key("score", pl.Float64) + graph.add_node_attr_key("count", pl.UInt32) + + keys = graph.node_attr_keys() + assert "node_id" in keys + assert "t" in keys + assert "score" in keys + assert "count" in keys + + def test_edge_attr_keys_returns_keys(self): + """Test that edge_attr_keys returns the correct keys.""" + graph = RustWorkXGraph() + + graph.add_edge_attr_key("weight", pl.Float64) + graph.add_edge_attr_key("distance", pl.Float64) + + keys = graph.edge_attr_keys() + assert "weight" in keys + assert "distance" in keys + + def test_signed_vs_unsigned_int_defaults(self): + """Test that signed and unsigned integers get different defaults.""" + graph = RustWorkXGraph() + + graph.add_node_attr_key("unsigned", pl.UInt32) + graph.add_node_attr_key("signed", pl.Int32) + + assert graph._node_attr_schemas["unsigned"].default_value == 0 + assert graph._node_attr_schemas["signed"].default_value == -1 + + def test_schema_defensive_copy(self): + """Test that passing AttrSchema creates a defensive copy to prevent mutation.""" + graph = RustWorkXGraph() + + # Create a schema object + original_schema = AttrSchema(key="score", dtype=pl.Float64, default_value=1.0) + + # Add it to 
the graph + graph.add_node_attr_key(original_schema) + + # Verify the schema was stored + stored_schema = graph._node_attr_schemas["score"] + assert stored_schema.key == "score" + assert stored_schema.dtype == pl.Float64 + assert stored_schema.default_value == 1.0 + + # Mutate the original schema + original_schema.default_value = 999.0 + + # Verify the stored schema wasn't affected (defensive copy worked) + assert graph._node_attr_schemas["score"].default_value == 1.0 + assert stored_schema.default_value == 1.0 + + # Verify the original was changed + assert original_schema.default_value == 999.0 + + def test_schema_copy_method(self): + """Test that AttrSchema.copy() creates an independent copy.""" + original = AttrSchema(key="value", dtype=pl.Int32, default_value=42) + + # Create a copy + copied = original.copy() + + # Verify the copy has the same values + assert copied.key == "value" + assert copied.dtype == pl.Int32 + assert copied.default_value == 42 + + # Verify they are different objects + assert copied is not original + + # Mutate the original + original.default_value = -999 + + # Verify the copy wasn't affected + assert copied.default_value == 42 diff --git a/src/tracksdata/utils/_dtypes.py b/src/tracksdata/utils/_dtypes.py index a79b10af..34588da3 100644 --- a/src/tracksdata/utils/_dtypes.py +++ b/src/tracksdata/utils/_dtypes.py @@ -1,7 +1,11 @@ +from __future__ import annotations + +from dataclasses import dataclass from typing import Any import numpy as np import polars as pl +import sqlalchemy as sa from cloudpickle import dumps, loads from polars.datatypes.classes import ( Boolean, @@ -19,6 +23,7 @@ UInt32, UInt64, ) +from sqlalchemy.sql.type_api import TypeEngine _POLARS_DTYPE_TO_NUMPY_DTYPE = { Datetime: np.datetime64, @@ -141,3 +146,306 @@ def infer_default_value(sample_value: Any) -> Any: if isinstance(sample_value, float | np.floating): return -1.0 return None + + +@dataclass +class AttrSchema: + """ + Schema information for a graph attribute. + + Stores both the polars dtype and the default value for an attribute key. + This is used to maintain consistent type information across graph operations + and when converting to polars DataFrames. + + Parameters + ---------- + key : str + The attribute key name. + dtype : pl.DataType + The polars data type for this attribute. + default_value : Any, optional + The default value for this attribute. If None, will be inferred from dtype. + + Examples + -------- + Create a schema with inferred default: + + ```python + schema = AttrSchema(key="count", dtype=pl.UInt32) + # default_value will be 0 + ``` + + Create a schema with custom default: + + ```python + schema = AttrSchema(key="score", dtype=pl.Float64, default_value=-99.0) + ``` + + Create an array schema: + + ```python + schema = AttrSchema(key="bbox", dtype=pl.Array(pl.Float64, 4)) + # default_value will be np.zeros(4, dtype=np.float64) + ``` + """ + + key: str + dtype: pl.DataType + default_value: Any = None + + def __post_init__(self): + """Infer default value if not provided and validate compatibility.""" + if self.default_value is None: + self.default_value = infer_default_value_from_dtype(self.dtype) + else: + validate_default_value_dtype_compatibility(self.default_value, self.dtype) + + def copy(self) -> AttrSchema: + """ + Create a defensive copy of this AttrSchema. + + Returns + ------- + AttrSchema + A new AttrSchema instance with the same key, dtype, and default_value. 
+ + Examples + -------- + ```python + original = AttrSchema(key="score", dtype=pl.Float64, default_value=1.0) + copied = original.copy() + + # Mutating original doesn't affect the copy + original.default_value = 999.0 + assert copied.default_value == 1.0 + ``` + """ + return AttrSchema(key=self.key, dtype=self.dtype, default_value=self.default_value) + + +def process_attr_key_args( + key_or_schema: str | AttrSchema, + dtype: pl.DataType | None, + default_value: Any, + attr_schemas: dict[str, AttrSchema], +) -> AttrSchema: + """ + Process arguments for add_node_attr_key/add_edge_attr_key and return a validated schema. + + This helper function handles both calling patterns (convenience and schema mode), + validates arguments, and ensures the key doesn't already exist. + + Parameters + ---------- + key_or_schema : str | AttrSchema + Either a key name or an AttrSchema object. + dtype : pl.DataType | None + The polars data type (required when key_or_schema is a string). + default_value : Any + The default value (will be inferred from dtype if None). + attr_schemas : dict[str, AttrSchema] + The dictionary of existing attribute schemas (for duplicate check). + + Returns + ------- + AttrSchema + A validated AttrSchema ready to be stored. + + Raises + ------ + TypeError + If dtype is not provided when using string key. + ValueError + If the key already exists or if default_value and dtype are incompatible. + + Examples + -------- + ```python + # Convenience mode + schema = process_attr_key_args("count", pl.UInt32, None, {}) + assert schema.default_value == 0 + + # Schema mode + original = AttrSchema(key="score", dtype=pl.Float64, default_value=1.0) + schema = process_attr_key_args(original, None, None, {}) + assert schema is not original # Defensive copy + ``` + """ + # Handle both calling patterns + if isinstance(key_or_schema, AttrSchema): + # Schema mode: create a defensive copy to avoid mutation bugs + schema = key_or_schema.copy() + key = schema.key + else: + # Convenience mode: build schema from parameters + key = key_or_schema + if dtype is None: + raise TypeError("dtype is required when not using AttrSchema") + + # Determine default_value if not provided + if default_value is None: + default_value = infer_default_value_from_dtype(dtype) + else: + # Validate compatibility if both provided + validate_default_value_dtype_compatibility(default_value, dtype) + + # Create schema + schema = AttrSchema(key=key, dtype=dtype, default_value=default_value) + + # Check key doesn't exist + if key in attr_schemas: + raise ValueError(f"Attribute key {key} already exists") + + return schema + + +# Default value mapping for polars dtypes +DTYPE_DEFAULT_MAP = { + pl.Boolean: False, + pl.UInt8: 0, + pl.UInt16: 0, + pl.UInt32: 0, + pl.UInt64: 0, + pl.Int8: -1, + pl.Int16: -1, + pl.Int32: -1, + pl.Int64: -1, + pl.Float32: -1.0, + pl.Float64: -1.0, + pl.String: "", + pl.Utf8: "", +} + + +def infer_default_value_from_dtype(dtype: pl.DataType) -> Any: + """ + Infer a sensible default value from a polars dtype. + + Parameters + ---------- + dtype : pl.DataType + The polars data type. + + Returns + ------- + Any + A sensible default value for the type. 
+ + Examples + -------- + >>> infer_default_value_from_dtype(pl.Int64) + -1 + >>> infer_default_value_from_dtype(pl.UInt32) + 0 + >>> infer_default_value_from_dtype(pl.Boolean) + False + >>> infer_default_value_from_dtype(pl.Array(pl.Float64, 3)) + array([0., 0., 0.]) + """ + # Handle array types - create zeros with correct shape and dtype + if isinstance(dtype, pl.Array): + inner_dtype = dtype.inner + shape = dtype.size # Use size instead of width (deprecated) + numpy_dtype = polars_dtype_to_numpy_dtype(inner_dtype, allow_sequence=False) + return np.zeros(shape, dtype=numpy_dtype) + + # Handle list types + if isinstance(dtype, pl.List): + return None + + # Use dictionary lookup for standard types + return DTYPE_DEFAULT_MAP.get(dtype, None) + + +# SQLAlchemy type mapping for polars dtypes +_POLARS_TO_SQLALCHEMY_TYPE_MAP = { + # Boolean + pl.Boolean: sa.Boolean, + # Small integer types + pl.Int8: sa.SmallInteger, + pl.UInt8: sa.SmallInteger, + pl.Int16: sa.SmallInteger, + pl.UInt16: sa.SmallInteger, + # Integer types + pl.Int32: sa.Integer, + pl.UInt32: sa.Integer, + # Big integer types + pl.Int64: sa.BigInteger, + pl.UInt64: sa.BigInteger, + # Float types + pl.Float32: sa.Float, + pl.Float64: sa.Float, + # String types + pl.String: sa.String, + pl.Utf8: sa.String, +} + + +def polars_dtype_to_sqlalchemy_type(dtype: pl.DataType) -> TypeEngine: + """ + Convert a polars dtype to SQLAlchemy type. + + Parameters + ---------- + dtype : pl.DataType + The polars data type. + + Returns + ------- + sa.TypeEngine + The corresponding SQLAlchemy type. + + Examples + -------- + >>> polars_dtype_to_sqlalchemy_type(pl.Int64) + + >>> polars_dtype_to_sqlalchemy_type(pl.Boolean) + + """ + # Handle sequence types - use PickleType for storage + if isinstance(dtype, pl.Array | pl.List): + return sa.PickleType() + + # Use dictionary lookup for standard types + sa_type_class = _POLARS_TO_SQLALCHEMY_TYPE_MAP.get(dtype) + if sa_type_class is not None: + return sa_type_class() + + # Object and fallback + return sa.PickleType() + + +def validate_default_value_dtype_compatibility(default_value: Any, dtype: pl.DataType) -> None: + """ + Validate that a default value is compatible with a polars dtype. + + Parameters + ---------- + default_value : Any + The default value to validate. + dtype : pl.DataType + The polars dtype to validate against. + + Raises + ------ + ValueError + If the default value is incompatible with the dtype. + + Examples + -------- + >>> validate_default_value_dtype_compatibility(42, pl.Int64) + # No error + + >>> validate_default_value_dtype_compatibility("string", pl.Int64) + ValueError: default_value 'string' (type: str) is incompatible with dtype Int64... + """ + try: + # Try to create a polars series and cast + s = pl.Series([default_value]) + s.cast(dtype) + except Exception as e: + raise ValueError( + f"default_value {default_value!r} (type: {type(default_value).__name__}) " + f"is incompatible with dtype {dtype}. " + f"Cannot cast to specified type. Error: {e}" + ) from e From e7aa9c7d65d8489455ad8623824b87a7102d4a25 Mon Sep 17 00:00:00 2001 From: Jordao Bragantini Date: Tue, 20 Jan 2026 10:25:21 -0800 Subject: [PATCH 02/14] Refactor: require explicit dtype parameter for add_node_attr_key and add_edge_attr_key This is a breaking change that requires dtype (polars.DataType) to be explicitly specified when adding attribute keys to graphs. 
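For example, a call site that previously passed only a default value now passes an explicit polars dtype (illustrative sketch; the attribute names below are hypothetical):

```python
import polars as pl

# before: the dtype was inferred from the default value
# graph.add_node_attr_key("score", 0.0)

# after: dtype is required; the default may be omitted and is then
# inferred from the dtype (floats -> -1.0, unsigned ints -> 0, bools -> False)
graph.add_node_attr_key("score", pl.Float64, default_value=0.0)
graph.add_edge_attr_key("weight", pl.Float64)
```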
Core changes: - Add AttrSchema dataclass to store key, dtype, and default_value together - Add utility functions: infer_default_value_from_dtype, polars_dtype_to_sqlalchemy_type, validate_default_value_dtype_compatibility, and process_attr_key_args - Update BaseGraph abstract methods with overloaded signatures supporting both direct parameters and AttrSchema objects - Replace _node_attr_keys/_edge_attr_keys lists with _node_attr_schemas dicts in RustWorkXGraph and SQLGraph - Remove _boolean_columns tracking in SQLGraph (now uses AttrSchema) - Fix _polars_schema_override to use == instead of isinstance for pl.Boolean check Backend implementations: - RustWorkXGraph: Use process_attr_key_args helper, store AttrSchema objects - SQLGraph: Use process_attr_key_args helper, pass AttrSchema to _add_new_column - GraphView: Delegate to root graph with all parameters preserved - Add dtype inference fallback for complex objects that polars can't parse Updated all callers: - Graph methods: from_other(), match() - Node operators: RandomNodes, RegionPropsNodes, GenericNodesOperator - Edge operators: DistanceEdges, GenericEdgesOperator - Solvers: ILPSolver, NearestNeighborsSolver - IO: numpy_array, ctc --- src/tracksdata/edges/_distance_edges.py | 3 ++- src/tracksdata/edges/_generic_edges.py | 3 ++- src/tracksdata/graph/_base_graph.py | 16 +++++++++------- src/tracksdata/graph/_rustworkx_graph.py | 12 ++++++++++-- src/tracksdata/graph/_sql_graph.py | 2 +- src/tracksdata/io/_ctc.py | 2 +- src/tracksdata/io/_numpy_array.py | 4 ++-- src/tracksdata/nodes/_generic_nodes.py | 9 ++++++++- src/tracksdata/nodes/_random.py | 3 ++- src/tracksdata/nodes/_regionprops.py | 12 ++++++++---- src/tracksdata/solvers/_ilp_solver.py | 4 ++-- .../solvers/_nearest_neighbors_solver.py | 5 +++-- src/tracksdata/utils/_dtypes.py | 4 ++++ 13 files changed, 54 insertions(+), 25 deletions(-) diff --git a/src/tracksdata/edges/_distance_edges.py b/src/tracksdata/edges/_distance_edges.py index 5f4d1d15..180fa89f 100644 --- a/src/tracksdata/edges/_distance_edges.py +++ b/src/tracksdata/edges/_distance_edges.py @@ -2,6 +2,7 @@ from typing import Any import numpy as np +import polars as pl from scipy.spatial import KDTree from tracksdata.attrs import NodeAttr @@ -110,7 +111,7 @@ def _init_edge_attrs(self, graph: BaseGraph) -> None: Initialize the edge attributes for the graph. """ if self.output_key not in graph.edge_attr_keys(): - graph.add_edge_attr_key(self.output_key, default_value=-99999.0) + graph.add_edge_attr_key(self.output_key, pl.Float64, default_value=-99999.0) def _add_edges_per_time( self, diff --git a/src/tracksdata/edges/_generic_edges.py b/src/tracksdata/edges/_generic_edges.py index d1ce1900..77cc3e96 100644 --- a/src/tracksdata/edges/_generic_edges.py +++ b/src/tracksdata/edges/_generic_edges.py @@ -2,6 +2,7 @@ from typing import Any import numpy as np +import polars as pl from tracksdata.attrs import NodeAttr from tracksdata.constants import DEFAULT_ATTR_KEYS @@ -54,7 +55,7 @@ def _init_edge_attrs(self, graph: BaseGraph) -> None: Initialize the edge attributes for the graph. 
""" if self.output_key not in graph.edge_attr_keys(): - graph.add_edge_attr_key(self.output_key, default_value=-99999.0) + graph.add_edge_attr_key(self.output_key, pl.Float64, default_value=-99999.0) def _edge_attrs_per_time( self, diff --git a/src/tracksdata/graph/_base_graph.py b/src/tracksdata/graph/_base_graph.py index 7237047a..1534d2c9 100644 --- a/src/tracksdata/graph/_base_graph.py +++ b/src/tracksdata/graph/_base_graph.py @@ -21,7 +21,6 @@ from tracksdata.utils._dtypes import ( AttrSchema, column_to_numpy, - infer_default_value, polars_dtype_to_numpy_dtype, ) from tracksdata.utils._logging import LOG @@ -1096,13 +1095,13 @@ def match( ) if matched_node_id_key not in self.node_attr_keys(): - self.add_node_attr_key(matched_node_id_key, -1) + self.add_node_attr_key(matched_node_id_key, pl.Int64, default_value=-1) if match_score_key not in self.node_attr_keys(): - self.add_node_attr_key(match_score_key, 0.0) + self.add_node_attr_key(match_score_key, pl.Float64, default_value=0.0) if matched_edge_mask_key not in self.edge_attr_keys(): - self.add_edge_attr_key(matched_edge_mask_key, False) + self.add_edge_attr_key(matched_edge_mask_key, pl.Boolean, default_value=False) node_ids = functools.reduce(operator.iadd, matching_data["mapped_comp"]) other_ids = functools.reduce(operator.iadd, matching_data["mapped_ref"]) @@ -1172,8 +1171,9 @@ def from_other(cls: type[T], other: "BaseGraph", **kwargs) -> T: for col in node_attrs.columns: if col != DEFAULT_ATTR_KEYS.T: - first_value = node_attrs[col].first() - graph.add_node_attr_key(col, infer_default_value(first_value)) + # Use the dtype from the source DataFrame + dtype = node_attrs[col].dtype + graph.add_node_attr_key(col, dtype) if graph.supports_custom_indices(): new_node_ids = graph.bulk_add_nodes( @@ -1197,7 +1197,9 @@ def from_other(cls: type[T], other: "BaseGraph", **kwargs) -> T: for col in edge_attrs.columns: if col not in [DEFAULT_ATTR_KEYS.EDGE_SOURCE, DEFAULT_ATTR_KEYS.EDGE_TARGET]: - graph.add_edge_attr_key(col, edge_attrs[col].first()) + # Use the dtype from the source DataFrame + dtype = edge_attrs[col].dtype + graph.add_edge_attr_key(col, dtype) edge_attrs = edge_attrs.with_columns( edge_attrs[col].map_elements(node_map.get, return_dtype=pl.Int64).alias(col) diff --git a/src/tracksdata/graph/_rustworkx_graph.py b/src/tracksdata/graph/_rustworkx_graph.py index 11c792e0..ce009c3f 100644 --- a/src/tracksdata/graph/_rustworkx_graph.py +++ b/src/tracksdata/graph/_rustworkx_graph.py @@ -371,7 +371,11 @@ def __init__(self, rx_graph: rx.PyDiGraph | None = None) -> None: for key, value in first_node_attrs.items(): if key == DEFAULT_ATTR_KEYS.NODE_ID: continue - dtype = pl.Series([value]).dtype + try: + dtype = pl.Series([value]).dtype + except (ValueError, TypeError): + # If polars can't infer dtype (e.g., for complex objects), use Object + dtype = pl.Object self._node_attr_schemas[key] = AttrSchema(key=key, dtype=dtype) # Process edges: set edge IDs and infer schemas @@ -389,7 +393,11 @@ def __init__(self, rx_graph: rx.PyDiGraph | None = None) -> None: # TODO: check if EDGE_SOURCE and EDGE_TARGET should be also ignored or in the schema if key == DEFAULT_ATTR_KEYS.EDGE_ID: continue - dtype = pl.Series([value]).dtype + try: + dtype = pl.Series([value]).dtype + except (ValueError, TypeError): + # If polars can't infer dtype (e.g., for complex objects), use Object + dtype = pl.Object self._edge_attr_schemas[key] = AttrSchema(key=key, dtype=dtype) @property diff --git a/src/tracksdata/graph/_sql_graph.py b/src/tracksdata/graph/_sql_graph.py 
index 859f5136..843badfa 100644 --- a/src/tracksdata/graph/_sql_graph.py +++ b/src/tracksdata/graph/_sql_graph.py @@ -623,7 +623,7 @@ def _polars_schema_override(self, table_class: type[DeclarativeBase]) -> SchemaD schemas = self._edge_attr_schemas # Return schema overrides for special types that need explicit casting - return {key: schema.dtype for key, schema in schemas.items() if isinstance(schema.dtype, pl.Boolean)} + return {key: schema.dtype for key, schema in schemas.items() if schema.dtype == pl.Boolean} def _cast_array_columns(self, table_class: type[DeclarativeBase], df: pl.DataFrame) -> pl.DataFrame: # Get the appropriate schema dict based on table class diff --git a/src/tracksdata/io/_ctc.py b/src/tracksdata/io/_ctc.py index 4ccdfc33..f4112144 100644 --- a/src/tracksdata/io/_ctc.py +++ b/src/tracksdata/io/_ctc.py @@ -220,7 +220,7 @@ def from_ctc( ) # is duplicating an attribute that bad? - graph.add_node_attr_key(DEFAULT_ATTR_KEYS.TRACKLET_ID, -1) + graph.add_node_attr_key(DEFAULT_ATTR_KEYS.TRACKLET_ID, pl.Int64, -1) graph.update_node_attrs( node_ids=nodes_df[DEFAULT_ATTR_KEYS.NODE_ID].to_list(), attrs={ diff --git a/src/tracksdata/io/_numpy_array.py b/src/tracksdata/io/_numpy_array.py index 8601a425..da5b1911 100644 --- a/src/tracksdata/io/_numpy_array.py +++ b/src/tracksdata/io/_numpy_array.py @@ -112,11 +112,11 @@ def from_array( "`tracklet_ids` must have the same length as `positions`. " f"Expected {positions.shape[0]}, got {len(tracklet_ids)}." ) - graph.add_node_attr_key(DEFAULT_ATTR_KEYS.TRACKLET_ID, -1) + graph.add_node_attr_key(DEFAULT_ATTR_KEYS.TRACKLET_ID, pl.Int64, -1) tracklet_ids = tracklet_ids.tolist() for col in spatial_cols: - graph.add_node_attr_key(col, -999_999) + graph.add_node_attr_key(col, pl.Float64, -999_999) node_attrs = [] diff --git a/src/tracksdata/nodes/_generic_nodes.py b/src/tracksdata/nodes/_generic_nodes.py index 343a0c7a..fa88b14e 100644 --- a/src/tracksdata/nodes/_generic_nodes.py +++ b/src/tracksdata/nodes/_generic_nodes.py @@ -2,6 +2,7 @@ from typing import Any, TypeVar import numpy as np +import polars as pl from numpy.typing import NDArray from tracksdata.attrs import NodeAttr @@ -109,7 +110,13 @@ def _init_node_attrs(self, graph: BaseGraph) -> None: Initialize the node attributes for the graph. 
""" if self.output_key not in graph.node_attr_keys(): - graph.add_node_attr_key(self.output_key, default_value=self.default_value) + # Infer dtype from default_value using polars + if self.default_value is None: + dtype = pl.Object + else: + # Use polars to infer the dtype from the value + dtype = pl.Series([self.default_value]).dtype + graph.add_node_attr_key(self.output_key, dtype, self.default_value) def add_node_attrs( self, diff --git a/src/tracksdata/nodes/_random.py b/src/tracksdata/nodes/_random.py index 48abd995..7f0d5564 100644 --- a/src/tracksdata/nodes/_random.py +++ b/src/tracksdata/nodes/_random.py @@ -2,6 +2,7 @@ from typing import Any, Literal import numpy as np +import polars as pl from tracksdata.constants import DEFAULT_ATTR_KEYS from tracksdata.graph._base_graph import BaseGraph @@ -124,7 +125,7 @@ def add_nodes( # Register each spatial column individually for col in self.spatial_cols: if col not in graph.node_attr_keys(): - graph.add_node_attr_key(col, -999999.0) + graph.add_node_attr_key(col, pl.Float64, -999999.0) if t is None: time_points = range(self.n_time_points) diff --git a/src/tracksdata/nodes/_regionprops.py b/src/tracksdata/nodes/_regionprops.py index 2ee2346d..d8b3cb0c 100644 --- a/src/tracksdata/nodes/_regionprops.py +++ b/src/tracksdata/nodes/_regionprops.py @@ -3,6 +3,7 @@ from typing import Any import numpy as np +import polars as pl from numpy.typing import NDArray from skimage.measure._regionprops import RegionProperties, regionprops from typing_extensions import override @@ -127,18 +128,21 @@ def _init_node_attrs(self, graph: BaseGraph, axis_names: list[str], ndims: int) Initialize the node attributes for the graph. """ if DEFAULT_ATTR_KEYS.MASK not in graph.node_attr_keys(): - graph.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, None) + graph.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, pl.Object, None) if DEFAULT_ATTR_KEYS.BBOX not in graph.node_attr_keys(): - graph.add_node_attr_key(DEFAULT_ATTR_KEYS.BBOX, np.zeros(2 * (ndims - 1), dtype=int)) + bbox_size = 2 * (ndims - 1) + graph.add_node_attr_key( + DEFAULT_ATTR_KEYS.BBOX, pl.Array(pl.Int64, bbox_size), np.zeros(bbox_size, dtype=int) + ) if "label" in self.attr_keys() and "label" not in graph.node_attr_keys(): - graph.add_node_attr_key("label", 0) + graph.add_node_attr_key("label", pl.Int64, 0) # initialize the remaining attribute keys for attr_key in axis_names + self.attr_keys(): if attr_key not in graph.node_attr_keys(): - graph.add_node_attr_key(attr_key, -1.0) + graph.add_node_attr_key(attr_key, pl.Float64, -1.0) def attr_keys(self) -> list[str]: """ diff --git a/src/tracksdata/solvers/_ilp_solver.py b/src/tracksdata/solvers/_ilp_solver.py index 7fe02232..98b1ea6f 100644 --- a/src/tracksdata/solvers/_ilp_solver.py +++ b/src/tracksdata/solvers/_ilp_solver.py @@ -401,7 +401,7 @@ def solve( selected_nodes = [node_id for node_id, var in self._node_vars.items() if solution[var.index] > 0.5] if self.output_key not in graph.node_attr_keys(): - graph.add_node_attr_key(self.output_key, False) + graph.add_node_attr_key(self.output_key, pl.Boolean, default_value=False) elif self.reset: graph.update_node_attrs(attrs={self.output_key: False}) @@ -413,7 +413,7 @@ def solve( selected_edges = [edge_id for edge_id, var in self._edge_vars.items() if solution[var.index] > 0.5] if self.output_key not in graph.edge_attr_keys(): - graph.add_edge_attr_key(self.output_key, False) + graph.add_edge_attr_key(self.output_key, pl.Boolean, default_value=False) elif self.reset: graph.update_edge_attrs(attrs={self.output_key: 
False}) diff --git a/src/tracksdata/solvers/_nearest_neighbors_solver.py b/src/tracksdata/solvers/_nearest_neighbors_solver.py index 1d74a947..637b9ac4 100644 --- a/src/tracksdata/solvers/_nearest_neighbors_solver.py +++ b/src/tracksdata/solvers/_nearest_neighbors_solver.py @@ -1,4 +1,5 @@ import numpy as np +import polars as pl from numba import njit, typed, types from tracksdata.attrs import Attr, EdgeAttr, ExprInput, NodeAttr @@ -274,7 +275,7 @@ def solve( solution_edges_df = edges_df.filter(solution) if self.output_key not in graph.edge_attr_keys(): - graph.add_edge_attr_key(self.output_key, False) + graph.add_edge_attr_key(self.output_key, pl.Boolean, default_value=False) elif self.reset: graph.update_edge_attrs(attrs={self.output_key: False}) @@ -293,7 +294,7 @@ def solve( ) if self.output_key not in graph.node_attr_keys(): - graph.add_node_attr_key(self.output_key, False) + graph.add_node_attr_key(self.output_key, pl.Boolean, default_value=False) graph.update_node_attrs( node_ids=node_ids, diff --git a/src/tracksdata/utils/_dtypes.py b/src/tracksdata/utils/_dtypes.py index 34588da3..dfba5941 100644 --- a/src/tracksdata/utils/_dtypes.py +++ b/src/tracksdata/utils/_dtypes.py @@ -439,6 +439,10 @@ def validate_default_value_dtype_compatibility(default_value: Any, dtype: pl.Dat >>> validate_default_value_dtype_compatibility("string", pl.Int64) ValueError: default_value 'string' (type: str) is incompatible with dtype Int64... """ + # Skip validation for Object and Binary types - they accept any value + if dtype in (pl.Object, pl.Binary): + return + try: # Try to create a polars series and cast s = pl.Series([default_value]) From 2c7a9e17d93d0fae056f0497b841657e7a7a6976 Mon Sep 17 00:00:00 2001 From: Jordao Bragantini Date: Tue, 20 Jan 2026 10:26:00 -0800 Subject: [PATCH 03/14] Update all tests to use explicit dtype parameter Update all test files to use the new required dtype parameter for add_node_attr_key and add_edge_attr_key methods. 
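The typical migration applied across the test suite looks like the sketch below (hypothetical attribute key, based on the bbox fix described in this commit):

```python
import numpy as np
import polars as pl

# before: the numpy default value also served as the implicit dtype
# graph.add_node_attr_key("bbox", np.zeros(4, dtype=int))

# after: an explicit pl.Array dtype with matching size; the default is unchanged
graph.add_node_attr_key("bbox", pl.Array(pl.Int64, 4), np.zeros(4, dtype=int))
```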
Key test fixes: - Add dtype inference logic in test_add_node_attr_key based on value type - Fix bbox attribute calls to use pl.Array(pl.Int64, size) instead of passing numpy arrays as dtype - Add missing polars imports where needed - Fix argument order from (key, False, dtype=pl.Boolean) to (key, pl.Boolean, default_value=False) --- .../array/_test/test_graph_array.py | 69 +++--- .../edges/_test/test_distance_edges.py | 43 ++-- .../edges/_test/test_generic_edges.py | 27 ++- src/tracksdata/edges/_test/test_iou_edges.py | 17 +- src/tracksdata/functional/_test/test_apply.py | 10 +- .../functional/_test/test_labeling.py | 4 +- .../functional/_test/test_napari.py | 7 +- .../graph/_test/test_graph_backends.py | 222 ++++++++++-------- src/tracksdata/graph/_test/test_subgraph.py | 22 +- .../filters/_test/test_spatial_filter.py | 34 +-- src/tracksdata/io/_test/test_ctc_io.py | 13 +- src/tracksdata/metrics/_test/test_matching.py | 68 +++--- .../nodes/_test/test_generic_nodes.py | 21 +- .../solvers/_test/test_ilp_solver.py | 103 ++++---- .../_test/test_nearest_neighbors_solver.py | 55 ++--- 15 files changed, 373 insertions(+), 342 deletions(-) diff --git a/src/tracksdata/array/_test/test_graph_array.py b/src/tracksdata/array/_test/test_graph_array.py index fc91014d..b37f23b2 100644 --- a/src/tracksdata/array/_test/test_graph_array.py +++ b/src/tracksdata/array/_test/test_graph_array.py @@ -1,6 +1,7 @@ from collections.abc import Sequence import numpy as np +import polars as pl import pytest from pytest import fixture @@ -31,9 +32,11 @@ def test_chain_indices() -> None: def test_graph_array_view_init(graph_backend: BaseGraph) -> None: """Test GraphArrayView initialization.""" # Add a attribute key - graph_backend.add_node_attr_key("label", 0) - graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, None) - graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.BBOX, np.asarray([0, 0, 0, 0, 0, 0])) + graph_backend.add_node_attr_key("label", dtype=pl.Int64, default_value=0) + graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, dtype=pl.Object, default_value=None) + graph_backend.add_node_attr_key( + DEFAULT_ATTR_KEYS.BBOX, dtype=pl.Object, default_value=np.asarray([0, 0, 0, 0, 0, 0]) + ) array_view = GraphArrayView(graph=graph_backend, shape=(10, 100, 100), attr_key="label", offset=0) @@ -57,9 +60,11 @@ def test_graph_array_view_init_invalid_attr_key(graph_backend: BaseGraph) -> Non def test_graph_array_view_getitem_empty_time(graph_backend: BaseGraph) -> None: """Test __getitem__ with empty time point (no nodes).""" - graph_backend.add_node_attr_key("label", 0) - graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, None) - graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.BBOX, np.asarray([0, 0, 0, 0, 0, 0])) + graph_backend.add_node_attr_key("label", dtype=pl.Int64, default_value=0) + graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, dtype=pl.Object, default_value=None) + graph_backend.add_node_attr_key( + DEFAULT_ATTR_KEYS.BBOX, dtype=pl.Object, default_value=np.asarray([0, 0, 0, 0, 0, 0]) + ) array_view = GraphArrayView(graph=graph_backend, shape=(10, 100, 100), attr_key="label") @@ -76,11 +81,11 @@ def test_graph_array_view_getitem_with_nodes(graph_backend: BaseGraph) -> None: """Test __getitem__ with nodes at time point.""" # Add attribute keys - graph_backend.add_node_attr_key("label", 0) - graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, None) - graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.BBOX, np.asarray([0, 0, 0, 0])) - graph_backend.add_node_attr_key("y", 0) - 
graph_backend.add_node_attr_key("x", 0) + graph_backend.add_node_attr_key("label", dtype=pl.Int64, default_value=0) + graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, dtype=pl.Object, default_value=None) + graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.BBOX, dtype=pl.Object, default_value=np.asarray([0, 0, 0, 0])) + graph_backend.add_node_attr_key("y", dtype=pl.Int64, default_value=0) + graph_backend.add_node_attr_key("x", dtype=pl.Int64, default_value=0) # Create a mask mask_data = np.array([[True, True], [True, False]], dtype=bool) @@ -126,11 +131,11 @@ def test_graph_array_view_getitem_multiple_nodes(graph_backend: BaseGraph) -> No """Test __getitem__ with multiple nodes at same time point.""" # Add attribute keys - graph_backend.add_node_attr_key("label", 0) - graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, None) - graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.BBOX, np.asarray([0, 0, 0, 0])) - graph_backend.add_node_attr_key("y", 0) - graph_backend.add_node_attr_key("x", 0) + graph_backend.add_node_attr_key("label", dtype=pl.Int64, default_value=0) + graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, dtype=pl.Object, default_value=None) + graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.BBOX, dtype=pl.Object, default_value=np.asarray([0, 0, 0, 0])) + graph_backend.add_node_attr_key("y", dtype=pl.Int64, default_value=0) + graph_backend.add_node_attr_key("x", dtype=pl.Int64, default_value=0) # Create two masks at different locations mask1_data = np.array([[True, True]], dtype=bool) @@ -181,11 +186,11 @@ def test_graph_array_view_getitem_boolean_dtype(graph_backend: BaseGraph) -> Non """Test __getitem__ with boolean attribute values.""" # Add attribute keys - graph_backend.add_node_attr_key("is_active", False) - graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, None) - graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.BBOX, np.asarray([0, 0, 0, 0])) - graph_backend.add_node_attr_key("y", 0) - graph_backend.add_node_attr_key("x", 0) + graph_backend.add_node_attr_key("is_active", dtype=pl.Boolean, default_value=False) + graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, dtype=pl.Object, default_value=None) + graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.BBOX, dtype=pl.Object, default_value=np.asarray([0, 0, 0, 0])) + graph_backend.add_node_attr_key("y", dtype=pl.Int64, default_value=0) + graph_backend.add_node_attr_key("x", dtype=pl.Int64, default_value=0) # Create a mask mask_data = np.array([[True]], dtype=bool) @@ -218,11 +223,11 @@ def test_graph_array_view_dtype_inference(graph_backend: BaseGraph) -> None: """Test that dtype is properly inferred from data.""" # Add attribute keys - graph_backend.add_node_attr_key("float_label", 0.0) - graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, None) - graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.BBOX, np.asarray([0, 0, 0, 0])) - graph_backend.add_node_attr_key("y", 0) - graph_backend.add_node_attr_key("x", 0) + graph_backend.add_node_attr_key("float_label", dtype=pl.Float64, default_value=0.0) + graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, dtype=pl.Object, default_value=None) + graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.BBOX, dtype=pl.Object, default_value=np.asarray([0, 0, 0, 0])) + graph_backend.add_node_attr_key("y", dtype=pl.Int64, default_value=0) + graph_backend.add_node_attr_key("x", dtype=pl.Int64, default_value=0) # Create a mask mask_data = np.array([[True]], dtype=bool) @@ -340,9 +345,9 @@ def 
test_graph_array_view_getitem_time_index_nested(multi_node_graph_from_image, def test_graph_array_set_options(graph_backend: BaseGraph) -> None: with Options(gav_chunk_shape=(512, 512), gav_default_dtype=np.int16): - graph_backend.add_node_attr_key("label", 0) - graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, None) - graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.BBOX, np.asarray([0, 0, 0, 0])) + graph_backend.add_node_attr_key("label", dtype=pl.Int64, default_value=0) + graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, dtype=pl.Object, default_value=None) + graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.BBOX, dtype=pl.Object, default_value=np.asarray([0, 0, 0, 0])) array_view = GraphArrayView(graph=graph_backend, shape=(10, 100, 100), attr_key="label") assert array_view.chunk_shape == (512, 512) assert array_view.dtype == np.int16 @@ -363,8 +368,10 @@ def test_graph_array_raise_error_on_non_scalar_attr_key(graph_backend: BaseGraph """Test that GraphArrayView raises error if attr_key values are non-scalar.""" # Add a attribute key - graph_backend.add_node_attr_key("label", np.array([0, 1])) # Non-scalar default value - graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, None) + graph_backend.add_node_attr_key( + "label", dtype=pl.Object, default_value=np.array([0, 1]) + ) # Non-scalar default value + graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, dtype=pl.Object, default_value=None) graph_backend.add_node( { DEFAULT_ATTR_KEYS.T: 0, diff --git a/src/tracksdata/edges/_test/test_distance_edges.py b/src/tracksdata/edges/_test/test_distance_edges.py index 41cd81cb..f91799de 100644 --- a/src/tracksdata/edges/_test/test_distance_edges.py +++ b/src/tracksdata/edges/_test/test_distance_edges.py @@ -1,3 +1,4 @@ +import polars as pl import pytest from tracksdata.constants import DEFAULT_ATTR_KEYS @@ -48,8 +49,8 @@ def test_distance_edges_add_edges_single_timepoint_no_previous() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("x", 0.0) - graph.add_node_attr_key("y", 0.0) + graph.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) + graph.add_node_attr_key("y", dtype=pl.Float64, default_value=0.0) # Add nodes only at t=1 (no t=0) graph.add_node({DEFAULT_ATTR_KEYS.T: 1, "x": 0.0, "y": 0.0}) @@ -67,8 +68,8 @@ def test_distance_edges_add_edges_single_timepoint_no_current() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("x", 0.0) - graph.add_node_attr_key("y", 0.0) + graph.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) + graph.add_node_attr_key("y", dtype=pl.Float64, default_value=0.0) # Add nodes only at t=0 (no t=1) graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 0.0, "y": 0.0}) @@ -86,8 +87,8 @@ def test_distance_edges_add_edges_2d_coordinates() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("x", 0.0) - graph.add_node_attr_key("y", 0.0) + graph.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) + graph.add_node_attr_key("y", dtype=pl.Float64, default_value=0.0) # Add nodes at t=0 _ = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 0.0, "y": 0.0}) @@ -114,9 +115,9 @@ def test_distance_edges_add_edges_3d_coordinates() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("x", 0.0) - graph.add_node_attr_key("y", 0.0) - graph.add_node_attr_key("z", 0.0) + graph.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) + graph.add_node_attr_key("y", dtype=pl.Float64, default_value=0.0) + 
graph.add_node_attr_key("z", dtype=pl.Float64, default_value=0.0) # Add nodes at t=0 _ = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 0.0, "y": 0.0, "z": 0.0}) @@ -139,8 +140,8 @@ def test_distance_edges_add_edges_custom_attr_keys() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("pos_x", 0.0) - graph.add_node_attr_key("pos_y", 0.0) + graph.add_node_attr_key("pos_x", dtype=pl.Float64, default_value=0.0) + graph.add_node_attr_key("pos_y", dtype=pl.Float64, default_value=0.0) # Add nodes at t=0 _ = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "pos_x": 0.0, "pos_y": 0.0}) @@ -163,8 +164,8 @@ def test_distance_edges_add_edges_distance_threshold() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("x", 0.0) - graph.add_node_attr_key("y", 0.0) + graph.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) + graph.add_node_attr_key("y", dtype=pl.Float64, default_value=0.0) # Add nodes at t=0 _ = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 0.0, "y": 0.0}) @@ -186,8 +187,8 @@ def test_distance_edges_add_edges_multiple_timepoints(n_workers: int) -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("x", 0.0) - graph.add_node_attr_key("y", 0.0) + graph.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) + graph.add_node_attr_key("y", dtype=pl.Float64, default_value=0.0) # Add nodes at multiple timepoints for t in range(3): @@ -209,8 +210,8 @@ def test_distance_edges_add_edges_custom_weight_key() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("x", 0.0) - graph.add_node_attr_key("y", 0.0) + graph.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) + graph.add_node_attr_key("y", dtype=pl.Float64, default_value=0.0) # Add nodes at t=0 _ = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 0.0, "y": 0.0}) @@ -237,8 +238,8 @@ def test_distance_edges_n_neighbors_limit() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("x", 0.0) - graph.add_node_attr_key("y", 0.0) + graph.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) + graph.add_node_attr_key("y", dtype=pl.Float64, default_value=0.0) # Add many nodes at t=0 for i in range(5): @@ -264,8 +265,8 @@ def test_distance_edges_add_edges_with_delta_t(n_workers: int) -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("x", 0.0) - graph.add_node_attr_key("y", 0.0) + graph.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) + graph.add_node_attr_key("y", dtype=pl.Float64, default_value=0.0) # Add nodes at t=0, t=1, t=2 node_0 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 0.0, "y": 0.0}) diff --git a/src/tracksdata/edges/_test/test_generic_edges.py b/src/tracksdata/edges/_test/test_generic_edges.py index e84b233b..e1c0ae4e 100644 --- a/src/tracksdata/edges/_test/test_generic_edges.py +++ b/src/tracksdata/edges/_test/test_generic_edges.py @@ -1,4 +1,5 @@ import numpy as np +import polars as pl from tracksdata.constants import DEFAULT_ATTR_KEYS from tracksdata.edges import GenericFuncEdgeAttrs @@ -41,8 +42,8 @@ def test_generic_edges_add_weights_single_attr_key() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("x", 0.0) - graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, 0.0) + graph.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) + graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, dtype=pl.Float64, default_value=0.0) # Add nodes at time 0 node0 = 
graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 1.0}) @@ -78,9 +79,9 @@ def test_generic_edges_add_weights_multiple_attr_keys() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("x", 0.0) - graph.add_node_attr_key("y", 0.0) - graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, 0.0) + graph.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) + graph.add_node_attr_key("y", dtype=pl.Float64, default_value=0.0) + graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, dtype=pl.Float64, default_value=0.0) # Add nodes at time 0 node0 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 0.0, "y": 0.0}) @@ -112,8 +113,8 @@ def test_generic_edges_add_weights_all_time_points() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("x", 0.0) - graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, 0.0) + graph.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) + graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, dtype=pl.Float64, default_value=0.0) # Add nodes at different time points node0_t0 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 1.0}) @@ -142,7 +143,7 @@ def test_generic_edges_no_edges_at_time_point() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("x", 0.0) + graph.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) # Add nodes but no edges at time 0 graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 1.0}) @@ -163,8 +164,8 @@ def test_generic_edges_creates_output_key() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("x", 0.0) - graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, 0.0) + graph.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) + graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, dtype=pl.Float64, default_value=0.0) # Add nodes and edge node0 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 1.0}) @@ -191,9 +192,9 @@ def test_generic_edges_dict_input_function() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("value", 0.0) - graph.add_node_attr_key("weight", 0.0) - graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, 0.0) + graph.add_node_attr_key("value", dtype=pl.Float64, default_value=0.0) + graph.add_node_attr_key("weight", dtype=pl.Float64, default_value=0.0) + graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, dtype=pl.Float64, default_value=0.0) # Add nodes node0 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "value": 10.0, "weight": 2.0}) diff --git a/src/tracksdata/edges/_test/test_iou_edges.py b/src/tracksdata/edges/_test/test_iou_edges.py index 8fddde60..cfdb6737 100644 --- a/src/tracksdata/edges/_test/test_iou_edges.py +++ b/src/tracksdata/edges/_test/test_iou_edges.py @@ -1,4 +1,5 @@ import numpy as np +import polars as pl import pytest from tracksdata.constants import DEFAULT_ATTR_KEYS @@ -32,8 +33,8 @@ def test_iou_edges_add_weights(n_workers: int) -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, None) - graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, 0.0) + graph.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, dtype=pl.Object, default_value=None) + graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, dtype=pl.Float64, default_value=0.0) # Create test masks mask1_data = np.array([[True, True], [True, False]], dtype=bool) @@ -82,8 +83,8 @@ def test_iou_edges_no_overlap() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, None) - 
graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, 0.0) + graph.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, dtype=pl.Object, default_value=None) + graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, dtype=pl.Float64, default_value=0.0) # Create non-overlapping masks mask1_data = np.array([[True, True], [False, False]], dtype=bool) @@ -121,8 +122,8 @@ def test_iou_edges_perfect_overlap() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, None) - graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, 0.0) + graph.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, dtype=pl.Object, default_value=None) + graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, dtype=pl.Float64, default_value=0.0) # Create identical masks mask_data = np.array([[True, True], [True, False]], dtype=bool) @@ -157,8 +158,8 @@ def test_iou_edges_custom_mask_key() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("custom_mask", None) - graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, 0.0) + graph.add_node_attr_key("custom_mask", dtype=pl.Object, default_value=None) + graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, dtype=pl.Float64, default_value=0.0) # Create test masks mask1_data = np.array([[True, True], [True, True]], dtype=bool) diff --git a/src/tracksdata/functional/_test/test_apply.py b/src/tracksdata/functional/_test/test_apply.py index 12750d17..fb5df478 100644 --- a/src/tracksdata/functional/_test/test_apply.py +++ b/src/tracksdata/functional/_test/test_apply.py @@ -10,9 +10,9 @@ def sample_graph() -> RustWorkXGraph: """Create a sample graph with spatial nodes for testing.""" graph = RustWorkXGraph() - graph.add_node_attr_key("z", 0) - graph.add_node_attr_key("y", 0) - graph.add_node_attr_key("x", 0) + graph.add_node_attr_key("z", dtype=pl.Int64, default_value=0) + graph.add_node_attr_key("y", dtype=pl.Int64, default_value=0) + graph.add_node_attr_key("x", dtype=pl.Int64, default_value=0) # Add nodes in a grid pattern nodes = [ @@ -112,8 +112,8 @@ def test_apply_tiled_default_attrs(sample_graph: RustWorkXGraph) -> None: def test_apply_tiled_2d_tiling() -> None: """Test apply_tiled with 2D spatial coordinates.""" graph = RustWorkXGraph() - graph.add_node_attr_key("y", 0) - graph.add_node_attr_key("x", 0) + graph.add_node_attr_key("y", dtype=pl.Int64, default_value=0) + graph.add_node_attr_key("x", dtype=pl.Int64, default_value=0) for y in [5, 11, 14]: for x in [10, 30]: diff --git a/src/tracksdata/functional/_test/test_labeling.py b/src/tracksdata/functional/_test/test_labeling.py index 814ecb0b..96225c49 100644 --- a/src/tracksdata/functional/_test/test_labeling.py +++ b/src/tracksdata/functional/_test/test_labeling.py @@ -1,3 +1,5 @@ +import polars as pl + import tracksdata as td @@ -43,7 +45,7 @@ def test_ancestral_connected_edges(): ref_graph.add_edge(7, 8, {}) # manual matching - input_graph.add_node_attr_key(td.DEFAULT_ATTR_KEYS.MATCHED_NODE_ID, -1) + input_graph.add_node_attr_key(td.DEFAULT_ATTR_KEYS.MATCHED_NODE_ID, dtype=pl.Int64, default_value=-1) input_graph.update_node_attrs( attrs={ td.DEFAULT_ATTR_KEYS.MATCHED_NODE_ID: [0, 1, 2, 3, 4, 5, 6, 8], diff --git a/src/tracksdata/functional/_test/test_napari.py b/src/tracksdata/functional/_test/test_napari.py index ec8faa2d..c1044ef9 100644 --- a/src/tracksdata/functional/_test/test_napari.py +++ b/src/tracksdata/functional/_test/test_napari.py @@ -1,4 +1,5 @@ import numpy as np +import polars as pl import pytest from tracksdata.constants import DEFAULT_ATTR_KEYS @@ 
-25,8 +26,8 @@ def test_napari_conversion(metadata_shape: bool) -> None: tracklet_ids=tracklet_ids, tracklet_id_graph=tracklet_id_graph, ) - graph.add_node_attr_key(DEFAULT_ATTR_KEYS.SOLUTION, True) - graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.SOLUTION, True) + graph.add_node_attr_key(DEFAULT_ATTR_KEYS.SOLUTION, dtype=pl.Boolean, default_value=True) + graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.SOLUTION, dtype=pl.Boolean, default_value=True) shape = (2, 10, 22, 32) if metadata_shape: @@ -43,7 +44,7 @@ def test_napari_conversion(metadata_shape: bool) -> None: mask_attrs.add_node_attrs(graph) # Maybe we should update the MaskDiskAttrs to handle bounding boxes - graph.add_node_attr_key(DEFAULT_ATTR_KEYS.BBOX, None) + graph.add_node_attr_key(DEFAULT_ATTR_KEYS.BBOX, dtype=pl.Object, default_value=None) masks = graph.node_attrs(attr_keys=[DEFAULT_ATTR_KEYS.MASK])[DEFAULT_ATTR_KEYS.MASK] graph.update_node_attrs( attrs={DEFAULT_ATTR_KEYS.BBOX: [mask.bbox for mask in masks]}, diff --git a/src/tracksdata/graph/_test/test_graph_backends.py b/src/tracksdata/graph/_test/test_graph_backends.py index 1420ee44..88d42885 100644 --- a/src/tracksdata/graph/_test/test_graph_backends.py +++ b/src/tracksdata/graph/_test/test_graph_backends.py @@ -19,10 +19,10 @@ def test_already_existing_keys(graph_backend: BaseGraph) -> None: """Test that adding already existing keys raises an error.""" - graph_backend.add_node_attr_key("x", None) + graph_backend.add_node_attr_key("x", dtype=pl.Object, default_value=None) with pytest.raises(ValueError): - graph_backend.add_node_attr_key("x", None) + graph_backend.add_node_attr_key("x", dtype=pl.Object, default_value=None) with pytest.raises(ValueError): # missing x @@ -57,7 +57,7 @@ def test_add_node(graph_backend: BaseGraph) -> None: """Test adding nodes with various attributes.""" for key in ["x", "y"]: - graph_backend.add_node_attr_key(key, 0.0) + graph_backend.add_node_attr_key(key, dtype=pl.Float64, default_value=0.0) node_id = graph_backend.add_node({"t": 0, "x": 1.0, "y": 2.0}) assert isinstance(node_id, int) @@ -77,7 +77,7 @@ def test_add_node(graph_backend: BaseGraph) -> None: def test_add_edge(graph_backend: BaseGraph) -> None: """Test adding edges with attributes.""" # Add node attribute key - graph_backend.add_node_attr_key("x", None) + graph_backend.add_node_attr_key("x", dtype=pl.Object, default_value=None) # Add two nodes first node1 = graph_backend.add_node({"t": 0, "x": 1.0}) @@ -85,7 +85,7 @@ def test_add_edge(graph_backend: BaseGraph) -> None: node3 = graph_backend.add_node({"t": 2, "x": 1.0}) # Add edge attribute key - graph_backend.add_edge_attr_key("weight", 0.0) + graph_backend.add_edge_attr_key("weight", dtype=pl.Float64, default_value=0.0) # Add edge edge_id = graph_backend.add_edge(node1, node2, attrs={"weight": 0.5}) @@ -98,7 +98,7 @@ def test_add_edge(graph_backend: BaseGraph) -> None: assert df["weight"].to_list() == [0.5] # testing adding new add attribute - graph_backend.add_edge_attr_key("new_attribute", 0.0) + graph_backend.add_edge_attr_key("new_attribute", dtype=pl.Float64, default_value=0.0) edge_id = graph_backend.add_edge(node2, node3, attrs={"new_attribute": 1.0, "weight": 0.1}) assert isinstance(edge_id, int) @@ -110,8 +110,8 @@ def test_add_edge(graph_backend: BaseGraph) -> None: def test_remove_edge_by_id(graph_backend: BaseGraph) -> None: """Test removing an edge by ID across backends using unified API.""" # Setup - graph_backend.add_node_attr_key("x", None) - graph_backend.add_edge_attr_key("weight", 0.0) + 
graph_backend.add_node_attr_key("x", dtype=pl.Object, default_value=None) + graph_backend.add_edge_attr_key("weight", dtype=pl.Float64, default_value=0.0) n1 = graph_backend.add_node({"t": 0, "x": 1.0}) n2 = graph_backend.add_node({"t": 1, "x": 2.0}) @@ -147,8 +147,8 @@ def test_remove_edge_by_id(graph_backend: BaseGraph) -> None: def test_remove_edge_by_nodes(graph_backend: BaseGraph) -> None: """Test removing an edge by its source/target IDs.""" - graph_backend.add_node_attr_key("x", None) - graph_backend.add_edge_attr_key("weight", 0.0) + graph_backend.add_node_attr_key("x", dtype=pl.Object, default_value=None) + graph_backend.add_edge_attr_key("weight", dtype=pl.Float64, default_value=0.0) a = graph_backend.add_node({"t": 0, "x": 0.0}) b = graph_backend.add_node({"t": 1, "x": 1.0}) @@ -184,7 +184,7 @@ def test_node_ids(graph_backend: BaseGraph) -> None: def test_filter_nodes_by_attribute(graph_backend: BaseGraph) -> None: """Test filtering nodes by attributes.""" - graph_backend.add_node_attr_key("label", None) + graph_backend.add_node_attr_key("label", dtype=pl.Object, default_value=None) node1 = graph_backend.add_node({"t": 0, "label": "A"}) node2 = graph_backend.add_node({"t": 0, "label": "B"}) @@ -235,8 +235,8 @@ def test_time_points(graph_backend: BaseGraph) -> None: def test_node_attrs(graph_backend: BaseGraph) -> None: """Test retrieving node attributes.""" - graph_backend.add_node_attr_key("x", 0.0) - graph_backend.add_node_attr_key("coordinates", np.array([0.0, 0.0])) + graph_backend.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) + graph_backend.add_node_attr_key("coordinates", dtype=pl.Object, default_value=np.array([0.0, 0.0])) node1 = graph_backend.add_node({"t": 0, "x": 1.0, "coordinates": np.array([10.0, 20.0])}) node2 = graph_backend.add_node({"t": 1, "x": 2.0, "coordinates": np.array([30.0, 40.0])}) @@ -257,8 +257,8 @@ def test_edge_attrs(graph_backend: BaseGraph) -> None: node1 = graph_backend.add_node({"t": 0}) node2 = graph_backend.add_node({"t": 1}) - graph_backend.add_edge_attr_key("weight", 0.0) - graph_backend.add_edge_attr_key("vector", np.array([0.0, 0.0])) + graph_backend.add_edge_attr_key("weight", dtype=pl.Float64, default_value=0.0) + graph_backend.add_edge_attr_key("vector", dtype=pl.Object, default_value=np.array([0.0, 0.0])) graph_backend.add_edge(node1, node2, attrs={"weight": 0.5, "vector": np.array([1.0, 2.0])}) @@ -276,7 +276,7 @@ def test_edge_attrs(graph_backend: BaseGraph) -> None: def test_edge_attrs_subgraph_edge_ids(graph_backend: BaseGraph) -> None: """Test that edge_attrs preserves original edge IDs when using node_ids parameter.""" # Add edge attribute key - graph_backend.add_edge_attr_key("weight", 0.0) + graph_backend.add_edge_attr_key("weight", dtype=pl.Float64, default_value=0.0) # Create nodes node1 = graph_backend.add_node({"t": 0}) @@ -335,10 +335,10 @@ def test_edge_attrs_subgraph_edge_ids(graph_backend: BaseGraph) -> None: def test_subgraph_with_node_and_edge_attr_filters(graph_backend: BaseGraph) -> None: """Test subgraph with node and edge attribute filters.""" - graph_backend.add_node_attr_key("x", 0.0) - graph_backend.add_node_attr_key("y", 0.0) - graph_backend.add_edge_attr_key("weight", 0.0) - graph_backend.add_edge_attr_key("length", 0.0) + graph_backend.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) + graph_backend.add_node_attr_key("y", dtype=pl.Float64, default_value=0.0) + graph_backend.add_edge_attr_key("weight", dtype=pl.Float64, default_value=0.0) + graph_backend.add_edge_attr_key("length", 
dtype=pl.Float64, default_value=0.0) node1 = graph_backend.add_node({"t": 0, "x": 1.0, "y": 0.0}) node2 = graph_backend.add_node({"t": 1, "x": 2.0, "y": 0.0}) @@ -374,8 +374,8 @@ def test_subgraph_with_node_and_edge_attr_filters(graph_backend: BaseGraph) -> N def test_subgraph_with_node_ids_and_filters(graph_backend: BaseGraph) -> None: """Test subgraph with node IDs and filters.""" - graph_backend.add_node_attr_key("x", None) - graph_backend.add_edge_attr_key("weight", 0.0) + graph_backend.add_node_attr_key("x", dtype=pl.Object, default_value=None) + graph_backend.add_edge_attr_key("weight", dtype=pl.Float64, default_value=0.0) node0 = graph_backend.add_node({"t": 0, "x": 1.0}) node1 = graph_backend.add_node({"t": 1, "x": 2.0}) @@ -423,7 +423,19 @@ def test_subgraph_with_node_ids_and_filters(graph_backend: BaseGraph) -> None: def test_add_node_attr_key(graph_backend: BaseGraph, value) -> None: """Test adding new node attribute keys.""" node = graph_backend.add_node({"t": 0}) - graph_backend.add_node_attr_key("new_attribute", value) + # Infer dtype from value + if isinstance(value, bool): + dtype = pl.Boolean + elif isinstance(value, int): + dtype = pl.Int64 + elif isinstance(value, float): + dtype = pl.Float64 + elif isinstance(value, str): + dtype = pl.String + else: + # For arrays, masks, and other objects + dtype = pl.Object + graph_backend.add_node_attr_key("new_attribute", dtype, default_value=value) df = graph_backend.filter(node_ids=[node]).node_attrs(attr_keys=["new_attribute"]) assert len(df) == 1 @@ -436,7 +448,7 @@ def test_add_node_attr_key(graph_backend: BaseGraph, value) -> None: def test_remove_node_attr_key(graph_backend: BaseGraph) -> None: """Test removing node attribute keys.""" - graph_backend.add_node_attr_key("label", "init") + graph_backend.add_node_attr_key("label", dtype=pl.String, default_value="init") node_a = graph_backend.add_node({"t": 0, "label": "a"}) node_b = graph_backend.add_node({"t": 1, "label": "b"}) @@ -460,7 +472,7 @@ def test_add_edge_attr_key(graph_backend: BaseGraph) -> None: node1 = graph_backend.add_node({"t": 0}) node2 = graph_backend.add_node({"t": 1}) - graph_backend.add_edge_attr_key("new_attribute", 42) + graph_backend.add_edge_attr_key("new_attribute", dtype=pl.Int64, default_value=42) graph_backend.add_edge(node1, node2, attrs={"new_attribute": 42}) df = graph_backend.edge_attrs(attr_keys=["new_attribute"]) @@ -472,7 +484,7 @@ def test_remove_edge_attr_key(graph_backend: BaseGraph) -> None: node1 = graph_backend.add_node({"t": 0}) node2 = graph_backend.add_node({"t": 1}) - graph_backend.add_edge_attr_key("weight", 0.5) + graph_backend.add_edge_attr_key("weight", dtype=pl.Float64, default_value=0.5) graph_backend.add_edge(node1, node2, attrs={"weight": 1.2}) assert "weight" in graph_backend.edge_attr_keys() @@ -489,7 +501,7 @@ def test_remove_edge_attr_key(graph_backend: BaseGraph) -> None: def test_update_node_attrs(graph_backend: BaseGraph) -> None: """Test updating node attributes.""" - graph_backend.add_node_attr_key("x", 0.0) + graph_backend.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) node_1 = graph_backend.add_node({"t": 0, "x": 1.0}) node_2 = graph_backend.add_node({"t": 0, "x": 2.0}) @@ -515,7 +527,7 @@ def test_update_edge_attrs(graph_backend: BaseGraph) -> None: node1 = graph_backend.add_node({"t": 0}) node2 = graph_backend.add_node({"t": 1}) - graph_backend.add_edge_attr_key("weight", 0.0) + graph_backend.add_edge_attr_key("weight", dtype=pl.Float64, default_value=0.0) edge_id = graph_backend.add_edge(node1, 
node2, attrs={"weight": 0.5}) graph_backend.update_edge_attrs(edge_ids=[edge_id], attrs={"weight": 1.0}) @@ -532,7 +544,7 @@ def test_num_edges(graph_backend: BaseGraph) -> None: node1 = graph_backend.add_node({"t": 0}) node2 = graph_backend.add_node({"t": 1}) - graph_backend.add_edge_attr_key("weight", 0.0) + graph_backend.add_edge_attr_key("weight", dtype=pl.Float64, default_value=0.0) graph_backend.add_edge(node1, node2, attrs={"weight": 0.5}) assert graph_backend.num_edges() == 1 @@ -549,7 +561,7 @@ def test_num_nodes(graph_backend: BaseGraph) -> None: def test_edge_attrs_include_targets(graph_backend: BaseGraph) -> None: """Test the inclusive flag behavior in edge_attrs method.""" # Add edge attribute key - graph_backend.add_edge_attr_key("weight", 0.0) + graph_backend.add_edge_attr_key("weight", dtype=pl.Float64, default_value=0.0) # Create a graph with 4 nodes # Graph structure: @@ -666,9 +678,9 @@ def test_from_ctc( def test_sucessors_and_degree(graph_backend: BaseGraph) -> None: """Test getting successors of nodes.""" # Add attribute keys - graph_backend.add_node_attr_key("x", 0.0) - graph_backend.add_node_attr_key("y", 0.0) - graph_backend.add_edge_attr_key("weight", 0.0) + graph_backend.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) + graph_backend.add_node_attr_key("y", dtype=pl.Float64, default_value=0.0) + graph_backend.add_edge_attr_key("weight", dtype=pl.Float64, default_value=0.0) # Create a simple graph structure: node0 -> node1 -> node2 # \-> node3 @@ -757,9 +769,9 @@ def test_sucessors_and_degree(graph_backend: BaseGraph) -> None: def test_predecessors_and_degree(graph_backend: BaseGraph) -> None: """Test getting predecessors of nodes.""" # Add attribute keys - graph_backend.add_node_attr_key("x", 0.0) - graph_backend.add_node_attr_key("y", 0.0) - graph_backend.add_edge_attr_key("weight", 0.0) + graph_backend.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) + graph_backend.add_node_attr_key("y", dtype=pl.Float64, default_value=0.0) + graph_backend.add_edge_attr_key("weight", dtype=pl.Float64, default_value=0.0) # Create a simple graph structure: node0 -> node1 -> node2 # \-> node3 @@ -846,10 +858,10 @@ def test_predecessors_and_degree(graph_backend: BaseGraph) -> None: def test_sucessors_with_attr_keys(graph_backend: BaseGraph) -> None: """Test getting successors with specific attribute keys.""" # Add attribute keys - graph_backend.add_node_attr_key("x", 0.0) - graph_backend.add_node_attr_key("y", 0.0) - graph_backend.add_node_attr_key("label", "X") - graph_backend.add_edge_attr_key("weight", 0.0) + graph_backend.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) + graph_backend.add_node_attr_key("y", dtype=pl.Float64, default_value=0.0) + graph_backend.add_node_attr_key("label", dtype=pl.String, default_value="X") + graph_backend.add_edge_attr_key("weight", dtype=pl.Float64, default_value=0.0) # Create nodes node0 = graph_backend.add_node({"t": 0, "x": 0.0, "y": 0.0, "label": "A"}) @@ -893,10 +905,10 @@ def test_sucessors_with_attr_keys(graph_backend: BaseGraph) -> None: def test_predecessors_with_attr_keys(graph_backend: BaseGraph) -> None: """Test getting predecessors with specific attribute keys.""" # Add attribute keys - graph_backend.add_node_attr_key("x", 0.0) - graph_backend.add_node_attr_key("y", 0.0) - graph_backend.add_node_attr_key("label", "X") - graph_backend.add_edge_attr_key("weight", 0.0) + graph_backend.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) + graph_backend.add_node_attr_key("y", dtype=pl.Float64, 
default_value=0.0) + graph_backend.add_node_attr_key("label", dtype=pl.String, default_value="X") + graph_backend.add_edge_attr_key("weight", dtype=pl.Float64, default_value=0.0) # Create nodes node0 = graph_backend.add_node({"t": 0, "x": 0.0, "y": 0.0, "label": "A"}) @@ -936,8 +948,8 @@ def test_predecessors_with_attr_keys(graph_backend: BaseGraph) -> None: def test_sucessors_predecessors_edge_cases(graph_backend: BaseGraph) -> None: """Test edge cases for successors and predecessors methods.""" # Add attribute keys - graph_backend.add_node_attr_key("x", 0.0) - graph_backend.add_edge_attr_key("weight", 0.0) + graph_backend.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) + graph_backend.add_edge_attr_key("weight", dtype=pl.Float64, default_value=0.0) # Create isolated nodes (no edges) node0 = graph_backend.add_node({"t": 0, "x": 0.0}) @@ -980,9 +992,9 @@ def test_sucessors_predecessors_edge_cases(graph_backend: BaseGraph) -> None: def test_match_method(graph_backend: BaseGraph) -> None: """Test the match method for matching nodes between two graphs.""" # Create first graph (self) with masks - graph_backend.add_node_attr_key("x", 0.0) - graph_backend.add_node_attr_key("y", 0.0) - graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, None) + graph_backend.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) + graph_backend.add_node_attr_key("y", dtype=pl.Float64, default_value=0.0) + graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, dtype=pl.Object, default_value=None) # Create masks for first graph mask1_data = np.array([[True, True], [True, True]], dtype=bool) @@ -999,7 +1011,7 @@ def test_match_method(graph_backend: BaseGraph) -> None: node2 = graph_backend.add_node({"t": 1, "x": 2.0, "y": 2.0, DEFAULT_ATTR_KEYS.MASK: mask2}) node3 = graph_backend.add_node({"t": 2, "x": 3.0, "y": 3.0, DEFAULT_ATTR_KEYS.MASK: mask3}) - graph_backend.add_edge_attr_key("weight", 0.0) + graph_backend.add_edge_attr_key("weight", dtype=pl.Float64, default_value=0.0) # this will not be matched graph_backend.add_edge(node1, node2, {"weight": 0.5}) graph_backend.add_edge(node2, node3, {"weight": 0.3}) @@ -1014,9 +1026,9 @@ def test_match_method(graph_backend: BaseGraph) -> None: kwargs = {} other_graph = graph_backend.__class__(**kwargs) - other_graph.add_node_attr_key("x", 0.0) - other_graph.add_node_attr_key("y", 0.0) - other_graph.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, None) + other_graph.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) + other_graph.add_node_attr_key("y", dtype=pl.Float64, default_value=0.0) + other_graph.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, dtype=pl.Object, default_value=None) # Create overlapping masks for second graph # This mask overlaps significantly with mask1 (IoU > 0.5) @@ -1043,7 +1055,7 @@ def test_match_method(graph_backend: BaseGraph) -> None: ref_node4 = other_graph.add_node({"t": 2, "x": 3.1, "y": 3.1, DEFAULT_ATTR_KEYS.MASK: ref_mask4}) # Add edges to reference graph - matching structure with first graph - other_graph.add_edge_attr_key("weight", 0.0) + other_graph.add_edge_attr_key("weight", dtype=pl.Float64, default_value=0.0) other_graph.add_edge(ref_node1, ref_node3, {"weight": 0.6}) # ref_node1 -> ref_node2 other_graph.add_edge(ref_node1, ref_node2, {"weight": 0.4}) # ref_node1 -> ref_node3 other_graph.add_edge(ref_node3, ref_node2, {"weight": 0.7}) # ref_node2 -> ref_node3 @@ -1119,15 +1131,15 @@ def test_match_method(graph_backend: BaseGraph) -> None: def test_attrs_with_duplicated_attr_keys(graph_backend: BaseGraph) -> None: 
"""Test that node attributeswith duplicated attribute keys are handled correctly.""" # Add attribute keys - graph_backend.add_node_attr_key("x", 0.0) - graph_backend.add_node_attr_key("y", 0.0) + graph_backend.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) + graph_backend.add_node_attr_key("y", dtype=pl.Float64, default_value=0.0) # Add nodes node_1 = graph_backend.add_node({"t": 0, "x": 1.0, "y": 1.0}) node_2 = graph_backend.add_node({"t": 1, "x": 2.0, "y": 2.0}) # Add edges - graph_backend.add_edge_attr_key("weight", 0.0) + graph_backend.add_edge_attr_key("weight", dtype=pl.Float64, default_value=0.0) graph_backend.add_edge(node_1, node_2, {"weight": 0.5}) # Test with duplicated attribute keys @@ -1354,9 +1366,9 @@ def test_from_other_with_edges( # Create source graph with nodes, edges, and attributes graph_backend.update_metadata(special_key="special_value") - graph_backend.add_node_attr_key("x", 0.0) - graph_backend.add_edge_attr_key("weight", 0.0) - graph_backend.add_edge_attr_key("type", "forward") + graph_backend.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) + graph_backend.add_edge_attr_key("weight", dtype=pl.Float64, default_value=0.0) + graph_backend.add_edge_attr_key("type", dtype=pl.String, default_value="forward") node1 = graph_backend.add_node({"t": 0, "x": 1.0}) node2 = graph_backend.add_node({"t": 1, "x": 2.0}) @@ -1490,7 +1502,7 @@ def build_node_map(graph: BaseGraph) -> dict[tuple[int, tuple[int, ...]], dict[s def test_compute_overlaps_basic(graph_backend: BaseGraph) -> None: """Test basic compute_overlaps functionality.""" - graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, None) + graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, dtype=pl.Object, default_value=None) # Create overlapping masks at time 0 mask1_data = np.array([[True, True], [True, True]], dtype=bool) @@ -1512,7 +1524,7 @@ def test_compute_overlaps_basic(graph_backend: BaseGraph) -> None: def test_compute_overlaps_with_threshold(graph_backend: BaseGraph) -> None: """Test compute_overlaps with different IoU thresholds.""" - graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, None) + graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, dtype=pl.Object, default_value=None) # Create masks with different overlap levels mask1_data = np.array([[True, True], [True, True]], dtype=bool) @@ -1546,7 +1558,7 @@ def test_compute_overlaps_with_threshold(graph_backend: BaseGraph) -> None: def test_compute_overlaps_multiple_timepoints(graph_backend: BaseGraph) -> None: """Test compute_overlaps across multiple time points.""" - graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, None) + graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, dtype=pl.Object, default_value=None) # Time 0: overlapping masks mask1_t0 = Mask(np.array([[True, True], [True, True]], dtype=bool), bbox=np.array([0, 0, 2, 2])) @@ -1573,7 +1585,7 @@ def test_sql_graph_mask_update_survives_reload(tmp_path: Path) -> None: """Ensure SQLGraph keeps pickled column types after reloading from disk.""" db_path = tmp_path / "mask_graph.db" graph = SQLGraph("sqlite", str(db_path)) - graph.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, None) + graph.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, dtype=pl.Object, default_value=None) mask_data = np.array([[True, False], [False, True]], dtype=bool) mask = Mask(mask_data, bbox=np.array([0, 0, 2, 2])) @@ -1624,9 +1636,9 @@ def test_compute_overlaps_empty_graph(graph_backend: BaseGraph) -> None: def test_summary(graph_backend: BaseGraph) -> None: """Test summary method.""" - 
graph_backend.add_node_attr_key("x", 0.0) - graph_backend.add_edge_attr_key("weight", 0.0) - graph_backend.add_edge_attr_key("type", "good") + graph_backend.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) + graph_backend.add_edge_attr_key("weight", dtype=pl.Float64, default_value=0.0) + graph_backend.add_edge_attr_key("type", dtype=pl.String, default_value="good") node1 = graph_backend.add_node({"t": 0, "x": 1.0}) node2 = graph_backend.add_node({"t": 1, "x": 2.0}) @@ -1647,10 +1659,10 @@ def test_summary(graph_backend: BaseGraph) -> None: def test_spatial_filter_basic(graph_backend: BaseGraph) -> None: - graph_backend.add_node_attr_key("x", 0.0) - graph_backend.add_node_attr_key("y", 0.0) - graph_backend.add_node_attr_key("z", 0.0) - graph_backend.add_node_attr_key("bbox", np.zeros(6, dtype=int)) + graph_backend.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) + graph_backend.add_node_attr_key("y", dtype=pl.Float64, default_value=0.0) + graph_backend.add_node_attr_key("z", dtype=pl.Float64, default_value=0.0) + graph_backend.add_node_attr_key("bbox", pl.Array(pl.Int64, 6), default_value=np.zeros(6, dtype=int)) node1 = graph_backend.add_node({"t": 0, "x": 1.0, "y": 1.0, "z": 1.0, "bbox": np.array([6, 6, 6, 8, 8, 8])}) node2 = graph_backend.add_node({"t": 1, "x": 2.0, "y": 2.0, "z": 2.0, "bbox": np.array([0, 0, 0, 3, 3, 3])}) @@ -1811,7 +1823,7 @@ def test_assign_tracklet_ids_node_id_filter(graph_backend: BaseGraph, return_id_ # Ensure tracklet_id attribute exists after nodes were added if DEFAULT_ATTR_KEYS.TRACKLET_ID not in graph_backend.node_attr_keys(): - graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.TRACKLET_ID, -1) + graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.TRACKLET_ID, dtype=pl.Int64, default_value=-1) for seeds, expected in ( ([A1], [[A0, A1, A2, A3]]), @@ -1921,7 +1933,7 @@ def test_assign_tracklet_ids_node_id_filter(graph_backend: BaseGraph, return_id_ def test_tracklet_graph_basic(graph_backend: BaseGraph) -> None: """Test basic tracklet_graph functionality.""" # Add tracklet_id attribute and nodes with track IDs - graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.TRACKLET_ID, -1) + graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.TRACKLET_ID, dtype=pl.Int64, default_value=-1) # Create nodes with different track IDs node0 = graph_backend.add_node({"t": 0, DEFAULT_ATTR_KEYS.TRACKLET_ID: 1}) @@ -1934,7 +1946,7 @@ def test_tracklet_graph_basic(graph_backend: BaseGraph) -> None: node7 = graph_backend.add_node({"t": 2, DEFAULT_ATTR_KEYS.TRACKLET_ID: 4}) node8 = graph_backend.add_node({"t": 3, DEFAULT_ATTR_KEYS.TRACKLET_ID: 4}) - graph_backend.add_edge_attr_key("weight", 0.0) + graph_backend.add_edge_attr_key("weight", dtype=pl.Float64, default_value=0.0) # Add edges within tracks (will be filtered out) graph_backend.add_edge(node0, node1, {"weight": 0.5}) @@ -1965,8 +1977,8 @@ def test_tracklet_graph_basic(graph_backend: BaseGraph) -> None: def test_tracklet_graph_with_ignore_tracklet_id(graph_backend: BaseGraph) -> None: """Test tracklet_graph with ignore_tracklet_id parameter.""" # Add tracklet_id attribute and nodes with track IDs - graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.TRACKLET_ID, -1) - graph_backend.add_edge_attr_key("weight", 0.0) + graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.TRACKLET_ID, dtype=pl.Int64, default_value=-1) + graph_backend.add_edge_attr_key("weight", dtype=pl.Float64, default_value=0.0) # Simple test case: just check that the method accepts the parameter # and filters out nodes properly when there are no edges 
@@ -1993,7 +2005,7 @@ def test_tracklet_graph_missing_tracklet_id_key(graph_backend: BaseGraph) -> Non def test_nodes_interface(graph_backend: BaseGraph) -> None: - graph_backend.add_node_attr_key("x", 0) + graph_backend.add_node_attr_key("x", dtype=pl.Int64, default_value=0) # Simple test case: just check that the method accepts the parameter # and filters out nodes properly when there are no edges @@ -2005,7 +2017,7 @@ def test_nodes_interface(graph_backend: BaseGraph) -> None: assert graph_backend[node2]["x"] == 0 assert graph_backend[node3]["x"] == -1 - graph_backend.add_node_attr_key("y", -1) + graph_backend.add_node_attr_key("y", dtype=pl.Int64, default_value=-1) graph_backend[node2]["y"] = 5 @@ -2025,8 +2037,8 @@ def test_custom_indices(graph_backend: BaseGraph) -> None: pytest.skip("Graph does not support custom indices") # Add attribute keys for testing - graph_backend.add_node_attr_key("x", 0.0) - graph_backend.add_node_attr_key("y", 0.0) + graph_backend.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) + graph_backend.add_node_attr_key("y", dtype=pl.Float64, default_value=0.0) # Test add_node with custom index custom_node_id = graph_backend.add_node({"t": 0, "x": 10.0, "y": 20.0}, index=12345) @@ -2070,7 +2082,7 @@ def test_sqlgraph_node_attr_index_create_and_drop(graph_backend: BaseGraph) -> N if not isinstance(graph_backend, SQLGraph): pytest.skip("Only SQLGraph supports explicit SQL indexes") - graph_backend.add_node_attr_key("label", "") + graph_backend.add_node_attr_key("label", dtype=pl.String, default_value="") index_name = f"ix_{graph_backend.Node.__tablename__.lower()}_t_label" graph_backend.create_node_attr_index(["t", "label"], unique=False) @@ -2091,7 +2103,7 @@ def test_sqlgraph_edge_attr_index_create_and_drop(graph_backend: BaseGraph) -> N if not isinstance(graph_backend, SQLGraph): pytest.skip("Only SQLGraph supports explicit SQL indexes") - graph_backend.add_edge_attr_key("score", 0.0) + graph_backend.add_edge_attr_key("score", dtype=pl.Float64, default_value=0.0) index_name = f"ix_{graph_backend.Edge.__tablename__.lower()}_score" graph_backend.create_edge_attr_index("score", unique=True) @@ -2120,9 +2132,9 @@ def test_sqlgraph_index_missing_column(graph_backend: BaseGraph) -> None: def test_remove_node(graph_backend: BaseGraph) -> None: """Test removing nodes from the graph.""" # Add attribute keys - graph_backend.add_node_attr_key("x", 0.0) - graph_backend.add_node_attr_key("y", 0.0) - graph_backend.add_edge_attr_key("weight", 0.0) + graph_backend.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) + graph_backend.add_node_attr_key("y", dtype=pl.Float64, default_value=0.0) + graph_backend.add_edge_attr_key("weight", dtype=pl.Float64, default_value=0.0) # Add nodes node1 = graph_backend.add_node({"t": 0, "x": 1.0, "y": 1.0}) @@ -2182,8 +2194,8 @@ def test_remove_node(graph_backend: BaseGraph) -> None: def test_remove_node_and_add_new_nodes(graph_backend: BaseGraph) -> None: """Test removing nodes and then adding new nodes.""" # Add attribute keys - graph_backend.add_node_attr_key("x", 0.0) - graph_backend.add_edge_attr_key("weight", 0.0) + graph_backend.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) + graph_backend.add_edge_attr_key("weight", dtype=pl.Float64, default_value=0.0) # Add initial nodes node1 = graph_backend.add_node({"t": 0, "x": 1.0}) @@ -2279,15 +2291,17 @@ def test_remove_all_nodes_in_time_point(graph_backend: BaseGraph) -> None: def _fill_mock_geff_graph(graph_backend: BaseGraph) -> None: - 
graph_backend.add_node_attr_key("x", 0.0) - graph_backend.add_node_attr_key("y", 0.0) - graph_backend.add_node_attr_key("z", 0.0) - graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.BBOX, np.array([0, 0, 1, 1], dtype=int)) - graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, None) - graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.TRACKLET_ID, -1) - graph_backend.add_node_attr_key("ndfeature", np.asarray([[1.0], [2.0], [3.0]])) + graph_backend.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) + graph_backend.add_node_attr_key("y", dtype=pl.Float64, default_value=0.0) + graph_backend.add_node_attr_key("z", dtype=pl.Float64, default_value=0.0) + graph_backend.add_node_attr_key( + DEFAULT_ATTR_KEYS.BBOX, pl.Array(pl.Int64, 4), default_value=np.array([0, 0, 1, 1], dtype=int) + ) + graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, dtype=pl.Object, default_value=None) + graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.TRACKLET_ID, dtype=pl.Int64, default_value=-1) + graph_backend.add_node_attr_key("ndfeature", dtype=pl.Object, default_value=np.asarray([[1.0], [2.0], [3.0]])) - graph_backend.add_edge_attr_key("weight", 0.0) + graph_backend.add_edge_attr_key("weight", dtype=pl.Float64, default_value=0.0) graph_backend.update_metadata( shape=[1, 25, 25], @@ -2461,9 +2475,9 @@ def test_pickle_roundtrip(graph_backend: BaseGraph) -> None: if isinstance(graph_backend, SQLGraph): pytest.skip("SQLGraph does not support pickle roundtrip") - graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.BBOX, None) - graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, None) - graph_backend.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, 0.0) + graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.BBOX, dtype=pl.Object, default_value=None) + graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, dtype=pl.Object, default_value=None) + graph_backend.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, dtype=pl.Float64, default_value=0.0) bbox = np.array([0, 0, 2, 2]) mask = Mask(np.array([[True, True], [True, True]], dtype=bool), bbox=bbox) @@ -2505,7 +2519,7 @@ def test_sql_graph_huge_update() -> None: random_t = np.random.randint(0, 1000, n_nodes).tolist() random_x = np.random.rand(n_nodes).tolist() graph.bulk_add_nodes([{"t": t} for t in random_t]) - graph.add_node_attr_key("x", -1.0) + graph.add_node_attr_key("x", dtype=pl.Float64, default_value=-1.0) # testing with varying values graph.update_node_attrs( @@ -2527,10 +2541,10 @@ def test_to_traccuracy_graph(graph_backend: BaseGraph) -> None: from traccuracy.metrics import CTCMetrics # Create first graph (self) with masks - graph_backend.add_node_attr_key("x", 0.0) - graph_backend.add_node_attr_key("y", 0.0) - graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, None) - graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.BBOX, np.zeros(4, dtype=int)) + graph_backend.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) + graph_backend.add_node_attr_key("y", dtype=pl.Float64, default_value=0.0) + graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, dtype=pl.Object, default_value=None) + graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.BBOX, pl.Array(pl.Int64, 4), default_value=np.zeros(4, dtype=int)) graph_backend.update_metadata( shape=[3, 25, 25], ) @@ -2556,7 +2570,7 @@ def test_to_traccuracy_graph(graph_backend: BaseGraph) -> None: {"t": 2, "x": 3.0, "y": 3.0, DEFAULT_ATTR_KEYS.MASK: mask3, DEFAULT_ATTR_KEYS.BBOX: mask3.bbox} ) - graph_backend.add_edge_attr_key("weight", 0.0) + graph_backend.add_edge_attr_key("weight", dtype=pl.Float64, 
default_value=0.0) graph_backend.add_edge(node1, node2, {"weight": 0.5}) graph_backend.add_edge(node2, node3, {"weight": 0.3}) graph_backend.add_edge(node1, node3, {"weight": 0.3}) diff --git a/src/tracksdata/graph/_test/test_subgraph.py b/src/tracksdata/graph/_test/test_subgraph.py index 1be01097..36b295b8 100644 --- a/src/tracksdata/graph/_test/test_subgraph.py +++ b/src/tracksdata/graph/_test/test_subgraph.py @@ -37,11 +37,11 @@ def create_test_graph(graph_backend: BaseGraph, use_subgraph: bool = False) -> B Either the original graph or a subgraph with test data. """ # Add attribute keys - graph_backend.add_node_attr_key("x", -1.0) - graph_backend.add_node_attr_key("y", -1.0) - graph_backend.add_node_attr_key("label", "0") - graph_backend.add_edge_attr_key("weight", 0.0) - graph_backend.add_edge_attr_key("new_attribute", 0.0) + graph_backend.add_node_attr_key("x", dtype=pl.Float64, default_value=-1.0) + graph_backend.add_node_attr_key("y", dtype=pl.Float64, default_value=-1.0) + graph_backend.add_node_attr_key("label", dtype=pl.String, default_value="0") + graph_backend.add_edge_attr_key("weight", dtype=pl.Float64, default_value=0.0) + graph_backend.add_edge_attr_key("new_attribute", dtype=pl.Float64, default_value=0.0) # Add nodes with various attributes node0 = graph_backend.add_node({"t": 0, "x": 0.0, "y": 0.0, "label": "0"}) @@ -199,7 +199,7 @@ def test_add_node_attr_key_with_data(graph_backend: BaseGraph, use_subgraph: boo graph_with_data = create_test_graph(graph_backend, use_subgraph) # Add a new attribute key with default value - graph_with_data.add_node_attr_key("new_node_attribute", 42) + graph_with_data.add_node_attr_key("new_node_attribute", dtype=pl.Int64, default_value=42) # Check that all nodes have this attribute with the default value nodes = graph_with_data._test_nodes # type: ignore @@ -214,7 +214,7 @@ def test_add_edge_attr_key_with_data(graph_backend: BaseGraph, use_subgraph: boo graph_with_data = create_test_graph(graph_backend, use_subgraph) # Add a new edge attribute key with default value - graph_with_data.add_edge_attr_key("new_edge_attribute", 99) + graph_with_data.add_edge_attr_key("new_edge_attribute", dtype=pl.Int64, default_value=99) # Check that all edges have this attribute with the default value df = graph_with_data.edge_attrs(attr_keys=["new_edge_attribute"]) @@ -862,7 +862,7 @@ def test_bulk_add_nodes_returned_ids(graph_backend: BaseGraph, use_subgraph: boo graph_with_data = create_test_graph(graph_backend, use_subgraph) # Add attribute keys for the new nodes - graph_with_data.add_node_attr_key("z", 0.0) + graph_with_data.add_node_attr_key("z", dtype=pl.Float64, default_value=0.0) # Test bulk adding nodes nodes_to_add = [ @@ -913,7 +913,7 @@ def test_bulk_add_edges_returned_ids(graph_backend: BaseGraph, use_subgraph: boo graph_with_data = create_test_graph(graph_backend, use_subgraph) # Add attribute keys for the new edges - graph_with_data.add_edge_attr_key("strength", 0.0) + graph_with_data.add_edge_attr_key("strength", dtype=pl.Float64, default_value=0.0) # Get some existing nodes to create edges between existing_nodes = graph_with_data._test_nodes # type: ignore @@ -1101,8 +1101,8 @@ def test_graph_view_remove_edge(graph_backend: BaseGraph) -> None: Tests removal by endpoints and by edge_id with the view in sync mode. 
""" # Setup root graph with attributes - graph_backend.add_node_attr_key("x", None) - graph_backend.add_edge_attr_key("weight", 0.0) + graph_backend.add_node_attr_key("x", dtype=pl.Object, default_value=None) + graph_backend.add_edge_attr_key("weight", dtype=pl.Float64, default_value=0.0) # Nodes and edges n0 = graph_backend.add_node({"t": 0, "x": 0.0}) diff --git a/src/tracksdata/graph/filters/_test/test_spatial_filter.py b/src/tracksdata/graph/filters/_test/test_spatial_filter.py index ecd4803f..bb7badeb 100644 --- a/src/tracksdata/graph/filters/_test/test_spatial_filter.py +++ b/src/tracksdata/graph/filters/_test/test_spatial_filter.py @@ -13,9 +13,9 @@ def sample_graph() -> RustWorkXGraph: """Create a sample graph with nodes for testing.""" graph = RustWorkXGraph() - graph.add_node_attr_key("z", 0) - graph.add_node_attr_key("y", 0) - graph.add_node_attr_key("x", 0) + graph.add_node_attr_key("z", dtype=pl.Int64, default_value=0) + graph.add_node_attr_key("y", dtype=pl.Int64, default_value=0) + graph.add_node_attr_key("x", dtype=pl.Int64, default_value=0) # Add some nodes with spatial coordinates nodes = [ @@ -35,7 +35,7 @@ def sample_graph() -> RustWorkXGraph: def sample_bbox_graph() -> RustWorkXGraph: """Create a sample graph with nodes for bounding box testing.""" graph = RustWorkXGraph() - graph.add_node_attr_key("bbox", [0, 0, 0, 0, 0, 0]) + graph.add_node_attr_key("bbox", dtype=pl.List(pl.Int64), default_value=[0, 0, 0, 0, 0, 0]) # Add some nodes with bounding box coordinates nodes = [ @@ -141,9 +141,9 @@ def test_spatial_filter_querying(sample_graph: RustWorkXGraph) -> None: def test_spatial_filter_dimensions() -> None: """Test SpatialFilter with different coordinate dimensions.""" graph = RustWorkXGraph() - graph.add_node_attr_key("z", 0) - graph.add_node_attr_key("y", 0) - graph.add_node_attr_key("x", 0) + graph.add_node_attr_key("z", dtype=pl.Int64, default_value=0) + graph.add_node_attr_key("y", dtype=pl.Int64, default_value=0) + graph.add_node_attr_key("x", dtype=pl.Int64, default_value=0) graph.add_node({"t": 0, "z": 0, "y": 10, "x": 20}) # Test 2D coordinates @@ -171,9 +171,9 @@ def test_spatial_filter_error_handling(sample_graph: RustWorkXGraph) -> None: def test_spatial_filter_with_edges() -> None: """Test SpatialFilter preserves edges in subgraphs.""" graph = RustWorkXGraph() - graph.add_node_attr_key("y", 0) - graph.add_node_attr_key("x", 0) - graph.add_edge_attr_key("weight", 0.0) + graph.add_node_attr_key("y", dtype=pl.Int64, default_value=0) + graph.add_node_attr_key("x", dtype=pl.Int64, default_value=0) + graph.add_edge_attr_key("weight", dtype=pl.Float64, default_value=0.0) # Add nodes and edge node1_id = graph.add_node({"t": 0, "y": 10, "x": 20}) @@ -191,7 +191,7 @@ def test_spatial_filter_with_edges() -> None: def test_bbox_spatial_filter_overlaps() -> None: """Test BoundingBoxSpatialFilter overlaps with existing nodes.""" graph = RustWorkXGraph() - graph.add_node_attr_key("bbox", [0, 0, 0, 0]) + graph.add_node_attr_key("bbox", dtype=pl.List(pl.Int64), default_value=[0, 0, 0, 0]) # Add nodes with bounding boxes bboxes = [ [0, 20, 10, 30], # Node 1 @@ -214,8 +214,8 @@ def test_bbox_spatial_filter_overlaps() -> None: def test_bbox_spatial_filter_with_edges() -> None: """Test SpatialFilter preserves edges in subgraphs.""" graph = RustWorkXGraph() - graph.add_node_attr_key("bbox", [0, 0, 0, 0]) - graph.add_edge_attr_key("weight", 0.0) + graph.add_node_attr_key("bbox", dtype=pl.List(pl.Int64), default_value=[0, 0, 0, 0]) + graph.add_edge_attr_key("weight", 
dtype=pl.Float64, default_value=0.0) # Add nodes and edge node1_id = graph.add_node({"t": 0, "bbox": [10, 20, 15, 25]}) @@ -263,7 +263,7 @@ def test_bbox_spatial_filter_querying(sample_bbox_graph: RustWorkXGraph) -> None def test_bbox_spatial_filter_dimensions() -> None: """Test BoundingBoxSpatialFilter with different coordinate dimensions.""" graph = RustWorkXGraph() - graph.add_node_attr_key("bbox", [0, 0, 0, 0, 0, 0]) + graph.add_node_attr_key("bbox", dtype=pl.List(pl.Int64), default_value=[0, 0, 0, 0, 0, 0]) graph.add_node({"t": 0, "bbox": [0, 10, 20, 1, 15, 25]}) # Test 3D coordinates @@ -284,7 +284,7 @@ def test_bbox_spatial_filter_dimensions() -> None: def test_bbox_spatial_filter_error_handling() -> None: """Test error handling for mismatched min/max attribute lengths.""" graph = RustWorkXGraph() - graph.add_node_attr_key("bbox", [0, 0, 0, 0]) + graph.add_node_attr_key("bbox", dtype=pl.List(pl.Int64), default_value=[0, 0, 0, 0]) graph.add_node({"t": 0, "bbox": [10, 20, 15, 25, 14]}) # Test mismatched min/max attributes length with pytest.raises(ValueError, match="Bounding box coordinates must have even number of dimensions"): @@ -292,7 +292,7 @@ def test_bbox_spatial_filter_error_handling() -> None: def test_add_and_remove_node(graph_backend: BaseGraph) -> None: - graph_backend.add_node_attr_key("bbox", np.asarray([0, 0, 0, 0])) + graph_backend.add_node_attr_key("bbox", dtype=pl.Object, default_value=np.asarray([0, 0, 0, 0])) # testing if _node_tree is created in BBoxSpatialFilter when graph is empty _ = BBoxSpatialFilter(graph_backend, frame_attr_key="t", bbox_attr_key="bbox") @@ -342,7 +342,7 @@ def test_add_and_remove_node(graph_backend: BaseGraph) -> None: def test_bbox_spatial_filter_handles_list_dtype(graph_backend: BaseGraph) -> None: """Ensure bounding boxes stored as list dtype still work with the spatial filter.""" - graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.BBOX, None) + graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.BBOX, dtype=pl.Object, default_value=None) first = graph_backend.add_node({"t": 0, "bbox": [0, 0, 2, 2]}) second = graph_backend.add_node({"t": 1, "bbox": [5, 5, 8, 8]}) diff --git a/src/tracksdata/io/_test/test_ctc_io.py b/src/tracksdata/io/_test/test_ctc_io.py index a708b750..e0a79886 100644 --- a/src/tracksdata/io/_test/test_ctc_io.py +++ b/src/tracksdata/io/_test/test_ctc_io.py @@ -1,6 +1,7 @@ from pathlib import Path import numpy as np +import polars as pl import pytest from tracksdata.constants import DEFAULT_ATTR_KEYS @@ -14,11 +15,11 @@ def test_export_from_ctc_roundtrip(tmp_path: Path, metadata_shape: bool) -> None # Create original graph with nodes and edges in_graph = RustWorkXGraph() - in_graph.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, None) - in_graph.add_node_attr_key(DEFAULT_ATTR_KEYS.BBOX, None) - in_graph.add_node_attr_key(DEFAULT_ATTR_KEYS.TRACKLET_ID, -1) - in_graph.add_node_attr_key("x", -999_999) - in_graph.add_node_attr_key("y", -999_999) + in_graph.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, pl.Object, None) + in_graph.add_node_attr_key(DEFAULT_ATTR_KEYS.BBOX, pl.Object, None) + in_graph.add_node_attr_key(DEFAULT_ATTR_KEYS.TRACKLET_ID, pl.Int64, -1) + in_graph.add_node_attr_key("x", pl.Float64, -999_999) + in_graph.add_node_attr_key("y", pl.Float64, -999_999) node_1 = in_graph.add_node( attrs={ @@ -62,7 +63,7 @@ def test_export_from_ctc_roundtrip(tmp_path: Path, metadata_shape: bool) -> None }, ) - in_graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, 0.0) + in_graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, pl.Float64, 
0.0) in_graph.add_edge(node_1, node_2, attrs={DEFAULT_ATTR_KEYS.EDGE_DIST: 1.0}) in_graph.add_edge(node_1, node_3, attrs={DEFAULT_ATTR_KEYS.EDGE_DIST: 1.0}) diff --git a/src/tracksdata/metrics/_test/test_matching.py b/src/tracksdata/metrics/_test/test_matching.py index 0b342f90..3daaa64f 100644 --- a/src/tracksdata/metrics/_test/test_matching.py +++ b/src/tracksdata/metrics/_test/test_matching.py @@ -50,8 +50,8 @@ def test_compute_weights_various_overlaps(self): graph2 = RustWorkXGraph() default_mask = Mask(np.array([[False]]), np.array([0, 0, 1, 1])) - graph1.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, default_mask) - graph2.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, default_mask) + graph1.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, dtype=pl.Object, default_value=default_mask) + graph2.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, dtype=pl.Object, default_value=default_mask) # Test 1: Perfect overlap (IoU = 1.0) mask_perfect = create_2d_mask_from_coords([(0, 0), (0, 1), (0, 2)]) @@ -72,8 +72,8 @@ def test_compute_weights_various_overlaps(self): # Test 2: Partial overlap (3/5 intersection, IoU = 3/7) graph1 = RustWorkXGraph() graph2 = RustWorkXGraph() - graph1.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, default_mask) - graph2.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, default_mask) + graph1.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, dtype=pl.Object, default_value=default_mask) + graph2.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, dtype=pl.Object, default_value=default_mask) mask1 = create_2d_mask_from_coords([(0, 0), (0, 1), (0, 2), (0, 3), (0, 4)]) mask2 = create_2d_mask_from_coords([(0, 2), (0, 3), (0, 4), (0, 5), (0, 6)]) @@ -94,8 +94,8 @@ def test_compute_weights_various_overlaps(self): # Test 3: Below threshold (1/5 = 0.2 < 0.5) - should not match graph1 = RustWorkXGraph() graph2 = RustWorkXGraph() - graph1.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, default_mask) - graph2.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, default_mask) + graph1.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, dtype=pl.Object, default_value=default_mask) + graph2.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, dtype=pl.Object, default_value=default_mask) mask1 = create_2d_mask_from_coords([(0, 0), (0, 1), (0, 2), (0, 3), (0, 4)]) mask2 = create_2d_mask_from_coords([(0, 4), (0, 5), (0, 6), (0, 7), (0, 8)]) @@ -134,10 +134,10 @@ def test_compute_weights_2d_and_3d(self): graph1 = RustWorkXGraph() graph2 = RustWorkXGraph() - graph1.add_node_attr_key("y", 0.0) - graph1.add_node_attr_key("x", 0.0) - graph2.add_node_attr_key("y", 0.0) - graph2.add_node_attr_key("x", 0.0) + graph1.add_node_attr_key("y", dtype=pl.Float64, default_value=0.0) + graph1.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) + graph2.add_node_attr_key("y", dtype=pl.Float64, default_value=0.0) + graph2.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) # Close nodes (distance ≈ 1.414) node1 = graph1.add_node({DEFAULT_ATTR_KEYS.T: 0, "y": 10.0, "x": 10.0}) @@ -157,10 +157,10 @@ def test_compute_weights_2d_and_3d(self): # Far nodes (distance ≈ 141.4) graph1 = RustWorkXGraph() graph2 = RustWorkXGraph() - graph1.add_node_attr_key("y", 0.0) - graph1.add_node_attr_key("x", 0.0) - graph2.add_node_attr_key("y", 0.0) - graph2.add_node_attr_key("x", 0.0) + graph1.add_node_attr_key("y", dtype=pl.Float64, default_value=0.0) + graph1.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) + graph2.add_node_attr_key("y", dtype=pl.Float64, default_value=0.0) + graph2.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) graph1.add_node({DEFAULT_ATTR_KEYS.T: 0, 
"y": 0.0, "x": 0.0}) graph2.add_node({DEFAULT_ATTR_KEYS.T: 0, "y": 100.0, "x": 100.0}) @@ -180,8 +180,8 @@ def test_compute_weights_2d_and_3d(self): graph2 = RustWorkXGraph() for key in ["z", "y", "x"]: - graph1.add_node_attr_key(key, 0.0) - graph2.add_node_attr_key(key, 0.0) + graph1.add_node_attr_key(key, dtype=pl.Float64, default_value=0.0) + graph2.add_node_attr_key(key, dtype=pl.Float64, default_value=0.0) node1 = graph1.add_node({DEFAULT_ATTR_KEYS.T: 0, "z": 5.0, "y": 10.0, "x": 15.0}) graph2.add_node({DEFAULT_ATTR_KEYS.T: 0, "z": 6.0, "y": 11.0, "x": 16.0}) @@ -202,10 +202,10 @@ def test_anisotropic_scaling(self): graph1 = RustWorkXGraph() graph2 = RustWorkXGraph() - graph1.add_node_attr_key("y", 0.0) - graph1.add_node_attr_key("x", 0.0) - graph2.add_node_attr_key("y", 0.0) - graph2.add_node_attr_key("x", 0.0) + graph1.add_node_attr_key("y", dtype=pl.Float64, default_value=0.0) + graph1.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) + graph2.add_node_attr_key("y", dtype=pl.Float64, default_value=0.0) + graph2.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) # Nodes far in y but close in x graph1.add_node({DEFAULT_ATTR_KEYS.T: 0, "y": 0.0, "x": 0.0}) @@ -234,10 +234,10 @@ def test_auto_detection_of_coordinates(self): graph1 = RustWorkXGraph() graph2 = RustWorkXGraph() - graph1.add_node_attr_key(DEFAULT_ATTR_KEYS.Y, 0.0) - graph1.add_node_attr_key(DEFAULT_ATTR_KEYS.X, 0.0) - graph2.add_node_attr_key(DEFAULT_ATTR_KEYS.Y, 0.0) - graph2.add_node_attr_key(DEFAULT_ATTR_KEYS.X, 0.0) + graph1.add_node_attr_key(DEFAULT_ATTR_KEYS.Y, dtype=pl.Float64, default_value=0.0) + graph1.add_node_attr_key(DEFAULT_ATTR_KEYS.X, dtype=pl.Float64, default_value=0.0) + graph2.add_node_attr_key(DEFAULT_ATTR_KEYS.Y, dtype=pl.Float64, default_value=0.0) + graph2.add_node_attr_key(DEFAULT_ATTR_KEYS.X, dtype=pl.Float64, default_value=0.0) node1 = graph1.add_node({DEFAULT_ATTR_KEYS.T: 0, DEFAULT_ATTR_KEYS.Y: 10.0, DEFAULT_ATTR_KEYS.X: 10.0}) graph2.add_node({DEFAULT_ATTR_KEYS.T: 0, DEFAULT_ATTR_KEYS.Y: 11.0, DEFAULT_ATTR_KEYS.X: 11.0}) @@ -258,10 +258,10 @@ def test_scale_validation(self): graph1 = RustWorkXGraph() graph2 = RustWorkXGraph() - graph1.add_node_attr_key("y", 0.0) - graph1.add_node_attr_key("x", 0.0) - graph2.add_node_attr_key("y", 0.0) - graph2.add_node_attr_key("x", 0.0) + graph1.add_node_attr_key("y", dtype=pl.Float64, default_value=0.0) + graph1.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) + graph2.add_node_attr_key("y", dtype=pl.Float64, default_value=0.0) + graph2.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) graph1.add_node({DEFAULT_ATTR_KEYS.T: 0, "y": 10.0, "x": 10.0}) graph2.add_node({DEFAULT_ATTR_KEYS.T: 0, "y": 11.0, "x": 11.0}) @@ -286,8 +286,8 @@ def test_graph_match_integration(self): graph2 = RustWorkXGraph() default_mask = Mask(np.array([[False]]), np.array([0, 0, 1, 1])) - graph1.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, default_mask) - graph2.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, default_mask) + graph1.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, dtype=pl.Object, default_value=default_mask) + graph2.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, dtype=pl.Object, default_value=default_mask) mask1 = create_2d_mask_from_coords([(0, 0), (0, 1), (0, 2), (0, 3), (0, 4)]) mask2 = create_2d_mask_from_coords([(0, 2), (0, 3), (0, 4), (0, 5), (0, 6)]) @@ -322,10 +322,10 @@ def test_graph_match_integration(self): graph1 = RustWorkXGraph() graph2 = RustWorkXGraph() - graph1.add_node_attr_key("y", 0.0) - graph1.add_node_attr_key("x", 0.0) - 
graph2.add_node_attr_key("y", 0.0) - graph2.add_node_attr_key("x", 0.0) + graph1.add_node_attr_key("y", dtype=pl.Float64, default_value=0.0) + graph1.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) + graph2.add_node_attr_key("y", dtype=pl.Float64, default_value=0.0) + graph2.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) node1 = graph1.add_node({DEFAULT_ATTR_KEYS.T: 0, "y": 10.0, "x": 10.0}) node2 = graph2.add_node({DEFAULT_ATTR_KEYS.T: 0, "y": 11.0, "x": 11.0}) diff --git a/src/tracksdata/nodes/_test/test_generic_nodes.py b/src/tracksdata/nodes/_test/test_generic_nodes.py index 1fdc9375..62cd7102 100644 --- a/src/tracksdata/nodes/_test/test_generic_nodes.py +++ b/src/tracksdata/nodes/_test/test_generic_nodes.py @@ -1,4 +1,5 @@ import numpy as np +import polars as pl import pytest from numpy.typing import NDArray @@ -60,7 +61,7 @@ def test_crop_func_attrs_simple_function_no_frames() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("value", 0.0) + graph.add_node_attr_key("value", dtype=pl.Float64, default_value=0.0) # Add nodes with values node1 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "value": 10.0}) @@ -93,7 +94,7 @@ def test_crop_func_attrs_function_with_frames() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, None) + graph.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, dtype=pl.Object, default_value=None) # Create test masks mask1_data = np.array([[True, True], [True, False]], dtype=bool) @@ -146,8 +147,8 @@ def test_crop_func_attrs_function_with_frames_and_attrs() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, None) - graph.add_node_attr_key("multiplier", 1.0) + graph.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, dtype=pl.Object, default_value=None) + graph.add_node_attr_key("multiplier", dtype=pl.Float64, default_value=1.0) # Create test masks mask1_data = np.array([[True, True], [True, False]], dtype=bool) @@ -199,7 +200,7 @@ def test_crop_func_attrs_function_returns_different_types() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, None) + graph.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, dtype=pl.Object, default_value=None) # Create test mask mask_data = np.array([[True, True], [True, False]], dtype=bool) @@ -265,7 +266,7 @@ def test_crop_func_attrs_error_handling_missing_attr_key() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, None) + graph.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, dtype=pl.Object, default_value=None) # Note: "value" is not registered # Create test mask @@ -296,7 +297,7 @@ def test_crop_func_attrs_function_with_frames_multiprocessing(n_workers: int) -> graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, None) + graph.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, dtype=pl.Object, default_value=None) # Create test masks for multiple time points mask1_data = np.array([[True, True], [True, False]], dtype=bool) @@ -348,7 +349,7 @@ def test_crop_func_attrs_empty_graph() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, None) + graph.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, dtype=pl.Object, default_value=None) def dummy_func(mask: Mask) -> float: return 1.0 @@ -377,7 +378,7 @@ def test_crop_func_attrs_batch_processing_without_frames() -> None: graph = RustWorkXGraph() # 
Register attribute keys - graph.add_node_attr_key("value", 0.0) + graph.add_node_attr_key("value", dtype=pl.Float64, default_value=0.0) # Add nodes with values node1 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "value": 10.0}) @@ -414,7 +415,7 @@ def test_crop_func_attrs_batch_processing_with_frames() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, None) + graph.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, dtype=pl.Object, default_value=None) # Create test masks mask1_data = np.array([[True, True], [True, False]], dtype=bool) diff --git a/src/tracksdata/solvers/_test/test_ilp_solver.py b/src/tracksdata/solvers/_test/test_ilp_solver.py index 0e9e6524..87bb9f63 100644 --- a/src/tracksdata/solvers/_test/test_ilp_solver.py +++ b/src/tracksdata/solvers/_test/test_ilp_solver.py @@ -1,5 +1,6 @@ import math +import polars as pl import pytest from tracksdata.attrs import Attr @@ -85,8 +86,8 @@ def test_ilp_solver_solve_no_edges(caplog: pytest.LogCaptureFixture) -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("x", 0.0) - graph.add_node_attr_key("y", 0.0) + graph.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) + graph.add_node_attr_key("y", dtype=pl.Float64, default_value=0.0) # Add some nodes graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 0.0, "y": 0.0}) @@ -107,9 +108,9 @@ def test_ilp_solver_solve_simple_case() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("x", 0.0) - graph.add_node_attr_key("y", 0.0) - graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, 0.0) + graph.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) + graph.add_node_attr_key("y", dtype=pl.Float64, default_value=0.0) + graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, dtype=pl.Float64, default_value=0.0) # Add nodes node0 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 0.0, "y": 0.0}) @@ -147,9 +148,9 @@ def test_ilp_solver_solve_with_appearance_weight() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("x", 0.0) - graph.add_node_attr_key("y", 0.0) - graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, 0.0) + graph.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) + graph.add_node_attr_key("y", dtype=pl.Float64, default_value=0.0) + graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, dtype=pl.Float64, default_value=0.0) # Add nodes node0 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 0.0, "y": 0.0}) @@ -187,9 +188,9 @@ def test_ilp_solver_solve_with_disappearance_weight() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("x", 0.0) - graph.add_node_attr_key("y", 0.0) - graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, 0.0) + graph.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) + graph.add_node_attr_key("y", dtype=pl.Float64, default_value=0.0) + graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, dtype=pl.Float64, default_value=0.0) # Add nodes node0 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 0.0, "y": 0.0}) @@ -227,9 +228,9 @@ def test_ilp_solver_solve_with_division_weight() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("x", 0.0) - graph.add_node_attr_key("y", 0.0) - graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, 0.0) + graph.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) + graph.add_node_attr_key("y", dtype=pl.Float64, default_value=0.0) + graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, dtype=pl.Float64, 
default_value=0.0) # Add nodes for division scenario node0 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 0.0, "y": 0.0}) @@ -279,10 +280,10 @@ def test_ilp_solver_solve_custom_edge_weight_expr() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("x", 0.0) - graph.add_node_attr_key("y", 0.0) - graph.add_edge_attr_key("custom_weight", 0.0) - graph.add_edge_attr_key("confidence", 0.0) + graph.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) + graph.add_node_attr_key("y", dtype=pl.Float64, default_value=0.0) + graph.add_edge_attr_key("custom_weight", dtype=pl.Float64, default_value=0.0) + graph.add_edge_attr_key("confidence", dtype=pl.Float64, default_value=0.0) # Add nodes node0 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 0.0, "y": 0.0}) @@ -309,9 +310,9 @@ def test_ilp_solver_solve_custom_node_weight_expr() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("x", 0.0) - graph.add_node_attr_key("quality", 0.0) - graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, 0.0) + graph.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) + graph.add_node_attr_key("quality", dtype=pl.Float64, default_value=0.0) + graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, dtype=pl.Float64, default_value=0.0) # Add nodes with quality attribute node0 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 0.0, "quality": 0.9}) @@ -337,9 +338,9 @@ def test_ilp_solver_solve_custom_output_key() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("x", 0.0) - graph.add_node_attr_key("y", 0.0) - graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, 0.0) + graph.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) + graph.add_node_attr_key("y", dtype=pl.Float64, default_value=0.0) + graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, dtype=pl.Float64, default_value=0.0) # Add nodes and edges node0 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 0.0, "y": 0.0}) @@ -365,9 +366,9 @@ def test_ilp_solver_solve_with_all_weights() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("x", 0.0) - graph.add_node_attr_key("y", 0.0) - graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, 0.0) + graph.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) + graph.add_node_attr_key("y", dtype=pl.Float64, default_value=0.0) + graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, dtype=pl.Float64, default_value=0.0) # Add nodes node0 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 0.0, "y": 0.0}) @@ -430,9 +431,9 @@ def test_ilp_solver_division_constraint() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("x", 0.0) - graph.add_node_attr_key("y", 0.0) - graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, 0.0) + graph.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) + graph.add_node_attr_key("y", dtype=pl.Float64, default_value=0.0) + graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, dtype=pl.Float64, default_value=0.0) # Create a scenario where division would be tempting but should be constrained # Time 0: 1 parent node @@ -503,9 +504,9 @@ def test_ilp_solver_solve_with_inf_expr() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("x", 0.0) - graph.add_node_attr_key("y", 0.0) - graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, 0.0) + graph.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) + graph.add_node_attr_key("y", dtype=pl.Float64, default_value=0.0) + 
graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, dtype=pl.Float64, default_value=0.0) # Add nodes node0 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 0.0, "y": 5.0}) @@ -536,8 +537,8 @@ def test_ilp_solver_solve_with_pos_inf_rejection() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("x", 0.0) - graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, 0.0) + graph.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) + graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, dtype=pl.Float64, default_value=0.0) # Add nodes node0 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 1.0}) @@ -565,8 +566,8 @@ def test_ilp_solver_solve_with_neg_inf_node_weight() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("priority", 0.0) - graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, 0.0) + graph.add_node_attr_key("priority", dtype=pl.Float64, default_value=0.0) + graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, dtype=pl.Float64, default_value=0.0) # Add nodes node0 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "priority": 1.0}) # High priority @@ -594,8 +595,8 @@ def test_ilp_solver_solve_with_inf_edge_weight() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("x", 0.0) - graph.add_edge_attr_key("confidence", 0.0) + graph.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) + graph.add_edge_attr_key("confidence", dtype=pl.Float64, default_value=0.0) # Add nodes node0 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 0.0}) @@ -626,9 +627,9 @@ def test_ilp_solver_solve_with_overlaps() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("x", 0.0) - graph.add_node_attr_key("y", 0.0) - graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, 0.0) + graph.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) + graph.add_node_attr_key("y", dtype=pl.Float64, default_value=0.0) + graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, dtype=pl.Float64, default_value=0.0) # Add nodes - overlapping pair at time t=1 node0 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 0.0, "y": 0.0}) @@ -681,9 +682,9 @@ def test_ilp_solver_solve_with_merge_weight() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("x", 0.0) - graph.add_node_attr_key("y", 0.0) - graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, 0.0) + graph.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) + graph.add_node_attr_key("y", dtype=pl.Float64, default_value=0.0) + graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, dtype=pl.Float64, default_value=0.0) # Simple merge scenario: 2 tracks -> 1 merge point track1_node = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 0.0, "y": 0.0}) @@ -713,8 +714,8 @@ def test_ilp_solver_solve_with_merge_weight() -> None: def test_ilp_solver_solve_with_positive_merge_weight() -> None: """Test solving with positive merge weight to penalize merges.""" graph = RustWorkXGraph() - graph.add_node_attr_key("x", 0.0) - graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, 0.0) + graph.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) + graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, dtype=pl.Float64, default_value=0.0) # Create merge scenario track1_node0 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 0.0}) @@ -742,8 +743,8 @@ def test_ilp_solver_solve_with_positive_merge_weight() -> None: def test_ilp_solver_solve_with_merge_expression() -> None: """Test solving with merge weight as an expression.""" graph = 
RustWorkXGraph() - graph.add_node_attr_key("merge_cost", 0.0) - graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, 0.0) + graph.add_node_attr_key("merge_cost", dtype=pl.Float64, default_value=0.0) + graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, dtype=pl.Float64, default_value=0.0) # Two source nodes source1 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "merge_cost": 0.0}) @@ -774,8 +775,8 @@ def test_ilp_solver_solve_merge_and_division_combined() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("x", 0.0) - graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, 0.0) + graph.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) + graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, dtype=pl.Float64, default_value=0.0) # Create complex scenario: merge followed by division # Time 0: Two separate tracks diff --git a/src/tracksdata/solvers/_test/test_nearest_neighbors_solver.py b/src/tracksdata/solvers/_test/test_nearest_neighbors_solver.py index 4481013f..98b73615 100644 --- a/src/tracksdata/solvers/_test/test_nearest_neighbors_solver.py +++ b/src/tracksdata/solvers/_test/test_nearest_neighbors_solver.py @@ -1,3 +1,4 @@ +import polars as pl import pytest from tracksdata.attrs import Attr @@ -47,8 +48,8 @@ def test_nearest_neighbors_solver_solve_no_edges() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("x", 0.0) - graph.add_node_attr_key("y", 0.0) + graph.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) + graph.add_node_attr_key("y", dtype=pl.Float64, default_value=0.0) # Add some nodes graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 0.0, "y": 0.0}) @@ -66,9 +67,9 @@ def test_nearest_neighbors_solver_solve_simple_case() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("x", 0.0) - graph.add_node_attr_key("y", 0.0) - graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, 0.0) + graph.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) + graph.add_node_attr_key("y", dtype=pl.Float64, default_value=0.0) + graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, dtype=pl.Float64, default_value=0.0) # Add nodes node0 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 0.0, "y": 0.0}) @@ -102,9 +103,9 @@ def test_nearest_neighbors_solver_solve_max_children_constraint() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("x", 0.0) - graph.add_node_attr_key("y", 0.0) - graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, 0.0) + graph.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) + graph.add_node_attr_key("y", dtype=pl.Float64, default_value=0.0) + graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, dtype=pl.Float64, default_value=0.0) # Add nodes node0 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 0.0, "y": 0.0}) # Parent @@ -142,9 +143,9 @@ def test_nearest_neighbors_solver_solve_one_parent_constraint() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("x", 0.0) - graph.add_node_attr_key("y", 0.0) - graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, 0.0) + graph.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) + graph.add_node_attr_key("y", dtype=pl.Float64, default_value=0.0) + graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, dtype=pl.Float64, default_value=0.0) # Add nodes node0 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 0.0, "y": 0.0}) # Parent 1 @@ -174,9 +175,9 @@ def test_nearest_neighbors_solver_solve_custom_weight_expr() -> None: graph = RustWorkXGraph() # Register 
attribute keys - graph.add_node_attr_key("x", 0.0) - graph.add_node_attr_key("y", 0.0) - graph.add_edge_attr_key("custom_weight", 0.0) + graph.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) + graph.add_node_attr_key("y", dtype=pl.Float64, default_value=0.0) + graph.add_edge_attr_key("custom_weight", dtype=pl.Float64, default_value=0.0) # Add nodes node0 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 0.0, "y": 0.0}) @@ -207,10 +208,10 @@ def test_nearest_neighbors_solver_solve_complex_expression() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("x", 0.0) - graph.add_node_attr_key("y", 0.0) - graph.add_edge_attr_key("distance", 0.0) - graph.add_edge_attr_key("confidence", 0.0) + graph.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) + graph.add_node_attr_key("y", dtype=pl.Float64, default_value=0.0) + graph.add_edge_attr_key("distance", dtype=pl.Float64, default_value=0.0) + graph.add_edge_attr_key("confidence", dtype=pl.Float64, default_value=0.0) # Add nodes node0 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 0.0, "y": 0.0}) @@ -242,9 +243,9 @@ def test_nearest_neighbors_solver_solve_custom_output_key() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("x", 0.0) - graph.add_node_attr_key("y", 0.0) - graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, 0.0) + graph.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) + graph.add_node_attr_key("y", dtype=pl.Float64, default_value=0.0) + graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, dtype=pl.Float64, default_value=0.0) # Add nodes and edges node0 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 0.0, "y": 0.0}) @@ -270,9 +271,9 @@ def test_nearest_neighbors_solver_solve_with_overlaps() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("x", 0.0) - graph.add_node_attr_key("y", 0.0) - graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, 0.0) + graph.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) + graph.add_node_attr_key("y", dtype=pl.Float64, default_value=0.0) + graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, dtype=pl.Float64, default_value=0.0) # Add nodes - overlapping pair at time t=1 node0 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 0.0, "y": 0.0}) @@ -325,9 +326,9 @@ def test_nearest_neighbors_solver_solve_large_graph() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("x", 0.0) - graph.add_node_attr_key("y", 0.0) - graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, 0.0) + graph.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) + graph.add_node_attr_key("y", dtype=pl.Float64, default_value=0.0) + graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, dtype=pl.Float64, default_value=0.0) # Create a more complex graph structure # Time 0: nodes 0, 1 From 5392cbfb6519a98209c357be278d6bad7635321d Mon Sep 17 00:00:00 2001 From: Jordao Bragantini Date: Tue, 20 Jan 2026 10:54:14 -0800 Subject: [PATCH 04/14] fixing correct usage of dtypes and using default when possible --- .../array/_test/test_graph_array.py | 28 +++---- src/tracksdata/edges/_test/test_iou_edges.py | 10 +-- .../functional/_test/test_napari.py | 2 +- .../graph/_test/test_graph_backends.py | 74 +++++++++---------- src/tracksdata/graph/_test/test_subgraph.py | 8 +- .../filters/_test/test_spatial_filter.py | 4 +- src/tracksdata/io/_test/test_ctc_io.py | 4 +- src/tracksdata/nodes/_regionprops.py | 6 +- .../nodes/_test/test_generic_nodes.py | 14 ++-- 9 files changed, 73 
insertions(+), 77 deletions(-) diff --git a/src/tracksdata/array/_test/test_graph_array.py b/src/tracksdata/array/_test/test_graph_array.py index b37f23b2..e19140ad 100644 --- a/src/tracksdata/array/_test/test_graph_array.py +++ b/src/tracksdata/array/_test/test_graph_array.py @@ -33,7 +33,7 @@ def test_graph_array_view_init(graph_backend: BaseGraph) -> None: """Test GraphArrayView initialization.""" # Add a attribute key graph_backend.add_node_attr_key("label", dtype=pl.Int64, default_value=0) - graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, dtype=pl.Object, default_value=None) + graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, pl.Object) graph_backend.add_node_attr_key( DEFAULT_ATTR_KEYS.BBOX, dtype=pl.Object, default_value=np.asarray([0, 0, 0, 0, 0, 0]) ) @@ -61,7 +61,7 @@ def test_graph_array_view_getitem_empty_time(graph_backend: BaseGraph) -> None: """Test __getitem__ with empty time point (no nodes).""" graph_backend.add_node_attr_key("label", dtype=pl.Int64, default_value=0) - graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, dtype=pl.Object, default_value=None) + graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, pl.Object) graph_backend.add_node_attr_key( DEFAULT_ATTR_KEYS.BBOX, dtype=pl.Object, default_value=np.asarray([0, 0, 0, 0, 0, 0]) ) @@ -82,8 +82,8 @@ def test_graph_array_view_getitem_with_nodes(graph_backend: BaseGraph) -> None: # Add attribute keys graph_backend.add_node_attr_key("label", dtype=pl.Int64, default_value=0) - graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, dtype=pl.Object, default_value=None) - graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.BBOX, dtype=pl.Object, default_value=np.asarray([0, 0, 0, 0])) + graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, pl.Object) + graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.BBOX, pl.Array(pl.Int64, 4)) graph_backend.add_node_attr_key("y", dtype=pl.Int64, default_value=0) graph_backend.add_node_attr_key("x", dtype=pl.Int64, default_value=0) @@ -132,8 +132,8 @@ def test_graph_array_view_getitem_multiple_nodes(graph_backend: BaseGraph) -> No # Add attribute keys graph_backend.add_node_attr_key("label", dtype=pl.Int64, default_value=0) - graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, dtype=pl.Object, default_value=None) - graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.BBOX, dtype=pl.Object, default_value=np.asarray([0, 0, 0, 0])) + graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, pl.Object) + graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.BBOX, pl.Array(pl.Int64, 4)) graph_backend.add_node_attr_key("y", dtype=pl.Int64, default_value=0) graph_backend.add_node_attr_key("x", dtype=pl.Int64, default_value=0) @@ -186,9 +186,9 @@ def test_graph_array_view_getitem_boolean_dtype(graph_backend: BaseGraph) -> Non """Test __getitem__ with boolean attribute values.""" # Add attribute keys - graph_backend.add_node_attr_key("is_active", dtype=pl.Boolean, default_value=False) - graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, dtype=pl.Object, default_value=None) - graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.BBOX, dtype=pl.Object, default_value=np.asarray([0, 0, 0, 0])) + graph_backend.add_node_attr_key("is_active", pl.Boolean) + graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, pl.Object) + graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.BBOX, pl.Array(pl.Int64, 4)) graph_backend.add_node_attr_key("y", dtype=pl.Int64, default_value=0) graph_backend.add_node_attr_key("x", dtype=pl.Int64, default_value=0) @@ -224,8 +224,8 @@ def 
test_graph_array_view_dtype_inference(graph_backend: BaseGraph) -> None: # Add attribute keys graph_backend.add_node_attr_key("float_label", dtype=pl.Float64, default_value=0.0) - graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, dtype=pl.Object, default_value=None) - graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.BBOX, dtype=pl.Object, default_value=np.asarray([0, 0, 0, 0])) + graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, pl.Object) + graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.BBOX, pl.Array(pl.Int64, 4)) graph_backend.add_node_attr_key("y", dtype=pl.Int64, default_value=0) graph_backend.add_node_attr_key("x", dtype=pl.Int64, default_value=0) @@ -346,8 +346,8 @@ def test_graph_array_view_getitem_time_index_nested(multi_node_graph_from_image, def test_graph_array_set_options(graph_backend: BaseGraph) -> None: with Options(gav_chunk_shape=(512, 512), gav_default_dtype=np.int16): graph_backend.add_node_attr_key("label", dtype=pl.Int64, default_value=0) - graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, dtype=pl.Object, default_value=None) - graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.BBOX, dtype=pl.Object, default_value=np.asarray([0, 0, 0, 0])) + graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, pl.Object) + graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.BBOX, pl.Array(pl.Int64, 4)) array_view = GraphArrayView(graph=graph_backend, shape=(10, 100, 100), attr_key="label") assert array_view.chunk_shape == (512, 512) assert array_view.dtype == np.int16 @@ -371,7 +371,7 @@ def test_graph_array_raise_error_on_non_scalar_attr_key(graph_backend: BaseGraph graph_backend.add_node_attr_key( "label", dtype=pl.Object, default_value=np.array([0, 1]) ) # Non-scalar default value - graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, dtype=pl.Object, default_value=None) + graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, pl.Object) graph_backend.add_node( { DEFAULT_ATTR_KEYS.T: 0, diff --git a/src/tracksdata/edges/_test/test_iou_edges.py b/src/tracksdata/edges/_test/test_iou_edges.py index cfdb6737..834e5e72 100644 --- a/src/tracksdata/edges/_test/test_iou_edges.py +++ b/src/tracksdata/edges/_test/test_iou_edges.py @@ -33,7 +33,7 @@ def test_iou_edges_add_weights(n_workers: int) -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, dtype=pl.Object, default_value=None) + graph.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, pl.Object) graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, dtype=pl.Float64, default_value=0.0) # Create test masks @@ -83,7 +83,7 @@ def test_iou_edges_no_overlap() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, dtype=pl.Object, default_value=None) + graph.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, pl.Object) graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, dtype=pl.Float64, default_value=0.0) # Create non-overlapping masks @@ -122,7 +122,7 @@ def test_iou_edges_perfect_overlap() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, dtype=pl.Object, default_value=None) + graph.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, pl.Object) graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, dtype=pl.Float64, default_value=0.0) # Create identical masks @@ -158,8 +158,8 @@ def test_iou_edges_custom_mask_key() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("custom_mask", dtype=pl.Object, default_value=None) - 
graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, dtype=pl.Float64, default_value=0.0) + graph.add_node_attr_key("custom_mask", pl.Object) + graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, pl.Float64, default_value=0.0) # Create test masks mask1_data = np.array([[True, True], [True, True]], dtype=bool) diff --git a/src/tracksdata/functional/_test/test_napari.py b/src/tracksdata/functional/_test/test_napari.py index c1044ef9..9b4a81dc 100644 --- a/src/tracksdata/functional/_test/test_napari.py +++ b/src/tracksdata/functional/_test/test_napari.py @@ -44,7 +44,7 @@ def test_napari_conversion(metadata_shape: bool) -> None: mask_attrs.add_node_attrs(graph) # Maybe we should update the MaskDiskAttrs to handle bounding boxes - graph.add_node_attr_key(DEFAULT_ATTR_KEYS.BBOX, dtype=pl.Object, default_value=None) + graph.add_node_attr_key(DEFAULT_ATTR_KEYS.BBOX, dtype=pl.Array(pl.Int64, 6)) masks = graph.node_attrs(attr_keys=[DEFAULT_ATTR_KEYS.MASK])[DEFAULT_ATTR_KEYS.MASK] graph.update_node_attrs( attrs={DEFAULT_ATTR_KEYS.BBOX: [mask.bbox for mask in masks]}, diff --git a/src/tracksdata/graph/_test/test_graph_backends.py b/src/tracksdata/graph/_test/test_graph_backends.py index 88d42885..373a8a23 100644 --- a/src/tracksdata/graph/_test/test_graph_backends.py +++ b/src/tracksdata/graph/_test/test_graph_backends.py @@ -19,10 +19,10 @@ def test_already_existing_keys(graph_backend: BaseGraph) -> None: """Test that adding already existing keys raises an error.""" - graph_backend.add_node_attr_key("x", dtype=pl.Object, default_value=None) + graph_backend.add_node_attr_key("x", pl.Float64) with pytest.raises(ValueError): - graph_backend.add_node_attr_key("x", dtype=pl.Object, default_value=None) + graph_backend.add_node_attr_key("x", pl.Float64) with pytest.raises(ValueError): # missing x @@ -77,7 +77,7 @@ def test_add_node(graph_backend: BaseGraph) -> None: def test_add_edge(graph_backend: BaseGraph) -> None: """Test adding edges with attributes.""" # Add node attribute key - graph_backend.add_node_attr_key("x", dtype=pl.Object, default_value=None) + graph_backend.add_node_attr_key("x", pl.Float64) # Add two nodes first node1 = graph_backend.add_node({"t": 0, "x": 1.0}) @@ -110,7 +110,7 @@ def test_add_edge(graph_backend: BaseGraph) -> None: def test_remove_edge_by_id(graph_backend: BaseGraph) -> None: """Test removing an edge by ID across backends using unified API.""" # Setup - graph_backend.add_node_attr_key("x", dtype=pl.Object, default_value=None) + graph_backend.add_node_attr_key("x", pl.Float64) graph_backend.add_edge_attr_key("weight", dtype=pl.Float64, default_value=0.0) n1 = graph_backend.add_node({"t": 0, "x": 1.0}) @@ -147,7 +147,7 @@ def test_remove_edge_by_id(graph_backend: BaseGraph) -> None: def test_remove_edge_by_nodes(graph_backend: BaseGraph) -> None: """Test removing an edge by its source/target IDs.""" - graph_backend.add_node_attr_key("x", dtype=pl.Object, default_value=None) + graph_backend.add_node_attr_key("x", pl.Float64) graph_backend.add_edge_attr_key("weight", dtype=pl.Float64, default_value=0.0) a = graph_backend.add_node({"t": 0, "x": 0.0}) @@ -184,7 +184,7 @@ def test_node_ids(graph_backend: BaseGraph) -> None: def test_filter_nodes_by_attribute(graph_backend: BaseGraph) -> None: """Test filtering nodes by attributes.""" - graph_backend.add_node_attr_key("label", dtype=pl.Object, default_value=None) + graph_backend.add_node_attr_key("label", pl.String) node1 = graph_backend.add_node({"t": 0, "label": "A"}) node2 = graph_backend.add_node({"t": 0, "label": "B"}) @@ 
-235,8 +235,8 @@ def test_time_points(graph_backend: BaseGraph) -> None: def test_node_attrs(graph_backend: BaseGraph) -> None: """Test retrieving node attributes.""" - graph_backend.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) - graph_backend.add_node_attr_key("coordinates", dtype=pl.Object, default_value=np.array([0.0, 0.0])) + graph_backend.add_node_attr_key("x", pl.Float64) + graph_backend.add_node_attr_key("coordinates", pl.Array(pl.Float64, 2)) node1 = graph_backend.add_node({"t": 0, "x": 1.0, "coordinates": np.array([10.0, 20.0])}) node2 = graph_backend.add_node({"t": 1, "x": 2.0, "coordinates": np.array([30.0, 40.0])}) @@ -257,8 +257,8 @@ def test_edge_attrs(graph_backend: BaseGraph) -> None: node1 = graph_backend.add_node({"t": 0}) node2 = graph_backend.add_node({"t": 1}) - graph_backend.add_edge_attr_key("weight", dtype=pl.Float64, default_value=0.0) - graph_backend.add_edge_attr_key("vector", dtype=pl.Object, default_value=np.array([0.0, 0.0])) + graph_backend.add_edge_attr_key("weight", pl.Float64, default_value=0.0) + graph_backend.add_edge_attr_key("vector", pl.Array(pl.Float64, 2)) graph_backend.add_edge(node1, node2, attrs={"weight": 0.5, "vector": np.array([1.0, 2.0])}) @@ -374,7 +374,7 @@ def test_subgraph_with_node_and_edge_attr_filters(graph_backend: BaseGraph) -> N def test_subgraph_with_node_ids_and_filters(graph_backend: BaseGraph) -> None: """Test subgraph with node IDs and filters.""" - graph_backend.add_node_attr_key("x", dtype=pl.Object, default_value=None) + graph_backend.add_node_attr_key("x", pl.Float64) graph_backend.add_edge_attr_key("weight", dtype=pl.Float64, default_value=0.0) node0 = graph_backend.add_node({"t": 0, "x": 1.0}) @@ -994,7 +994,7 @@ def test_match_method(graph_backend: BaseGraph) -> None: # Create first graph (self) with masks graph_backend.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) graph_backend.add_node_attr_key("y", dtype=pl.Float64, default_value=0.0) - graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, dtype=pl.Object, default_value=None) + graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, pl.Object) # Create masks for first graph mask1_data = np.array([[True, True], [True, True]], dtype=bool) @@ -1028,7 +1028,7 @@ def test_match_method(graph_backend: BaseGraph) -> None: other_graph = graph_backend.__class__(**kwargs) other_graph.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) other_graph.add_node_attr_key("y", dtype=pl.Float64, default_value=0.0) - other_graph.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, dtype=pl.Object, default_value=None) + other_graph.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, pl.Object) # Create overlapping masks for second graph # This mask overlaps significantly with mask1 (IoU > 0.5) @@ -1502,7 +1502,7 @@ def build_node_map(graph: BaseGraph) -> dict[tuple[int, tuple[int, ...]], dict[s def test_compute_overlaps_basic(graph_backend: BaseGraph) -> None: """Test basic compute_overlaps functionality.""" - graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, dtype=pl.Object, default_value=None) + graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, pl.Object) # Create overlapping masks at time 0 mask1_data = np.array([[True, True], [True, True]], dtype=bool) @@ -1524,7 +1524,7 @@ def test_compute_overlaps_basic(graph_backend: BaseGraph) -> None: def test_compute_overlaps_with_threshold(graph_backend: BaseGraph) -> None: """Test compute_overlaps with different IoU thresholds.""" - graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, dtype=pl.Object, 
default_value=None) + graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, pl.Object) # Create masks with different overlap levels mask1_data = np.array([[True, True], [True, True]], dtype=bool) @@ -1558,7 +1558,7 @@ def test_compute_overlaps_with_threshold(graph_backend: BaseGraph) -> None: def test_compute_overlaps_multiple_timepoints(graph_backend: BaseGraph) -> None: """Test compute_overlaps across multiple time points.""" - graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, dtype=pl.Object, default_value=None) + graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, pl.Object) # Time 0: overlapping masks mask1_t0 = Mask(np.array([[True, True], [True, True]], dtype=bool), bbox=np.array([0, 0, 2, 2])) @@ -1585,7 +1585,7 @@ def test_sql_graph_mask_update_survives_reload(tmp_path: Path) -> None: """Ensure SQLGraph keeps pickled column types after reloading from disk.""" db_path = tmp_path / "mask_graph.db" graph = SQLGraph("sqlite", str(db_path)) - graph.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, dtype=pl.Object, default_value=None) + graph.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, pl.Object) mask_data = np.array([[True, False], [False, True]], dtype=bool) mask = Mask(mask_data, bbox=np.array([0, 0, 2, 2])) @@ -1659,10 +1659,10 @@ def test_summary(graph_backend: BaseGraph) -> None: def test_spatial_filter_basic(graph_backend: BaseGraph) -> None: - graph_backend.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) - graph_backend.add_node_attr_key("y", dtype=pl.Float64, default_value=0.0) - graph_backend.add_node_attr_key("z", dtype=pl.Float64, default_value=0.0) - graph_backend.add_node_attr_key("bbox", pl.Array(pl.Int64, 6), default_value=np.zeros(6, dtype=int)) + graph_backend.add_node_attr_key("x", pl.Float64) + graph_backend.add_node_attr_key("y", pl.Float64) + graph_backend.add_node_attr_key("z", pl.Float64) + graph_backend.add_node_attr_key("bbox", pl.Array(pl.Int64, 6)) node1 = graph_backend.add_node({"t": 0, "x": 1.0, "y": 1.0, "z": 1.0, "bbox": np.array([6, 6, 6, 8, 8, 8])}) node2 = graph_backend.add_node({"t": 1, "x": 2.0, "y": 2.0, "z": 2.0, "bbox": np.array([0, 0, 0, 3, 3, 3])}) @@ -1823,7 +1823,7 @@ def test_assign_tracklet_ids_node_id_filter(graph_backend: BaseGraph, return_id_ # Ensure tracklet_id attribute exists after nodes were added if DEFAULT_ATTR_KEYS.TRACKLET_ID not in graph_backend.node_attr_keys(): - graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.TRACKLET_ID, dtype=pl.Int64, default_value=-1) + graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.TRACKLET_ID, pl.Int64) for seeds, expected in ( ([A1], [[A0, A1, A2, A3]]), @@ -2017,7 +2017,7 @@ def test_nodes_interface(graph_backend: BaseGraph) -> None: assert graph_backend[node2]["x"] == 0 assert graph_backend[node3]["x"] == -1 - graph_backend.add_node_attr_key("y", dtype=pl.Int64, default_value=-1) + graph_backend.add_node_attr_key("y", pl.Int64) graph_backend[node2]["y"] = 5 @@ -2297,11 +2297,11 @@ def _fill_mock_geff_graph(graph_backend: BaseGraph) -> None: graph_backend.add_node_attr_key( DEFAULT_ATTR_KEYS.BBOX, pl.Array(pl.Int64, 4), default_value=np.array([0, 0, 1, 1], dtype=int) ) - graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, dtype=pl.Object, default_value=None) - graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.TRACKLET_ID, dtype=pl.Int64, default_value=-1) - graph_backend.add_node_attr_key("ndfeature", dtype=pl.Object, default_value=np.asarray([[1.0], [2.0], [3.0]])) + graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, pl.Object) + 
graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.TRACKLET_ID, pl.Int64) + graph_backend.add_node_attr_key("ndfeature", pl.Object) - graph_backend.add_edge_attr_key("weight", dtype=pl.Float64, default_value=0.0) + graph_backend.add_edge_attr_key("weight", pl.Float64) graph_backend.update_metadata( shape=[1, 25, 25], @@ -2475,9 +2475,9 @@ def test_pickle_roundtrip(graph_backend: BaseGraph) -> None: if isinstance(graph_backend, SQLGraph): pytest.skip("SQLGraph does not support pickle roundtrip") - graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.BBOX, dtype=pl.Object, default_value=None) - graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, dtype=pl.Object, default_value=None) - graph_backend.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, dtype=pl.Float64, default_value=0.0) + graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.BBOX, pl.Array(pl.Int64, 4)) + graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, pl.Object) + graph_backend.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, pl.Float64) bbox = np.array([0, 0, 2, 2]) mask = Mask(np.array([[True, True], [True, True]], dtype=bool), bbox=bbox) @@ -2519,7 +2519,7 @@ def test_sql_graph_huge_update() -> None: random_t = np.random.randint(0, 1000, n_nodes).tolist() random_x = np.random.rand(n_nodes).tolist() graph.bulk_add_nodes([{"t": t} for t in random_t]) - graph.add_node_attr_key("x", dtype=pl.Float64, default_value=-1.0) + graph.add_node_attr_key("x", pl.Float64) # testing with varying values graph.update_node_attrs( @@ -2541,13 +2541,11 @@ def test_to_traccuracy_graph(graph_backend: BaseGraph) -> None: from traccuracy.metrics import CTCMetrics # Create first graph (self) with masks - graph_backend.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) - graph_backend.add_node_attr_key("y", dtype=pl.Float64, default_value=0.0) - graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, dtype=pl.Object, default_value=None) - graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.BBOX, pl.Array(pl.Int64, 4), default_value=np.zeros(4, dtype=int)) - graph_backend.update_metadata( - shape=[3, 25, 25], - ) + graph_backend.add_node_attr_key("x", pl.Float64) + graph_backend.add_node_attr_key("y", pl.Float64) + graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, pl.Object) + graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.BBOX, pl.Array(pl.Int64, 4)) + graph_backend.update_metadata(shape=[3, 25, 25]) # Create masks for first graph mask1_data = np.array([[True, True], [True, True]], dtype=bool) diff --git a/src/tracksdata/graph/_test/test_subgraph.py b/src/tracksdata/graph/_test/test_subgraph.py index 36b295b8..56f158b2 100644 --- a/src/tracksdata/graph/_test/test_subgraph.py +++ b/src/tracksdata/graph/_test/test_subgraph.py @@ -37,8 +37,8 @@ def create_test_graph(graph_backend: BaseGraph, use_subgraph: bool = False) -> B Either the original graph or a subgraph with test data. 
""" # Add attribute keys - graph_backend.add_node_attr_key("x", dtype=pl.Float64, default_value=-1.0) - graph_backend.add_node_attr_key("y", dtype=pl.Float64, default_value=-1.0) + graph_backend.add_node_attr_key("x", pl.Float64) + graph_backend.add_node_attr_key("y", pl.Float64) graph_backend.add_node_attr_key("label", dtype=pl.String, default_value="0") graph_backend.add_edge_attr_key("weight", dtype=pl.Float64, default_value=0.0) graph_backend.add_edge_attr_key("new_attribute", dtype=pl.Float64, default_value=0.0) @@ -1101,8 +1101,8 @@ def test_graph_view_remove_edge(graph_backend: BaseGraph) -> None: Tests removal by endpoints and by edge_id with the view in sync mode. """ # Setup root graph with attributes - graph_backend.add_node_attr_key("x", dtype=pl.Object, default_value=None) - graph_backend.add_edge_attr_key("weight", dtype=pl.Float64, default_value=0.0) + graph_backend.add_node_attr_key("x", pl.Float64) + graph_backend.add_edge_attr_key("weight", pl.Float64, default_value=0.0) # Nodes and edges n0 = graph_backend.add_node({"t": 0, "x": 0.0}) diff --git a/src/tracksdata/graph/filters/_test/test_spatial_filter.py b/src/tracksdata/graph/filters/_test/test_spatial_filter.py index bb7badeb..0ee21027 100644 --- a/src/tracksdata/graph/filters/_test/test_spatial_filter.py +++ b/src/tracksdata/graph/filters/_test/test_spatial_filter.py @@ -292,7 +292,7 @@ def test_bbox_spatial_filter_error_handling() -> None: def test_add_and_remove_node(graph_backend: BaseGraph) -> None: - graph_backend.add_node_attr_key("bbox", dtype=pl.Object, default_value=np.asarray([0, 0, 0, 0])) + graph_backend.add_node_attr_key("bbox", pl.Array(pl.Int64, 4)) # testing if _node_tree is created in BBoxSpatialFilter when graph is empty _ = BBoxSpatialFilter(graph_backend, frame_attr_key="t", bbox_attr_key="bbox") @@ -342,7 +342,7 @@ def test_add_and_remove_node(graph_backend: BaseGraph) -> None: def test_bbox_spatial_filter_handles_list_dtype(graph_backend: BaseGraph) -> None: """Ensure bounding boxes stored as list dtype still work with the spatial filter.""" - graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.BBOX, dtype=pl.Object, default_value=None) + graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.BBOX, pl.Array(pl.Int64, 4)) first = graph_backend.add_node({"t": 0, "bbox": [0, 0, 2, 2]}) second = graph_backend.add_node({"t": 1, "bbox": [5, 5, 8, 8]}) diff --git a/src/tracksdata/io/_test/test_ctc_io.py b/src/tracksdata/io/_test/test_ctc_io.py index e0a79886..7c5fb925 100644 --- a/src/tracksdata/io/_test/test_ctc_io.py +++ b/src/tracksdata/io/_test/test_ctc_io.py @@ -15,8 +15,8 @@ def test_export_from_ctc_roundtrip(tmp_path: Path, metadata_shape: bool) -> None # Create original graph with nodes and edges in_graph = RustWorkXGraph() - in_graph.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, pl.Object, None) - in_graph.add_node_attr_key(DEFAULT_ATTR_KEYS.BBOX, pl.Object, None) + in_graph.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, pl.Object) + in_graph.add_node_attr_key(DEFAULT_ATTR_KEYS.BBOX, pl.Array(pl.Int64, 4)) in_graph.add_node_attr_key(DEFAULT_ATTR_KEYS.TRACKLET_ID, pl.Int64, -1) in_graph.add_node_attr_key("x", pl.Float64, -999_999) in_graph.add_node_attr_key("y", pl.Float64, -999_999) diff --git a/src/tracksdata/nodes/_regionprops.py b/src/tracksdata/nodes/_regionprops.py index d8b3cb0c..c78feb32 100644 --- a/src/tracksdata/nodes/_regionprops.py +++ b/src/tracksdata/nodes/_regionprops.py @@ -128,13 +128,11 @@ def _init_node_attrs(self, graph: BaseGraph, axis_names: list[str], ndims: int) Initialize the node 
attributes for the graph. """ if DEFAULT_ATTR_KEYS.MASK not in graph.node_attr_keys(): - graph.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, pl.Object, None) + graph.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, pl.Object) if DEFAULT_ATTR_KEYS.BBOX not in graph.node_attr_keys(): bbox_size = 2 * (ndims - 1) - graph.add_node_attr_key( - DEFAULT_ATTR_KEYS.BBOX, pl.Array(pl.Int64, bbox_size), np.zeros(bbox_size, dtype=int) - ) + graph.add_node_attr_key(DEFAULT_ATTR_KEYS.BBOX, pl.Array(pl.Int64, bbox_size)) if "label" in self.attr_keys() and "label" not in graph.node_attr_keys(): graph.add_node_attr_key("label", pl.Int64, 0) diff --git a/src/tracksdata/nodes/_test/test_generic_nodes.py b/src/tracksdata/nodes/_test/test_generic_nodes.py index 62cd7102..51bdb342 100644 --- a/src/tracksdata/nodes/_test/test_generic_nodes.py +++ b/src/tracksdata/nodes/_test/test_generic_nodes.py @@ -94,7 +94,7 @@ def test_crop_func_attrs_function_with_frames() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, dtype=pl.Object, default_value=None) + graph.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, pl.Object) # Create test masks mask1_data = np.array([[True, True], [True, False]], dtype=bool) @@ -147,7 +147,7 @@ def test_crop_func_attrs_function_with_frames_and_attrs() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, dtype=pl.Object, default_value=None) + graph.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, pl.Object) graph.add_node_attr_key("multiplier", dtype=pl.Float64, default_value=1.0) # Create test masks @@ -200,7 +200,7 @@ def test_crop_func_attrs_function_returns_different_types() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, dtype=pl.Object, default_value=None) + graph.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, pl.Object) # Create test mask mask_data = np.array([[True, True], [True, False]], dtype=bool) @@ -266,7 +266,7 @@ def test_crop_func_attrs_error_handling_missing_attr_key() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, dtype=pl.Object, default_value=None) + graph.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, pl.Object) # Note: "value" is not registered # Create test mask @@ -297,7 +297,7 @@ def test_crop_func_attrs_function_with_frames_multiprocessing(n_workers: int) -> graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, dtype=pl.Object, default_value=None) + graph.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, pl.Object) # Create test masks for multiple time points mask1_data = np.array([[True, True], [True, False]], dtype=bool) @@ -349,7 +349,7 @@ def test_crop_func_attrs_empty_graph() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, dtype=pl.Object, default_value=None) + graph.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, pl.Object) def dummy_func(mask: Mask) -> float: return 1.0 @@ -415,7 +415,7 @@ def test_crop_func_attrs_batch_processing_with_frames() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, dtype=pl.Object, default_value=None) + graph.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, pl.Object) # Create test masks mask1_data = np.array([[True, True], [True, False]], dtype=bool) From b2368e18c077acb7be553691aafff45d74556ab4 Mon Sep 17 00:00:00 2001 From: Jordao Bragantini Date: Tue, 20 Jan 2026 11:04:08 -0800 
Subject: [PATCH 05/14] fixing more attr usage --- src/tracksdata/array/_test/test_graph_array.py | 8 ++------ src/tracksdata/functional/_test/test_labeling.py | 2 +- src/tracksdata/graph/_base_graph.py | 4 ++-- src/tracksdata/graph/_test/test_graph_backends.py | 6 +++--- src/tracksdata/solvers/_ilp_solver.py | 4 ++-- src/tracksdata/solvers/_nearest_neighbors_solver.py | 4 ++-- 6 files changed, 12 insertions(+), 16 deletions(-) diff --git a/src/tracksdata/array/_test/test_graph_array.py b/src/tracksdata/array/_test/test_graph_array.py index e19140ad..057e5c06 100644 --- a/src/tracksdata/array/_test/test_graph_array.py +++ b/src/tracksdata/array/_test/test_graph_array.py @@ -34,9 +34,7 @@ def test_graph_array_view_init(graph_backend: BaseGraph) -> None: # Add a attribute key graph_backend.add_node_attr_key("label", dtype=pl.Int64, default_value=0) graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, pl.Object) - graph_backend.add_node_attr_key( - DEFAULT_ATTR_KEYS.BBOX, dtype=pl.Object, default_value=np.asarray([0, 0, 0, 0, 0, 0]) - ) + graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.BBOX, pl.Array(pl.Int64, 6)) array_view = GraphArrayView(graph=graph_backend, shape=(10, 100, 100), attr_key="label", offset=0) @@ -62,9 +60,7 @@ def test_graph_array_view_getitem_empty_time(graph_backend: BaseGraph) -> None: graph_backend.add_node_attr_key("label", dtype=pl.Int64, default_value=0) graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, pl.Object) - graph_backend.add_node_attr_key( - DEFAULT_ATTR_KEYS.BBOX, dtype=pl.Object, default_value=np.asarray([0, 0, 0, 0, 0, 0]) - ) + graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.BBOX, pl.Array(pl.Int64, 6)) array_view = GraphArrayView(graph=graph_backend, shape=(10, 100, 100), attr_key="label") diff --git a/src/tracksdata/functional/_test/test_labeling.py b/src/tracksdata/functional/_test/test_labeling.py index 96225c49..26f0bbaf 100644 --- a/src/tracksdata/functional/_test/test_labeling.py +++ b/src/tracksdata/functional/_test/test_labeling.py @@ -45,7 +45,7 @@ def test_ancestral_connected_edges(): ref_graph.add_edge(7, 8, {}) # manual matching - input_graph.add_node_attr_key(td.DEFAULT_ATTR_KEYS.MATCHED_NODE_ID, dtype=pl.Int64, default_value=-1) + input_graph.add_node_attr_key(td.DEFAULT_ATTR_KEYS.MATCHED_NODE_ID, pl.Int64) input_graph.update_node_attrs( attrs={ td.DEFAULT_ATTR_KEYS.MATCHED_NODE_ID: [0, 1, 2, 3, 4, 5, 6, 8], diff --git a/src/tracksdata/graph/_base_graph.py b/src/tracksdata/graph/_base_graph.py index 1534d2c9..4d69466d 100644 --- a/src/tracksdata/graph/_base_graph.py +++ b/src/tracksdata/graph/_base_graph.py @@ -1095,13 +1095,13 @@ def match( ) if matched_node_id_key not in self.node_attr_keys(): - self.add_node_attr_key(matched_node_id_key, pl.Int64, default_value=-1) + self.add_node_attr_key(matched_node_id_key, pl.Int64) if match_score_key not in self.node_attr_keys(): self.add_node_attr_key(match_score_key, pl.Float64, default_value=0.0) if matched_edge_mask_key not in self.edge_attr_keys(): - self.add_edge_attr_key(matched_edge_mask_key, pl.Boolean, default_value=False) + self.add_edge_attr_key(matched_edge_mask_key, pl.Boolean) node_ids = functools.reduce(operator.iadd, matching_data["mapped_comp"]) other_ids = functools.reduce(operator.iadd, matching_data["mapped_ref"]) diff --git a/src/tracksdata/graph/_test/test_graph_backends.py b/src/tracksdata/graph/_test/test_graph_backends.py index 373a8a23..5b77d3d9 100644 --- a/src/tracksdata/graph/_test/test_graph_backends.py +++ 
b/src/tracksdata/graph/_test/test_graph_backends.py @@ -1933,7 +1933,7 @@ def test_assign_tracklet_ids_node_id_filter(graph_backend: BaseGraph, return_id_ def test_tracklet_graph_basic(graph_backend: BaseGraph) -> None: """Test basic tracklet_graph functionality.""" # Add tracklet_id attribute and nodes with track IDs - graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.TRACKLET_ID, dtype=pl.Int64, default_value=-1) + graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.TRACKLET_ID, pl.Int64) # Create nodes with different track IDs node0 = graph_backend.add_node({"t": 0, DEFAULT_ATTR_KEYS.TRACKLET_ID: 1}) @@ -1977,7 +1977,7 @@ def test_tracklet_graph_basic(graph_backend: BaseGraph) -> None: def test_tracklet_graph_with_ignore_tracklet_id(graph_backend: BaseGraph) -> None: """Test tracklet_graph with ignore_tracklet_id parameter.""" # Add tracklet_id attribute and nodes with track IDs - graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.TRACKLET_ID, dtype=pl.Int64, default_value=-1) + graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.TRACKLET_ID, pl.Int64) graph_backend.add_edge_attr_key("weight", dtype=pl.Float64, default_value=0.0) # Simple test case: just check that the method accepts the parameter @@ -2082,7 +2082,7 @@ def test_sqlgraph_node_attr_index_create_and_drop(graph_backend: BaseGraph) -> N if not isinstance(graph_backend, SQLGraph): pytest.skip("Only SQLGraph supports explicit SQL indexes") - graph_backend.add_node_attr_key("label", dtype=pl.String, default_value="") + graph_backend.add_node_attr_key("label", pl.String) index_name = f"ix_{graph_backend.Node.__tablename__.lower()}_t_label" graph_backend.create_node_attr_index(["t", "label"], unique=False) diff --git a/src/tracksdata/solvers/_ilp_solver.py b/src/tracksdata/solvers/_ilp_solver.py index 98b1ea6f..6485eaf7 100644 --- a/src/tracksdata/solvers/_ilp_solver.py +++ b/src/tracksdata/solvers/_ilp_solver.py @@ -401,7 +401,7 @@ def solve( selected_nodes = [node_id for node_id, var in self._node_vars.items() if solution[var.index] > 0.5] if self.output_key not in graph.node_attr_keys(): - graph.add_node_attr_key(self.output_key, pl.Boolean, default_value=False) + graph.add_node_attr_key(self.output_key, pl.Boolean) elif self.reset: graph.update_node_attrs(attrs={self.output_key: False}) @@ -413,7 +413,7 @@ def solve( selected_edges = [edge_id for edge_id, var in self._edge_vars.items() if solution[var.index] > 0.5] if self.output_key not in graph.edge_attr_keys(): - graph.add_edge_attr_key(self.output_key, pl.Boolean, default_value=False) + graph.add_edge_attr_key(self.output_key, pl.Boolean) elif self.reset: graph.update_edge_attrs(attrs={self.output_key: False}) diff --git a/src/tracksdata/solvers/_nearest_neighbors_solver.py b/src/tracksdata/solvers/_nearest_neighbors_solver.py index 637b9ac4..34011dee 100644 --- a/src/tracksdata/solvers/_nearest_neighbors_solver.py +++ b/src/tracksdata/solvers/_nearest_neighbors_solver.py @@ -275,7 +275,7 @@ def solve( solution_edges_df = edges_df.filter(solution) if self.output_key not in graph.edge_attr_keys(): - graph.add_edge_attr_key(self.output_key, pl.Boolean, default_value=False) + graph.add_edge_attr_key(self.output_key, pl.Boolean) elif self.reset: graph.update_edge_attrs(attrs={self.output_key: False}) @@ -294,7 +294,7 @@ def solve( ) if self.output_key not in graph.node_attr_keys(): - graph.add_node_attr_key(self.output_key, pl.Boolean, default_value=False) + graph.add_node_attr_key(self.output_key, pl.Boolean) graph.update_node_attrs( node_ids=node_ids, From 
e3ac6ead2b30bab43bbb49a83b5da42f46bd387f Mon Sep 17 00:00:00 2001
From: Jordao Bragantini
Date: Tue, 20 Jan 2026 11:14:29 -0800
Subject: [PATCH 06/14] removing unused function

---
 src/tracksdata/graph/_sql_graph.py        |  7 ++-----
 src/tracksdata/utils/_dtypes.py           | 18 ------------------
 src/tracksdata/utils/_test/test_dtypes.py | 20 --------------------
 3 files changed, 2 insertions(+), 43 deletions(-)
 delete mode 100644 src/tracksdata/utils/_test/test_dtypes.py

diff --git a/src/tracksdata/graph/_sql_graph.py b/src/tracksdata/graph/_sql_graph.py
index 843badfa..2b8f2181 100644
--- a/src/tracksdata/graph/_sql_graph.py
+++ b/src/tracksdata/graph/_sql_graph.py
@@ -20,7 +20,6 @@
 from tracksdata.utils._dataframe import unpack_array_attrs, unpickle_bytes_columns
 from tracksdata.utils._dtypes import (
     AttrSchema,
-    infer_default_value_from_dtype,
     polars_dtype_to_sqlalchemy_type,
     process_attr_key_args,
 )
@@ -562,11 +561,10 @@ def _init_schemas_from_tables(self) -> None:
             column = self.Node.__table__.columns[column_name]
             # Infer polars dtype from SQLAlchemy type
             pl_dtype = self._sqlalchemy_type_to_polars_dtype(column.type)
-            default_value = infer_default_value_from_dtype(pl_dtype)
+            # AttrSchema.__post_init__ will infer the default_value
             self._node_attr_schemas[column_name] = AttrSchema(
                 key=column_name,
                 dtype=pl_dtype,
-                default_value=default_value,
             )
 
         # Initialize edge schemas from Edge table columns
@@ -578,11 +576,10 @@ def _init_schemas_from_tables(self) -> None:
             column = self.Edge.__table__.columns[column_name]
             # Infer polars dtype from SQLAlchemy type
             pl_dtype = self._sqlalchemy_type_to_polars_dtype(column.type)
-            default_value = infer_default_value_from_dtype(pl_dtype)
+            # AttrSchema.__post_init__ will infer the default_value
             self._edge_attr_schemas[column_name] = AttrSchema(
                 key=column_name,
                 dtype=pl_dtype,
-                default_value=default_value,
             )
 
     def _sqlalchemy_type_to_polars_dtype(self, sa_type: TypeEngine) -> pl.DataType:
diff --git a/src/tracksdata/utils/_dtypes.py b/src/tracksdata/utils/_dtypes.py
index dfba5941..9e215a98 100644
--- a/src/tracksdata/utils/_dtypes.py
+++ b/src/tracksdata/utils/_dtypes.py
@@ -130,24 +130,6 @@ def column_to_numpy(series: pl.Series) -> np.ndarray:
     return series.to_numpy()
 
 
-def infer_default_value(sample_value: Any) -> Any:
-    """
-    Infer a sensible default value based on a sample attribute value.
-    """
-    if isinstance(sample_value, bool | np.bool_):
-        return False
-    dtype = getattr(sample_value, "dtype", None)
-    if dtype is not None and np.issubdtype(dtype, np.unsignedinteger):
-        return 0
-    if isinstance(sample_value, np.unsignedinteger):
-        return 0
-    if isinstance(sample_value, int | np.integer):
-        return -1
-    if isinstance(sample_value, float | np.floating):
-        return -1.0
-    return None
-
-
 @dataclass
 class AttrSchema:
     """
diff --git a/src/tracksdata/utils/_test/test_dtypes.py b/src/tracksdata/utils/_test/test_dtypes.py
deleted file mode 100644
index bf8575f8..00000000
--- a/src/tracksdata/utils/_test/test_dtypes.py
+++ /dev/null
@@ -1,20 +0,0 @@
-import numpy as np
-import pytest
-
-from tracksdata.utils._dtypes import infer_default_value
-
-
-@pytest.mark.parametrize(
-    ("sample", "expected"),
-    [
-        (True, False),
-        (42, -1),
-        (3.14, -1.0),
-        (np.uint8(5), 0),
-        (np.int32(7), -1),
-        (np.float32(3.14), -1.0),
-        ("foo", None),
-    ],
-)
-def test_infer_default_value(sample: object, expected: object) -> None:
-    assert infer_default_value(sample) == expected

From b7b00f4287492986f5b7208a72f415f03302be27 Mon Sep 17 00:00:00 2001
From: Jordao Bragantini
Date: Tue, 20 Jan 2026 11:29:52 -0800
Subject: [PATCH 07/14] improving attr key default usage

---
 .../array/_test/test_graph_array.py           |  28 ++---
 .../edges/_test/test_distance_edges.py        |  42 ++++----
 .../edges/_test/test_generic_edges.py         |  26 ++---
 src/tracksdata/edges/_test/test_iou_edges.py  |   8 +-
 src/tracksdata/functional/_test/test_apply.py |  10 +-
 src/tracksdata/graph/_base_graph.py           |   2 +-
 .../graph/_test/test_graph_backends.py        | 100 ++++++++---------
 src/tracksdata/graph/_test/test_subgraph.py   |  10 +-
 .../filters/_test/test_spatial_filter.py      |  20 ++--
 src/tracksdata/metrics/_test/test_matching.py |  52 ++++-----
 .../nodes/_test/test_generic_nodes.py         |   4 +-
 .../solvers/_test/test_ilp_solver.py          | 102 +++++++++---------
 .../_test/test_nearest_neighbors_solver.py    |  54 +++++-----
 13 files changed, 229 insertions(+), 229 deletions(-)

diff --git a/src/tracksdata/array/_test/test_graph_array.py b/src/tracksdata/array/_test/test_graph_array.py
index 057e5c06..f8dbc552 100644
--- a/src/tracksdata/array/_test/test_graph_array.py
+++ b/src/tracksdata/array/_test/test_graph_array.py
@@ -32,7 +32,7 @@ def test_chain_indices() -> None:
 def test_graph_array_view_init(graph_backend: BaseGraph) -> None:
     """Test GraphArrayView initialization."""
     # Add a attribute key
-    graph_backend.add_node_attr_key("label", dtype=pl.Int64, default_value=0)
+    graph_backend.add_node_attr_key("label", dtype=pl.Int64)
     graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, pl.Object)
     graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.BBOX, pl.Array(pl.Int64, 6))
 
@@ -58,7 +58,7 @@ def test_graph_array_view_init_invalid_attr_key(graph_backend: BaseGraph) -> Non
 
 def test_graph_array_view_getitem_empty_time(graph_backend: BaseGraph) -> None:
     """Test __getitem__ with empty time point (no nodes)."""
-    graph_backend.add_node_attr_key("label", dtype=pl.Int64, default_value=0)
+    graph_backend.add_node_attr_key("label", dtype=pl.Int64)
     graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, pl.Object)
     graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.BBOX, pl.Array(pl.Int64, 6))
 
@@ -77,11 +77,11 @@ def test_graph_array_view_getitem_with_nodes(graph_backend: BaseGraph) -> None:
     """Test __getitem__ with nodes at time point."""
 
     # Add attribute keys
-    graph_backend.add_node_attr_key("label", dtype=pl.Int64, default_value=0)
+    graph_backend.add_node_attr_key("label", dtype=pl.Int64)
graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, pl.Object) graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.BBOX, pl.Array(pl.Int64, 4)) - graph_backend.add_node_attr_key("y", dtype=pl.Int64, default_value=0) - graph_backend.add_node_attr_key("x", dtype=pl.Int64, default_value=0) + graph_backend.add_node_attr_key("y", dtype=pl.Int64) + graph_backend.add_node_attr_key("x", dtype=pl.Int64) # Create a mask mask_data = np.array([[True, True], [True, False]], dtype=bool) @@ -127,11 +127,11 @@ def test_graph_array_view_getitem_multiple_nodes(graph_backend: BaseGraph) -> No """Test __getitem__ with multiple nodes at same time point.""" # Add attribute keys - graph_backend.add_node_attr_key("label", dtype=pl.Int64, default_value=0) + graph_backend.add_node_attr_key("label", dtype=pl.Int64) graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, pl.Object) graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.BBOX, pl.Array(pl.Int64, 4)) - graph_backend.add_node_attr_key("y", dtype=pl.Int64, default_value=0) - graph_backend.add_node_attr_key("x", dtype=pl.Int64, default_value=0) + graph_backend.add_node_attr_key("y", dtype=pl.Int64) + graph_backend.add_node_attr_key("x", dtype=pl.Int64) # Create two masks at different locations mask1_data = np.array([[True, True]], dtype=bool) @@ -185,8 +185,8 @@ def test_graph_array_view_getitem_boolean_dtype(graph_backend: BaseGraph) -> Non graph_backend.add_node_attr_key("is_active", pl.Boolean) graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, pl.Object) graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.BBOX, pl.Array(pl.Int64, 4)) - graph_backend.add_node_attr_key("y", dtype=pl.Int64, default_value=0) - graph_backend.add_node_attr_key("x", dtype=pl.Int64, default_value=0) + graph_backend.add_node_attr_key("y", dtype=pl.Int64) + graph_backend.add_node_attr_key("x", dtype=pl.Int64) # Create a mask mask_data = np.array([[True]], dtype=bool) @@ -219,11 +219,11 @@ def test_graph_array_view_dtype_inference(graph_backend: BaseGraph) -> None: """Test that dtype is properly inferred from data.""" # Add attribute keys - graph_backend.add_node_attr_key("float_label", dtype=pl.Float64, default_value=0.0) + graph_backend.add_node_attr_key("float_label", dtype=pl.Float64) graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, pl.Object) graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.BBOX, pl.Array(pl.Int64, 4)) - graph_backend.add_node_attr_key("y", dtype=pl.Int64, default_value=0) - graph_backend.add_node_attr_key("x", dtype=pl.Int64, default_value=0) + graph_backend.add_node_attr_key("y", dtype=pl.Int64) + graph_backend.add_node_attr_key("x", dtype=pl.Int64) # Create a mask mask_data = np.array([[True]], dtype=bool) @@ -341,7 +341,7 @@ def test_graph_array_view_getitem_time_index_nested(multi_node_graph_from_image, def test_graph_array_set_options(graph_backend: BaseGraph) -> None: with Options(gav_chunk_shape=(512, 512), gav_default_dtype=np.int16): - graph_backend.add_node_attr_key("label", dtype=pl.Int64, default_value=0) + graph_backend.add_node_attr_key("label", dtype=pl.Int64) graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, pl.Object) graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.BBOX, pl.Array(pl.Int64, 4)) array_view = GraphArrayView(graph=graph_backend, shape=(10, 100, 100), attr_key="label") diff --git a/src/tracksdata/edges/_test/test_distance_edges.py b/src/tracksdata/edges/_test/test_distance_edges.py index f91799de..e6630b26 100644 --- a/src/tracksdata/edges/_test/test_distance_edges.py +++ 
b/src/tracksdata/edges/_test/test_distance_edges.py @@ -49,8 +49,8 @@ def test_distance_edges_add_edges_single_timepoint_no_previous() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) - graph.add_node_attr_key("y", dtype=pl.Float64, default_value=0.0) + graph.add_node_attr_key("x", dtype=pl.Float64) + graph.add_node_attr_key("y", dtype=pl.Float64) # Add nodes only at t=1 (no t=0) graph.add_node({DEFAULT_ATTR_KEYS.T: 1, "x": 0.0, "y": 0.0}) @@ -68,8 +68,8 @@ def test_distance_edges_add_edges_single_timepoint_no_current() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) - graph.add_node_attr_key("y", dtype=pl.Float64, default_value=0.0) + graph.add_node_attr_key("x", dtype=pl.Float64) + graph.add_node_attr_key("y", dtype=pl.Float64) # Add nodes only at t=0 (no t=1) graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 0.0, "y": 0.0}) @@ -87,8 +87,8 @@ def test_distance_edges_add_edges_2d_coordinates() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) - graph.add_node_attr_key("y", dtype=pl.Float64, default_value=0.0) + graph.add_node_attr_key("x", dtype=pl.Float64) + graph.add_node_attr_key("y", dtype=pl.Float64) # Add nodes at t=0 _ = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 0.0, "y": 0.0}) @@ -115,9 +115,9 @@ def test_distance_edges_add_edges_3d_coordinates() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) - graph.add_node_attr_key("y", dtype=pl.Float64, default_value=0.0) - graph.add_node_attr_key("z", dtype=pl.Float64, default_value=0.0) + graph.add_node_attr_key("x", dtype=pl.Float64) + graph.add_node_attr_key("y", dtype=pl.Float64) + graph.add_node_attr_key("z", dtype=pl.Float64) # Add nodes at t=0 _ = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 0.0, "y": 0.0, "z": 0.0}) @@ -140,8 +140,8 @@ def test_distance_edges_add_edges_custom_attr_keys() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("pos_x", dtype=pl.Float64, default_value=0.0) - graph.add_node_attr_key("pos_y", dtype=pl.Float64, default_value=0.0) + graph.add_node_attr_key("pos_x", dtype=pl.Float64) + graph.add_node_attr_key("pos_y", dtype=pl.Float64) # Add nodes at t=0 _ = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "pos_x": 0.0, "pos_y": 0.0}) @@ -164,8 +164,8 @@ def test_distance_edges_add_edges_distance_threshold() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) - graph.add_node_attr_key("y", dtype=pl.Float64, default_value=0.0) + graph.add_node_attr_key("x", dtype=pl.Float64) + graph.add_node_attr_key("y", dtype=pl.Float64) # Add nodes at t=0 _ = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 0.0, "y": 0.0}) @@ -187,8 +187,8 @@ def test_distance_edges_add_edges_multiple_timepoints(n_workers: int) -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) - graph.add_node_attr_key("y", dtype=pl.Float64, default_value=0.0) + graph.add_node_attr_key("x", dtype=pl.Float64) + graph.add_node_attr_key("y", dtype=pl.Float64) # Add nodes at multiple timepoints for t in range(3): @@ -210,8 +210,8 @@ def test_distance_edges_add_edges_custom_weight_key() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("x", 
dtype=pl.Float64, default_value=0.0) - graph.add_node_attr_key("y", dtype=pl.Float64, default_value=0.0) + graph.add_node_attr_key("x", dtype=pl.Float64) + graph.add_node_attr_key("y", dtype=pl.Float64) # Add nodes at t=0 _ = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 0.0, "y": 0.0}) @@ -238,8 +238,8 @@ def test_distance_edges_n_neighbors_limit() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) - graph.add_node_attr_key("y", dtype=pl.Float64, default_value=0.0) + graph.add_node_attr_key("x", dtype=pl.Float64) + graph.add_node_attr_key("y", dtype=pl.Float64) # Add many nodes at t=0 for i in range(5): @@ -265,8 +265,8 @@ def test_distance_edges_add_edges_with_delta_t(n_workers: int) -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) - graph.add_node_attr_key("y", dtype=pl.Float64, default_value=0.0) + graph.add_node_attr_key("x", dtype=pl.Float64) + graph.add_node_attr_key("y", dtype=pl.Float64) # Add nodes at t=0, t=1, t=2 node_0 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 0.0, "y": 0.0}) diff --git a/src/tracksdata/edges/_test/test_generic_edges.py b/src/tracksdata/edges/_test/test_generic_edges.py index e1c0ae4e..38685a03 100644 --- a/src/tracksdata/edges/_test/test_generic_edges.py +++ b/src/tracksdata/edges/_test/test_generic_edges.py @@ -42,8 +42,8 @@ def test_generic_edges_add_weights_single_attr_key() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) - graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, dtype=pl.Float64, default_value=0.0) + graph.add_node_attr_key("x", dtype=pl.Float64) + graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, dtype=pl.Float64) # Add nodes at time 0 node0 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 1.0}) @@ -79,9 +79,9 @@ def test_generic_edges_add_weights_multiple_attr_keys() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) - graph.add_node_attr_key("y", dtype=pl.Float64, default_value=0.0) - graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, dtype=pl.Float64, default_value=0.0) + graph.add_node_attr_key("x", dtype=pl.Float64) + graph.add_node_attr_key("y", dtype=pl.Float64) + graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, dtype=pl.Float64) # Add nodes at time 0 node0 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 0.0, "y": 0.0}) @@ -113,8 +113,8 @@ def test_generic_edges_add_weights_all_time_points() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) - graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, dtype=pl.Float64, default_value=0.0) + graph.add_node_attr_key("x", dtype=pl.Float64) + graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, dtype=pl.Float64) # Add nodes at different time points node0_t0 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 1.0}) @@ -143,7 +143,7 @@ def test_generic_edges_no_edges_at_time_point() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) + graph.add_node_attr_key("x", dtype=pl.Float64) # Add nodes but no edges at time 0 graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 1.0}) @@ -164,8 +164,8 @@ def test_generic_edges_creates_output_key() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("x", dtype=pl.Float64, 
default_value=0.0) - graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, dtype=pl.Float64, default_value=0.0) + graph.add_node_attr_key("x", dtype=pl.Float64) + graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, dtype=pl.Float64) # Add nodes and edge node0 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 1.0}) @@ -192,9 +192,9 @@ def test_generic_edges_dict_input_function() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("value", dtype=pl.Float64, default_value=0.0) - graph.add_node_attr_key("weight", dtype=pl.Float64, default_value=0.0) - graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, dtype=pl.Float64, default_value=0.0) + graph.add_node_attr_key("value", dtype=pl.Float64) + graph.add_node_attr_key("weight", dtype=pl.Float64) + graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, dtype=pl.Float64) # Add nodes node0 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "value": 10.0, "weight": 2.0}) diff --git a/src/tracksdata/edges/_test/test_iou_edges.py b/src/tracksdata/edges/_test/test_iou_edges.py index 834e5e72..c9ecdf6d 100644 --- a/src/tracksdata/edges/_test/test_iou_edges.py +++ b/src/tracksdata/edges/_test/test_iou_edges.py @@ -34,7 +34,7 @@ def test_iou_edges_add_weights(n_workers: int) -> None: # Register attribute keys graph.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, pl.Object) - graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, dtype=pl.Float64, default_value=0.0) + graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, dtype=pl.Float64) # Create test masks mask1_data = np.array([[True, True], [True, False]], dtype=bool) @@ -84,7 +84,7 @@ def test_iou_edges_no_overlap() -> None: # Register attribute keys graph.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, pl.Object) - graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, dtype=pl.Float64, default_value=0.0) + graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, dtype=pl.Float64) # Create non-overlapping masks mask1_data = np.array([[True, True], [False, False]], dtype=bool) @@ -123,7 +123,7 @@ def test_iou_edges_perfect_overlap() -> None: # Register attribute keys graph.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, pl.Object) - graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, dtype=pl.Float64, default_value=0.0) + graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, dtype=pl.Float64) # Create identical masks mask_data = np.array([[True, True], [True, False]], dtype=bool) @@ -159,7 +159,7 @@ def test_iou_edges_custom_mask_key() -> None: # Register attribute keys graph.add_node_attr_key("custom_mask", pl.Object) - graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, pl.Float64, default_value=0.0) + graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, pl.Float64) # Create test masks mask1_data = np.array([[True, True], [True, True]], dtype=bool) diff --git a/src/tracksdata/functional/_test/test_apply.py b/src/tracksdata/functional/_test/test_apply.py index fb5df478..7791165c 100644 --- a/src/tracksdata/functional/_test/test_apply.py +++ b/src/tracksdata/functional/_test/test_apply.py @@ -10,9 +10,9 @@ def sample_graph() -> RustWorkXGraph: """Create a sample graph with spatial nodes for testing.""" graph = RustWorkXGraph() - graph.add_node_attr_key("z", dtype=pl.Int64, default_value=0) - graph.add_node_attr_key("y", dtype=pl.Int64, default_value=0) - graph.add_node_attr_key("x", dtype=pl.Int64, default_value=0) + graph.add_node_attr_key("z", dtype=pl.Int64) + graph.add_node_attr_key("y", dtype=pl.Int64) + graph.add_node_attr_key("x", dtype=pl.Int64) # Add nodes in a grid pattern nodes = [ @@ -112,8 
+112,8 @@ def test_apply_tiled_default_attrs(sample_graph: RustWorkXGraph) -> None: def test_apply_tiled_2d_tiling() -> None: """Test apply_tiled with 2D spatial coordinates.""" graph = RustWorkXGraph() - graph.add_node_attr_key("y", dtype=pl.Int64, default_value=0) - graph.add_node_attr_key("x", dtype=pl.Int64, default_value=0) + graph.add_node_attr_key("y", dtype=pl.Int64) + graph.add_node_attr_key("x", dtype=pl.Int64) for y in [5, 11, 14]: for x in [10, 30]: diff --git a/src/tracksdata/graph/_base_graph.py b/src/tracksdata/graph/_base_graph.py index 4d69466d..b384bdbb 100644 --- a/src/tracksdata/graph/_base_graph.py +++ b/src/tracksdata/graph/_base_graph.py @@ -794,7 +794,7 @@ def add_edge_attr_key( Add edge attribute with custom default: ```python - graph.add_edge_attr_key("distance", pl.Float64, default_value=0.0) + graph.add_edge_attr_key("distance", pl.Float64) ``` Using AttrSchema: diff --git a/src/tracksdata/graph/_test/test_graph_backends.py b/src/tracksdata/graph/_test/test_graph_backends.py index 5b77d3d9..70ef1c2f 100644 --- a/src/tracksdata/graph/_test/test_graph_backends.py +++ b/src/tracksdata/graph/_test/test_graph_backends.py @@ -57,7 +57,7 @@ def test_add_node(graph_backend: BaseGraph) -> None: """Test adding nodes with various attributes.""" for key in ["x", "y"]: - graph_backend.add_node_attr_key(key, dtype=pl.Float64, default_value=0.0) + graph_backend.add_node_attr_key(key, dtype=pl.Float64) node_id = graph_backend.add_node({"t": 0, "x": 1.0, "y": 2.0}) assert isinstance(node_id, int) @@ -148,7 +148,7 @@ def test_remove_edge_by_id(graph_backend: BaseGraph) -> None: def test_remove_edge_by_nodes(graph_backend: BaseGraph) -> None: """Test removing an edge by its source/target IDs.""" graph_backend.add_node_attr_key("x", pl.Float64) - graph_backend.add_edge_attr_key("weight", dtype=pl.Float64, default_value=0.0) + graph_backend.add_edge_attr_key("weight", dtype=pl.Float64) a = graph_backend.add_node({"t": 0, "x": 0.0}) b = graph_backend.add_node({"t": 1, "x": 1.0}) @@ -257,7 +257,7 @@ def test_edge_attrs(graph_backend: BaseGraph) -> None: node1 = graph_backend.add_node({"t": 0}) node2 = graph_backend.add_node({"t": 1}) - graph_backend.add_edge_attr_key("weight", pl.Float64, default_value=0.0) + graph_backend.add_edge_attr_key("weight", pl.Float64) graph_backend.add_edge_attr_key("vector", pl.Array(pl.Float64, 2)) graph_backend.add_edge(node1, node2, attrs={"weight": 0.5, "vector": np.array([1.0, 2.0])}) @@ -276,7 +276,7 @@ def test_edge_attrs(graph_backend: BaseGraph) -> None: def test_edge_attrs_subgraph_edge_ids(graph_backend: BaseGraph) -> None: """Test that edge_attrs preserves original edge IDs when using node_ids parameter.""" # Add edge attribute key - graph_backend.add_edge_attr_key("weight", dtype=pl.Float64, default_value=0.0) + graph_backend.add_edge_attr_key("weight", dtype=pl.Float64) # Create nodes node1 = graph_backend.add_node({"t": 0}) @@ -335,10 +335,10 @@ def test_edge_attrs_subgraph_edge_ids(graph_backend: BaseGraph) -> None: def test_subgraph_with_node_and_edge_attr_filters(graph_backend: BaseGraph) -> None: """Test subgraph with node and edge attribute filters.""" - graph_backend.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) - graph_backend.add_node_attr_key("y", dtype=pl.Float64, default_value=0.0) - graph_backend.add_edge_attr_key("weight", dtype=pl.Float64, default_value=0.0) - graph_backend.add_edge_attr_key("length", dtype=pl.Float64, default_value=0.0) + graph_backend.add_node_attr_key("x", dtype=pl.Float64) + 
graph_backend.add_node_attr_key("y", dtype=pl.Float64) + graph_backend.add_edge_attr_key("weight", dtype=pl.Float64) + graph_backend.add_edge_attr_key("length", dtype=pl.Float64) node1 = graph_backend.add_node({"t": 0, "x": 1.0, "y": 0.0}) node2 = graph_backend.add_node({"t": 1, "x": 2.0, "y": 0.0}) @@ -375,7 +375,7 @@ def test_subgraph_with_node_and_edge_attr_filters(graph_backend: BaseGraph) -> N def test_subgraph_with_node_ids_and_filters(graph_backend: BaseGraph) -> None: """Test subgraph with node IDs and filters.""" graph_backend.add_node_attr_key("x", pl.Float64) - graph_backend.add_edge_attr_key("weight", dtype=pl.Float64, default_value=0.0) + graph_backend.add_edge_attr_key("weight", dtype=pl.Float64) node0 = graph_backend.add_node({"t": 0, "x": 1.0}) node1 = graph_backend.add_node({"t": 1, "x": 2.0}) @@ -472,7 +472,7 @@ def test_add_edge_attr_key(graph_backend: BaseGraph) -> None: node1 = graph_backend.add_node({"t": 0}) node2 = graph_backend.add_node({"t": 1}) - graph_backend.add_edge_attr_key("new_attribute", dtype=pl.Int64, default_value=42) + graph_backend.add_edge_attr_key("new_attribute", dtype=pl.Int64) graph_backend.add_edge(node1, node2, attrs={"new_attribute": 42}) df = graph_backend.edge_attrs(attr_keys=["new_attribute"]) @@ -501,7 +501,7 @@ def test_remove_edge_attr_key(graph_backend: BaseGraph) -> None: def test_update_node_attrs(graph_backend: BaseGraph) -> None: """Test updating node attributes.""" - graph_backend.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) + graph_backend.add_node_attr_key("x", dtype=pl.Float64) node_1 = graph_backend.add_node({"t": 0, "x": 1.0}) node_2 = graph_backend.add_node({"t": 0, "x": 2.0}) @@ -527,7 +527,7 @@ def test_update_edge_attrs(graph_backend: BaseGraph) -> None: node1 = graph_backend.add_node({"t": 0}) node2 = graph_backend.add_node({"t": 1}) - graph_backend.add_edge_attr_key("weight", dtype=pl.Float64, default_value=0.0) + graph_backend.add_edge_attr_key("weight", dtype=pl.Float64) edge_id = graph_backend.add_edge(node1, node2, attrs={"weight": 0.5}) graph_backend.update_edge_attrs(edge_ids=[edge_id], attrs={"weight": 1.0}) @@ -544,7 +544,7 @@ def test_num_edges(graph_backend: BaseGraph) -> None: node1 = graph_backend.add_node({"t": 0}) node2 = graph_backend.add_node({"t": 1}) - graph_backend.add_edge_attr_key("weight", dtype=pl.Float64, default_value=0.0) + graph_backend.add_edge_attr_key("weight", dtype=pl.Float64) graph_backend.add_edge(node1, node2, attrs={"weight": 0.5}) assert graph_backend.num_edges() == 1 @@ -561,7 +561,7 @@ def test_num_nodes(graph_backend: BaseGraph) -> None: def test_edge_attrs_include_targets(graph_backend: BaseGraph) -> None: """Test the inclusive flag behavior in edge_attrs method.""" # Add edge attribute key - graph_backend.add_edge_attr_key("weight", dtype=pl.Float64, default_value=0.0) + graph_backend.add_edge_attr_key("weight", dtype=pl.Float64) # Create a graph with 4 nodes # Graph structure: @@ -678,9 +678,9 @@ def test_from_ctc( def test_sucessors_and_degree(graph_backend: BaseGraph) -> None: """Test getting successors of nodes.""" # Add attribute keys - graph_backend.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) - graph_backend.add_node_attr_key("y", dtype=pl.Float64, default_value=0.0) - graph_backend.add_edge_attr_key("weight", dtype=pl.Float64, default_value=0.0) + graph_backend.add_node_attr_key("x", dtype=pl.Float64) + graph_backend.add_node_attr_key("y", dtype=pl.Float64) + graph_backend.add_edge_attr_key("weight", dtype=pl.Float64) # Create a simple graph 
structure: node0 -> node1 -> node2 # \-> node3 @@ -769,9 +769,9 @@ def test_sucessors_and_degree(graph_backend: BaseGraph) -> None: def test_predecessors_and_degree(graph_backend: BaseGraph) -> None: """Test getting predecessors of nodes.""" # Add attribute keys - graph_backend.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) - graph_backend.add_node_attr_key("y", dtype=pl.Float64, default_value=0.0) - graph_backend.add_edge_attr_key("weight", dtype=pl.Float64, default_value=0.0) + graph_backend.add_node_attr_key("x", dtype=pl.Float64) + graph_backend.add_node_attr_key("y", dtype=pl.Float64) + graph_backend.add_edge_attr_key("weight", dtype=pl.Float64) # Create a simple graph structure: node0 -> node1 -> node2 # \-> node3 @@ -858,10 +858,10 @@ def test_predecessors_and_degree(graph_backend: BaseGraph) -> None: def test_sucessors_with_attr_keys(graph_backend: BaseGraph) -> None: """Test getting successors with specific attribute keys.""" # Add attribute keys - graph_backend.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) - graph_backend.add_node_attr_key("y", dtype=pl.Float64, default_value=0.0) + graph_backend.add_node_attr_key("x", dtype=pl.Float64) + graph_backend.add_node_attr_key("y", dtype=pl.Float64) graph_backend.add_node_attr_key("label", dtype=pl.String, default_value="X") - graph_backend.add_edge_attr_key("weight", dtype=pl.Float64, default_value=0.0) + graph_backend.add_edge_attr_key("weight", dtype=pl.Float64) # Create nodes node0 = graph_backend.add_node({"t": 0, "x": 0.0, "y": 0.0, "label": "A"}) @@ -905,10 +905,10 @@ def test_sucessors_with_attr_keys(graph_backend: BaseGraph) -> None: def test_predecessors_with_attr_keys(graph_backend: BaseGraph) -> None: """Test getting predecessors with specific attribute keys.""" # Add attribute keys - graph_backend.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) - graph_backend.add_node_attr_key("y", dtype=pl.Float64, default_value=0.0) + graph_backend.add_node_attr_key("x", dtype=pl.Float64) + graph_backend.add_node_attr_key("y", dtype=pl.Float64) graph_backend.add_node_attr_key("label", dtype=pl.String, default_value="X") - graph_backend.add_edge_attr_key("weight", dtype=pl.Float64, default_value=0.0) + graph_backend.add_edge_attr_key("weight", dtype=pl.Float64) # Create nodes node0 = graph_backend.add_node({"t": 0, "x": 0.0, "y": 0.0, "label": "A"}) @@ -948,8 +948,8 @@ def test_predecessors_with_attr_keys(graph_backend: BaseGraph) -> None: def test_sucessors_predecessors_edge_cases(graph_backend: BaseGraph) -> None: """Test edge cases for successors and predecessors methods.""" # Add attribute keys - graph_backend.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) - graph_backend.add_edge_attr_key("weight", dtype=pl.Float64, default_value=0.0) + graph_backend.add_node_attr_key("x", dtype=pl.Float64) + graph_backend.add_edge_attr_key("weight", dtype=pl.Float64) # Create isolated nodes (no edges) node0 = graph_backend.add_node({"t": 0, "x": 0.0}) @@ -1131,15 +1131,15 @@ def test_match_method(graph_backend: BaseGraph) -> None: def test_attrs_with_duplicated_attr_keys(graph_backend: BaseGraph) -> None: """Test that node attributeswith duplicated attribute keys are handled correctly.""" # Add attribute keys - graph_backend.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) - graph_backend.add_node_attr_key("y", dtype=pl.Float64, default_value=0.0) + graph_backend.add_node_attr_key("x", dtype=pl.Float64) + graph_backend.add_node_attr_key("y", dtype=pl.Float64) # Add nodes node_1 = 
graph_backend.add_node({"t": 0, "x": 1.0, "y": 1.0}) node_2 = graph_backend.add_node({"t": 1, "x": 2.0, "y": 2.0}) # Add edges - graph_backend.add_edge_attr_key("weight", dtype=pl.Float64, default_value=0.0) + graph_backend.add_edge_attr_key("weight", dtype=pl.Float64) graph_backend.add_edge(node_1, node_2, {"weight": 0.5}) # Test with duplicated attribute keys @@ -1366,8 +1366,8 @@ def test_from_other_with_edges( # Create source graph with nodes, edges, and attributes graph_backend.update_metadata(special_key="special_value") - graph_backend.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) - graph_backend.add_edge_attr_key("weight", dtype=pl.Float64, default_value=0.0) + graph_backend.add_node_attr_key("x", dtype=pl.Float64) + graph_backend.add_edge_attr_key("weight", dtype=pl.Float64) graph_backend.add_edge_attr_key("type", dtype=pl.String, default_value="forward") node1 = graph_backend.add_node({"t": 0, "x": 1.0}) @@ -1636,8 +1636,8 @@ def test_compute_overlaps_empty_graph(graph_backend: BaseGraph) -> None: def test_summary(graph_backend: BaseGraph) -> None: """Test summary method.""" - graph_backend.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) - graph_backend.add_edge_attr_key("weight", dtype=pl.Float64, default_value=0.0) + graph_backend.add_node_attr_key("x", dtype=pl.Float64) + graph_backend.add_edge_attr_key("weight", dtype=pl.Float64) graph_backend.add_edge_attr_key("type", dtype=pl.String, default_value="good") node1 = graph_backend.add_node({"t": 0, "x": 1.0}) @@ -1946,7 +1946,7 @@ def test_tracklet_graph_basic(graph_backend: BaseGraph) -> None: node7 = graph_backend.add_node({"t": 2, DEFAULT_ATTR_KEYS.TRACKLET_ID: 4}) node8 = graph_backend.add_node({"t": 3, DEFAULT_ATTR_KEYS.TRACKLET_ID: 4}) - graph_backend.add_edge_attr_key("weight", dtype=pl.Float64, default_value=0.0) + graph_backend.add_edge_attr_key("weight", dtype=pl.Float64) # Add edges within tracks (will be filtered out) graph_backend.add_edge(node0, node1, {"weight": 0.5}) @@ -1978,7 +1978,7 @@ def test_tracklet_graph_with_ignore_tracklet_id(graph_backend: BaseGraph) -> Non """Test tracklet_graph with ignore_tracklet_id parameter.""" # Add tracklet_id attribute and nodes with track IDs graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.TRACKLET_ID, pl.Int64) - graph_backend.add_edge_attr_key("weight", dtype=pl.Float64, default_value=0.0) + graph_backend.add_edge_attr_key("weight", dtype=pl.Float64) # Simple test case: just check that the method accepts the parameter # and filters out nodes properly when there are no edges @@ -2005,7 +2005,7 @@ def test_tracklet_graph_missing_tracklet_id_key(graph_backend: BaseGraph) -> Non def test_nodes_interface(graph_backend: BaseGraph) -> None: - graph_backend.add_node_attr_key("x", dtype=pl.Int64, default_value=0) + graph_backend.add_node_attr_key("x", dtype=pl.Int64) # Simple test case: just check that the method accepts the parameter # and filters out nodes properly when there are no edges @@ -2037,8 +2037,8 @@ def test_custom_indices(graph_backend: BaseGraph) -> None: pytest.skip("Graph does not support custom indices") # Add attribute keys for testing - graph_backend.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) - graph_backend.add_node_attr_key("y", dtype=pl.Float64, default_value=0.0) + graph_backend.add_node_attr_key("x", dtype=pl.Float64) + graph_backend.add_node_attr_key("y", dtype=pl.Float64) # Test add_node with custom index custom_node_id = graph_backend.add_node({"t": 0, "x": 10.0, "y": 20.0}, index=12345) @@ -2103,7 +2103,7 @@ 
def test_sqlgraph_edge_attr_index_create_and_drop(graph_backend: BaseGraph) -> N if not isinstance(graph_backend, SQLGraph): pytest.skip("Only SQLGraph supports explicit SQL indexes") - graph_backend.add_edge_attr_key("score", dtype=pl.Float64, default_value=0.0) + graph_backend.add_edge_attr_key("score", dtype=pl.Float64) index_name = f"ix_{graph_backend.Edge.__tablename__.lower()}_score" graph_backend.create_edge_attr_index("score", unique=True) @@ -2132,9 +2132,9 @@ def test_sqlgraph_index_missing_column(graph_backend: BaseGraph) -> None: def test_remove_node(graph_backend: BaseGraph) -> None: """Test removing nodes from the graph.""" # Add attribute keys - graph_backend.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) - graph_backend.add_node_attr_key("y", dtype=pl.Float64, default_value=0.0) - graph_backend.add_edge_attr_key("weight", dtype=pl.Float64, default_value=0.0) + graph_backend.add_node_attr_key("x", dtype=pl.Float64) + graph_backend.add_node_attr_key("y", dtype=pl.Float64) + graph_backend.add_edge_attr_key("weight", dtype=pl.Float64) # Add nodes node1 = graph_backend.add_node({"t": 0, "x": 1.0, "y": 1.0}) @@ -2194,8 +2194,8 @@ def test_remove_node(graph_backend: BaseGraph) -> None: def test_remove_node_and_add_new_nodes(graph_backend: BaseGraph) -> None: """Test removing nodes and then adding new nodes.""" # Add attribute keys - graph_backend.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) - graph_backend.add_edge_attr_key("weight", dtype=pl.Float64, default_value=0.0) + graph_backend.add_node_attr_key("x", dtype=pl.Float64) + graph_backend.add_edge_attr_key("weight", dtype=pl.Float64) # Add initial nodes node1 = graph_backend.add_node({"t": 0, "x": 1.0}) @@ -2291,9 +2291,9 @@ def test_remove_all_nodes_in_time_point(graph_backend: BaseGraph) -> None: def _fill_mock_geff_graph(graph_backend: BaseGraph) -> None: - graph_backend.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) - graph_backend.add_node_attr_key("y", dtype=pl.Float64, default_value=0.0) - graph_backend.add_node_attr_key("z", dtype=pl.Float64, default_value=0.0) + graph_backend.add_node_attr_key("x", dtype=pl.Float64) + graph_backend.add_node_attr_key("y", dtype=pl.Float64) + graph_backend.add_node_attr_key("z", dtype=pl.Float64) graph_backend.add_node_attr_key( DEFAULT_ATTR_KEYS.BBOX, pl.Array(pl.Int64, 4), default_value=np.array([0, 0, 1, 1], dtype=int) ) @@ -2568,7 +2568,7 @@ def test_to_traccuracy_graph(graph_backend: BaseGraph) -> None: {"t": 2, "x": 3.0, "y": 3.0, DEFAULT_ATTR_KEYS.MASK: mask3, DEFAULT_ATTR_KEYS.BBOX: mask3.bbox} ) - graph_backend.add_edge_attr_key("weight", dtype=pl.Float64, default_value=0.0) + graph_backend.add_edge_attr_key("weight", dtype=pl.Float64) graph_backend.add_edge(node1, node2, {"weight": 0.5}) graph_backend.add_edge(node2, node3, {"weight": 0.3}) graph_backend.add_edge(node1, node3, {"weight": 0.3}) diff --git a/src/tracksdata/graph/_test/test_subgraph.py b/src/tracksdata/graph/_test/test_subgraph.py index 56f158b2..8664620a 100644 --- a/src/tracksdata/graph/_test/test_subgraph.py +++ b/src/tracksdata/graph/_test/test_subgraph.py @@ -40,8 +40,8 @@ def create_test_graph(graph_backend: BaseGraph, use_subgraph: bool = False) -> B graph_backend.add_node_attr_key("x", pl.Float64) graph_backend.add_node_attr_key("y", pl.Float64) graph_backend.add_node_attr_key("label", dtype=pl.String, default_value="0") - graph_backend.add_edge_attr_key("weight", dtype=pl.Float64, default_value=0.0) - graph_backend.add_edge_attr_key("new_attribute", 
dtype=pl.Float64, default_value=0.0) + graph_backend.add_edge_attr_key("weight", dtype=pl.Float64) + graph_backend.add_edge_attr_key("new_attribute", dtype=pl.Float64) # Add nodes with various attributes node0 = graph_backend.add_node({"t": 0, "x": 0.0, "y": 0.0, "label": "0"}) @@ -862,7 +862,7 @@ def test_bulk_add_nodes_returned_ids(graph_backend: BaseGraph, use_subgraph: boo graph_with_data = create_test_graph(graph_backend, use_subgraph) # Add attribute keys for the new nodes - graph_with_data.add_node_attr_key("z", dtype=pl.Float64, default_value=0.0) + graph_with_data.add_node_attr_key("z", dtype=pl.Float64) # Test bulk adding nodes nodes_to_add = [ @@ -913,7 +913,7 @@ def test_bulk_add_edges_returned_ids(graph_backend: BaseGraph, use_subgraph: boo graph_with_data = create_test_graph(graph_backend, use_subgraph) # Add attribute keys for the new edges - graph_with_data.add_edge_attr_key("strength", dtype=pl.Float64, default_value=0.0) + graph_with_data.add_edge_attr_key("strength", dtype=pl.Float64) # Get some existing nodes to create edges between existing_nodes = graph_with_data._test_nodes # type: ignore @@ -1102,7 +1102,7 @@ def test_graph_view_remove_edge(graph_backend: BaseGraph) -> None: """ # Setup root graph with attributes graph_backend.add_node_attr_key("x", pl.Float64) - graph_backend.add_edge_attr_key("weight", pl.Float64, default_value=0.0) + graph_backend.add_edge_attr_key("weight", pl.Float64) # Nodes and edges n0 = graph_backend.add_node({"t": 0, "x": 0.0}) diff --git a/src/tracksdata/graph/filters/_test/test_spatial_filter.py b/src/tracksdata/graph/filters/_test/test_spatial_filter.py index 0ee21027..52778539 100644 --- a/src/tracksdata/graph/filters/_test/test_spatial_filter.py +++ b/src/tracksdata/graph/filters/_test/test_spatial_filter.py @@ -13,9 +13,9 @@ def sample_graph() -> RustWorkXGraph: """Create a sample graph with nodes for testing.""" graph = RustWorkXGraph() - graph.add_node_attr_key("z", dtype=pl.Int64, default_value=0) - graph.add_node_attr_key("y", dtype=pl.Int64, default_value=0) - graph.add_node_attr_key("x", dtype=pl.Int64, default_value=0) + graph.add_node_attr_key("z", dtype=pl.Int64) + graph.add_node_attr_key("y", dtype=pl.Int64) + graph.add_node_attr_key("x", dtype=pl.Int64) # Add some nodes with spatial coordinates nodes = [ @@ -141,9 +141,9 @@ def test_spatial_filter_querying(sample_graph: RustWorkXGraph) -> None: def test_spatial_filter_dimensions() -> None: """Test SpatialFilter with different coordinate dimensions.""" graph = RustWorkXGraph() - graph.add_node_attr_key("z", dtype=pl.Int64, default_value=0) - graph.add_node_attr_key("y", dtype=pl.Int64, default_value=0) - graph.add_node_attr_key("x", dtype=pl.Int64, default_value=0) + graph.add_node_attr_key("z", dtype=pl.Int64) + graph.add_node_attr_key("y", dtype=pl.Int64) + graph.add_node_attr_key("x", dtype=pl.Int64) graph.add_node({"t": 0, "z": 0, "y": 10, "x": 20}) # Test 2D coordinates @@ -171,9 +171,9 @@ def test_spatial_filter_error_handling(sample_graph: RustWorkXGraph) -> None: def test_spatial_filter_with_edges() -> None: """Test SpatialFilter preserves edges in subgraphs.""" graph = RustWorkXGraph() - graph.add_node_attr_key("y", dtype=pl.Int64, default_value=0) - graph.add_node_attr_key("x", dtype=pl.Int64, default_value=0) - graph.add_edge_attr_key("weight", dtype=pl.Float64, default_value=0.0) + graph.add_node_attr_key("y", dtype=pl.Int64) + graph.add_node_attr_key("x", dtype=pl.Int64) + graph.add_edge_attr_key("weight", dtype=pl.Float64) # Add nodes and edge node1_id = 
graph.add_node({"t": 0, "y": 10, "x": 20}) @@ -215,7 +215,7 @@ def test_bbox_spatial_filter_with_edges() -> None: """Test SpatialFilter preserves edges in subgraphs.""" graph = RustWorkXGraph() graph.add_node_attr_key("bbox", dtype=pl.List(pl.Int64), default_value=[0, 0, 0, 0]) - graph.add_edge_attr_key("weight", dtype=pl.Float64, default_value=0.0) + graph.add_edge_attr_key("weight", dtype=pl.Float64) # Add nodes and edge node1_id = graph.add_node({"t": 0, "bbox": [10, 20, 15, 25]}) diff --git a/src/tracksdata/metrics/_test/test_matching.py b/src/tracksdata/metrics/_test/test_matching.py index 3daaa64f..2aa1ea46 100644 --- a/src/tracksdata/metrics/_test/test_matching.py +++ b/src/tracksdata/metrics/_test/test_matching.py @@ -134,10 +134,10 @@ def test_compute_weights_2d_and_3d(self): graph1 = RustWorkXGraph() graph2 = RustWorkXGraph() - graph1.add_node_attr_key("y", dtype=pl.Float64, default_value=0.0) - graph1.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) - graph2.add_node_attr_key("y", dtype=pl.Float64, default_value=0.0) - graph2.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) + graph1.add_node_attr_key("y", dtype=pl.Float64) + graph1.add_node_attr_key("x", dtype=pl.Float64) + graph2.add_node_attr_key("y", dtype=pl.Float64) + graph2.add_node_attr_key("x", dtype=pl.Float64) # Close nodes (distance ≈ 1.414) node1 = graph1.add_node({DEFAULT_ATTR_KEYS.T: 0, "y": 10.0, "x": 10.0}) @@ -157,10 +157,10 @@ def test_compute_weights_2d_and_3d(self): # Far nodes (distance ≈ 141.4) graph1 = RustWorkXGraph() graph2 = RustWorkXGraph() - graph1.add_node_attr_key("y", dtype=pl.Float64, default_value=0.0) - graph1.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) - graph2.add_node_attr_key("y", dtype=pl.Float64, default_value=0.0) - graph2.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) + graph1.add_node_attr_key("y", dtype=pl.Float64) + graph1.add_node_attr_key("x", dtype=pl.Float64) + graph2.add_node_attr_key("y", dtype=pl.Float64) + graph2.add_node_attr_key("x", dtype=pl.Float64) graph1.add_node({DEFAULT_ATTR_KEYS.T: 0, "y": 0.0, "x": 0.0}) graph2.add_node({DEFAULT_ATTR_KEYS.T: 0, "y": 100.0, "x": 100.0}) @@ -180,8 +180,8 @@ def test_compute_weights_2d_and_3d(self): graph2 = RustWorkXGraph() for key in ["z", "y", "x"]: - graph1.add_node_attr_key(key, dtype=pl.Float64, default_value=0.0) - graph2.add_node_attr_key(key, dtype=pl.Float64, default_value=0.0) + graph1.add_node_attr_key(key, dtype=pl.Float64) + graph2.add_node_attr_key(key, dtype=pl.Float64) node1 = graph1.add_node({DEFAULT_ATTR_KEYS.T: 0, "z": 5.0, "y": 10.0, "x": 15.0}) graph2.add_node({DEFAULT_ATTR_KEYS.T: 0, "z": 6.0, "y": 11.0, "x": 16.0}) @@ -202,10 +202,10 @@ def test_anisotropic_scaling(self): graph1 = RustWorkXGraph() graph2 = RustWorkXGraph() - graph1.add_node_attr_key("y", dtype=pl.Float64, default_value=0.0) - graph1.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) - graph2.add_node_attr_key("y", dtype=pl.Float64, default_value=0.0) - graph2.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) + graph1.add_node_attr_key("y", dtype=pl.Float64) + graph1.add_node_attr_key("x", dtype=pl.Float64) + graph2.add_node_attr_key("y", dtype=pl.Float64) + graph2.add_node_attr_key("x", dtype=pl.Float64) # Nodes far in y but close in x graph1.add_node({DEFAULT_ATTR_KEYS.T: 0, "y": 0.0, "x": 0.0}) @@ -234,10 +234,10 @@ def test_auto_detection_of_coordinates(self): graph1 = RustWorkXGraph() graph2 = RustWorkXGraph() - graph1.add_node_attr_key(DEFAULT_ATTR_KEYS.Y, 
dtype=pl.Float64, default_value=0.0) - graph1.add_node_attr_key(DEFAULT_ATTR_KEYS.X, dtype=pl.Float64, default_value=0.0) - graph2.add_node_attr_key(DEFAULT_ATTR_KEYS.Y, dtype=pl.Float64, default_value=0.0) - graph2.add_node_attr_key(DEFAULT_ATTR_KEYS.X, dtype=pl.Float64, default_value=0.0) + graph1.add_node_attr_key(DEFAULT_ATTR_KEYS.Y, dtype=pl.Float64) + graph1.add_node_attr_key(DEFAULT_ATTR_KEYS.X, dtype=pl.Float64) + graph2.add_node_attr_key(DEFAULT_ATTR_KEYS.Y, dtype=pl.Float64) + graph2.add_node_attr_key(DEFAULT_ATTR_KEYS.X, dtype=pl.Float64) node1 = graph1.add_node({DEFAULT_ATTR_KEYS.T: 0, DEFAULT_ATTR_KEYS.Y: 10.0, DEFAULT_ATTR_KEYS.X: 10.0}) graph2.add_node({DEFAULT_ATTR_KEYS.T: 0, DEFAULT_ATTR_KEYS.Y: 11.0, DEFAULT_ATTR_KEYS.X: 11.0}) @@ -258,10 +258,10 @@ def test_scale_validation(self): graph1 = RustWorkXGraph() graph2 = RustWorkXGraph() - graph1.add_node_attr_key("y", dtype=pl.Float64, default_value=0.0) - graph1.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) - graph2.add_node_attr_key("y", dtype=pl.Float64, default_value=0.0) - graph2.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) + graph1.add_node_attr_key("y", dtype=pl.Float64) + graph1.add_node_attr_key("x", dtype=pl.Float64) + graph2.add_node_attr_key("y", dtype=pl.Float64) + graph2.add_node_attr_key("x", dtype=pl.Float64) graph1.add_node({DEFAULT_ATTR_KEYS.T: 0, "y": 10.0, "x": 10.0}) graph2.add_node({DEFAULT_ATTR_KEYS.T: 0, "y": 11.0, "x": 11.0}) @@ -322,10 +322,10 @@ def test_graph_match_integration(self): graph1 = RustWorkXGraph() graph2 = RustWorkXGraph() - graph1.add_node_attr_key("y", dtype=pl.Float64, default_value=0.0) - graph1.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) - graph2.add_node_attr_key("y", dtype=pl.Float64, default_value=0.0) - graph2.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) + graph1.add_node_attr_key("y", dtype=pl.Float64) + graph1.add_node_attr_key("x", dtype=pl.Float64) + graph2.add_node_attr_key("y", dtype=pl.Float64) + graph2.add_node_attr_key("x", dtype=pl.Float64) node1 = graph1.add_node({DEFAULT_ATTR_KEYS.T: 0, "y": 10.0, "x": 10.0}) node2 = graph2.add_node({DEFAULT_ATTR_KEYS.T: 0, "y": 11.0, "x": 11.0}) diff --git a/src/tracksdata/nodes/_test/test_generic_nodes.py b/src/tracksdata/nodes/_test/test_generic_nodes.py index 51bdb342..e58f7665 100644 --- a/src/tracksdata/nodes/_test/test_generic_nodes.py +++ b/src/tracksdata/nodes/_test/test_generic_nodes.py @@ -61,7 +61,7 @@ def test_crop_func_attrs_simple_function_no_frames() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("value", dtype=pl.Float64, default_value=0.0) + graph.add_node_attr_key("value", dtype=pl.Float64) # Add nodes with values node1 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "value": 10.0}) @@ -378,7 +378,7 @@ def test_crop_func_attrs_batch_processing_without_frames() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("value", dtype=pl.Float64, default_value=0.0) + graph.add_node_attr_key("value", dtype=pl.Float64) # Add nodes with values node1 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "value": 10.0}) diff --git a/src/tracksdata/solvers/_test/test_ilp_solver.py b/src/tracksdata/solvers/_test/test_ilp_solver.py index 87bb9f63..86e31867 100644 --- a/src/tracksdata/solvers/_test/test_ilp_solver.py +++ b/src/tracksdata/solvers/_test/test_ilp_solver.py @@ -86,8 +86,8 @@ def test_ilp_solver_solve_no_edges(caplog: pytest.LogCaptureFixture) -> None: graph = RustWorkXGraph() # Register attribute keys - 
graph.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) - graph.add_node_attr_key("y", dtype=pl.Float64, default_value=0.0) + graph.add_node_attr_key("x", dtype=pl.Float64) + graph.add_node_attr_key("y", dtype=pl.Float64) # Add some nodes graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 0.0, "y": 0.0}) @@ -108,9 +108,9 @@ def test_ilp_solver_solve_simple_case() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) - graph.add_node_attr_key("y", dtype=pl.Float64, default_value=0.0) - graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, dtype=pl.Float64, default_value=0.0) + graph.add_node_attr_key("x", dtype=pl.Float64) + graph.add_node_attr_key("y", dtype=pl.Float64) + graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, dtype=pl.Float64) # Add nodes node0 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 0.0, "y": 0.0}) @@ -148,9 +148,9 @@ def test_ilp_solver_solve_with_appearance_weight() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) - graph.add_node_attr_key("y", dtype=pl.Float64, default_value=0.0) - graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, dtype=pl.Float64, default_value=0.0) + graph.add_node_attr_key("x", dtype=pl.Float64) + graph.add_node_attr_key("y", dtype=pl.Float64) + graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, dtype=pl.Float64) # Add nodes node0 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 0.0, "y": 0.0}) @@ -188,9 +188,9 @@ def test_ilp_solver_solve_with_disappearance_weight() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) - graph.add_node_attr_key("y", dtype=pl.Float64, default_value=0.0) - graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, dtype=pl.Float64, default_value=0.0) + graph.add_node_attr_key("x", dtype=pl.Float64) + graph.add_node_attr_key("y", dtype=pl.Float64) + graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, dtype=pl.Float64) # Add nodes node0 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 0.0, "y": 0.0}) @@ -228,9 +228,9 @@ def test_ilp_solver_solve_with_division_weight() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) - graph.add_node_attr_key("y", dtype=pl.Float64, default_value=0.0) - graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, dtype=pl.Float64, default_value=0.0) + graph.add_node_attr_key("x", dtype=pl.Float64) + graph.add_node_attr_key("y", dtype=pl.Float64) + graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, dtype=pl.Float64) # Add nodes for division scenario node0 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 0.0, "y": 0.0}) @@ -280,10 +280,10 @@ def test_ilp_solver_solve_custom_edge_weight_expr() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) - graph.add_node_attr_key("y", dtype=pl.Float64, default_value=0.0) - graph.add_edge_attr_key("custom_weight", dtype=pl.Float64, default_value=0.0) - graph.add_edge_attr_key("confidence", dtype=pl.Float64, default_value=0.0) + graph.add_node_attr_key("x", dtype=pl.Float64) + graph.add_node_attr_key("y", dtype=pl.Float64) + graph.add_edge_attr_key("custom_weight", dtype=pl.Float64) + graph.add_edge_attr_key("confidence", dtype=pl.Float64) # Add nodes node0 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 0.0, "y": 0.0}) @@ -310,9 +310,9 @@ def 
test_ilp_solver_solve_custom_node_weight_expr() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) - graph.add_node_attr_key("quality", dtype=pl.Float64, default_value=0.0) - graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, dtype=pl.Float64, default_value=0.0) + graph.add_node_attr_key("x", dtype=pl.Float64) + graph.add_node_attr_key("quality", dtype=pl.Float64) + graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, dtype=pl.Float64) # Add nodes with quality attribute node0 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 0.0, "quality": 0.9}) @@ -338,9 +338,9 @@ def test_ilp_solver_solve_custom_output_key() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) - graph.add_node_attr_key("y", dtype=pl.Float64, default_value=0.0) - graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, dtype=pl.Float64, default_value=0.0) + graph.add_node_attr_key("x", dtype=pl.Float64) + graph.add_node_attr_key("y", dtype=pl.Float64) + graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, dtype=pl.Float64) # Add nodes and edges node0 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 0.0, "y": 0.0}) @@ -366,9 +366,9 @@ def test_ilp_solver_solve_with_all_weights() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) - graph.add_node_attr_key("y", dtype=pl.Float64, default_value=0.0) - graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, dtype=pl.Float64, default_value=0.0) + graph.add_node_attr_key("x", dtype=pl.Float64) + graph.add_node_attr_key("y", dtype=pl.Float64) + graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, dtype=pl.Float64) # Add nodes node0 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 0.0, "y": 0.0}) @@ -431,9 +431,9 @@ def test_ilp_solver_division_constraint() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) - graph.add_node_attr_key("y", dtype=pl.Float64, default_value=0.0) - graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, dtype=pl.Float64, default_value=0.0) + graph.add_node_attr_key("x", dtype=pl.Float64) + graph.add_node_attr_key("y", dtype=pl.Float64) + graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, dtype=pl.Float64) # Create a scenario where division would be tempting but should be constrained # Time 0: 1 parent node @@ -504,9 +504,9 @@ def test_ilp_solver_solve_with_inf_expr() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) - graph.add_node_attr_key("y", dtype=pl.Float64, default_value=0.0) - graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, dtype=pl.Float64, default_value=0.0) + graph.add_node_attr_key("x", dtype=pl.Float64) + graph.add_node_attr_key("y", dtype=pl.Float64) + graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, dtype=pl.Float64) # Add nodes node0 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 0.0, "y": 5.0}) @@ -537,8 +537,8 @@ def test_ilp_solver_solve_with_pos_inf_rejection() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) - graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, dtype=pl.Float64, default_value=0.0) + graph.add_node_attr_key("x", dtype=pl.Float64) + graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, dtype=pl.Float64) # Add nodes node0 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 
1.0}) @@ -566,8 +566,8 @@ def test_ilp_solver_solve_with_neg_inf_node_weight() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("priority", dtype=pl.Float64, default_value=0.0) - graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, dtype=pl.Float64, default_value=0.0) + graph.add_node_attr_key("priority", dtype=pl.Float64) + graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, dtype=pl.Float64) # Add nodes node0 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "priority": 1.0}) # High priority @@ -595,8 +595,8 @@ def test_ilp_solver_solve_with_inf_edge_weight() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) - graph.add_edge_attr_key("confidence", dtype=pl.Float64, default_value=0.0) + graph.add_node_attr_key("x", dtype=pl.Float64) + graph.add_edge_attr_key("confidence", dtype=pl.Float64) # Add nodes node0 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 0.0}) @@ -627,9 +627,9 @@ def test_ilp_solver_solve_with_overlaps() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) - graph.add_node_attr_key("y", dtype=pl.Float64, default_value=0.0) - graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, dtype=pl.Float64, default_value=0.0) + graph.add_node_attr_key("x", dtype=pl.Float64) + graph.add_node_attr_key("y", dtype=pl.Float64) + graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, dtype=pl.Float64) # Add nodes - overlapping pair at time t=1 node0 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 0.0, "y": 0.0}) @@ -682,9 +682,9 @@ def test_ilp_solver_solve_with_merge_weight() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) - graph.add_node_attr_key("y", dtype=pl.Float64, default_value=0.0) - graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, dtype=pl.Float64, default_value=0.0) + graph.add_node_attr_key("x", dtype=pl.Float64) + graph.add_node_attr_key("y", dtype=pl.Float64) + graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, dtype=pl.Float64) # Simple merge scenario: 2 tracks -> 1 merge point track1_node = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 0.0, "y": 0.0}) @@ -714,8 +714,8 @@ def test_ilp_solver_solve_with_merge_weight() -> None: def test_ilp_solver_solve_with_positive_merge_weight() -> None: """Test solving with positive merge weight to penalize merges.""" graph = RustWorkXGraph() - graph.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) - graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, dtype=pl.Float64, default_value=0.0) + graph.add_node_attr_key("x", dtype=pl.Float64) + graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, dtype=pl.Float64) # Create merge scenario track1_node0 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 0.0}) @@ -743,8 +743,8 @@ def test_ilp_solver_solve_with_positive_merge_weight() -> None: def test_ilp_solver_solve_with_merge_expression() -> None: """Test solving with merge weight as an expression.""" graph = RustWorkXGraph() - graph.add_node_attr_key("merge_cost", dtype=pl.Float64, default_value=0.0) - graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, dtype=pl.Float64, default_value=0.0) + graph.add_node_attr_key("merge_cost", dtype=pl.Float64) + graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, dtype=pl.Float64) # Two source nodes source1 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "merge_cost": 0.0}) @@ -775,8 +775,8 @@ def test_ilp_solver_solve_merge_and_division_combined() 
-> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) - graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, dtype=pl.Float64, default_value=0.0) + graph.add_node_attr_key("x", dtype=pl.Float64) + graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, dtype=pl.Float64) # Create complex scenario: merge followed by division # Time 0: Two separate tracks diff --git a/src/tracksdata/solvers/_test/test_nearest_neighbors_solver.py b/src/tracksdata/solvers/_test/test_nearest_neighbors_solver.py index 98b73615..10f91efd 100644 --- a/src/tracksdata/solvers/_test/test_nearest_neighbors_solver.py +++ b/src/tracksdata/solvers/_test/test_nearest_neighbors_solver.py @@ -48,8 +48,8 @@ def test_nearest_neighbors_solver_solve_no_edges() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) - graph.add_node_attr_key("y", dtype=pl.Float64, default_value=0.0) + graph.add_node_attr_key("x", dtype=pl.Float64) + graph.add_node_attr_key("y", dtype=pl.Float64) # Add some nodes graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 0.0, "y": 0.0}) @@ -67,9 +67,9 @@ def test_nearest_neighbors_solver_solve_simple_case() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) - graph.add_node_attr_key("y", dtype=pl.Float64, default_value=0.0) - graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, dtype=pl.Float64, default_value=0.0) + graph.add_node_attr_key("x", dtype=pl.Float64) + graph.add_node_attr_key("y", dtype=pl.Float64) + graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, dtype=pl.Float64) # Add nodes node0 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 0.0, "y": 0.0}) @@ -103,9 +103,9 @@ def test_nearest_neighbors_solver_solve_max_children_constraint() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) - graph.add_node_attr_key("y", dtype=pl.Float64, default_value=0.0) - graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, dtype=pl.Float64, default_value=0.0) + graph.add_node_attr_key("x", dtype=pl.Float64) + graph.add_node_attr_key("y", dtype=pl.Float64) + graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, dtype=pl.Float64) # Add nodes node0 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 0.0, "y": 0.0}) # Parent @@ -143,9 +143,9 @@ def test_nearest_neighbors_solver_solve_one_parent_constraint() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) - graph.add_node_attr_key("y", dtype=pl.Float64, default_value=0.0) - graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, dtype=pl.Float64, default_value=0.0) + graph.add_node_attr_key("x", dtype=pl.Float64) + graph.add_node_attr_key("y", dtype=pl.Float64) + graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, dtype=pl.Float64) # Add nodes node0 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 0.0, "y": 0.0}) # Parent 1 @@ -175,9 +175,9 @@ def test_nearest_neighbors_solver_solve_custom_weight_expr() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) - graph.add_node_attr_key("y", dtype=pl.Float64, default_value=0.0) - graph.add_edge_attr_key("custom_weight", dtype=pl.Float64, default_value=0.0) + graph.add_node_attr_key("x", dtype=pl.Float64) + graph.add_node_attr_key("y", dtype=pl.Float64) + graph.add_edge_attr_key("custom_weight", 
dtype=pl.Float64) # Add nodes node0 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 0.0, "y": 0.0}) @@ -208,10 +208,10 @@ def test_nearest_neighbors_solver_solve_complex_expression() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) - graph.add_node_attr_key("y", dtype=pl.Float64, default_value=0.0) - graph.add_edge_attr_key("distance", dtype=pl.Float64, default_value=0.0) - graph.add_edge_attr_key("confidence", dtype=pl.Float64, default_value=0.0) + graph.add_node_attr_key("x", dtype=pl.Float64) + graph.add_node_attr_key("y", dtype=pl.Float64) + graph.add_edge_attr_key("distance", dtype=pl.Float64) + graph.add_edge_attr_key("confidence", dtype=pl.Float64) # Add nodes node0 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 0.0, "y": 0.0}) @@ -243,9 +243,9 @@ def test_nearest_neighbors_solver_solve_custom_output_key() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) - graph.add_node_attr_key("y", dtype=pl.Float64, default_value=0.0) - graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, dtype=pl.Float64, default_value=0.0) + graph.add_node_attr_key("x", dtype=pl.Float64) + graph.add_node_attr_key("y", dtype=pl.Float64) + graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, dtype=pl.Float64) # Add nodes and edges node0 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 0.0, "y": 0.0}) @@ -271,9 +271,9 @@ def test_nearest_neighbors_solver_solve_with_overlaps() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) - graph.add_node_attr_key("y", dtype=pl.Float64, default_value=0.0) - graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, dtype=pl.Float64, default_value=0.0) + graph.add_node_attr_key("x", dtype=pl.Float64) + graph.add_node_attr_key("y", dtype=pl.Float64) + graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, dtype=pl.Float64) # Add nodes - overlapping pair at time t=1 node0 = graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 0.0, "y": 0.0}) @@ -326,9 +326,9 @@ def test_nearest_neighbors_solver_solve_large_graph() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("x", dtype=pl.Float64, default_value=0.0) - graph.add_node_attr_key("y", dtype=pl.Float64, default_value=0.0) - graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, dtype=pl.Float64, default_value=0.0) + graph.add_node_attr_key("x", dtype=pl.Float64) + graph.add_node_attr_key("y", dtype=pl.Float64) + graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, dtype=pl.Float64) # Create a more complex graph structure # Time 0: nodes 0, 1 From 2bcafcb4a1f2296c3958032e9cafdfeb599b04e0 Mon Sep 17 00:00:00 2001 From: Jordao Bragantini Date: Tue, 20 Jan 2026 11:39:23 -0800 Subject: [PATCH 08/14] improving testing --- .../graph/_test/test_graph_backends.py | 33 +++++++------------ src/tracksdata/utils/_dtypes.py | 3 +- 2 files changed, 12 insertions(+), 24 deletions(-) diff --git a/src/tracksdata/graph/_test/test_graph_backends.py b/src/tracksdata/graph/_test/test_graph_backends.py index 70ef1c2f..fc52fe59 100644 --- a/src/tracksdata/graph/_test/test_graph_backends.py +++ b/src/tracksdata/graph/_test/test_graph_backends.py @@ -408,33 +408,22 @@ def test_subgraph_with_node_ids_and_filters(graph_backend: BaseGraph) -> None: @pytest.mark.parametrize( - "value", + "dtype, value", [ - pytest.param(42, id="int-42"), - pytest.param(3.14, id="float-3.14"), - pytest.param("test_string", 
id="str-test_string"), - pytest.param(np.array([1, 2, 3]), id="ndarray-1d"), - pytest.param(np.array([[1.0, 2.0], [3.0, 4.0]]), id="ndarray-2d"), - pytest.param(Mask(mask=np.array([[True, False], [False, True]]), bbox=(0, 0, 2, 2)), id="mask"), - pytest.param(True, id="bool-True"), - pytest.param(False, id="bool-False"), + pytest.param(pl.Int64, 42, id="int-42"), + pytest.param(pl.Float64, 3.14, id="float-3.14"), + pytest.param(pl.String, "test_string", id="str-test_string"), + pytest.param(pl.Array(pl.Int64, 3), np.array([1, 2, 3]), id="ndarray-1d"), + pytest.param(pl.Array(pl.Float64, (2, 2)), np.array([[1.0, 2.0], [3.0, 4.0]]), id="ndarray-2d"), + pytest.param(pl.Object, Mask(mask=np.array([[True, False], [False, True]]), bbox=(0, 0, 2, 2)), id="mask"), + pytest.param(pl.Boolean, True, id="bool-True"), + pytest.param(pl.Boolean, False, id="bool-False"), ], ) -def test_add_node_attr_key(graph_backend: BaseGraph, value) -> None: +def test_add_node_attr_key(graph_backend: BaseGraph, dtype: pl.DataType, value: Any) -> None: """Test adding new node attribute keys.""" node = graph_backend.add_node({"t": 0}) - # Infer dtype from value - if isinstance(value, bool): - dtype = pl.Boolean - elif isinstance(value, int): - dtype = pl.Int64 - elif isinstance(value, float): - dtype = pl.Float64 - elif isinstance(value, str): - dtype = pl.String - else: - # For arrays, masks, and other objects - dtype = pl.Object + graph_backend.add_node_attr_key("new_attribute", dtype, default_value=value) df = graph_backend.filter(node_ids=[node]).node_attrs(attr_keys=["new_attribute"]) diff --git a/src/tracksdata/utils/_dtypes.py b/src/tracksdata/utils/_dtypes.py index 9e215a98..cce97bd3 100644 --- a/src/tracksdata/utils/_dtypes.py +++ b/src/tracksdata/utils/_dtypes.py @@ -427,8 +427,7 @@ def validate_default_value_dtype_compatibility(default_value: Any, dtype: pl.Dat try: # Try to create a polars series and cast - s = pl.Series([default_value]) - s.cast(dtype) + pl.Series([default_value], dtype=dtype) except Exception as e: raise ValueError( f"default_value {default_value!r} (type: {type(default_value).__name__}) " From 8288a6933f30b8a745387c33db8d4b239888fbdd Mon Sep 17 00:00:00 2001 From: Jordao Bragantini Date: Tue, 20 Jan 2026 13:03:00 -0800 Subject: [PATCH 09/14] fixing spatial filtering typing --- .../graph/filters/_test/test_spatial_filter.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/tracksdata/graph/filters/_test/test_spatial_filter.py b/src/tracksdata/graph/filters/_test/test_spatial_filter.py index 52778539..41b08a07 100644 --- a/src/tracksdata/graph/filters/_test/test_spatial_filter.py +++ b/src/tracksdata/graph/filters/_test/test_spatial_filter.py @@ -35,7 +35,7 @@ def sample_graph() -> RustWorkXGraph: def sample_bbox_graph() -> RustWorkXGraph: """Create a sample graph with nodes for bounding box testing.""" graph = RustWorkXGraph() - graph.add_node_attr_key("bbox", dtype=pl.List(pl.Int64), default_value=[0, 0, 0, 0, 0, 0]) + graph.add_node_attr_key("bbox", dtype=pl.Array(pl.Int64, 6)) # Add some nodes with bounding box coordinates nodes = [ @@ -191,7 +191,7 @@ def test_spatial_filter_with_edges() -> None: def test_bbox_spatial_filter_overlaps() -> None: """Test BoundingBoxSpatialFilter overlaps with existing nodes.""" graph = RustWorkXGraph() - graph.add_node_attr_key("bbox", dtype=pl.List(pl.Int64), default_value=[0, 0, 0, 0]) + graph.add_node_attr_key("bbox", dtype=pl.Array(pl.Int64, 4)) # Add nodes with bounding boxes bboxes = [ [0, 20, 10, 30], # Node 1 @@ 
-214,7 +214,7 @@ def test_bbox_spatial_filter_overlaps() -> None: def test_bbox_spatial_filter_with_edges() -> None: """Test SpatialFilter preserves edges in subgraphs.""" graph = RustWorkXGraph() - graph.add_node_attr_key("bbox", dtype=pl.List(pl.Int64), default_value=[0, 0, 0, 0]) + graph.add_node_attr_key("bbox", dtype=pl.Array(pl.Int64, 4)) graph.add_edge_attr_key("weight", dtype=pl.Float64) # Add nodes and edge @@ -263,7 +263,7 @@ def test_bbox_spatial_filter_querying(sample_bbox_graph: RustWorkXGraph) -> None def test_bbox_spatial_filter_dimensions() -> None: """Test BoundingBoxSpatialFilter with different coordinate dimensions.""" graph = RustWorkXGraph() - graph.add_node_attr_key("bbox", dtype=pl.List(pl.Int64), default_value=[0, 0, 0, 0, 0, 0]) + graph.add_node_attr_key("bbox", dtype=pl.Array(pl.Int64, 6)) graph.add_node({"t": 0, "bbox": [0, 10, 20, 1, 15, 25]}) # Test 3D coordinates @@ -284,7 +284,7 @@ def test_bbox_spatial_filter_dimensions() -> None: def test_bbox_spatial_filter_error_handling() -> None: """Test error handling for mismatched min/max attribute lengths.""" graph = RustWorkXGraph() - graph.add_node_attr_key("bbox", dtype=pl.List(pl.Int64), default_value=[0, 0, 0, 0]) + graph.add_node_attr_key("bbox", dtype=pl.Array(pl.Int64, 4)) graph.add_node({"t": 0, "bbox": [10, 20, 15, 25, 14]}) # Test mismatched min/max attributes length with pytest.raises(ValueError, match="Bounding box coordinates must have even number of dimensions"): From 58256e2c0f19a938e2c6f8e3aeb6fd0292eb89bc Mon Sep 17 00:00:00 2001 From: Jordao Bragantini Date: Tue, 20 Jan 2026 13:18:50 -0800 Subject: [PATCH 10/14] simplifying sql type detection case --- src/tracksdata/graph/_sql_graph.py | 41 +++++++++++++++--------------- 1 file changed, 21 insertions(+), 20 deletions(-) diff --git a/src/tracksdata/graph/_sql_graph.py b/src/tracksdata/graph/_sql_graph.py index 2b8f2181..dd96438a 100644 --- a/src/tracksdata/graph/_sql_graph.py +++ b/src/tracksdata/graph/_sql_graph.py @@ -1,7 +1,7 @@ import binascii from collections.abc import Callable, Sequence from enum import Enum -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, ClassVar import cloudpickle import numpy as np @@ -582,30 +582,31 @@ def _init_schemas_from_tables(self) -> None: dtype=pl_dtype, ) + # Mapping from SQLAlchemy types to polars dtypes for schema loading + _SQLALCHEMY_TO_POLARS_TYPE_MAP: ClassVar[dict[TypeEngine, pl.DataType]] = { + sa.Boolean: pl.Boolean, + sa.SmallInteger: pl.Int16, + sa.Integer: pl.Int32, + sa.BigInteger: pl.Int64, + sa.Float: pl.Float64, + sa.String: pl.String, + sa.Text: pl.String, + sa.LargeBinary: pl.Object, + sa.PickleType: pl.Object, + } + def _sqlalchemy_type_to_polars_dtype(self, sa_type: TypeEngine) -> pl.DataType: """ Convert a SQLAlchemy type to a polars dtype. This is a best-effort conversion for loading existing schemas. 
""" - if isinstance(sa_type, sa.Boolean): - return pl.Boolean - elif isinstance(sa_type, sa.SmallInteger): - return pl.Int16 - elif isinstance(sa_type, sa.Integer): - return pl.Int32 - elif isinstance(sa_type, sa.BigInteger): - return pl.Int64 - elif isinstance(sa_type, sa.Float): - return pl.Float64 - elif isinstance(sa_type, sa.String | sa.Text): - return pl.String - elif isinstance(sa_type, sa.LargeBinary | sa.PickleType): - # For pickled/binary types, default to Object - # Array types will need to be re-added explicitly - return pl.Object - else: - # Fallback to Object for unknown types - return pl.Object + # Check the type map for known types + for sa_type_class, pl_dtype in self._SQLALCHEMY_TO_POLARS_TYPE_MAP.items(): + if isinstance(sa_type, sa_type_class): + return pl_dtype + + # Fallback to Object for unknown types + return pl.Object def _restore_pickled_column_types(self, table: sa.Table) -> None: for column in table.columns: From a05788cb5e19b1799120db349d8650cb74dbdbce Mon Sep 17 00:00:00 2001 From: Jordao Bragantini Date: Tue, 20 Jan 2026 13:27:46 -0800 Subject: [PATCH 11/14] refactoring generic usage --- src/tracksdata/nodes/_generic_nodes.py | 28 +++++++----- src/tracksdata/nodes/_mask.py | 19 ++++++-- .../nodes/_test/test_generic_nodes.py | 45 ++++++++++++++----- 3 files changed, 67 insertions(+), 25 deletions(-) diff --git a/src/tracksdata/nodes/_generic_nodes.py b/src/tracksdata/nodes/_generic_nodes.py index fa88b14e..8831af20 100644 --- a/src/tracksdata/nodes/_generic_nodes.py +++ b/src/tracksdata/nodes/_generic_nodes.py @@ -2,7 +2,6 @@ from typing import Any, TypeVar import numpy as np -import polars as pl from numpy.typing import NDArray from tracksdata.attrs import NodeAttr @@ -59,6 +58,8 @@ def intensity_median_times_t(image: NDArray, mask: Mask, t: int) -> float: attr_keys=["mask", "t"], ) + graph.add_node_attr_key("intensity_median", pl.Float64) + crop_attrs.add_node_attrs(graph, frames=video) ``` @@ -85,6 +86,8 @@ def intensity_median_times_t(image: NDArray, masks: list[Mask], t: list[int]) -> attr_keys=["mask", "t"], ) + graph.add_node_attr_key("intensity_median", pl.Float64) + crop_attrs.add_node_attrs(graph, frames=video) ``` """ @@ -95,28 +98,31 @@ def __init__( self, func: Callable[[T], R] | Callable[[list[T]], list[R]], output_key: str, - default_value: Any = None, attr_keys: Sequence[str] = (), batch_size: int = 0, ) -> None: super().__init__(output_key) self.func = func + self.output_key = output_key self.attr_keys = attr_keys - self.default_value = default_value self.batch_size = batch_size def _init_node_attrs(self, graph: BaseGraph) -> None: """ - Initialize the node attributes for the graph. + Validate that the output key exists in the graph. + + The output key must be added to the graph before using this operator. + + Raises + ------ + ValueError + If the output key is not found in the graph. """ if self.output_key not in graph.node_attr_keys(): - # Infer dtype from default_value using polars - if self.default_value is None: - dtype = pl.Object - else: - # Use polars to infer the dtype from the value - dtype = pl.Series([self.default_value]).dtype - graph.add_node_attr_key(self.output_key, dtype, self.default_value) + raise ValueError( + f"Output key '{self.output_key}' not found in graph. " + f"You must add it with graph.add_node_attr_key() before using this operator." 
+ ) def add_node_attrs( self, diff --git a/src/tracksdata/nodes/_mask.py b/src/tracksdata/nodes/_mask.py index 4320ba29..97b91d60 100644 --- a/src/tracksdata/nodes/_mask.py +++ b/src/tracksdata/nodes/_mask.py @@ -4,17 +4,20 @@ import blosc2 import numpy as np +import polars as pl import skimage.morphology as morph from numpy.typing import ArrayLike, NDArray from skimage.measure import regionprops -if TYPE_CHECKING: - from skimage.measure._regionprops import RegionProperties - from tracksdata.constants import DEFAULT_ATTR_KEYS from tracksdata.functional._iou import fast_intersection_with_bbox, fast_iou_with_bbox from tracksdata.nodes._generic_nodes import GenericFuncNodeAttrs +if TYPE_CHECKING: + from skimage.measure._regionprops import RegionProperties + + from tracksdata.graph._base_graph import BaseGraph + @lru_cache(maxsize=5) def _nd_sphere( @@ -497,6 +500,8 @@ def __init__( f"Expected image shape {image_shape} to have the same number of dimensions as attr_keys '{attr_keys}'." ) + self._image_shape = image_shape + super().__init__( func=lambda **kwargs: Mask.from_coordinates( center=np.asarray(list(kwargs.values())), @@ -505,6 +510,12 @@ def __init__( ), output_key=output_key, attr_keys=attr_keys, - default_value=None, batch_size=0, ) + + def _init_node_attrs(self, graph: "BaseGraph") -> None: + """ + Validate that the output key exists in the graph. + """ + if self.output_key not in graph.node_attr_keys(): + graph.add_node_attr_key(self.output_key, pl.Array(pl.Int64, 2 * len(self._image_shape))) diff --git a/src/tracksdata/nodes/_test/test_generic_nodes.py b/src/tracksdata/nodes/_test/test_generic_nodes.py index e58f7665..2e00c1c8 100644 --- a/src/tracksdata/nodes/_test/test_generic_nodes.py +++ b/src/tracksdata/nodes/_test/test_generic_nodes.py @@ -70,6 +70,9 @@ def test_crop_func_attrs_simple_function_no_frames() -> None: def double_value(value: float) -> float: return value * 2.0 + # Register output key before using operator + graph.add_node_attr_key("doubled_value", dtype=pl.Float64) + # Create operator and add attributes operator = GenericFuncNodeAttrs( func=double_value, @@ -118,6 +121,9 @@ def intensity_sum(frame: NDArray, mask: Mask) -> float: cropped = mask.crop(frame) return float(np.sum(cropped[mask.mask])) + # Register output key before using operator + graph.add_node_attr_key("intensity_sum", dtype=pl.Float64) + # Create operator and add attributes operator = GenericFuncNodeAttrs( func=intensity_sum, @@ -172,6 +178,9 @@ def intensity_sum_times_multiplier(frame: NDArray, mask: Mask, multiplier: float cropped = mask.crop(frame) return float(np.sum(cropped[mask.mask]) * multiplier) + # Register output key before using operator + graph.add_node_attr_key("weighted_intensity", dtype=pl.Float64) + # Create operator and add attributes operator = GenericFuncNodeAttrs( func=intensity_sum_times_multiplier, @@ -221,6 +230,12 @@ def return_dict(mask: Mask) -> dict[str, int]: def return_array(mask: Mask) -> NDArray: return np.asarray([1, 2, 3]) + # Register output keys before using operators + graph.add_node_attr_key("string_result", dtype=pl.String) + graph.add_node_attr_key("list_result", dtype=pl.List(pl.Int64)) + graph.add_node_attr_key("dict_result", dtype=pl.Struct({"count": pl.Int64})) + graph.add_node_attr_key("array_result", dtype=pl.Array(pl.Int64, 3)) + # Test string return type operator_str = GenericFuncNodeAttrs( func=return_string, @@ -262,32 +277,32 @@ def return_array(mask: Mask) -> NDArray: def test_crop_func_attrs_error_handling_missing_attr_key() -> None: - """Test 
error handling when required attr_key is missing.""" + """Test error handling when output key is not registered.""" graph = RustWorkXGraph() # Register attribute keys graph.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, pl.Object) - # Note: "value" is not registered + # Note: "result" output key is NOT registered # Create test mask mask_data = np.array([[True, True], [True, False]], dtype=bool) mask = Mask(mask_data, bbox=np.array([0, 0, 2, 2])) - # Add node without the required attribute + # Add node graph.add_node({DEFAULT_ATTR_KEYS.T: 0, DEFAULT_ATTR_KEYS.MASK: mask}) - def use_value(mask: Mask, value: float) -> float: - return value * 2.0 + def double_mask_sum(mask: Mask) -> float: + return float(np.sum(mask.mask)) * 2.0 - # Create operator that requires "value" attribute + # Create operator with output key that is not registered operator = GenericFuncNodeAttrs( - func=use_value, + func=double_mask_sum, output_key="result", - attr_keys=["value"], + attr_keys=[DEFAULT_ATTR_KEYS.MASK], ) - # Should raise an error when trying to access missing attribute - with pytest.raises(KeyError): # Specific exception type depends on graph backend + # Should raise ValueError when output key is not registered + with pytest.raises(ValueError, match="Output key 'result' not found in graph"): operator.add_node_attrs(graph) @@ -322,6 +337,9 @@ def intensity_sum(frame: NDArray, mask: Mask) -> float: cropped = mask.crop(frame) return float(np.sum(cropped[mask.mask])) + # Register output key before using operator + graph.add_node_attr_key("intensity_sum", dtype=pl.Float64) + # Create operator and add attributes operator = GenericFuncNodeAttrs( func=intensity_sum, @@ -350,6 +368,7 @@ def test_crop_func_attrs_empty_graph() -> None: # Register attribute keys graph.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, pl.Object) + graph.add_node_attr_key("result", dtype=pl.Float64) def dummy_func(mask: Mask) -> float: return 1.0 @@ -389,6 +408,9 @@ def batch_double_value(value: list[float]) -> list[float]: """Batch function that doubles each value.""" return [v * 2.0 for v in value] + # Register output key before using operator + graph.add_node_attr_key("doubled_value", dtype=pl.Float64) + # Create operator with batch_size = 2 operator = GenericFuncNodeAttrs( func=batch_double_value, @@ -447,6 +469,9 @@ def batch_intensity_sum(frame: NDArray, mask: list[Mask]) -> list[float]: results.append(float(np.sum(cropped[m.mask]))) return results + # Register output key before using operator + graph.add_node_attr_key("intensity_sum", dtype=pl.Float64) + # Create operator with batch_size = 2 operator = GenericFuncNodeAttrs( func=batch_intensity_sum, From b83b58745160eaa15f5c8ce032c150fd3f187525 Mon Sep 17 00:00:00 2001 From: Jordao Bragantini Date: Tue, 20 Jan 2026 13:36:51 -0800 Subject: [PATCH 12/14] moving tests and fixing get attr usage --- .../_test/test_attr_key_dtype.py | 27 ++++++++++++++----- 1 file changed, 20 insertions(+), 7 deletions(-) rename src/tracksdata/{graph => utils}/_test/test_attr_key_dtype.py (89%) diff --git a/src/tracksdata/graph/_test/test_attr_key_dtype.py b/src/tracksdata/utils/_test/test_attr_key_dtype.py similarity index 89% rename from src/tracksdata/graph/_test/test_attr_key_dtype.py rename to src/tracksdata/utils/_test/test_attr_key_dtype.py index 5c0bff70..b6bb37c6 100644 --- a/src/tracksdata/graph/_test/test_attr_key_dtype.py +++ b/src/tracksdata/utils/_test/test_attr_key_dtype.py @@ -57,6 +57,19 @@ def test_add_node_attr_key_with_array_dtype(self): assert default.dtype == np.float64 
np.testing.assert_array_equal(default, np.zeros(4, dtype=np.float64)) + def test_add_node_attr_key_with_nd_array_dtype(self): + """Test adding node attribute with ndarray dtype.""" + graph = RustWorkXGraph() + graph.add_node_attr_key("something", pl.Array(pl.Float64, (5, 3, 2))) + + assert graph._node_attr_schemas["something"].dtype == pl.Array(pl.Float64, (5, 3, 2)) + default = graph._node_attr_schemas["something"].default_value + + assert isinstance(default, np.ndarray) + assert default.shape == (5, 3, 2) + assert default.dtype == np.float64 + np.testing.assert_array_equal(default, np.zeros((5, 3, 2), dtype=np.float64)) + def test_add_node_attr_key_with_schema_object(self): """Test adding node attribute using AttrSchema object.""" graph = RustWorkXGraph() @@ -96,15 +109,15 @@ def test_defaults_applied_to_existing_nodes(self): graph = RustWorkXGraph() # Add a node - node_id = graph.add_node({"t": 0}) + graph.add_node({"t": 0}) # Add new attribute graph.add_node_attr_key("score", pl.Float64) # Verify the default was applied - node_attrs = graph.rx_graph[node_id] - assert "score" in node_attrs - assert node_attrs["score"] == -1.0 + node_attrs = graph.node_attrs() + assert "score" in node_attrs.columns + assert node_attrs["score"].item() == -1.0 def test_add_edge_attr_key_with_dtype_only(self): """Test adding edge attribute with dtype only.""" @@ -139,9 +152,9 @@ def test_defaults_applied_to_existing_edges(self): graph.add_edge_attr_key("weight", pl.Float64, default_value=1.0) # Verify the default was applied - edge_attrs = graph.rx_graph.get_edge_data(n1, n2) - assert "weight" in edge_attrs - assert edge_attrs["weight"] == 1.0 + edge_attrs = graph.edge_attrs() + assert "weight" in edge_attrs.columns + assert edge_attrs["weight"].item() == 1.0 def test_node_attr_keys_returns_keys(self): """Test that node_attr_keys returns the correct keys.""" From 3ca91640abd2f919a5bb7a6577c76d1160d928b2 Mon Sep 17 00:00:00 2001 From: Jordao Bragantini Date: Tue, 20 Jan 2026 13:50:21 -0800 Subject: [PATCH 13/14] cleanup on sql type conversion --- src/tracksdata/graph/_sql_graph.py | 33 +++--------------- src/tracksdata/utils/_dtypes.py | 56 +++++++++++++++++++++++++++--- 2 files changed, 56 insertions(+), 33 deletions(-) diff --git a/src/tracksdata/graph/_sql_graph.py b/src/tracksdata/graph/_sql_graph.py index dd96438a..bcadf01e 100644 --- a/src/tracksdata/graph/_sql_graph.py +++ b/src/tracksdata/graph/_sql_graph.py @@ -1,7 +1,7 @@ import binascii from collections.abc import Callable, Sequence from enum import Enum -from typing import TYPE_CHECKING, Any, ClassVar +from typing import TYPE_CHECKING, Any import cloudpickle import numpy as np @@ -22,6 +22,7 @@ AttrSchema, polars_dtype_to_sqlalchemy_type, process_attr_key_args, + sqlalchemy_type_to_polars_dtype, ) from tracksdata.utils._logging import LOG from tracksdata.utils._signal import is_signal_on @@ -560,7 +561,7 @@ def _init_schemas_from_tables(self) -> None: if column_name not in self._node_attr_schemas: column = self.Node.__table__.columns[column_name] # Infer polars dtype from SQLAlchemy type - pl_dtype = self._sqlalchemy_type_to_polars_dtype(column.type) + pl_dtype = sqlalchemy_type_to_polars_dtype(column.type) # AttrSchema.__post_init__ will infer the default_value self._node_attr_schemas[column_name] = AttrSchema( key=column_name, @@ -575,39 +576,13 @@ def _init_schemas_from_tables(self) -> None: if column_name not in self._edge_attr_schemas: column = self.Edge.__table__.columns[column_name] # Infer polars dtype from SQLAlchemy type - pl_dtype = 
self._sqlalchemy_type_to_polars_dtype(column.type) + pl_dtype = sqlalchemy_type_to_polars_dtype(column.type) # AttrSchema.__post_init__ will infer the default_value self._edge_attr_schemas[column_name] = AttrSchema( key=column_name, dtype=pl_dtype, ) - # Mapping from SQLAlchemy types to polars dtypes for schema loading - _SQLALCHEMY_TO_POLARS_TYPE_MAP: ClassVar[dict[TypeEngine, pl.DataType]] = { - sa.Boolean: pl.Boolean, - sa.SmallInteger: pl.Int16, - sa.Integer: pl.Int32, - sa.BigInteger: pl.Int64, - sa.Float: pl.Float64, - sa.String: pl.String, - sa.Text: pl.String, - sa.LargeBinary: pl.Object, - sa.PickleType: pl.Object, - } - - def _sqlalchemy_type_to_polars_dtype(self, sa_type: TypeEngine) -> pl.DataType: - """ - Convert a SQLAlchemy type to a polars dtype. - This is a best-effort conversion for loading existing schemas. - """ - # Check the type map for known types - for sa_type_class, pl_dtype in self._SQLALCHEMY_TO_POLARS_TYPE_MAP.items(): - if isinstance(sa_type, sa_type_class): - return pl_dtype - - # Fallback to Object for unknown types - return pl.Object - def _restore_pickled_column_types(self, table: sa.Table) -> None: for column in table.columns: if isinstance(column.type, sa.LargeBinary): diff --git a/src/tracksdata/utils/_dtypes.py b/src/tracksdata/utils/_dtypes.py index cce97bd3..8eea2128 100644 --- a/src/tracksdata/utils/_dtypes.py +++ b/src/tracksdata/utils/_dtypes.py @@ -327,13 +327,12 @@ def infer_default_value_from_dtype(dtype: pl.DataType) -> Any: # Handle array types - create zeros with correct shape and dtype if isinstance(dtype, pl.Array): inner_dtype = dtype.inner - shape = dtype.size # Use size instead of width (deprecated) - numpy_dtype = polars_dtype_to_numpy_dtype(inner_dtype, allow_sequence=False) - return np.zeros(shape, dtype=numpy_dtype) + numpy_dtype = polars_dtype_to_numpy_dtype(inner_dtype, allow_sequence=True) + return np.zeros(dtype.shape, dtype=numpy_dtype) # Handle list types if isinstance(dtype, pl.List): - return None + return [] # Use dictionary lookup for standard types return DTYPE_DEFAULT_MAP.get(dtype, None) @@ -397,6 +396,55 @@ def polars_dtype_to_sqlalchemy_type(dtype: pl.DataType) -> TypeEngine: return sa.PickleType() +# SQLAlchemy to polars type mapping for schema loading +# Order matters: more specific types must come before more general types +# (e.g., BigInteger before Integer, since BigInteger is a subclass of Integer) +_SQLALCHEMY_TO_POLARS_TYPE_MAP = [ + (sa.Boolean, pl.Boolean), + (sa.BigInteger, pl.Int64), # Must come before Integer + (sa.SmallInteger, pl.Int16), # Must come before Integer + (sa.Integer, pl.Int32), + (sa.Float, pl.Float64), + (sa.Text, pl.String), # Must come before String + (sa.String, pl.String), + (sa.PickleType, pl.Object), # Must come before LargeBinary + (sa.LargeBinary, pl.Object), +] + + +def sqlalchemy_type_to_polars_dtype(sa_type: TypeEngine) -> pl.DataType: + """ + Convert a SQLAlchemy type to a polars dtype. + + This is a best-effort conversion for loading existing database schemas. + + Parameters + ---------- + sa_type : TypeEngine + The SQLAlchemy type. + + Returns + ------- + pl.DataType + The corresponding polars dtype. 
+ + Examples + -------- + >>> sqlalchemy_type_to_polars_dtype(sa.BigInteger()) + Int64 + >>> sqlalchemy_type_to_polars_dtype(sa.Boolean()) + Boolean + """ + # Check the type map for known types + # Order matters: more specific types are checked first + for sa_type_class, pl_dtype in _SQLALCHEMY_TO_POLARS_TYPE_MAP: + if isinstance(sa_type, sa_type_class): + return pl_dtype + + # Fallback to Object for unknown types + return pl.Object + + def validate_default_value_dtype_compatibility(default_value: Any, dtype: pl.DataType) -> None: """ Validate that a default value is compatible with a polars dtype. From 7940fee5dad7efdff224a7b87e064ad839a6153e Mon Sep 17 00:00:00 2001 From: Jordao Bragantini Date: Tue, 20 Jan 2026 16:22:25 -0800 Subject: [PATCH 14/14] fixing docs --- src/tracksdata/nodes/_generic_nodes.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/tracksdata/nodes/_generic_nodes.py b/src/tracksdata/nodes/_generic_nodes.py index 8831af20..95fdcf0b 100644 --- a/src/tracksdata/nodes/_generic_nodes.py +++ b/src/tracksdata/nodes/_generic_nodes.py @@ -29,9 +29,6 @@ class GenericFuncNodeAttrs(BaseNodeAttrsOperator): Key of the new attribute to add. attr_keys : Sequence[str], optional Additional attributes to pass to the `func` as keyword arguments. - default_value : Any, optional - Default value to use for the new attribute. - TODO: this should be replaced by a more advanced typing that takes default values. batch_size : int, optional Batch size to use for the function. If 0, the function will be called for each node separately.
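
Taken together, the diffs above change the attribute-key workflow in two ways: `add_node_attr_key`/`add_edge_attr_key` now take a polars dtype (with the default value inferred when it is omitted, e.g. `-1.0` for `pl.Float64` and zero-filled arrays for `pl.Array`), and `GenericFuncNodeAttrs` no longer registers its output key implicitly. The sketch below is illustrative only and is not part of the patch series; the import paths, the lambda-based operator, and the keyword-argument calling convention are assumptions inferred from the updated tests shown above.

```python
# Illustrative sketch only (not part of this patch series).
# Assumptions: RustWorkXGraph and GenericFuncNodeAttrs are importable from these
# paths, and values for attr_keys are passed to `func` as keyword arguments,
# as the updated tests above suggest.
import polars as pl

from tracksdata.constants import DEFAULT_ATTR_KEYS
from tracksdata.graph import RustWorkXGraph
from tracksdata.nodes import GenericFuncNodeAttrs

graph = RustWorkXGraph()

# dtype is now required; when default_value is omitted it is inferred
# (e.g. pl.Float64 -> -1.0, pl.Array -> zeros of the declared shape).
graph.add_node_attr_key("x", dtype=pl.Float64)
graph.add_node_attr_key("bbox", dtype=pl.Array(pl.Int64, 4))
graph.add_edge_attr_key(DEFAULT_ATTR_KEYS.EDGE_DIST, dtype=pl.Float64)

# The output key must be registered before running the operator; after this
# patch series the operator raises ValueError instead of inferring a dtype
# from a default_value.
graph.add_node_attr_key("doubled_x", dtype=pl.Float64)

graph.add_node({DEFAULT_ATTR_KEYS.T: 0, "x": 2.0, "bbox": [0, 0, 4, 4]})

operator = GenericFuncNodeAttrs(
    func=lambda x: x * 2.0,  # hypothetical per-node function over the "x" attribute
    output_key="doubled_x",
    attr_keys=["x"],
)
operator.add_node_attrs(graph)
```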