Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion city2graph/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def __init__(
self.node_mappings: dict[str, dict[str, dict[str | int, int] | str | list[str | int]]] = {}
self.node_feature_cols: dict[str, list[str]] | list[str] | None = None
self.node_label_cols: dict[str, list[str]] | list[str] | None = None
self.edge_feature_cols: dict[str, list[str]] | list[str] | None = None
self.edge_feature_cols: dict[tuple[str, str, str], list[str]] | list[str] | None = None
self.edge_index_values: (
dict[tuple[str, str, str], list[list[str | int]]] | list[list[str | int]] | None
) = None
Expand Down
46 changes: 24 additions & 22 deletions city2graph/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def __init__(
self,
node_feature_cols: dict[str, list[str]] | list[str] | None = None,
node_label_cols: dict[str, list[str]] | list[str] | None = None,
edge_feature_cols: dict[str, list[str]] | list[str] | None = None,
edge_feature_cols: dict[tuple[str, str, str], list[str]] | list[str] | None = None,
device: str | torch.device | None = None,
dtype: torch.dtype | None = None,
keep_geom: bool = True,
Expand Down Expand Up @@ -182,7 +182,7 @@ def pyg_to_gdf(
_node_types: str | list[str] | None = None,
_edge_types: str | list[tuple[str, str, str]] | None = None,
additional_node_cols: dict[str, list[str]] | list[str] | None = None,
additional_edge_cols: dict[str | tuple[str, str, str], list[str]] | list[str] | None = None,
additional_edge_cols: dict[tuple[str, str, str], list[str]] | list[str] | None = None,
) -> (
tuple[dict[str, gpd.GeoDataFrame], dict[tuple[str, str, str], gpd.GeoDataFrame]]
| tuple[gpd.GeoDataFrame | None, gpd.GeoDataFrame | None]
Expand Down Expand Up @@ -552,7 +552,7 @@ def _process_hetero_edges(
data: HeteroData,
edges_dict: dict[tuple[str, str, str], gpd.GeoDataFrame],
node_mappings: dict[str, dict[str, dict[str | int, int] | str | list[str | int]]],
edge_feature_cols: dict[str, list[str]] | None,
edge_feature_cols: dict[tuple[str, str, str], list[str]] | None,
) -> None:
"""
Process all edge types for heterogeneous graph.
Expand All @@ -567,8 +567,8 @@ def _process_hetero_edges(
Dictionary mapping edge types to GeoDataFrames.
node_mappings : dict
Dictionary containing node mapping information.
edge_feature_cols : dict[str, list[str]], optional
Dictionary mapping relation types to feature column names.
edge_feature_cols : dict[tuple[str, str, str], list[str]], optional
Dictionary mapping edge types to feature column names.
"""
device = _get_device(self.device)

Expand Down Expand Up @@ -601,7 +601,9 @@ def _process_hetero_edges(
)
data[edge_type].edge_index = edge_index

feature_cols = edge_feature_cols.get(rel_type) if edge_feature_cols else None
data[edge_type].edge_index = edge_index

feature_cols = edge_feature_cols.get(edge_type) if edge_feature_cols else None
data[edge_type].edge_attr = self._create_features(edge_gdf, feature_cols)
else:
data[edge_type].edge_index = torch.zeros((2, 0), dtype=torch.long, device=device)
Expand All @@ -619,7 +621,7 @@ def _store_hetero_metadata(
edges_dict: dict[tuple[str, str, str], gpd.GeoDataFrame],
node_feature_cols: dict[str, list[str]] | None,
node_label_cols: dict[str, list[str]] | None,
edge_feature_cols: dict[str, list[str]] | None,
edge_feature_cols: dict[tuple[str, str, str], list[str]] | None,
) -> None:
"""
Store metadata for heterogeneous graph.
Expand All @@ -640,8 +642,8 @@ def _store_hetero_metadata(
Dictionary mapping node types to feature column names.
node_label_cols : dict[str, list[str]], optional
Dictionary mapping node types to label column names.
edge_feature_cols : dict[str, list[str]], optional
Dictionary mapping relation types to feature column names.
edge_feature_cols : dict[tuple[str, str, str], list[str]], optional
Dictionary mapping edge types to feature column names.
"""
# Store mappings and column metadata
metadata = GraphMetadata(is_hetero=True)
Expand Down Expand Up @@ -1043,12 +1045,12 @@ def _extract_features(

# Extract edge features (edge_attr)
elif hasattr(obj_data, "edge_attr") and obj_data.edge_attr is not None:
feature_cols = metadata.edge_feature_cols
edge_feat_cols = metadata.edge_feature_cols
cols_list = None
if is_hetero and isinstance(type_name, tuple) and isinstance(feature_cols, dict):
cols_list = feature_cols.get(type_name[1]) # relation type
elif not is_hetero and isinstance(feature_cols, list):
cols_list = feature_cols
if is_hetero and isinstance(type_name, tuple) and isinstance(edge_feat_cols, dict):
cols_list = edge_feat_cols.get(type_name) # full edge type tuple
elif not is_hetero and isinstance(edge_feat_cols, list):
cols_list = edge_feat_cols

if cols_list is None:
num_features = obj_data.edge_attr.shape[1]
Expand Down Expand Up @@ -1645,7 +1647,7 @@ def gdf_to_pyg(
edges: dict[tuple[str, str, str], gpd.GeoDataFrame] | gpd.GeoDataFrame | None = None,
node_feature_cols: dict[str, list[str]] | list[str] | None = None,
node_label_cols: dict[str, list[str]] | list[str] | None = None,
edge_feature_cols: dict[str, list[str]] | list[str] | None = None,
edge_feature_cols: dict[tuple[str, str, str], list[str]] | list[str] | None = None,
device: str | torch.device | None = None,
dtype: torch.dtype | None = None,
keep_geom: bool = True,
Expand Down Expand Up @@ -1682,9 +1684,9 @@ def gdf_to_pyg(
Column names to use as node labels for supervised learning tasks.
For heterogeneous graphs, provide a dictionary mapping node types
to their label columns.
edge_feature_cols : dict[str, list[str]] or list[str], optional
edge_feature_cols : dict[tuple[str, str, str], list[str]] or list[str], optional
Column names to use as edge features. For heterogeneous graphs,
provide a dictionary mapping relation types to their feature columns.
provide a dictionary mapping edge types to their feature columns.
device : str or torch.device, optional
Target device for tensor placement ('cpu', 'cuda', or torch.device).
If None, automatically selects CUDA if available, otherwise CPU.
Expand Down Expand Up @@ -1782,7 +1784,7 @@ def pyg_to_gdf(
edge_types: str | list[tuple[str, str, str]] | None = None,
keep_geom: bool = True,
additional_node_cols: dict[str, list[str]] | list[str] | None = None,
additional_edge_cols: dict[str | tuple[str, str, str], list[str]] | list[str] | None = None,
additional_edge_cols: dict[tuple[str, str, str], list[str]] | list[str] | None = None,
) -> (
tuple[dict[str, gpd.GeoDataFrame], dict[tuple[str, str, str], gpd.GeoDataFrame]]
| tuple[gpd.GeoDataFrame | None, gpd.GeoDataFrame | None]
Expand Down Expand Up @@ -1814,10 +1816,10 @@ def pyg_to_gdf(
Additional columns to extract from the PyG object attributes.
For homogeneous graphs, a list of attribute names.
For heterogeneous graphs, a dict mapping node types to lists of attribute names.
additional_edge_cols : dict[str | tuple[str, str, str], list[str]] or list[str], optional
additional_edge_cols : dict[tuple[str, str, str], list[str]] or list[str], optional
Additional columns to extract from the PyG object attributes.
For homogeneous graphs, a list of attribute names.
For heterogeneous graphs, a dict mapping edge type tuples or relation names
For heterogeneous graphs, a dict mapping edge type tuples
to lists of attribute names.

Returns
Expand Down Expand Up @@ -1877,7 +1879,7 @@ def pyg_to_nx(
data: Data | HeteroData,
keep_geom: bool = True,
additional_node_cols: dict[str, list[str]] | list[str] | None = None,
additional_edge_cols: dict[str | tuple[str, str, str], list[str]] | list[str] | None = None,
additional_edge_cols: dict[tuple[str, str, str], list[str]] | list[str] | None = None,
) -> nx.Graph:
"""
Convert a PyTorch Geometric object to a NetworkX graph.
Expand All @@ -1896,7 +1898,7 @@ def pyg_to_nx(
geometries exist, reconstructs geometries from node positions.
additional_node_cols : dict[str, list[str]] or list[str], optional
Additional columns to extract from the PyG object attributes.
additional_edge_cols : dict[str | tuple[str, str, str], list[str]] or list[str], optional
additional_edge_cols : dict[tuple[str, str, str], list[str]] or list[str], optional
Additional columns to extract from the PyG object attributes.

Returns
Expand Down
10 changes: 6 additions & 4 deletions tests/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,10 @@ def test_gdf_to_pyg_with_features(
sample_hetero_edges_dict,
node_feature_cols={"building": ["b_feat1"], "road": ["length"]},
node_label_cols={"building": ["b_label"]},
edge_feature_cols={"connects_to": ["conn_feat1"], "links_to": ["link_feat1"]},
edge_feature_cols={
("building", "connects_to", "road"): ["conn_feat1"],
("road", "links_to", "road"): ["link_feat1"],
},
)
assert data["building"].x.shape[1] == 1
assert data["road"].x.shape[1] == 1
Expand Down Expand Up @@ -1050,11 +1053,10 @@ def test_hetero_edge_type_deduction(self) -> None:
metadata.edge_types = [edge_type_1, edge_type_2]
data.graph_metadata = metadata

# Test matching by full edge type tuple for e1
# and matching by relation string for e2
# Test matching by full edge type tuple for both
_, edges_dict = pyg_to_gdf(
data,
additional_edge_cols={edge_type_1: ["z"], "via": ["w"]},
additional_edge_cols={edge_type_1: ["z"], edge_type_2: ["w"]},
)
assert isinstance(edges_dict, dict)
assert "z" in edges_dict[edge_type_1].columns
Expand Down