From b3467d0d0adf4e0f9886ac6051d8ec195281694a Mon Sep 17 00:00:00 2001 From: Yuta Sato Date: Wed, 7 Jan 2026 04:09:23 +0000 Subject: [PATCH] Fix edge_feature_cols in gdf_to_pyg --- city2graph/base.py | 2 +- city2graph/graph.py | 46 +++++++++++++++++++++++---------------------- tests/test_graph.py | 10 ++++++---- 3 files changed, 31 insertions(+), 27 deletions(-) diff --git a/city2graph/base.py b/city2graph/base.py index f404b67..e428b45 100644 --- a/city2graph/base.py +++ b/city2graph/base.py @@ -100,7 +100,7 @@ def __init__( self.node_mappings: dict[str, dict[str, dict[str | int, int] | str | list[str | int]]] = {} self.node_feature_cols: dict[str, list[str]] | list[str] | None = None self.node_label_cols: dict[str, list[str]] | list[str] | None = None - self.edge_feature_cols: dict[str, list[str]] | list[str] | None = None + self.edge_feature_cols: dict[tuple[str, str, str], list[str]] | list[str] | None = None self.edge_index_values: ( dict[tuple[str, str, str], list[list[str | int]]] | list[list[str | int]] | None ) = None diff --git a/city2graph/graph.py b/city2graph/graph.py index 1b9b8c6..c6b8831 100644 --- a/city2graph/graph.py +++ b/city2graph/graph.py @@ -114,7 +114,7 @@ def __init__( self, node_feature_cols: dict[str, list[str]] | list[str] | None = None, node_label_cols: dict[str, list[str]] | list[str] | None = None, - edge_feature_cols: dict[str, list[str]] | list[str] | None = None, + edge_feature_cols: dict[tuple[str, str, str], list[str]] | list[str] | None = None, device: str | torch.device | None = None, dtype: torch.dtype | None = None, keep_geom: bool = True, @@ -182,7 +182,7 @@ def pyg_to_gdf( _node_types: str | list[str] | None = None, _edge_types: str | list[tuple[str, str, str]] | None = None, additional_node_cols: dict[str, list[str]] | list[str] | None = None, - additional_edge_cols: dict[str | tuple[str, str, str], list[str]] | list[str] | None = None, + additional_edge_cols: dict[tuple[str, str, str], list[str]] | list[str] | None = None, ) -> ( tuple[dict[str, gpd.GeoDataFrame], dict[tuple[str, str, str], gpd.GeoDataFrame]] | tuple[gpd.GeoDataFrame | None, gpd.GeoDataFrame | None] @@ -552,7 +552,7 @@ def _process_hetero_edges( data: HeteroData, edges_dict: dict[tuple[str, str, str], gpd.GeoDataFrame], node_mappings: dict[str, dict[str, dict[str | int, int] | str | list[str | int]]], - edge_feature_cols: dict[str, list[str]] | None, + edge_feature_cols: dict[tuple[str, str, str], list[str]] | None, ) -> None: """ Process all edge types for heterogeneous graph. @@ -567,8 +567,8 @@ def _process_hetero_edges( Dictionary mapping edge types to GeoDataFrames. node_mappings : dict Dictionary containing node mapping information. - edge_feature_cols : dict[str, list[str]], optional - Dictionary mapping relation types to feature column names. + edge_feature_cols : dict[tuple[str, str, str], list[str]], optional + Dictionary mapping edge types to feature column names. """ device = _get_device(self.device) @@ -601,7 +601,9 @@ def _process_hetero_edges( ) data[edge_type].edge_index = edge_index - feature_cols = edge_feature_cols.get(rel_type) if edge_feature_cols else None + data[edge_type].edge_index = edge_index + + feature_cols = edge_feature_cols.get(edge_type) if edge_feature_cols else None data[edge_type].edge_attr = self._create_features(edge_gdf, feature_cols) else: data[edge_type].edge_index = torch.zeros((2, 0), dtype=torch.long, device=device) @@ -619,7 +621,7 @@ def _store_hetero_metadata( edges_dict: dict[tuple[str, str, str], gpd.GeoDataFrame], node_feature_cols: dict[str, list[str]] | None, node_label_cols: dict[str, list[str]] | None, - edge_feature_cols: dict[str, list[str]] | None, + edge_feature_cols: dict[tuple[str, str, str], list[str]] | None, ) -> None: """ Store metadata for heterogeneous graph. @@ -640,8 +642,8 @@ def _store_hetero_metadata( Dictionary mapping node types to feature column names. node_label_cols : dict[str, list[str]], optional Dictionary mapping node types to label column names. - edge_feature_cols : dict[str, list[str]], optional - Dictionary mapping relation types to feature column names. + edge_feature_cols : dict[tuple[str, str, str], list[str]], optional + Dictionary mapping edge types to feature column names. """ # Store mappings and column metadata metadata = GraphMetadata(is_hetero=True) @@ -1043,12 +1045,12 @@ def _extract_features( # Extract edge features (edge_attr) elif hasattr(obj_data, "edge_attr") and obj_data.edge_attr is not None: - feature_cols = metadata.edge_feature_cols + edge_feat_cols = metadata.edge_feature_cols cols_list = None - if is_hetero and isinstance(type_name, tuple) and isinstance(feature_cols, dict): - cols_list = feature_cols.get(type_name[1]) # relation type - elif not is_hetero and isinstance(feature_cols, list): - cols_list = feature_cols + if is_hetero and isinstance(type_name, tuple) and isinstance(edge_feat_cols, dict): + cols_list = edge_feat_cols.get(type_name) # full edge type tuple + elif not is_hetero and isinstance(edge_feat_cols, list): + cols_list = edge_feat_cols if cols_list is None: num_features = obj_data.edge_attr.shape[1] @@ -1645,7 +1647,7 @@ def gdf_to_pyg( edges: dict[tuple[str, str, str], gpd.GeoDataFrame] | gpd.GeoDataFrame | None = None, node_feature_cols: dict[str, list[str]] | list[str] | None = None, node_label_cols: dict[str, list[str]] | list[str] | None = None, - edge_feature_cols: dict[str, list[str]] | list[str] | None = None, + edge_feature_cols: dict[tuple[str, str, str], list[str]] | list[str] | None = None, device: str | torch.device | None = None, dtype: torch.dtype | None = None, keep_geom: bool = True, @@ -1682,9 +1684,9 @@ def gdf_to_pyg( Column names to use as node labels for supervised learning tasks. For heterogeneous graphs, provide a dictionary mapping node types to their label columns. - edge_feature_cols : dict[str, list[str]] or list[str], optional + edge_feature_cols : dict[tuple[str, str, str], list[str]] or list[str], optional Column names to use as edge features. For heterogeneous graphs, - provide a dictionary mapping relation types to their feature columns. + provide a dictionary mapping edge types to their feature columns. device : str or torch.device, optional Target device for tensor placement ('cpu', 'cuda', or torch.device). If None, automatically selects CUDA if available, otherwise CPU. @@ -1782,7 +1784,7 @@ def pyg_to_gdf( edge_types: str | list[tuple[str, str, str]] | None = None, keep_geom: bool = True, additional_node_cols: dict[str, list[str]] | list[str] | None = None, - additional_edge_cols: dict[str | tuple[str, str, str], list[str]] | list[str] | None = None, + additional_edge_cols: dict[tuple[str, str, str], list[str]] | list[str] | None = None, ) -> ( tuple[dict[str, gpd.GeoDataFrame], dict[tuple[str, str, str], gpd.GeoDataFrame]] | tuple[gpd.GeoDataFrame | None, gpd.GeoDataFrame | None] @@ -1814,10 +1816,10 @@ def pyg_to_gdf( Additional columns to extract from the PyG object attributes. For homogeneous graphs, a list of attribute names. For heterogeneous graphs, a dict mapping node types to lists of attribute names. - additional_edge_cols : dict[str | tuple[str, str, str], list[str]] or list[str], optional + additional_edge_cols : dict[tuple[str, str, str], list[str]] or list[str], optional Additional columns to extract from the PyG object attributes. For homogeneous graphs, a list of attribute names. - For heterogeneous graphs, a dict mapping edge type tuples or relation names + For heterogeneous graphs, a dict mapping edge type tuples to lists of attribute names. Returns @@ -1877,7 +1879,7 @@ def pyg_to_nx( data: Data | HeteroData, keep_geom: bool = True, additional_node_cols: dict[str, list[str]] | list[str] | None = None, - additional_edge_cols: dict[str | tuple[str, str, str], list[str]] | list[str] | None = None, + additional_edge_cols: dict[tuple[str, str, str], list[str]] | list[str] | None = None, ) -> nx.Graph: """ Convert a PyTorch Geometric object to a NetworkX graph. @@ -1896,7 +1898,7 @@ def pyg_to_nx( geometries exist, reconstructs geometries from node positions. additional_node_cols : dict[str, list[str]] or list[str], optional Additional columns to extract from the PyG object attributes. - additional_edge_cols : dict[str | tuple[str, str, str], list[str]] or list[str], optional + additional_edge_cols : dict[tuple[str, str, str], list[str]] or list[str], optional Additional columns to extract from the PyG object attributes. Returns diff --git a/tests/test_graph.py b/tests/test_graph.py index 8b43799..07d1e1b 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -174,7 +174,10 @@ def test_gdf_to_pyg_with_features( sample_hetero_edges_dict, node_feature_cols={"building": ["b_feat1"], "road": ["length"]}, node_label_cols={"building": ["b_label"]}, - edge_feature_cols={"connects_to": ["conn_feat1"], "links_to": ["link_feat1"]}, + edge_feature_cols={ + ("building", "connects_to", "road"): ["conn_feat1"], + ("road", "links_to", "road"): ["link_feat1"], + }, ) assert data["building"].x.shape[1] == 1 assert data["road"].x.shape[1] == 1 @@ -1050,11 +1053,10 @@ def test_hetero_edge_type_deduction(self) -> None: metadata.edge_types = [edge_type_1, edge_type_2] data.graph_metadata = metadata - # Test matching by full edge type tuple for e1 - # and matching by relation string for e2 + # Test matching by full edge type tuple for both _, edges_dict = pyg_to_gdf( data, - additional_edge_cols={edge_type_1: ["z"], "via": ["w"]}, + additional_edge_cols={edge_type_1: ["z"], edge_type_2: ["w"]}, ) assert isinstance(edges_dict, dict) assert "z" in edges_dict[edge_type_1].columns