diff --git a/src/tracksdata/graph/_base_graph.py b/src/tracksdata/graph/_base_graph.py index b384bdbb..e323f6a6 100644 --- a/src/tracksdata/graph/_base_graph.py +++ b/src/tracksdata/graph/_base_graph.py @@ -1911,24 +1911,35 @@ def copy(self, **kwargs) -> "BaseGraph": """ return self.__class__.from_other(self, **kwargs) - def __getitem__(self, node_id: int) -> "NodeInterface": + @property + def nodes(self) -> "NodesAccessor": """ - Helper method to interact with a single node. + Access node attributes with dictionary-style syntax. - Parameters - ---------- - node_id : int - The id of the node to interact with. + Use bracket notation to get or set attributes for a specific node, + or call to_dict() to retrieve all attributes as a dictionary. Returns ------- - NodeInterface - A node interface for the given node id. + NodesAccessor + An accessor for node attributes. """ + return NodesAccessor(self) - if not isinstance(node_id, int): - raise ValueError(f"graph index must be a integer, found '{node_id}' of type {type(node_id)}") - return NodeInterface(self, node_id) + @property + def edges(self) -> "EdgesAccessor": + """ + Access edge attributes with dictionary-style syntax. + + Use bracket notation to get or set attributes for a specific edge, + or call to_dict() to retrieve all attributes as a dictionary. + + Returns + ------- + EdgesAccessor + An accessor for edge attributes. + """ + return EdgesAccessor(self) class NodeInterface: @@ -1978,3 +1989,141 @@ def edge_list(self) -> list[list[int, int]]: """ Get the edge list of the graph. """ + + +class NodesAccessor: + """ + Accessor class for node attributes with dictionary-style syntax. + + Parameters + ---------- + graph : BaseGraph + The graph to access nodes from. + """ + + def __init__(self, graph: BaseGraph): + self._graph = graph + + def __getitem__(self, node_id: int) -> NodeInterface: + """ + Access a specific node's attributes. + + Parameters + ---------- + node_id : int + The id of the node to access. + + Returns + ------- + NodeInterface + Interface for accessing the node's attributes. + """ + if not isinstance(node_id, int): + raise ValueError(f"node_id must be an integer, found '{node_id}' of type {type(node_id)}") + return NodeInterface(self._graph, node_id) + + +class EdgesAccessor: + """ + Accessor class for edge attributes with dictionary-style syntax. + + Parameters + ---------- + graph : BaseGraph + The graph to access edges from. + """ + + def __init__(self, graph: BaseGraph): + self._graph = graph + + def __getitem__(self, edge_id: int) -> "EdgeInterface": + """ + Access a specific edge's attributes. + + Parameters + ---------- + edge_id : int + The id of the edge to access. + + Returns + ------- + EdgeInterface + Interface for accessing the edge's attributes. + """ + if not isinstance(edge_id, int): + raise ValueError(f"edge_id must be an integer, found '{edge_id}' of type {type(edge_id)}") + return EdgeInterface(self._graph, edge_id) + + +class EdgeInterface: + """ + Helper class to interact with a single edge. + + Parameters + ---------- + graph : BaseGraph + The graph to interact with. + edge_id : int + The id of the edge to interact with. + + See Also + -------- + [BaseGraph][tracksdata.graph.BaseGraph] The base graph class. + """ + + def __init__(self, graph: BaseGraph, edge_id: int): + self._graph = graph + self._edge_id = edge_id + + def __getitem__(self, key: str) -> Any: + """ + Get an edge attribute value. + + Parameters + ---------- + key : str + The attribute key to retrieve. + + Returns + ------- + Any + The attribute value. + """ + df = self._graph.edge_attrs(attr_keys=[key]) + filtered = df.filter(pl.col(DEFAULT_ATTR_KEYS.EDGE_ID) == self._edge_id) + return filtered[key].item() + + def __setitem__(self, key: str, value: Any) -> None: + """ + Set an edge attribute value. + + Parameters + ---------- + key : str + The attribute key to set. + value : Any + The value to set. + """ + return self._graph.update_edge_attrs(attrs={key: value}, edge_ids=[self._edge_id]) + + def __str__(self) -> str: + df = self._graph.edge_attrs() + edge_attr = df.filter(pl.col(DEFAULT_ATTR_KEYS.EDGE_ID) == self._edge_id) + return str(edge_attr) + + def __repr__(self) -> str: + return str(self) + + def to_dict(self) -> dict[str, Any]: + """ + Get all edge attributes as a dictionary. + + Returns + ------- + dict[str, Any] + Dictionary of attribute keys and values. + """ + df = self._graph.edge_attrs() + filtered = df.filter(pl.col(DEFAULT_ATTR_KEYS.EDGE_ID) == self._edge_id) + data = filtered.drop(DEFAULT_ATTR_KEYS.EDGE_ID).rows(named=True)[0] + return data diff --git a/src/tracksdata/graph/_test/test_graph_backends.py b/src/tracksdata/graph/_test/test_graph_backends.py index fc52fe59..acd34240 100644 --- a/src/tracksdata/graph/_test/test_graph_backends.py +++ b/src/tracksdata/graph/_test/test_graph_backends.py @@ -2002,21 +2002,56 @@ def test_nodes_interface(graph_backend: BaseGraph) -> None: node2 = graph_backend.add_node({"t": 1, "x": 0}) node3 = graph_backend.add_node({"t": 2, "x": -1}) - assert graph_backend[node1]["x"] == 1 - assert graph_backend[node2]["x"] == 0 - assert graph_backend[node3]["x"] == -1 + assert graph_backend.nodes[node1]["x"] == 1 + assert graph_backend.nodes[node2]["x"] == 0 + assert graph_backend.nodes[node3]["x"] == -1 graph_backend.add_node_attr_key("y", pl.Int64) - graph_backend[node2]["y"] = 5 + graph_backend.nodes[node2]["y"] = 5 - assert graph_backend[node1]["y"] == -1 - assert graph_backend[node2]["y"] == 5 - assert graph_backend[node3]["y"] == -1 + assert graph_backend.nodes[node1]["y"] == -1 + assert graph_backend.nodes[node2]["y"] == 5 + assert graph_backend.nodes[node3]["y"] == -1 - assert graph_backend[node1].to_dict() == {"t": 0, "x": 1, "y": -1} - assert graph_backend[node2].to_dict() == {"t": 1, "x": 0, "y": 5} - assert graph_backend[node3].to_dict() == {"t": 2, "x": -1, "y": -1} + assert graph_backend.nodes[node1].to_dict() == {"t": 0, "x": 1, "y": -1} + assert graph_backend.nodes[node2].to_dict() == {"t": 1, "x": 0, "y": 5} + assert graph_backend.nodes[node3].to_dict() == {"t": 2, "x": -1, "y": -1} + + +def test_edges_interface(graph_backend: BaseGraph) -> None: + """Test edge attribute access using graph.edges[edge_id]['attr'] syntax.""" + graph_backend.add_node_attr_key("x", dtype=pl.Int64, default_value=-1) + graph_backend.add_edge_attr_key("weight", dtype=pl.Float64, default_value=0.0) + graph_backend.add_edge_attr_key("score", dtype=pl.Float64, default_value=-1.0) + + # Create nodes and edges + node1 = graph_backend.add_node({"t": 0, "x": 1}) + node2 = graph_backend.add_node({"t": 1, "x": 2}) + node3 = graph_backend.add_node({"t": 2, "x": 3}) + + edge1 = graph_backend.add_edge(node1, node2, {"weight": 0.5, "score": -1.0}) + edge2 = graph_backend.add_edge(node2, node3, {"weight": 0.8, "score": -1.0}) + + # Test getting edge attributes + assert graph_backend.edges[edge1]["weight"] == 0.5 + assert graph_backend.edges[edge2]["weight"] == 0.8 + + # Test setting edge attributes + graph_backend.edges[edge1]["score"] = 0.95 + graph_backend.edges[edge2]["score"] = 0.75 + + assert graph_backend.edges[edge1]["score"] == 0.95 + assert graph_backend.edges[edge2]["score"] == 0.75 + + # Test to_dict method + edge1_dict = graph_backend.edges[edge1].to_dict() + assert edge1_dict["weight"] == 0.5 + assert edge1_dict["score"] == 0.95 + + edge2_dict = graph_backend.edges[edge2].to_dict() + assert edge2_dict["weight"] == 0.8 + assert edge2_dict["score"] == 0.75 def test_custom_indices(graph_backend: BaseGraph) -> None: @@ -2368,7 +2403,7 @@ def test_geff_roundtrip(graph_backend: BaseGraph) -> None: assert set(graph_backend.edge_attr_keys()) == set(geff_graph.edge_attr_keys()) for node_id in geff_graph.node_ids(): - assert geff_graph[node_id].to_dict() == graph_backend[node_id].to_dict() + assert geff_graph.nodes[node_id].to_dict() == graph_backend.nodes[node_id].to_dict() assert rx.is_isomorphic( rx_graph, diff --git a/src/tracksdata/graph/filters/_spatial_filter.py b/src/tracksdata/graph/filters/_spatial_filter.py index 3a86236d..52ce080e 100644 --- a/src/tracksdata/graph/filters/_spatial_filter.py +++ b/src/tracksdata/graph/filters/_spatial_filter.py @@ -388,7 +388,7 @@ def _add_node(self, node_id: int) -> None: else: raise ValueError("Spatial filter is not initialized") - attrs = self._graph[node_id].to_dict() + attrs = self._graph.nodes[node_id].to_dict() positions_min, positions_max = self._attrs_to_bb_window(attrs) self._node_rtree.insert_bb_items( @@ -409,7 +409,7 @@ def _remove_node(self, node_id: int) -> None: if self._node_rtree is None: raise ValueError("Spatial filter is not initialized") - attrs = self._graph[node_id].to_dict() + attrs = self._graph.nodes[node_id].to_dict() positions_min, positions_max = self._attrs_to_bb_window(attrs) self._node_rtree.delete_items(