diff --git a/src/tracksdata/array/_graph_array.py b/src/tracksdata/array/_graph_array.py index aebe8a2d..018e7ae6 100644 --- a/src/tracksdata/array/_graph_array.py +++ b/src/tracksdata/array/_graph_array.py @@ -153,7 +153,7 @@ def __init__( buffer_cache_size: int | None = None, dtype: np.dtype | None = None, ): - if attr_key not in graph.node_attr_keys(): + if attr_key not in graph.node_attr_keys(return_ids=True): raise ValueError(f"Attribute key '{attr_key}' not found in graph. Expected '{graph.node_attr_keys()}'") self.graph = graph diff --git a/src/tracksdata/edges/_test/test_distance_edges.py b/src/tracksdata/edges/_test/test_distance_edges.py index 48673cc3..118fabc9 100644 --- a/src/tracksdata/edges/_test/test_distance_edges.py +++ b/src/tracksdata/edges/_test/test_distance_edges.py @@ -320,8 +320,8 @@ def test_distance_edges_neighbors_per_frame_false() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("x", 0.0) - graph.add_node_attr_key("y", 0.0) + graph.add_node_attr_key("x", pl.Float64, 0.0) + graph.add_node_attr_key("y", pl.Float64, 0.0) # Add nodes at t=0, t=1, t=2 # At t=0: two nodes close to origin @@ -363,8 +363,8 @@ def test_distance_edges_neighbors_per_frame_true() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("x", 0.0) - graph.add_node_attr_key("y", 0.0) + graph.add_node_attr_key("x", pl.Float64, 0.0) + graph.add_node_attr_key("y", pl.Float64, 0.0) # Add nodes at t=0, t=1, t=2 # At t=0: two nodes close to origin @@ -408,8 +408,8 @@ def test_distance_edges_neighbors_per_frame_with_distance_threshold() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("x", 0.0) - graph.add_node_attr_key("y", 0.0) + graph.add_node_attr_key("x", pl.Float64, 0.0) + graph.add_node_attr_key("y", pl.Float64, 0.0) # Add nodes at t=0 (far away), t=1 (close), t=2 # At t=0: nodes very far from where t=2 node will be @@ -444,8 +444,8 @@ def 
test_distance_edges_neighbors_per_frame_single_delta_t() -> None: graph = RustWorkXGraph() # Register attribute keys - graph.add_node_attr_key("x", 0.0) - graph.add_node_attr_key("y", 0.0) + graph.add_node_attr_key("x", pl.Float64, 0.0) + graph.add_node_attr_key("y", pl.Float64, 0.0) # Add nodes at t=0 and t=1 for i in range(3): diff --git a/src/tracksdata/functional/_test/test_apply.py b/src/tracksdata/functional/_test/test_apply.py index 7791165c..fa775852 100644 --- a/src/tracksdata/functional/_test/test_apply.py +++ b/src/tracksdata/functional/_test/test_apply.py @@ -112,8 +112,8 @@ def test_apply_tiled_default_attrs(sample_graph: RustWorkXGraph) -> None: def test_apply_tiled_2d_tiling() -> None: """Test apply_tiled with 2D spatial coordinates.""" graph = RustWorkXGraph() - graph.add_node_attr_key("y", dtype=pl.Int64) - graph.add_node_attr_key("x", dtype=pl.Int64) + graph.add_node_attr_key("y", dtype=pl.Float64) + graph.add_node_attr_key("x", dtype=pl.Float64) for y in [5, 11, 14]: for x in [10, 30]: @@ -192,6 +192,9 @@ def test_apply_tile_scale_invariance() -> None: for scale in scales: graph = RustWorkXGraph() + # hack: updating schema + graph._node_attr_schemas()["t"].dtype = pl.Float64 + for p in pos: graph.add_node({"t": p * scale}) diff --git a/src/tracksdata/graph/_base_graph.py b/src/tracksdata/graph/_base_graph.py index b384bdbb..5ece3425 100644 --- a/src/tracksdata/graph/_base_graph.py +++ b/src/tracksdata/graph/_base_graph.py @@ -63,7 +63,7 @@ def supports_custom_indices(self) -> bool: def _validate_attributes( attrs: dict[str, Any], reference_keys: list[str], - mode: str, + mode: Literal["node", "edge"], ) -> None: """ Validate the attributes of a node. 
@@ -85,11 +85,18 @@ def _validate_attributes( f"`graph.add_{mode}_attr_key(key, default_value)`" ) - for ref_key in reference_keys: - if ref_key not in attrs.keys() and ref_key != DEFAULT_ATTR_KEYS.NODE_ID: - raise ValueError( - f"Attribute '{ref_key}' not found in attrs: '{attrs.keys()}'\nRequested keys: '{reference_keys}'" - ) + missing_keys = set(reference_keys) - set(attrs.keys()) + missing_keys = missing_keys - { + DEFAULT_ATTR_KEYS.NODE_ID, + DEFAULT_ATTR_KEYS.EDGE_ID, + DEFAULT_ATTR_KEYS.EDGE_SOURCE, + DEFAULT_ATTR_KEYS.EDGE_TARGET, + } + + if missing_keys: + raise ValueError( + f"{mode} attribute keys not found in attrs: '{missing_keys}'\nRequested keys: '{reference_keys}'" + ) @abc.abstractmethod def add_node( @@ -626,15 +633,27 @@ def edge_attrs( """ @abc.abstractmethod - def node_attr_keys(self) -> list[str]: + def node_attr_keys(self, return_ids: bool = False) -> list[str]: """ Get the keys of the attributes of the nodes. + + Parameters + ---------- + return_ids : bool, default False + Whether to include NODE_ID in the returned keys. Defaults to False. + If True, NODE_ID will be included in the list. """ @abc.abstractmethod - def edge_attr_keys(self) -> list[str]: + def edge_attr_keys(self, return_ids: bool = False) -> list[str]: """ Get the keys of the attributes of the edges. + + Parameters + ---------- + return_ids : bool, optional + Whether to include EDGE_ID, EDGE_SOURCE, and EDGE_TARGET in the returned keys. + Defaults to False. If True, these ID fields will be included in the list. 
 """ @overload @@ -1169,11 +1188,10 @@ def from_other(cls: type[T], other: "BaseGraph", **kwargs) -> T: graph = cls(**kwargs) graph.update_metadata(**other.metadata()) - for col in node_attrs.columns: - if col != DEFAULT_ATTR_KEYS.T: - # Use the dtype from the source DataFrame - dtype = node_attrs[col].dtype - graph.add_node_attr_key(col, dtype) + current_node_attr_schemas = graph._node_attr_schemas() + for k, v in other._node_attr_schemas().items(): + if k not in current_node_attr_schemas: + graph.add_node_attr_key(k, v.dtype, v.default_value) if graph.supports_custom_indices(): new_node_ids = graph.bulk_add_nodes( @@ -1195,11 +1213,11 @@ def from_other(cls: type[T], other: "BaseGraph", **kwargs) -> T: edge_attrs = other.edge_attrs() edge_attrs = edge_attrs.drop(DEFAULT_ATTR_KEYS.EDGE_ID) - for col in edge_attrs.columns: - if col not in [DEFAULT_ATTR_KEYS.EDGE_SOURCE, DEFAULT_ATTR_KEYS.EDGE_TARGET]: - # Use the dtype from the source DataFrame - dtype = edge_attrs[col].dtype - graph.add_edge_attr_key(col, dtype) + current_edge_attr_schemas = graph._edge_attr_schemas() + for k, v in other._edge_attr_schemas().items(): + if k not in current_edge_attr_schemas: + # schema is missing from the target graph; copy dtype and default from source + graph.add_edge_attr_key(k, v.dtype, v.default_value) edge_attrs = edge_attrs.with_columns( edge_attrs[col].map_elements(node_map.get, return_dtype=pl.Int64).alias(col) @@ -1930,6 +1948,18 @@ def __getitem__(self, node_id: int) -> "NodeInterface": raise ValueError(f"graph index must be a integer, found '{node_id}' of type {type(node_id)}") return NodeInterface(self, node_id) + @abc.abstractmethod + def _node_attr_schemas(self) -> dict[str, AttrSchema]: + """ + Get the attribute schemas for the nodes. + """ + + @abc.abstractmethod + def _edge_attr_schemas(self) -> dict[str, AttrSchema]: + """ + Get the attribute schemas for the edges.
+ """ + class NodeInterface: """ diff --git a/src/tracksdata/graph/_graph_view.py b/src/tracksdata/graph/_graph_view.py index 3e35978f..b9f82ead 100644 --- a/src/tracksdata/graph/_graph_view.py +++ b/src/tracksdata/graph/_graph_view.py @@ -113,9 +113,29 @@ def __init__( self._node_attr_keys = node_attr_keys self._edge_attr_keys = edge_attr_keys + # add default keys to the node and edge attr keys + if self._node_attr_keys is not None: + self._node_attr_keys = [DEFAULT_ATTR_KEYS.NODE_ID, DEFAULT_ATTR_KEYS.T, *self._node_attr_keys] + self._node_attr_keys = list(dict.fromkeys(self._node_attr_keys)) + + if self._edge_attr_keys is not None: + self._edge_attr_keys = [ + DEFAULT_ATTR_KEYS.EDGE_ID, + DEFAULT_ATTR_KEYS.EDGE_SOURCE, + DEFAULT_ATTR_KEYS.EDGE_TARGET, + *self._edge_attr_keys, + ] + self._edge_attr_keys = list(dict.fromkeys(self._edge_attr_keys)) + # use parent graph overlaps self._overlaps = None + def _node_attr_schemas(self) -> dict[str, AttrSchema]: + return self._root._node_attr_schemas() + + def _edge_attr_schemas(self) -> dict[str, AttrSchema]: + return self._root._edge_attr_schemas() + def supports_custom_indices(self) -> bool: return self._root.supports_custom_indices() @@ -230,11 +250,48 @@ def filter( include_sources=include_sources, ) - def node_attr_keys(self) -> list[str]: - return self._root.node_attr_keys() if self._node_attr_keys is None else self._node_attr_keys + def node_attr_keys(self, return_ids: bool = False) -> list[str]: + """ + Get the keys of the attributes of the nodes. - def edge_attr_keys(self) -> list[str]: - return self._root.edge_attr_keys() if self._edge_attr_keys is None else self._edge_attr_keys + Parameters + ---------- + return_ids : bool, optional + Whether to include NODE_ID in the returned keys. Defaults to False. + If True, NODE_ID will be included in the list. 
+ """ + if self._node_attr_keys is None: + return self._root.node_attr_keys(return_ids=return_ids) + else: + keys = self._node_attr_keys.copy() + if not return_ids: + try: + keys.remove(DEFAULT_ATTR_KEYS.NODE_ID) + except ValueError: + pass + return keys + + def edge_attr_keys(self, return_ids: bool = False) -> list[str]: + """ + Get the keys of the attributes of the edges. + + Parameters + ---------- + return_ids : bool, optional + Whether to include EDGE_ID, EDGE_SOURCE, and EDGE_TARGET in the returned keys. + Defaults to False. If True, these ID fields will be included in the list. + """ + if self._edge_attr_keys is None: + return self._root.edge_attr_keys(return_ids=return_ids) + else: + keys = self._edge_attr_keys.copy() + if not return_ids: + for k in [DEFAULT_ATTR_KEYS.EDGE_ID, DEFAULT_ATTR_KEYS.EDGE_SOURCE, DEFAULT_ATTR_KEYS.EDGE_TARGET]: + try: + keys.remove(k) + except ValueError: + pass + return keys def add_node_attr_key( self, @@ -258,7 +315,7 @@ def add_node_attr_key( if not self._is_root_rx_graph: if self.sync: # Get the schema from root to get the actual default value used - schema = self._root._node_attr_schemas[key] + schema = self._root._node_attr_schemas()[key] # Apply to local rx_graph rx_graph = self.rx_graph for node_id in rx_graph.node_indices(): @@ -300,7 +357,7 @@ def add_edge_attr_key( if not self._is_root_rx_graph: if self.sync: # Get the schema from root to get the actual default value used - schema = self._root._edge_attr_schemas[key] + schema = self._root._edge_attr_schemas()[key] # Apply to local rx_graph for _, _, edge_attr in self.rx_graph.weighted_edge_list(): edge_attr[key] = schema.default_value diff --git a/src/tracksdata/graph/_rustworkx_graph.py b/src/tracksdata/graph/_rustworkx_graph.py index ce009c3f..0463d8ef 100644 --- a/src/tracksdata/graph/_rustworkx_graph.py +++ b/src/tracksdata/graph/_rustworkx_graph.py @@ -52,14 +52,28 @@ def _pop_time_eq( return out_attrs, time +def _maybe_fill_null(s: pl.Series, schema: AttrSchema) 
-> pl.Series: + if s.has_nulls() and schema.default_value is not None: + if isinstance(schema.default_value, np.ndarray): + value = schema.default_value.tolist() + elif schema.dtype == pl.Object: + value = pl.lit(schema.default_value, allow_object=True) + else: + value = schema.default_value + s = s.fill_null(value) + return s + + def _create_filter_func( attr_comps: Sequence[AttrComparison], + schema: dict[str, AttrSchema], ) -> Callable[[dict[str, Any]], bool]: LOG.info(f"Creating filter function for {attr_comps}") def _filter(attrs: dict[str, Any]) -> bool: for attr_op in attr_comps: - if not attr_op.op(attrs[str(attr_op.column)], attr_op.other): + value = attrs.get(attr_op.column, schema[attr_op.column].default_value) + if not attr_op.op(value, attr_op.other): return False return True @@ -138,7 +152,7 @@ def node_ids(self) -> list[int]: def _edge_attrs(self) -> pl.DataFrame: node_ids = self._current_node_ids() - _filter_func = _create_filter_func(self._edge_attr_comps) + _filter_func = _create_filter_func(self._edge_attr_comps, self._graph._edge_attr_schemas()) neigh_funcs = [self._graph.rx_graph.out_edges] if self._include_sources: @@ -167,12 +181,22 @@ def _edge_attrs(self) -> pl.DataFrame: sources.append(src) targets.append(tgt) for k in data.keys(): - data[k].append(attr[k]) + data[k].append(attr.get(k, None)) + + for k in data.keys(): + schema = self._graph._edge_attr_schemas()[k] + s = pl.Series(name=k, values=data[k], dtype=schema.dtype) + s = _maybe_fill_null(s, schema) + data[k] = s - df = pl.DataFrame(data).with_columns( - pl.Series(sources, dtype=pl.Int64).alias(DEFAULT_ATTR_KEYS.EDGE_SOURCE), - pl.Series(targets, dtype=pl.Int64).alias(DEFAULT_ATTR_KEYS.EDGE_TARGET), + data[DEFAULT_ATTR_KEYS.EDGE_SOURCE] = pl.Series( + name=DEFAULT_ATTR_KEYS.EDGE_SOURCE, values=sources, dtype=pl.Int64 ) + data[DEFAULT_ATTR_KEYS.EDGE_TARGET] = pl.Series( + name=DEFAULT_ATTR_KEYS.EDGE_TARGET, values=targets, dtype=pl.Int64 + ) + + df = pl.DataFrame(data) return df 
@cache_method @@ -229,16 +253,11 @@ def subgraph( rx_graph, node_map = self._graph._rx_subgraph_with_nodemap(node_ids) if self._edge_attr_comps: - _filter_func = _create_filter_func(self._edge_attr_comps) + _filter_func = _create_filter_func(self._edge_attr_comps, self._graph._edge_attr_schemas()) for src, tgt, attr in rx_graph.weighted_edge_list(): if not _filter_func(attr): rx_graph.remove_edge(src, tgt) - # Ensure the time key is in the node attributes - if node_attr_keys is not None: - node_attr_keys = [DEFAULT_ATTR_KEYS.T, *node_attr_keys] - node_attr_keys = list(dict.fromkeys(node_attr_keys)) - graph_view = GraphView( rx_graph, node_map_to_root=dict(node_map.items()), @@ -318,20 +337,26 @@ def __init__(self, rx_graph: rx.PyDiGraph | None = None) -> None: super().__init__() self._time_to_nodes: dict[int, list[int]] = {} - self._node_attr_schemas: dict[str, AttrSchema] = {} - self._edge_attr_schemas: dict[str, AttrSchema] = {} + self.__node_attr_schemas: dict[str, AttrSchema] = {} + self.__edge_attr_schemas: dict[str, AttrSchema] = {} self._overlaps: list[list[int, 2]] = [] # Add default node attributes with inferred schemas - self._node_attr_schemas[DEFAULT_ATTR_KEYS.NODE_ID] = AttrSchema( - key=DEFAULT_ATTR_KEYS.NODE_ID, - dtype=pl.Int64, - ) - self._node_attr_schemas[DEFAULT_ATTR_KEYS.T] = AttrSchema( + self.__node_attr_schemas[DEFAULT_ATTR_KEYS.T] = AttrSchema( key=DEFAULT_ATTR_KEYS.T, + dtype=pl.Int32, + ) + self.__node_attr_schemas[DEFAULT_ATTR_KEYS.NODE_ID] = AttrSchema( + key=DEFAULT_ATTR_KEYS.NODE_ID, dtype=pl.Int64, ) + for key in [DEFAULT_ATTR_KEYS.EDGE_ID, DEFAULT_ATTR_KEYS.EDGE_SOURCE, DEFAULT_ATTR_KEYS.EDGE_TARGET]: + self.__edge_attr_schemas[key] = AttrSchema( + key=key, + dtype=pl.Int32 if key == DEFAULT_ATTR_KEYS.EDGE_ID else pl.Int64, + ) + if rx_graph is None: self._graph = rx.PyDiGraph(attrs={}) else: @@ -376,7 +401,7 @@ def __init__(self, rx_graph: rx.PyDiGraph | None = None) -> None: except (ValueError, TypeError): # If polars can't infer 
dtype (e.g., for complex objects), use Object dtype = pl.Object - self._node_attr_schemas[key] = AttrSchema(key=key, dtype=dtype) + self.__node_attr_schemas[key] = AttrSchema(key=key, dtype=dtype) # Process edges: set edge IDs and infer schemas edge_idx_map = self._graph.edge_index_map() @@ -398,7 +423,13 @@ def __init__(self, rx_graph: rx.PyDiGraph | None = None) -> None: except (ValueError, TypeError): # If polars can't infer dtype (e.g., for complex objects), use Object dtype = pl.Object - self._edge_attr_schemas[key] = AttrSchema(key=key, dtype=dtype) + self.__edge_attr_schemas[key] = AttrSchema(key=key, dtype=dtype) + + def _node_attr_schemas(self) -> dict[str, AttrSchema]: + return self.__node_attr_schemas + + def _edge_attr_schemas(self) -> dict[str, AttrSchema]: + return self.__edge_attr_schemas @property def rx_graph(self) -> rx.PyDiGraph: @@ -870,7 +901,7 @@ def _filter_nodes_by_attrs( # subgraph of selected nodes rx_graph, node_map = rx_graph.subgraph_with_nodemap(selected_nodes) - _filter_func = _create_filter_func(attrs) + _filter_func = _create_filter_func(attrs, self._node_attr_schemas()) if node_map is None: return list(rx_graph.filter_nodes(_filter_func)) @@ -895,17 +926,36 @@ def time_points(self) -> list[int]: """ return list(self._time_to_nodes.keys()) - def node_attr_keys(self) -> list[str]: + def node_attr_keys(self, return_ids: bool = False) -> list[str]: """ Get the keys of the attributes of the nodes. + + Parameters + ---------- + return_ids : bool, optional + Whether to include NODE_ID in the returned keys. Defaults to False. + If True, NODE_ID will be included in the list. """ - return list(self._node_attr_schemas.keys()) + keys = list(self._node_attr_schemas().keys()) + if not return_ids and DEFAULT_ATTR_KEYS.NODE_ID in keys: + keys.remove(DEFAULT_ATTR_KEYS.NODE_ID) + return keys - def edge_attr_keys(self) -> list[str]: + def edge_attr_keys(self, return_ids: bool = False) -> list[str]: """ Get the keys of the attributes of the edges. 
+ + Parameters + ---------- + return_ids : bool, optional + Whether to include EDGE_ID, EDGE_SOURCE, and EDGE_TARGET in the returned keys. + Defaults to False. If True, these ID fields will be included in the list. """ - return list(self._edge_attr_schemas.keys()) + keys = list(self.__edge_attr_schemas.keys()) + if not return_ids: + for id_key in [DEFAULT_ATTR_KEYS.EDGE_ID, DEFAULT_ATTR_KEYS.EDGE_SOURCE, DEFAULT_ATTR_KEYS.EDGE_TARGET]: + keys.remove(id_key) + return keys def add_node_attr_key( self, @@ -928,15 +978,10 @@ def add_node_attr_key( If None, will be inferred from dtype. """ # Process arguments and create validated schema - schema = process_attr_key_args(key_or_schema, dtype, default_value, self._node_attr_schemas) + schema = process_attr_key_args(key_or_schema, dtype, default_value, self.__node_attr_schemas) # Store schema - self._node_attr_schemas[schema.key] = schema - - # Add to all existing nodes - rx_graph = self.rx_graph - for node_id in rx_graph.node_indices(): - rx_graph[node_id][schema.key] = schema.default_value + self.__node_attr_schemas[schema.key] = schema def remove_node_attr_key(self, key: str) -> None: """ @@ -948,7 +993,7 @@ def remove_node_attr_key(self, key: str) -> None: if key in (DEFAULT_ATTR_KEYS.NODE_ID, DEFAULT_ATTR_KEYS.T): raise ValueError(f"Cannot remove required node attribute key {key}") - del self._node_attr_schemas[key] + del self.__node_attr_schemas[key] for node_attr in self.rx_graph.nodes(): node_attr.pop(key, None) @@ -973,14 +1018,10 @@ def add_edge_attr_key( If None, will be inferred from dtype. 
""" # Process arguments and create validated schema - schema = process_attr_key_args(key_or_schema, dtype, default_value, self._edge_attr_schemas) + schema = process_attr_key_args(key_or_schema, dtype, default_value, self.__edge_attr_schemas) # Store schema - self._edge_attr_schemas[schema.key] = schema - - # Add to all existing edges - for edge_attr in self.rx_graph.edges(): - edge_attr[schema.key] = schema.default_value + self.__edge_attr_schemas[schema.key] = schema def remove_edge_attr_key(self, key: str) -> None: """ @@ -989,7 +1030,7 @@ def remove_edge_attr_key(self, key: str) -> None: if key not in self.edge_attr_keys(): raise ValueError(f"Edge attribute key {key} does not exist") - del self._edge_attr_schemas[key] + del self.__edge_attr_schemas[key] for edge_attr in self.rx_graph.edges(): edge_attr.pop(key, None) @@ -1012,7 +1053,7 @@ def _node_attrs_from_node_ids( The attribute keys to get. If None, all the attributes of the first node are used. unpack : bool - Whether to unpack array attributesinto multiple scalar attributes. + Whether to unpack array attributes into multiple scalar attributes. 
Returns ------- @@ -1025,13 +1066,16 @@ def _node_attrs_from_node_ids( node_ids = list(rx_graph.node_indices()) if attr_keys is None: - attr_keys = self.node_attr_keys() + attr_keys = self.node_attr_keys(return_ids=True) if isinstance(attr_keys, str): attr_keys = [attr_keys] + node_attr_schemas = self._node_attr_schemas() + pl_schema = {k: node_attr_schemas[k].dtype for k in attr_keys} + if len(node_ids) == 0: - return pl.DataFrame({key: [] for key in attr_keys}) + return pl.DataFrame({key: [] for key in attr_keys}, schema=pl_schema) # making them unique attr_keys = list(dict.fromkeys(attr_keys)) @@ -1040,17 +1084,24 @@ def _node_attrs_from_node_ids( columns = {key: [] for key in attr_keys} if DEFAULT_ATTR_KEYS.NODE_ID in attr_keys: - columns[DEFAULT_ATTR_KEYS.NODE_ID] = np.asarray(node_ids, dtype=int) + columns[DEFAULT_ATTR_KEYS.NODE_ID] = pl.Series( + name=DEFAULT_ATTR_KEYS.NODE_ID, + values=node_ids, + dtype=pl.Int64, + ) attr_keys.remove(DEFAULT_ATTR_KEYS.NODE_ID) # Build columns in a vectorized way for node_id in node_ids: node_data = rx_graph[node_id] for key in attr_keys: - columns[key].append(node_data[key]) + columns[key].append(node_data.get(key)) for key in attr_keys: - columns[key] = np.asarray(columns[key]) + schema = node_attr_schemas[key] + s = pl.Series(name=key, values=columns[key], dtype=schema.dtype) + s = _maybe_fill_null(s, schema) + columns[key] = s # Create DataFrame and set node_id as index in one shot df = pl.DataFrame(columns) @@ -1115,12 +1166,16 @@ def edge_attrs( for row in data: for key in attr_keys: - columns[key].append(row[key]) + columns[key].append(row.get(key)) columns[DEFAULT_ATTR_KEYS.EDGE_SOURCE] = source columns[DEFAULT_ATTR_KEYS.EDGE_TARGET] = target - columns = {k: np.asarray(v) for k, v in columns.items()} + for key in attr_keys: + schema = self._edge_attr_schemas()[key] + s = pl.Series(name=key, values=columns[key], dtype=schema.dtype) + s = _maybe_fill_null(s, schema) + columns[key] = s df = pl.DataFrame(columns) if 
unpack: @@ -1235,7 +1290,7 @@ def assign_tracklet_ids( else: if output_key not in self.node_attr_keys(): previous_id_df = None - self.add_node_attr_key(output_key, -1) + self.add_node_attr_key(output_key, dtype=pl.Int64, default_value=-1) if tracklet_id_offset is None: tracklet_id_offset = 1 elif reset: diff --git a/src/tracksdata/graph/_sql_graph.py b/src/tracksdata/graph/_sql_graph.py index bcadf01e..985cbdc9 100644 --- a/src/tracksdata/graph/_sql_graph.py +++ b/src/tracksdata/graph/_sql_graph.py @@ -297,14 +297,14 @@ def subgraph( # Give the node_attr_keys as a list, since otherwise the SQL results return the # Ensure the time key is in the node attributes if node_attr_keys is not None: - node_attr_keys = [DEFAULT_ATTR_KEYS.T, *node_attr_keys] + node_attr_keys = [DEFAULT_ATTR_KEYS.T, DEFAULT_ATTR_KEYS.NODE_ID, *node_attr_keys] else: - node_attr_keys = [DEFAULT_ATTR_KEYS.T, *self._graph.node_attr_keys()] + node_attr_keys = [DEFAULT_ATTR_KEYS.T, *self._graph.node_attr_keys(return_ids=True)] node_attr_keys = list(dict.fromkeys(node_attr_keys)) if edge_attr_keys is None: - edge_attr_keys = self._graph.edge_attr_keys().copy() + edge_attr_keys = self._graph.edge_attr_keys(return_ids=True).copy() node_query = self._query_from_attr_keys( query=self._node_query, @@ -469,8 +469,8 @@ def __init__( # Create unique classes for this instance self._define_schema(overwrite=overwrite) - self._node_attr_schemas: dict[str, AttrSchema] = {} - self._edge_attr_schemas: dict[str, AttrSchema] = {} + self.__node_attr_schemas: dict[str, AttrSchema] = {} + self.__edge_attr_schemas: dict[str, AttrSchema] = {} if overwrite: self.Base.metadata.drop_all(self._engine) @@ -529,15 +529,15 @@ class Node(Base): class Edge(Base): __tablename__ = "Edge" - edge_id = sa.Column(sa.Integer, primary_key=True, unique=True, autoincrement=True) - source_id = sa.Column(sa.BigInteger, sa.ForeignKey(f"{node_tb_name}.node_id"), index=True) - target_id = sa.Column(sa.BigInteger, 
sa.ForeignKey(f"{node_tb_name}.node_id"), index=True) + edge_id = sa.Column(sa.Integer, sa.Identity(always=True), primary_key=True, unique=True) + source_id = sa.Column(sa.BigInteger, sa.ForeignKey(f"{node_tb_name}.node_id"), index=True, nullable=False) + target_id = sa.Column(sa.BigInteger, sa.ForeignKey(f"{node_tb_name}.node_id"), index=True, nullable=False) class Overlap(Base): __tablename__ = "Overlap" - overlap_id = sa.Column(sa.Integer, primary_key=True, unique=True, autoincrement=True) - source_id = sa.Column(sa.BigInteger, sa.ForeignKey(f"{node_tb_name}.node_id"), index=True) - target_id = sa.Column(sa.BigInteger, sa.ForeignKey(f"{node_tb_name}.node_id"), index=True) + overlap_id = sa.Column(sa.Integer, sa.Identity(always=True), primary_key=True, unique=True) + source_id = sa.Column(sa.BigInteger, sa.ForeignKey(f"{node_tb_name}.node_id"), index=True, nullable=False) + target_id = sa.Column(sa.BigInteger, sa.ForeignKey(f"{node_tb_name}.node_id"), index=True, nullable=False) class Metadata(Base): __tablename__ = "Metadata" @@ -558,12 +558,12 @@ def _init_schemas_from_tables(self) -> None: """ # Initialize node schemas from Node table columns for column_name in self.Node.__table__.columns.keys(): - if column_name not in self._node_attr_schemas: + if column_name not in self.__node_attr_schemas: column = self.Node.__table__.columns[column_name] # Infer polars dtype from SQLAlchemy type pl_dtype = sqlalchemy_type_to_polars_dtype(column.type) # AttrSchema.__post_init__ will infer the default_value - self._node_attr_schemas[column_name] = AttrSchema( + self.__node_attr_schemas[column_name] = AttrSchema( key=column_name, dtype=pl_dtype, ) @@ -571,14 +571,12 @@ def _init_schemas_from_tables(self) -> None: # Initialize edge schemas from Edge table columns for column_name in self.Edge.__table__.columns.keys(): # Skip internal edge columns - if column_name in [DEFAULT_ATTR_KEYS.EDGE_ID, DEFAULT_ATTR_KEYS.EDGE_SOURCE, DEFAULT_ATTR_KEYS.EDGE_TARGET]: - continue - if 
column_name not in self._edge_attr_schemas: + if column_name not in self.__edge_attr_schemas: column = self.Edge.__table__.columns[column_name] # Infer polars dtype from SQLAlchemy type pl_dtype = sqlalchemy_type_to_polars_dtype(column.type) # AttrSchema.__post_init__ will infer the default_value - self._edge_attr_schemas[column_name] = AttrSchema( + self.__edge_attr_schemas[column_name] = AttrSchema( key=column_name, dtype=pl_dtype, ) @@ -591,19 +589,23 @@ def _restore_pickled_column_types(self, table: sa.Table) -> None: def _polars_schema_override(self, table_class: type[DeclarativeBase]) -> SchemaDict: # Get the appropriate schema dict based on table class if table_class.__tablename__ == self.Node.__tablename__: - schemas = self._node_attr_schemas + schemas = self._node_attr_schemas() else: - schemas = self._edge_attr_schemas + schemas = self._edge_attr_schemas() # Return schema overrides for special types that need explicit casting - return {key: schema.dtype for key, schema in schemas.items() if schema.dtype == pl.Boolean} + return { + key: schema.dtype + for key, schema in schemas.items() + if not (schema.dtype == pl.Object or isinstance(schema.dtype, pl.Array | pl.List)) + } def _cast_array_columns(self, table_class: type[DeclarativeBase], df: pl.DataFrame) -> pl.DataFrame: # Get the appropriate schema dict based on table class if table_class.__tablename__ == self.Node.__tablename__: - schemas = self._node_attr_schemas + schemas = self._node_attr_schemas() else: - schemas = self._edge_attr_schemas + schemas = self._edge_attr_schemas() # Cast array columns (stored as blobs in database) df = df.with_columns( @@ -1332,14 +1334,42 @@ def edge_attrs( return edges_df - def node_attr_keys(self) -> list[str]: + def _node_attr_schemas(self) -> dict[str, AttrSchema]: + return self.__node_attr_schemas + + def _edge_attr_schemas(self) -> dict[str, AttrSchema]: + return self.__edge_attr_schemas + + def node_attr_keys(self, return_ids: bool = False) -> list[str]: + """ + 
Get the keys of the attributes of the nodes. + + Parameters + ---------- + return_ids : bool, optional + Whether to include NODE_ID in the returned keys. Defaults to False. + If True, NODE_ID will be included in the list. + """ keys = list(self.Node.__table__.columns.keys()) + if not return_ids and DEFAULT_ATTR_KEYS.NODE_ID in keys: + keys.remove(DEFAULT_ATTR_KEYS.NODE_ID) return keys - def edge_attr_keys(self) -> list[str]: + def edge_attr_keys(self, return_ids: bool = False) -> list[str]: + """ + Get the keys of the attributes of the edges. + + Parameters + ---------- + return_ids : bool, optional + Whether to include EDGE_ID, EDGE_SOURCE, and EDGE_TARGET in the returned keys. + Defaults to False. If True, these ID fields will be included in the list. + """ keys = list(self.Edge.__table__.columns.keys()) - for k in [DEFAULT_ATTR_KEYS.EDGE_ID, DEFAULT_ATTR_KEYS.EDGE_SOURCE, DEFAULT_ATTR_KEYS.EDGE_TARGET]: - keys.remove(k) + if not return_ids: + for id_key in [DEFAULT_ATTR_KEYS.EDGE_ID, DEFAULT_ATTR_KEYS.EDGE_SOURCE, DEFAULT_ATTR_KEYS.EDGE_TARGET]: + if id_key in keys: + keys.remove(id_key) return keys def _resolve_attr_keys( @@ -1588,10 +1618,10 @@ def add_node_attr_key( default_value: Any = None, ) -> None: # Process arguments and create validated schema - schema = process_attr_key_args(key_or_schema, dtype, default_value, self._node_attr_schemas) + schema = process_attr_key_args(key_or_schema, dtype, default_value, self.__node_attr_schemas) # Store schema - self._node_attr_schemas[schema.key] = schema + self.__node_attr_schemas[schema.key] = schema # Add column to database self._add_new_column(self.Node, schema) @@ -1604,7 +1634,7 @@ def remove_node_attr_key(self, key: str) -> None: raise ValueError(f"Cannot remove required node attribute key {key}") self._drop_column(self.Node, key) - self._node_attr_schemas.pop(key, None) + self.__node_attr_schemas.pop(key, None) def add_edge_attr_key( self, @@ -1613,10 +1643,10 @@ def add_edge_attr_key( default_value: Any = 
None, ) -> None: # Process arguments and create validated schema - schema = process_attr_key_args(key_or_schema, dtype, default_value, self._edge_attr_schemas) + schema = process_attr_key_args(key_or_schema, dtype, default_value, self.__edge_attr_schemas) # Store schema - self._edge_attr_schemas[schema.key] = schema + self.__edge_attr_schemas[schema.key] = schema # Add column to database self._add_new_column(self.Edge, schema) @@ -1626,7 +1656,7 @@ def remove_edge_attr_key(self, key: str) -> None: raise ValueError(f"Edge attribute key {key} does not exist") self._drop_column(self.Edge, key) - self._edge_attr_schemas.pop(key, None) + self.__edge_attr_schemas.pop(key, None) def num_edges(self) -> int: with Session(self._engine) as session: diff --git a/src/tracksdata/graph/_test/test_graph_backends.py b/src/tracksdata/graph/_test/test_graph_backends.py index fc52fe59..4c3714e4 100644 --- a/src/tracksdata/graph/_test/test_graph_backends.py +++ b/src/tracksdata/graph/_test/test_graph_backends.py @@ -374,8 +374,8 @@ def test_subgraph_with_node_and_edge_attr_filters(graph_backend: BaseGraph) -> N def test_subgraph_with_node_ids_and_filters(graph_backend: BaseGraph) -> None: """Test subgraph with node IDs and filters.""" - graph_backend.add_node_attr_key("x", pl.Float64) - graph_backend.add_edge_attr_key("weight", dtype=pl.Float64) + graph_backend.add_node_attr_key("x", pl.Float32) + graph_backend.add_edge_attr_key("weight", dtype=pl.Float32) node0 = graph_backend.add_node({"t": 0, "x": 1.0}) node1 = graph_backend.add_node({"t": 1, "x": 2.0}) @@ -406,6 +406,12 @@ def test_subgraph_with_node_ids_and_filters(graph_backend: BaseGraph) -> None: subgraph_edge_ids = subgraph.edge_ids() assert len(subgraph_edge_ids) == 0 + assert subgraph._node_attr_schemas() == graph_backend._node_attr_schemas() + assert subgraph._edge_attr_schemas() == graph_backend._edge_attr_schemas() + + assert dict(subgraph.node_attrs().schema) == dict(graph_backend.node_attrs().schema) + assert 
dict(subgraph.edge_attrs().schema) == dict(graph_backend.edge_attrs().schema) + @pytest.mark.parametrize( "dtype, value", @@ -1356,7 +1362,7 @@ def test_from_other_with_edges( graph_backend.update_metadata(special_key="special_value") graph_backend.add_node_attr_key("x", dtype=pl.Float64) - graph_backend.add_edge_attr_key("weight", dtype=pl.Float64) + graph_backend.add_edge_attr_key("weight", dtype=pl.Float64, default_value=-1) graph_backend.add_edge_attr_key("type", dtype=pl.String, default_value="forward") node1 = graph_backend.add_node({"t": 0, "x": 1.0}) @@ -1382,6 +1388,12 @@ def test_from_other_with_edges( assert new_graph.metadata() == graph_backend.metadata() + assert new_graph._node_attr_schemas() == graph_backend._node_attr_schemas() + assert new_graph._edge_attr_schemas() == graph_backend._edge_attr_schemas() + + assert dict(new_graph.node_attrs().schema) == dict(graph_backend.node_attrs().schema) + assert dict(new_graph.edge_attrs().schema) == dict(graph_backend.edge_attrs().schema) + # Verify edge attributes are copied correctly source_edges = graph_backend.edge_attrs(attr_keys=["weight", "type"]) new_edges = new_graph.edge_attrs(attr_keys=["weight", "type"]) @@ -2288,7 +2300,7 @@ def _fill_mock_geff_graph(graph_backend: BaseGraph) -> None: ) graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, pl.Object) graph_backend.add_node_attr_key(DEFAULT_ATTR_KEYS.TRACKLET_ID, pl.Int64) - graph_backend.add_node_attr_key("ndfeature", pl.Object) + graph_backend.add_node_attr_key("ndfeature", pl.Array(pl.Float64, (3, 1))) graph_backend.add_edge_attr_key("weight", pl.Float64) diff --git a/src/tracksdata/graph/_test/test_index_graph.py b/src/tracksdata/graph/_test/test_index_graph.py index 4ac30c6e..4161799a 100644 --- a/src/tracksdata/graph/_test/test_index_graph.py +++ b/src/tracksdata/graph/_test/test_index_graph.py @@ -23,7 +23,8 @@ def test_index_rx_graph_with_mapping() -> None: ) assert graph.node_ids() == [1, 5_000, 3] - assert set(graph.node_attr_keys()) 
== {DEFAULT_ATTR_KEYS.NODE_ID, DEFAULT_ATTR_KEYS.T, "a"} + assert set(graph.node_attr_keys()) == {DEFAULT_ATTR_KEYS.T, "a"} + assert set(graph.node_attr_keys(return_ids=True)) == {DEFAULT_ATTR_KEYS.T, "a", DEFAULT_ATTR_KEYS.NODE_ID} def test_duplicate_index_map() -> None: diff --git a/src/tracksdata/graph/filters/_indexed_filter.py b/src/tracksdata/graph/filters/_indexed_filter.py index 9fab8e38..7c43cd5a 100644 --- a/src/tracksdata/graph/filters/_indexed_filter.py +++ b/src/tracksdata/graph/filters/_indexed_filter.py @@ -59,7 +59,7 @@ def subgraph( rx_graph, node_map = self._graph._rx_subgraph_with_nodemap(node_ids) if self._edge_attr_comps: - _filter_func = _create_filter_func(self._edge_attr_comps) + _filter_func = _create_filter_func(self._edge_attr_comps, self._graph._edge_attr_schemas()) for src, tgt, attr in rx_graph.weighted_edge_list(): if not _filter_func(attr): rx_graph.remove_edge(src, tgt) @@ -68,11 +68,6 @@ def subgraph( if hasattr(self._graph, "_root"): root = self._graph._root - # Ensure the time key is in the node attributes - if node_attr_keys is not None: - node_attr_keys = [DEFAULT_ATTR_KEYS.T, *node_attr_keys] - node_attr_keys = list(dict.fromkeys(node_attr_keys)) - graph_view = GraphView( rx_graph, node_map_to_root=dict(node_map.items()), diff --git a/src/tracksdata/graph/filters/_test/test_spatial_filter.py b/src/tracksdata/graph/filters/_test/test_spatial_filter.py index 41b08a07..d17ba94b 100644 --- a/src/tracksdata/graph/filters/_test/test_spatial_filter.py +++ b/src/tracksdata/graph/filters/_test/test_spatial_filter.py @@ -284,7 +284,7 @@ def test_bbox_spatial_filter_dimensions() -> None: def test_bbox_spatial_filter_error_handling() -> None: """Test error handling for mismatched min/max attribute lengths.""" graph = RustWorkXGraph() - graph.add_node_attr_key("bbox", dtype=pl.Array(pl.Int64, 4)) + graph.add_node_attr_key("bbox", dtype=pl.Array(pl.Int64, 5)) graph.add_node({"t": 0, "bbox": [10, 20, 15, 25, 14]}) # Test mismatched min/max 
attributes length with pytest.raises(ValueError, match="Bounding box coordinates must have even number of dimensions"): @@ -309,6 +309,6 @@ def test_add_and_remove_node(graph_backend: BaseGraph) -> None: empty_region = spatial_filter[0:3, 6:9, 6:9].node_attrs() assert empty_region.is_empty() new_node_id = graph.add_node({"t": 2, "bbox": np.asarray([7, 7, 8, 8])}) assert len(spatial_filter._node_rtree) == 3 diff --git a/src/tracksdata/nodes/_mask.py b/src/tracksdata/nodes/_mask.py index d6bd0fcc..e13d7e66 100644 --- a/src/tracksdata/nodes/_mask.py +++ b/src/tracksdata/nodes/_mask.py @@ -524,4 +524,4 @@ def _init_node_attrs(self, graph: "BaseGraph") -> None: Validate that the output key exists in the graph. """ if self.output_key not in graph.node_attr_keys(): - graph.add_node_attr_key(self.output_key, pl.Array(pl.Int64, 2 * len(self._image_shape))) + graph.add_node_attr_key(self.output_key, pl.Object) diff --git a/src/tracksdata/utils/_test/test_attr_key_dtype.py b/src/tracksdata/utils/_test/test_attr_key_dtype.py index b6bb37c6..836b5812 100644 --- a/src/tracksdata/utils/_test/test_attr_key_dtype.py +++ b/src/tracksdata/utils/_test/test_attr_key_dtype.py @@ -22,18 +22,19 @@ def test_add_node_attr_key_with_dtype_only(self): graph.add_node_attr_key("name", pl.String) # Verify schemas are stored - assert "count" in graph._node_attr_schemas - assert graph._node_attr_schemas["count"].dtype == pl.UInt32 - assert graph._node_attr_schemas["count"].default_value == 0 + schemas = graph._node_attr_schemas() + assert "count" in schemas + assert schemas["count"].dtype == pl.UInt32 + assert schemas["count"].default_value == 0 - assert graph._node_attr_schemas["score"].dtype == pl.Float64 - assert graph._node_attr_schemas["score"].default_value == -1.0 + assert schemas["score"].dtype == pl.Float64 + assert schemas["score"].default_value == -1.0 - assert graph._node_attr_schemas["flag"].dtype == pl.Boolean - assert graph._node_attr_schemas["flag"].default_value is
False + assert schemas["flag"].dtype == pl.Boolean + assert schemas["flag"].default_value is False - assert graph._node_attr_schemas["name"].dtype == pl.String - assert graph._node_attr_schemas["name"].default_value == "" + assert schemas["name"].dtype == pl.String + assert schemas["name"].default_value == "" def test_add_node_attr_key_with_dtype_and_default(self): """Test adding node attribute with both dtype and default value.""" @@ -41,8 +42,9 @@ def test_add_node_attr_key_with_dtype_and_default(self): graph.add_node_attr_key("score", pl.Float64, default_value=0.0) - assert graph._node_attr_schemas["score"].dtype == pl.Float64 - assert graph._node_attr_schemas["score"].default_value == 0.0 + schemas = graph._node_attr_schemas() + assert schemas["score"].dtype == pl.Float64 + assert schemas["score"].default_value == 0.0 def test_add_node_attr_key_with_array_dtype(self): """Test adding node attribute with array dtype (zeros default).""" @@ -50,8 +52,9 @@ def test_add_node_attr_key_with_array_dtype(self): graph.add_node_attr_key("bbox", pl.Array(pl.Float64, 4)) - assert graph._node_attr_schemas["bbox"].dtype == pl.Array(pl.Float64, 4) - default = graph._node_attr_schemas["bbox"].default_value + schemas = graph._node_attr_schemas() + assert schemas["bbox"].dtype == pl.Array(pl.Float64, 4) + default = schemas["bbox"].default_value assert isinstance(default, np.ndarray) assert default.shape == (4,) assert default.dtype == np.float64 @@ -62,8 +65,9 @@ def test_add_node_attr_key_with_nd_array_dtype(self): graph = RustWorkXGraph() graph.add_node_attr_key("something", pl.Array(pl.Float64, (5, 3, 2))) - assert graph._node_attr_schemas["something"].dtype == pl.Array(pl.Float64, (5, 3, 2)) - default = graph._node_attr_schemas["something"].default_value + schemas = graph._node_attr_schemas() + assert schemas["something"].dtype == pl.Array(pl.Float64, (5, 3, 2)) + default = schemas["something"].default_value assert isinstance(default, np.ndarray) assert default.shape == (5, 3, 
2) @@ -77,9 +81,10 @@ def test_add_node_attr_key_with_schema_object(self): schema = AttrSchema(key="intensity", dtype=pl.Float64) graph.add_node_attr_key(schema) - assert "intensity" in graph._node_attr_schemas - assert graph._node_attr_schemas["intensity"].dtype == pl.Float64 - assert graph._node_attr_schemas["intensity"].default_value == -1.0 + schemas = graph._node_attr_schemas() + assert "intensity" in schemas + assert schemas["intensity"].dtype == pl.Float64 + assert schemas["intensity"].default_value == -1.0 def test_add_node_attr_key_missing_dtype_raises(self): """Test that missing dtype raises TypeError.""" @@ -125,9 +130,10 @@ def test_add_edge_attr_key_with_dtype_only(self): graph.add_edge_attr_key("weight", pl.Float64) - assert "weight" in graph._edge_attr_schemas - assert graph._edge_attr_schemas["weight"].dtype == pl.Float64 - assert graph._edge_attr_schemas["weight"].default_value == -1.0 + schemas = graph._edge_attr_schemas() + assert "weight" in schemas + assert schemas["weight"].dtype == pl.Float64 + assert schemas["weight"].default_value == -1.0 def test_add_edge_attr_key_with_schema(self): """Test adding edge attribute using AttrSchema.""" @@ -136,8 +142,9 @@ def test_add_edge_attr_key_with_schema(self): schema = AttrSchema(key="distance", dtype=pl.Float64, default_value=0.0) graph.add_edge_attr_key(schema) - assert "distance" in graph._edge_attr_schemas - assert graph._edge_attr_schemas["distance"].default_value == 0.0 + schemas = graph._edge_attr_schemas() + assert "distance" in schemas + assert schemas["distance"].default_value == 0.0 def test_defaults_applied_to_existing_edges(self): """Test that defaults are applied to existing edges.""" @@ -163,7 +170,7 @@ def test_node_attr_keys_returns_keys(self): graph.add_node_attr_key("score", pl.Float64) graph.add_node_attr_key("count", pl.UInt32) - keys = graph.node_attr_keys() + keys = graph.node_attr_keys(return_ids=True) assert "node_id" in keys assert "t" in keys assert "score" in keys @@ -187,8 
+194,9 @@ def test_signed_vs_unsigned_int_defaults(self): graph.add_node_attr_key("unsigned", pl.UInt32) graph.add_node_attr_key("signed", pl.Int32) - assert graph._node_attr_schemas["unsigned"].default_value == 0 - assert graph._node_attr_schemas["signed"].default_value == -1 + schemas = graph._node_attr_schemas() + assert schemas["unsigned"].default_value == 0 + assert schemas["signed"].default_value == -1 def test_schema_defensive_copy(self): """Test that passing AttrSchema creates a defensive copy to prevent mutation.""" @@ -201,7 +209,8 @@ def test_schema_defensive_copy(self): graph.add_node_attr_key(original_schema) # Verify the schema was stored - stored_schema = graph._node_attr_schemas["score"] + schemas = graph._node_attr_schemas() + stored_schema = schemas["score"] assert stored_schema.key == "score" assert stored_schema.dtype == pl.Float64 assert stored_schema.default_value == 1.0 @@ -210,7 +219,8 @@ def test_schema_defensive_copy(self): original_schema.default_value = 999.0 # Verify the stored schema wasn't affected (defensive copy worked) - assert graph._node_attr_schemas["score"].default_value == 1.0 + schemas = graph._node_attr_schemas() + assert schemas["score"].default_value == 1.0 assert stored_schema.default_value == 1.0 # Verify the original was changed