From 0863fc4921bc8ebb35e97f5f6a5c8689f65be9ce Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lu=C3=ADs=20F=2E=20Pereira?= Date: Mon, 25 Nov 2024 12:05:54 -0800 Subject: [PATCH 01/28] Add initial redesign of liftings maps --- .../test_SimplicialCliqueLifting.py | 19 +- topobenchmarkx/complex.py | 85 +++++ topobenchmarkx/transforms/converters.py | 313 ++++++++++++++++++ .../transforms/feature_liftings/base.py | 13 + .../transforms/feature_liftings/identity.py | 37 +-- .../feature_liftings/projection_sum.py | 67 +--- topobenchmarkx/transforms/liftings/base.py | 95 +++++- .../liftings/graph2simplicial/clique.py | 30 +- 8 files changed, 550 insertions(+), 109 deletions(-) create mode 100644 topobenchmarkx/complex.py create mode 100644 topobenchmarkx/transforms/converters.py create mode 100644 topobenchmarkx/transforms/feature_liftings/base.py diff --git a/test/transforms/liftings/simplicial/test_SimplicialCliqueLifting.py b/test/transforms/liftings/simplicial/test_SimplicialCliqueLifting.py index 41b8ac45..fc2f89f8 100644 --- a/test/transforms/liftings/simplicial/test_SimplicialCliqueLifting.py +++ b/test/transforms/liftings/simplicial/test_SimplicialCliqueLifting.py @@ -3,21 +3,24 @@ import torch from topobenchmarkx.transforms.liftings.graph2simplicial import ( - SimplicialCliqueLifting, + SimplicialCliqueLifting ) - +from topobenchmarkx.transforms.converters import Data2NxGraph, Complex2Dict +from topobenchmarkx.transforms.liftings.base import LiftingTransform class TestSimplicialCliqueLifting: """Test the SimplicialCliqueLifting class.""" def setup_method(self): # Initialise the SimplicialCliqueLifting class - self.lifting_signed = SimplicialCliqueLifting( - complex_dim=3, signed=True - ) - self.lifting_unsigned = SimplicialCliqueLifting( - complex_dim=3, signed=False - ) + data2graph = Data2NxGraph() + simplicial2dict_signed = Complex2Dict(signed=True) + simplicial2dict_unsigned = Complex2Dict(signed=False) + + lifting_map = SimplicialCliqueLifting(complex_dim=3) + + self.lifting_signed = LiftingTransform(data2graph, simplicial2dict_signed, lifting_map) + self.lifting_unsigned = LiftingTransform(data2graph, simplicial2dict_unsigned, lifting_map) def test_lift_topology(self, simple_graph_1): """Test the lift_topology method.""" diff --git a/topobenchmarkx/complex.py b/topobenchmarkx/complex.py new file mode 100644 index 00000000..8a2949f2 --- /dev/null +++ b/topobenchmarkx/complex.py @@ -0,0 +1,85 @@ +import torch + + +class PlainComplex: + def __init__( + self, + incidence, + down_laplacian, + up_laplacian, + adjacency, + coadjacency, + hodge_laplacian, + features=None, + ): + # TODO: allow None with nice error message if callable? + + # TODO: make this private? do not allow for changes in these values? + self.incidence = incidence + self.down_laplacian = down_laplacian + self.up_laplacian = up_laplacian + self.adjacency = adjacency + self.coadjacency = coadjacency + self.hodge_laplacian = hodge_laplacian + + if features is None: + features = [None for _ in range(len(self.incidence))] + else: + for rank, dim in enumerate(self.shape): + # TODO: make error message more informative + if ( + features[rank] is not None + and features[rank].shape[0] != dim + ): + raise ValueError("Features have wrong shape.") + + self.features = features + + @property + def shape(self): + """Shape of the complex. + + Returns + ------- + list[int] + """ + return [incidence.shape[-1] for incidence in self.incidence] + + @property + def max_rank(self): + """Maximum rank of the complex. + + Returns + ------- + int + """ + return len(self.incidence) + + def update_features(self, rank, values): + """Update features. + + Parameters + ---------- + rank : int + Rank of simplices the features belong to. + values : array-like + New features for the rank-simplices. + """ + self.features[rank] = values + + def reset_features(self): + """Reset features.""" + self.features = [None for _ in self.features] + + def propagate_values(self, rank, values): + """Propagate features from a rank to an upper one. + + Parameters + ---------- + rank : int + Rank of the simplices the values belong to. + values : array-like + Features for the rank-simplices. + """ + # TODO: can be made much better + return torch.matmul(torch.abs(self.incidence[rank + 1].t()), values) diff --git a/topobenchmarkx/transforms/converters.py b/topobenchmarkx/transforms/converters.py new file mode 100644 index 00000000..96920b74 --- /dev/null +++ b/topobenchmarkx/transforms/converters.py @@ -0,0 +1,313 @@ +import abc + +import networkx as nx +import numpy as np +import torch +import torch_geometric +from topomodelx.utils.sparse import from_sparse +from torch_geometric.utils.undirected import is_undirected, to_undirected + +from topobenchmarkx.complex import PlainComplex +from topobenchmarkx.data.utils.utils import ( + generate_zero_sparse_connectivity, + select_neighborhoods_of_interest, +) + + +class Converter(abc.ABC): + """Convert between data structures representing the same domain.""" + + def __call__(self, domain): + """Convert domain's data structure.""" + return self.convert(domain) + + @abc.abstractmethod + def convert(self, domain): + """Convert domain's data structure.""" + + +class IdentityConverter(Converter): + """Identity conversion. + + Retrieves same data structure for domain. + """ + + def convert(self, domain): + """Convert domain.""" + return domain + + +class Data2NxGraph(Converter): + """Data to nx.Graph conversion. + + Parameters + ---------- + preserve_edge_attr : bool + Whether to preserve edge attributes. + """ + + def __init__(self, preserve_edge_attr=False): + self.preserve_edge_attr = preserve_edge_attr + + def _data_has_edge_attr(self, data: torch_geometric.data.Data) -> bool: + r"""Check if the input data object has edge attributes. + + Parameters + ---------- + data : torch_geometric.data.Data + The input data. + + Returns + ------- + bool + Whether the data object has edge attributes. + """ + return hasattr(data, "edge_attr") and data.edge_attr is not None + + def convert(self, domain: torch_geometric.data.Data) -> nx.Graph: + r"""Generate a NetworkX graph from the input data object. + + Parameters + ---------- + domain : torch_geometric.data.Data + The input data. + + Returns + ------- + nx.Graph + The generated NetworkX graph. + """ + # Check if data object have edge_attr, return list of tuples as [(node_id, {'features':data}, 'dim':1)] or ?? + nodes = [ + (n, dict(features=domain.x[n], dim=0)) + for n in range(domain.x.shape[0]) + ] + + if self.preserve_edge_attr and self._data_has_edge_attr(domain): + # In case edge features are given, assign features to every edge + edge_index, edge_attr = ( + domain.edge_index, + ( + domain.edge_attr + if is_undirected(domain.edge_index, domain.edge_attr) + else to_undirected(domain.edge_index, domain.edge_attr) + ), + ) + edges = [ + (i.item(), j.item(), dict(features=edge_attr[edge_idx], dim=1)) + for edge_idx, (i, j) in enumerate( + zip(edge_index[0], edge_index[1], strict=False) + ) + ] + + else: + # If edge_attr is not present, return list list of edges + edges = [ + (i.item(), j.item(), {}) + for i, j in zip( + domain.edge_index[0], domain.edge_index[1], strict=False + ) + ] + graph = nx.Graph() + graph.add_nodes_from(nodes) + graph.add_edges_from(edges) + return graph + + +class Complex2PlainComplex(Converter): + """toponetx.Complex to PlainComplex conversion. + + NB: order of features plays a crucial role, as ``PlainComplex`` + simply stores them as lists (i.e. the reference to the indices + of the simplex are lost). + + Parameters + ---------- + max_rank : int + Maximum rank of the complex. + neighborhoods : list, optional + List of neighborhoods of interest. + signed : bool, optional + If True, returns signed connectivity matrices. + transfer_features : bool, optional + Whether to transfer features. + """ + + def __init__( + self, + max_rank=None, + neighborhoods=None, + signed=False, + transfer_features=True, + ): + super().__init__() + self.max_rank = max_rank + self.neighborhoods = neighborhoods + self.signed = signed + self.transfer_features = transfer_features + + def convert(self, domain): + """Convert toponetx.Complex to PlainComplex. + + Parameters + ---------- + domain : toponetx.Complex + + Returns + ------- + PlainComplex + """ + # NB: just a slightly rewriting of get_complex_connectivity + + max_rank = self.max_rank or domain.dim + signed = self.signed + neighborhoods = self.neighborhoods + + connectivity_infos = [ + "incidence", + "down_laplacian", + "up_laplacian", + "adjacency", + "coadjacency", + "hodge_laplacian", + ] + + practical_shape = list( + np.pad(list(domain.shape), (0, max_rank + 1 - len(domain.shape))) + ) + data = { + connectivity_info: [] for connectivity_info in connectivity_infos + } + for rank_idx in range(max_rank + 1): + for connectivity_info in connectivity_infos: + try: + data[connectivity_info].append( + from_sparse( + getattr(domain, f"{connectivity_info}_matrix")( + rank=rank_idx, signed=signed + ) + ) + ) + except ValueError: + if connectivity_info == "incidence": + data[connectivity_info].append( + generate_zero_sparse_connectivity( + m=practical_shape[rank_idx - 1], + n=practical_shape[rank_idx], + ) + ) + else: + data[connectivity_info].append( + generate_zero_sparse_connectivity( + m=practical_shape[rank_idx], + n=practical_shape[rank_idx], + ) + ) + + # TODO: handle this + if neighborhoods is not None: + data = select_neighborhoods_of_interest(data, neighborhoods) + + # TODO: simplex specific? + # TODO: how to do this for other? + if self.transfer_features and hasattr( + domain, "get_simplex_attributes" + ): + # TODO: confirm features are in the right order; update this + data["features"] = [] + for rank in range(max_rank + 1): + rank_features_dict = domain.get_simplex_attributes( + "features", rank + ) + if rank_features_dict: + rank_features = torch.stack( + list(rank_features_dict.values()) + ) + else: + rank_features = None + data["features"].append(rank_features) + + return PlainComplex(**data) + + +class PlainComplex2Dict(Converter): + """PlainComplex to dict conversion.""" + + def convert(self, domain): + """Convert PlainComplex to dict. + + Parameters + ---------- + domain : toponetx.Complex + + Returns + ------- + dict + """ + data = {} + connectivity_infos = [ + "incidence", + "down_laplacian", + "up_laplacian", + "adjacency", + "coadjacency", + "hodge_laplacian", + ] + for connectivity_info in connectivity_infos: + info = getattr(domain, connectivity_info) + for rank, rank_info in enumerate(info): + data[f"{connectivity_info}_{rank}"] = rank_info + + # TODO: handle neighborhoods + data["shape"] = domain.shape + + for index, values in enumerate(domain.features): + if values is not None: + data[f"x_{index}"] = values + + return data + + +class ConverterComposition(Converter): + def __init__(self, converters): + super().__init__() + self.converters = converters + + def convert(self, domain): + """Convert domain""" + for converter in self.converters: + domain = converter(domain) + + return domain + + +class Complex2Dict(ConverterComposition): + """Complex to dict conversion. + + Parameters + ---------- + max_rank : int + Maximum rank of the complex. + neighborhoods : list, optional + List of neighborhoods of interest. + signed : bool, optional + If True, returns signed connectivity matrices. + transfer_features : bool, optional + Whether to transfer features. + """ + + def __init__( + self, + max_rank=None, + neighborhoods=None, + signed=False, + transfer_features=True, + ): + complex2plain = Complex2PlainComplex( + max_rank=max_rank, + neighborhoods=neighborhoods, + signed=signed, + transfer_features=transfer_features, + ) + plain2dict = PlainComplex2Dict() + super().__init__(converters=(complex2plain, plain2dict)) diff --git a/topobenchmarkx/transforms/feature_liftings/base.py b/topobenchmarkx/transforms/feature_liftings/base.py new file mode 100644 index 00000000..c5969398 --- /dev/null +++ b/topobenchmarkx/transforms/feature_liftings/base.py @@ -0,0 +1,13 @@ +import abc + + +class FeatureLiftingMap(abc.ABC): + """Feature lifting map.""" + + def __call__(self, domain): + """Lift features of a domain.""" + return self.lift_features(domain) + + @abc.abstractmethod + def lift_features(self, domain): + """Lift features of a domain.""" diff --git a/topobenchmarkx/transforms/feature_liftings/identity.py b/topobenchmarkx/transforms/feature_liftings/identity.py index 93806f1d..9abf4e5d 100644 --- a/topobenchmarkx/transforms/feature_liftings/identity.py +++ b/topobenchmarkx/transforms/feature_liftings/identity.py @@ -1,36 +1,13 @@ """Identity transform that does nothing to the input data.""" -import torch_geometric +from .base import FeatureLiftingMap -class Identity(torch_geometric.transforms.BaseTransform): - r"""An identity transform that does nothing to the input data. +class Identity(FeatureLiftingMap): + """Identity feature lifting map.""" - Parameters - ---------- - **kwargs : optional - Parameters for the base transform. - """ + # TODO: rename to IdentityFeatureLifting - def __init__(self, **kwargs): - super().__init__() - self.type = "domain2domain" - self.parameters = kwargs - - def __repr__(self) -> str: - return f"{self.__class__.__name__}(type={self.type!r}, parameters={self.parameters!r})" - - def forward(self, data: torch_geometric.data.Data): - r"""Apply the transform to the input data. - - Parameters - ---------- - data : torch_geometric.data.Data - The input data. - - Returns - ------- - torch_geometric.data.Data - The same data. - """ - return data + def lift_features(self, domain): + """Lift features of a domain using identity map.""" + return domain diff --git a/topobenchmarkx/transforms/feature_liftings/projection_sum.py b/topobenchmarkx/transforms/feature_liftings/projection_sum.py index 3cce03eb..4d0c04b5 100644 --- a/topobenchmarkx/transforms/feature_liftings/projection_sum.py +++ b/topobenchmarkx/transforms/feature_liftings/projection_sum.py @@ -1,69 +1,30 @@ """ProjectionSum class.""" -import torch -import torch_geometric +from .base import FeatureLiftingMap -class ProjectionSum(torch_geometric.transforms.BaseTransform): - r"""Lift r-cell features to r+1-cells by projection. +class ProjectionSum(FeatureLiftingMap): + r"""Lift r-cell features to r+1-cells by projection.""" - Parameters - ---------- - **kwargs : optional - Additional arguments for the class. - """ - - def __init__(self, **kwargs): - super().__init__() - - def __repr__(self) -> str: - return f"{self.__class__.__name__}()" - - def lift_features( - self, data: torch_geometric.data.Data | dict - ) -> torch_geometric.data.Data | dict: + def lift_features(self, domain): r"""Project r-cell features of a graph to r+1-cell structures. Parameters ---------- - data : torch_geometric.data.Data | dict + data : PlainComplex The input data to be lifted. Returns ------- - torch_geometric.data.Data | dict - The data with the lifted features. + PlainComplex + Domain with the lifted features. """ - keys = sorted( - [ - key.split("_")[1] - for key in data - if ("incidence" in key and "-" not in key) - ] - ) - for elem in keys: - if f"x_{elem}" not in data: - idx_to_project = 0 if elem == "hyperedges" else int(elem) - 1 - data["x_" + elem] = torch.matmul( - abs(data["incidence_" + elem].t()), - data[f"x_{idx_to_project}"], - ) - return data + for rank in range(domain.max_rank - 1): + if domain.features[rank + 1] is not None: + continue - def forward( - self, data: torch_geometric.data.Data | dict - ) -> torch_geometric.data.Data | dict: - r"""Apply the lifting to the input data. - - Parameters - ---------- - data : torch_geometric.data.Data | dict - The input data to be lifted. + domain.features[rank + 1] = domain.propagate_values( + rank, domain.features[rank] + ) - Returns - ------- - torch_geometric.data.Data | dict - The lifted data. - """ - data = self.lift_features(data) - return data + return domain diff --git a/topobenchmarkx/transforms/liftings/base.py b/topobenchmarkx/transforms/liftings/base.py index c08a54e5..fa00e40e 100644 --- a/topobenchmarkx/transforms/liftings/base.py +++ b/topobenchmarkx/transforms/liftings/base.py @@ -1,10 +1,99 @@ """Abstract class for topological liftings.""" -from abc import abstractmethod +import abc import torch_geometric +from topobenchmarkx.transforms.converters import IdentityConverter from topobenchmarkx.transforms.feature_liftings import FEATURE_LIFTINGS +from topobenchmarkx.transforms.feature_liftings.identity import ( + Identity, +) + + +class LiftingTransform(torch_geometric.transforms.BaseTransform): + """Lifting transform. + + Parameters + ---------- + data2domain : Converter + Conversion between ``torch_geometric.Data`` into + domain for consumption by lifting. + domain2dict : Converter + Conversion between output domain of feature lifting + and ``torch_geometric.Data``. + lifting : LiftingMap + Lifting map. + domain2domain : Converter + Conversion between output domain of lifting + and input domain for feature lifting. + feature_lifting : FeatureLiftingMap + Feature lifting map. + """ + + # NB: emulates previous AbstractLifting + def __init__( + self, + data2domain, + domain2dict, + lifting, + domain2domain=None, + feature_lifting=None, + ): + if feature_lifting is None: + feature_lifting = Identity() + + if domain2domain is None: + domain2domain = IdentityConverter() + + self.data2domain = data2domain + self.domain2domain = domain2domain + self.domain2dict = domain2dict + self.lifting = lifting + self.feature_lifting = feature_lifting + + def forward( + self, data: torch_geometric.data.Data + ) -> torch_geometric.data.Data: + r"""Apply the full lifting (topology + features) to the input data. + + Parameters + ---------- + data : torch_geometric.data.Data + The input data to be lifted. + + Returns + ------- + torch_geometric.data.Data + The lifted data. + """ + initial_data = data.to_dict() + + domain = self.data2domain(data) + lifted_topology = self.lifting(domain) + lifted_topology = self.domain2domain(lifted_topology) + lifted_topology = self.feature_lifting(lifted_topology) + lifted_topology_dict = self.domain2dict(lifted_topology) + + # TODO: make this line more clear + return torch_geometric.data.Data( + **initial_data, **lifted_topology_dict + ) + + +class LiftingMap(abc.ABC): + """Lifting map. + + Lifts a domain into another. + """ + + def __call__(self, domain): + """Lift domain.""" + return self.lift(domain) + + @abc.abstractmethod + def lift(self, domain): + """Lift domain.""" class AbstractLifting(torch_geometric.transforms.BaseTransform): @@ -18,12 +107,14 @@ class AbstractLifting(torch_geometric.transforms.BaseTransform): Additional arguments for the class. """ + # TODO: delete + def __init__(self, feature_lifting=None, **kwargs): super().__init__() self.feature_lifting = FEATURE_LIFTINGS[feature_lifting]() self.neighborhoods = kwargs.get("neighborhoods") - @abstractmethod + @abc.abstractmethod def lift_topology(self, data: torch_geometric.data.Data) -> dict: r"""Lift the topology of a graph to higher-order topological domains. diff --git a/topobenchmarkx/transforms/liftings/graph2simplicial/clique.py b/topobenchmarkx/transforms/liftings/graph2simplicial/clique.py index af7d5cdf..990d2e6e 100755 --- a/topobenchmarkx/transforms/liftings/graph2simplicial/clique.py +++ b/topobenchmarkx/transforms/liftings/graph2simplicial/clique.py @@ -1,32 +1,30 @@ """This module implements the CliqueLifting class, which lifts graphs to simplicial complexes.""" from itertools import combinations -from typing import Any import networkx as nx -import torch_geometric from toponetx.classes import SimplicialComplex -from topobenchmarkx.transforms.liftings.graph2simplicial import ( - Graph2SimplicialLifting, -) +from topobenchmarkx.transforms.liftings.base import LiftingMap -class SimplicialCliqueLifting(Graph2SimplicialLifting): +class SimplicialCliqueLifting(LiftingMap): r"""Lift graphs to simplicial complex domain. The algorithm creates simplices by identifying the cliques and considering them as simplices of the same dimension. Parameters ---------- - **kwargs : optional - Additional arguments for the class. + complex_dim : int + Maximum rank of the complex. """ - def __init__(self, **kwargs): - super().__init__(**kwargs) + def __init__(self, complex_dim=2): + super().__init__() + # TODO: better naming + self.complex_dim = complex_dim - def lift_topology(self, data: torch_geometric.data.Data) -> dict: + def lift(self, domain): r"""Lift the topology of a graph to a simplicial complex. Parameters @@ -39,12 +37,11 @@ def lift_topology(self, data: torch_geometric.data.Data) -> dict: dict The lifted topology. """ - graph = self._generate_graph_from_data(data) + graph = domain + simplicial_complex = SimplicialComplex(graph) cliques = nx.find_cliques(graph) - simplices: list[set[tuple[Any, ...]]] = [ - set() for _ in range(2, self.complex_dim + 1) - ] + simplices = [set() for _ in range(2, self.complex_dim + 1)] for clique in cliques: for i in range(2, self.complex_dim + 1): for c in combinations(clique, i + 1): @@ -53,4 +50,5 @@ def lift_topology(self, data: torch_geometric.data.Data) -> dict: for set_k_simplices in simplices: simplicial_complex.add_simplices_from(list(set_k_simplices)) - return self._get_lifted_topology(simplicial_complex, graph) + # TODO: need to check for edge preservation + return simplicial_complex From e5ee0f338734f6e29b4059dff433a423853a4b66 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lu=C3=ADs=20F=2E=20Pereira?= Date: Thu, 5 Dec 2024 18:56:36 -0800 Subject: [PATCH 02/28] Rename Complex and move propagate_values to projection sum feature lifting --- topobenchmarkx/complex.py | 18 +----------------- .../feature_liftings/projection_sum.py | 10 ++++++++-- 2 files changed, 9 insertions(+), 19 deletions(-) diff --git a/topobenchmarkx/complex.py b/topobenchmarkx/complex.py index 8a2949f2..531592dc 100644 --- a/topobenchmarkx/complex.py +++ b/topobenchmarkx/complex.py @@ -1,7 +1,4 @@ -import torch - - -class PlainComplex: +class Complex: def __init__( self, incidence, @@ -70,16 +67,3 @@ def update_features(self, rank, values): def reset_features(self): """Reset features.""" self.features = [None for _ in self.features] - - def propagate_values(self, rank, values): - """Propagate features from a rank to an upper one. - - Parameters - ---------- - rank : int - Rank of the simplices the values belong to. - values : array-like - Features for the rank-simplices. - """ - # TODO: can be made much better - return torch.matmul(torch.abs(self.incidence[rank + 1].t()), values) diff --git a/topobenchmarkx/transforms/feature_liftings/projection_sum.py b/topobenchmarkx/transforms/feature_liftings/projection_sum.py index 4d0c04b5..a02a1db5 100644 --- a/topobenchmarkx/transforms/feature_liftings/projection_sum.py +++ b/topobenchmarkx/transforms/feature_liftings/projection_sum.py @@ -1,5 +1,7 @@ """ProjectionSum class.""" +import torch + from .base import FeatureLiftingMap @@ -23,8 +25,12 @@ def lift_features(self, domain): if domain.features[rank + 1] is not None: continue - domain.features[rank + 1] = domain.propagate_values( - rank, domain.features[rank] + domain.update_features( + rank + 1, + torch.matmul( + torch.abs(domain.incidence[rank + 1].t()), + domain.features[rank], + ), ) return domain From 55b41207c19791a3216ce183ca2783983b3f6594 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lu=C3=ADs=20F=2E=20Pereira?= Date: Thu, 5 Dec 2024 18:57:15 -0800 Subject: [PATCH 03/28] Rename adapters --- topobenchmarkx/transforms/converters.py | 76 +++++++++++----------- topobenchmarkx/transforms/liftings/base.py | 4 +- 2 files changed, 40 insertions(+), 40 deletions(-) diff --git a/topobenchmarkx/transforms/converters.py b/topobenchmarkx/transforms/converters.py index 96920b74..bc31ef1f 100644 --- a/topobenchmarkx/transforms/converters.py +++ b/topobenchmarkx/transforms/converters.py @@ -7,38 +7,38 @@ from topomodelx.utils.sparse import from_sparse from torch_geometric.utils.undirected import is_undirected, to_undirected -from topobenchmarkx.complex import PlainComplex -from topobenchmarkx.data.utils.utils import ( +from topobenchmarkx.complex import Complex +from topobenchmarkx.data.utils import ( generate_zero_sparse_connectivity, select_neighborhoods_of_interest, ) -class Converter(abc.ABC): - """Convert between data structures representing the same domain.""" +class Adapter(abc.ABC): + """Adapt between data structures representing the same domain.""" def __call__(self, domain): - """Convert domain's data structure.""" - return self.convert(domain) + """Adapt domain's data structure.""" + return self.adapt(domain) @abc.abstractmethod - def convert(self, domain): - """Convert domain's data structure.""" + def adapt(self, domain): + """Adapt domain's data structure.""" -class IdentityConverter(Converter): - """Identity conversion. +class IdentityAdapter(Adapter): + """Identity adaptation. Retrieves same data structure for domain. """ - def convert(self, domain): - """Convert domain.""" + def adapt(self, domain): + """Adapt domain.""" return domain -class Data2NxGraph(Converter): - """Data to nx.Graph conversion. +class Data2NxGraph(Adapter): + """Data to nx.Graph adaptation. Parameters ---------- @@ -64,7 +64,7 @@ def _data_has_edge_attr(self, data: torch_geometric.data.Data) -> bool: """ return hasattr(data, "edge_attr") and data.edge_attr is not None - def convert(self, domain: torch_geometric.data.Data) -> nx.Graph: + def adapt(self, domain: torch_geometric.data.Data) -> nx.Graph: r"""Generate a NetworkX graph from the input data object. Parameters @@ -114,10 +114,10 @@ def convert(self, domain: torch_geometric.data.Data) -> nx.Graph: return graph -class Complex2PlainComplex(Converter): - """toponetx.Complex to PlainComplex conversion. +class TnxComplex2Complex(Adapter): + """toponetx.Complex to Complex adaptation. - NB: order of features plays a crucial role, as ``PlainComplex`` + NB: order of features plays a crucial role, as ``Complex`` simply stores them as lists (i.e. the reference to the indices of the simplex are lost). @@ -146,8 +146,8 @@ def __init__( self.signed = signed self.transfer_features = transfer_features - def convert(self, domain): - """Convert toponetx.Complex to PlainComplex. + def adapt(self, domain): + """Adapt toponetx.Complex to Complex. Parameters ---------- @@ -155,7 +155,7 @@ def convert(self, domain): Returns ------- - PlainComplex + Complex """ # NB: just a slightly rewriting of get_complex_connectivity @@ -227,14 +227,14 @@ def convert(self, domain): rank_features = None data["features"].append(rank_features) - return PlainComplex(**data) + return Complex(**data) -class PlainComplex2Dict(Converter): - """PlainComplex to dict conversion.""" +class Complex2Dict(Adapter): + """Complex to dict adaptation.""" - def convert(self, domain): - """Convert PlainComplex to dict. + def adapt(self, domain): + """Adapt Complex to dict. Parameters ---------- @@ -268,21 +268,21 @@ def convert(self, domain): return data -class ConverterComposition(Converter): - def __init__(self, converters): +class AdapterComposition(Adapter): + def __init__(self, adapters): super().__init__() - self.converters = converters + self.adapters = adapters - def convert(self, domain): - """Convert domain""" - for converter in self.converters: - domain = converter(domain) + def adapt(self, domain): + """Adapt domain""" + for adapter in self.adapters: + domain = adapter(domain) return domain -class Complex2Dict(ConverterComposition): - """Complex to dict conversion. +class TnxComplex2Dict(AdapterComposition): + """toponetx.Complex to dict adaptation. Parameters ---------- @@ -303,11 +303,11 @@ def __init__( signed=False, transfer_features=True, ): - complex2plain = Complex2PlainComplex( + complex2plain = TnxComplex2Complex( max_rank=max_rank, neighborhoods=neighborhoods, signed=signed, transfer_features=transfer_features, ) - plain2dict = PlainComplex2Dict() - super().__init__(converters=(complex2plain, plain2dict)) + plain2dict = Complex2Dict() + super().__init__(adapters=(complex2plain, plain2dict)) diff --git a/topobenchmarkx/transforms/liftings/base.py b/topobenchmarkx/transforms/liftings/base.py index fa00e40e..8dddde67 100644 --- a/topobenchmarkx/transforms/liftings/base.py +++ b/topobenchmarkx/transforms/liftings/base.py @@ -4,7 +4,7 @@ import torch_geometric -from topobenchmarkx.transforms.converters import IdentityConverter +from topobenchmarkx.transforms.converters import IdentityAdapter from topobenchmarkx.transforms.feature_liftings import FEATURE_LIFTINGS from topobenchmarkx.transforms.feature_liftings.identity import ( Identity, @@ -44,7 +44,7 @@ def __init__( feature_lifting = Identity() if domain2domain is None: - domain2domain = IdentityConverter() + domain2domain = IdentityAdapter() self.data2domain = data2domain self.domain2domain = domain2domain From 8c146f27a6d61ea139e549493a09429a8076ef4a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lu=C3=ADs=20F=2E=20Pereira?= Date: Thu, 5 Dec 2024 18:58:34 -0800 Subject: [PATCH 04/28] Move Complex and adapters to data utils --- .../{transforms/converters.py => data/utils/adapters.py} | 0 topobenchmarkx/{complex.py => data/utils/domain.py} | 0 2 files changed, 0 insertions(+), 0 deletions(-) rename topobenchmarkx/{transforms/converters.py => data/utils/adapters.py} (100%) rename topobenchmarkx/{complex.py => data/utils/domain.py} (100%) diff --git a/topobenchmarkx/transforms/converters.py b/topobenchmarkx/data/utils/adapters.py similarity index 100% rename from topobenchmarkx/transforms/converters.py rename to topobenchmarkx/data/utils/adapters.py diff --git a/topobenchmarkx/complex.py b/topobenchmarkx/data/utils/domain.py similarity index 100% rename from topobenchmarkx/complex.py rename to topobenchmarkx/data/utils/domain.py From 59051c6c1f10e488e068727d32f500b4e97765d3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lu=C3=ADs=20F=2E=20Pereira?= Date: Tue, 24 Dec 2024 16:05:42 -0800 Subject: [PATCH 05/28] Update imports --- topobenchmarkx/data/utils/__init__.py | 2 ++ topobenchmarkx/data/utils/adapters.py | 4 ++-- topobenchmarkx/transforms/liftings/base.py | 2 +- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/topobenchmarkx/data/utils/__init__.py b/topobenchmarkx/data/utils/__init__.py index 74f57c96..01e77220 100644 --- a/topobenchmarkx/data/utils/__init__.py +++ b/topobenchmarkx/data/utils/__init__.py @@ -1,5 +1,7 @@ """Init file for data/utils module.""" +from .adapters import * +from .domain import Complex from .utils import ( ensure_serializable, # noqa: F401 generate_zero_sparse_connectivity, # noqa: F401 diff --git a/topobenchmarkx/data/utils/adapters.py b/topobenchmarkx/data/utils/adapters.py index bc31ef1f..6aa33fac 100644 --- a/topobenchmarkx/data/utils/adapters.py +++ b/topobenchmarkx/data/utils/adapters.py @@ -7,8 +7,8 @@ from topomodelx.utils.sparse import from_sparse from torch_geometric.utils.undirected import is_undirected, to_undirected -from topobenchmarkx.complex import Complex -from topobenchmarkx.data.utils import ( +from topobenchmarkx.data.utils.domain import Complex +from topobenchmarkx.data.utils.utils import ( generate_zero_sparse_connectivity, select_neighborhoods_of_interest, ) diff --git a/topobenchmarkx/transforms/liftings/base.py b/topobenchmarkx/transforms/liftings/base.py index 8dddde67..0a436585 100644 --- a/topobenchmarkx/transforms/liftings/base.py +++ b/topobenchmarkx/transforms/liftings/base.py @@ -4,7 +4,7 @@ import torch_geometric -from topobenchmarkx.transforms.converters import IdentityAdapter +from topobenchmarkx.data.utils import IdentityAdapter from topobenchmarkx.transforms.feature_liftings import FEATURE_LIFTINGS from topobenchmarkx.transforms.feature_liftings.identity import ( Identity, From e4165d81f83612dcee230f382ac8c9b80a806564 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lu=C3=ADs=20F=2E=20Pereira?= Date: Tue, 14 Jan 2025 16:05:29 -0800 Subject: [PATCH 06/28] Update SimplicialKHopLifting to work with new design --- .../liftings/graph2simplicial/khop.py | 38 +++++++++++-------- 1 file changed, 22 insertions(+), 16 deletions(-) diff --git a/topobenchmark/transforms/liftings/graph2simplicial/khop.py b/topobenchmark/transforms/liftings/graph2simplicial/khop.py index 50239f18..dc9e13e2 100755 --- a/topobenchmark/transforms/liftings/graph2simplicial/khop.py +++ b/topobenchmark/transforms/liftings/graph2simplicial/khop.py @@ -4,15 +4,14 @@ from itertools import combinations from typing import Any +import torch import torch_geometric from toponetx.classes import SimplicialComplex -from topobenchmark.transforms.liftings.graph2simplicial.base import ( - Graph2SimplicialLifting, -) +from topobenchmark.transforms.liftings.base import LiftingMap -class SimplicialKHopLifting(Graph2SimplicialLifting): +class SimplicialKHopLifting(LiftingMap): r"""Lift graphs to simplicial complex domain. The function lifts a graph to a simplicial complex by considering k-hop @@ -23,38 +22,43 @@ class SimplicialKHopLifting(Graph2SimplicialLifting): Parameters ---------- + complex_dim : int + Dimension of the desired complex. max_k_simplices : int, optional The maximum number of k-simplices to consider. Default is 5000. - **kwargs : optional - Additional arguments for the class. """ - def __init__(self, max_k_simplices=5000, **kwargs): - super().__init__(**kwargs) + def __init__(self, complex_dim=3, max_k_simplices=5000): + super().__init__() + self.complex_dim = complex_dim self.max_k_simplices = max_k_simplices def __repr__(self) -> str: return f"{self.__class__.__name__}(max_k_simplices={self.max_k_simplices!r})" - def lift_topology(self, data: torch_geometric.data.Data) -> dict: + def lift(self, domain): r"""Lift the topology to simplicial complex domain. Parameters ---------- - data : torch_geometric.data.Data - The input data to be lifted. + domain : nx.Graph + Graph to be lifted. Returns ------- - dict - The lifted topology. + toponetx.Complex + Lifted simplicial complex. """ - graph = self._generate_graph_from_data(data) + graph = domain + simplicial_complex = SimplicialComplex(graph) - edge_index = torch_geometric.utils.to_undirected(data.edge_index) + edge_index = torch_geometric.utils.to_undirected( + torch.tensor(list(zip(*graph.edges, strict=False))) + ) simplices: list[set[tuple[Any, ...]]] = [ set() for _ in range(2, self.complex_dim + 1) ] + for n in range(graph.number_of_nodes()): # Find 1-hop node n neighbors neighbors, _, _, _ = torch_geometric.utils.k_hop_subgraph( @@ -67,10 +71,12 @@ def lift_topology(self, data: torch_geometric.data.Data) -> dict: for i in range(1, self.complex_dim): for c in combinations(neighbors, i + 1): simplices[i - 1].add(tuple(c)) + for set_k_simplices in simplices: list_k_simplices = list(set_k_simplices) if len(set_k_simplices) > self.max_k_simplices: random.shuffle(list_k_simplices) list_k_simplices = list_k_simplices[: self.max_k_simplices] simplicial_complex.add_simplices_from(list_k_simplices) - return self._get_lifted_topology(simplicial_complex, graph) + + return simplicial_complex From 2777f9b607354f540d555af65088376bab7e1bfb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lu=C3=ADs=20F=2E=20Pereira?= Date: Tue, 14 Jan 2025 16:18:11 -0800 Subject: [PATCH 07/28] Add IdentityAdapter as default for all the adaptations in the LiftingTransform pipeline --- topobenchmark/transforms/liftings/base.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/topobenchmark/transforms/liftings/base.py b/topobenchmark/transforms/liftings/base.py index 6f5f35a7..d5c78dc9 100644 --- a/topobenchmark/transforms/liftings/base.py +++ b/topobenchmark/transforms/liftings/base.py @@ -32,15 +32,21 @@ class LiftingTransform(torch_geometric.transforms.BaseTransform): # NB: emulates previous AbstractLifting def __init__( self, - data2domain, - domain2dict, lifting, + data2domain=None, + domain2dict=None, domain2domain=None, feature_lifting=None, ): if feature_lifting is None: feature_lifting = Identity() + if data2domain is None: + data2domain = IdentityAdapter() + + if domain2dict is None: + domain2dict = IdentityAdapter() + if domain2domain is None: domain2domain = IdentityAdapter() From e5d48d3918354a0174c6748dc4d9697c579e0564 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lu=C3=ADs=20F=2E=20Pereira?= Date: Tue, 14 Jan 2025 16:19:46 -0800 Subject: [PATCH 08/28] Improve TnxComplex2Complex api and signatures; improve variable naming --- topobenchmark/data/utils/adapters.py | 32 +++++++++++++++------------- 1 file changed, 17 insertions(+), 15 deletions(-) diff --git a/topobenchmark/data/utils/adapters.py b/topobenchmark/data/utils/adapters.py index 9aa8355a..9db40c08 100644 --- a/topobenchmark/data/utils/adapters.py +++ b/topobenchmark/data/utils/adapters.py @@ -123,8 +123,9 @@ class TnxComplex2Complex(Adapter): Parameters ---------- - max_rank : int - Maximum rank of the complex. + complex_dim : int + Dimension of the desired subcomplex. + If ``None``, adapts the (full) complex. neighborhoods : list, optional List of neighborhoods of interest. signed : bool, optional @@ -135,13 +136,13 @@ class TnxComplex2Complex(Adapter): def __init__( self, - max_rank=None, + complex_dim=None, neighborhoods=None, signed=False, transfer_features=True, ): super().__init__() - self.max_rank = max_rank + self.complex_dim = complex_dim self.neighborhoods = neighborhoods self.signed = signed self.transfer_features = transfer_features @@ -159,7 +160,7 @@ def adapt(self, domain): """ # NB: just a slightly rewriting of get_complex_connectivity - max_rank = self.max_rank or domain.dim + dim = self.complex_dim or domain.dim signed = self.signed neighborhoods = self.neighborhoods @@ -173,12 +174,12 @@ def adapt(self, domain): ] practical_shape = list( - np.pad(list(domain.shape), (0, max_rank + 1 - len(domain.shape))) + np.pad(list(domain.shape), (0, dim + 1 - len(domain.shape))) ) data = { connectivity_info: [] for connectivity_info in connectivity_infos } - for rank_idx in range(max_rank + 1): + for rank_idx in range(dim + 1): for connectivity_info in connectivity_infos: try: data[connectivity_info].append( @@ -215,7 +216,7 @@ def adapt(self, domain): ): # TODO: confirm features are in the right order; update this data["features"] = [] - for rank in range(max_rank + 1): + for rank in range(dim + 1): rank_features_dict = domain.get_simplex_attributes( "features", rank ) @@ -286,8 +287,9 @@ class TnxComplex2Dict(AdapterComposition): Parameters ---------- - max_rank : int - Maximum rank of the complex. + complex_dim : int + Dimension of the desired subcomplex. + If ``None``, adapts the (full) complex. neighborhoods : list, optional List of neighborhoods of interest. signed : bool, optional @@ -298,16 +300,16 @@ class TnxComplex2Dict(AdapterComposition): def __init__( self, - max_rank=None, + complex_dim=None, neighborhoods=None, signed=False, transfer_features=True, ): - complex2plain = TnxComplex2Complex( - max_rank=max_rank, + tnxcomplex2complex = TnxComplex2Complex( + complex_dim=complex_dim, neighborhoods=neighborhoods, signed=signed, transfer_features=transfer_features, ) - plain2dict = Complex2Dict() - super().__init__(adapters=(complex2plain, plain2dict)) + complex2dict = Complex2Dict() + super().__init__(adapters=(tnxcomplex2complex, complex2dict)) From 601a0e1ad6e7288ce2020a8369fd832fb8e300c6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lu=C3=ADs=20F=2E=20Pereira?= Date: Tue, 14 Jan 2025 17:38:13 -0800 Subject: [PATCH 09/28] Update graph2hypergraph liftings to work with new design --- .../liftings/graph2hypergraph/khop.py | 20 ++++++++----------- .../liftings/graph2hypergraph/knn.py | 20 +++++++------------ 2 files changed, 15 insertions(+), 25 deletions(-) diff --git a/topobenchmark/transforms/liftings/graph2hypergraph/khop.py b/topobenchmark/transforms/liftings/graph2hypergraph/khop.py index 298fa135..f8997e31 100755 --- a/topobenchmark/transforms/liftings/graph2hypergraph/khop.py +++ b/topobenchmark/transforms/liftings/graph2hypergraph/khop.py @@ -3,12 +3,10 @@ import torch import torch_geometric -from topobenchmark.transforms.liftings.graph2hypergraph import ( - Graph2HypergraphLifting, -) +from topobenchmark.transforms.liftings.base import LiftingMap -class HypergraphKHopLifting(Graph2HypergraphLifting): +class HypergraphKHopLifting(LiftingMap): r"""Lift graph to hypergraphs by considering k-hop neighborhoods. The class transforms graphs to hypergraph domain by considering k-hop neighborhoods of @@ -19,18 +17,16 @@ class HypergraphKHopLifting(Graph2HypergraphLifting): ---------- k_value : int, optional The number of hops to consider. Default is 1. - **kwargs : optional - Additional arguments for the class. """ - def __init__(self, k_value=1, **kwargs): - super().__init__(**kwargs) - self.k = k_value + def __init__(self, k_value=1): + super().__init__() + self.n_hops = k_value def __repr__(self) -> str: - return f"{self.__class__.__name__}(k={self.k!r})" + return f"{self.__class__.__name__}(k={self.n_hops!r})" - def lift_topology(self, data: torch_geometric.data.Data) -> dict: + def lift(self, data: torch_geometric.data.Data) -> dict: r"""Lift a graphs to hypergraphs by considering k-hop neighborhoods. Parameters @@ -70,7 +66,7 @@ def lift_topology(self, data: torch_geometric.data.Data) -> dict: for n in range(num_nodes): neighbors, _, _, _ = torch_geometric.utils.k_hop_subgraph( - n, self.k, edge_index + n, self.n_hops, edge_index ) incidence_1[n, neighbors] = 1 diff --git a/topobenchmark/transforms/liftings/graph2hypergraph/knn.py b/topobenchmark/transforms/liftings/graph2hypergraph/knn.py index 03d0a13a..5b0de672 100755 --- a/topobenchmark/transforms/liftings/graph2hypergraph/knn.py +++ b/topobenchmark/transforms/liftings/graph2hypergraph/knn.py @@ -3,12 +3,10 @@ import torch import torch_geometric -from topobenchmark.transforms.liftings.graph2hypergraph import ( - Graph2HypergraphLifting, -) +from topobenchmark.transforms.liftings.base import LiftingMap -class HypergraphKNNLifting(Graph2HypergraphLifting): +class HypergraphKNNLifting(LiftingMap): r"""Lift graphs to hypergraph domain by considering k-nearest neighbors. Parameters @@ -17,8 +15,6 @@ class HypergraphKNNLifting(Graph2HypergraphLifting): The number of nearest neighbors to consider. Must be positive. Default is 1. loop : bool, optional If True the hyperedges will contain the node they were created from. - **kwargs : optional - Additional arguments for the class. Raises ------ @@ -28,8 +24,8 @@ class HypergraphKNNLifting(Graph2HypergraphLifting): If k_value is not an integer or if loop is not a boolean. """ - def __init__(self, k_value=1, loop=True, **kwargs): - super().__init__(**kwargs) + def __init__(self, k_value=1, loop=True): + super().__init__() # Validate k_value if not isinstance(k_value, int): @@ -41,11 +37,9 @@ def __init__(self, k_value=1, loop=True, **kwargs): if not isinstance(loop, bool): raise TypeError("loop must be a boolean") - self.k = k_value - self.loop = loop - self.transform = torch_geometric.transforms.KNNGraph(self.k, self.loop) + self.transform = torch_geometric.transforms.KNNGraph(k_value, loop) - def lift_topology(self, data: torch_geometric.data.Data) -> dict: + def lift(self, data: torch_geometric.data.Data) -> dict: r"""Lift a graph to hypergraph by considering k-nearest neighbors. Parameters @@ -64,7 +58,7 @@ def lift_topology(self, data: torch_geometric.data.Data) -> dict: incidence_1 = torch.zeros(num_nodes, num_nodes) data_lifted = self.transform(data) # check for loops, since KNNGraph is inconsistent with nodes with equal features - if self.loop: + if self.transform.loop: for i in range(num_nodes): if not torch.any( torch.all( From 1f9b12d76095fc428617ab405852ece84aa7efe6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lu=C3=ADs=20F=2E=20Pereira?= Date: Tue, 14 Jan 2025 17:38:39 -0800 Subject: [PATCH 10/28] Update graph2cell liftings to work with new design --- .../transforms/liftings/graph2cell/cycle.py | 40 +++++++++---------- 1 file changed, 19 insertions(+), 21 deletions(-) diff --git a/topobenchmark/transforms/liftings/graph2cell/cycle.py b/topobenchmark/transforms/liftings/graph2cell/cycle.py index 31e94d8b..63160701 100755 --- a/topobenchmark/transforms/liftings/graph2cell/cycle.py +++ b/topobenchmark/transforms/liftings/graph2cell/cycle.py @@ -1,15 +1,12 @@ """This module implements the cycle lifting for graphs to cell complexes.""" import networkx as nx -import torch_geometric from toponetx.classes import CellComplex -from topobenchmark.transforms.liftings.graph2cell.base import ( - Graph2CellLifting, -) +from topobenchmark.transforms.liftings.base import LiftingMap -class CellCycleLifting(Graph2CellLifting): +class CellCycleLifting(LiftingMap): r"""Lift graphs to cell complexes. The algorithm creates 2-cells by identifying the cycles and considering them as 2-cells. @@ -18,39 +15,40 @@ class CellCycleLifting(Graph2CellLifting): ---------- max_cell_length : int, optional The maximum length of the cycles to be lifted. Default is None. - **kwargs : optional - Additional arguments for the class. """ - def __init__(self, max_cell_length=None, **kwargs): - super().__init__(**kwargs) - self.complex_dim = 2 + def __init__(self, max_cell_length=None): + super().__init__() + self._complex_dim = 2 self.max_cell_length = max_cell_length - def lift_topology(self, data: torch_geometric.data.Data) -> dict: + def lift(self, domain): r"""Find the cycles of a graph and lifts them to 2-cells. Parameters ---------- - data : torch_geometric.data.Data - The input data to be lifted. + domain : nx.Graph + Graph to be lifted. Returns ------- - dict - The lifted topology. + CellComplex + The cell complex. """ - G = self._generate_graph_from_data(data) - cycles = nx.cycle_basis(G) - cell_complex = CellComplex(G) + graph = domain + + cycles = nx.cycle_basis(graph) + cell_complex = CellComplex(graph) # Eliminate self-loop cycles cycles = [cycle for cycle in cycles if len(cycle) != 1] - # Eliminate cycles that are greater than the max_cell_lenght + + # Eliminate cycles that are greater than the max_cell_length if self.max_cell_length is not None: cycles = [ cycle for cycle in cycles if len(cycle) <= self.max_cell_length ] if len(cycles) != 0: - cell_complex.add_cells_from(cycles, rank=self.complex_dim) - return self._get_lifted_topology(cell_complex, G) + cell_complex.add_cells_from(cycles, rank=self._complex_dim) + + return cell_complex From f5604753268b667f4d6fe22d60442e461feb9d3c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lu=C3=ADs=20F=2E=20Pereira?= Date: Tue, 14 Jan 2025 17:40:06 -0800 Subject: [PATCH 11/28] Improve SimplicialCliqueLifting docstrings --- .../liftings/graph2simplicial/clique.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/topobenchmark/transforms/liftings/graph2simplicial/clique.py b/topobenchmark/transforms/liftings/graph2simplicial/clique.py index 2bb8c405..37a5cc15 100755 --- a/topobenchmark/transforms/liftings/graph2simplicial/clique.py +++ b/topobenchmark/transforms/liftings/graph2simplicial/clique.py @@ -11,17 +11,17 @@ class SimplicialCliqueLifting(LiftingMap): r"""Lift graphs to simplicial complex domain. - The algorithm creates simplices by identifying the cliques and considering them as simplices of the same dimension. + The algorithm creates simplices by identifying the cliques + and considering them as simplices of the same dimension. Parameters ---------- complex_dim : int - Maximum rank of the complex. + Dimension of the subcomplex. """ def __init__(self, complex_dim=2): super().__init__() - # TODO: better naming self.complex_dim = complex_dim def lift(self, domain): @@ -29,13 +29,13 @@ def lift(self, domain): Parameters ---------- - data : torch_geometric.data.Data - The input data to be lifted. + domain : nx.Graph + Graph to be lifted. Returns ------- - dict - The lifted topology. + toponetx.Complex + Lifted simplicial complex. """ graph = domain @@ -50,5 +50,4 @@ def lift(self, domain): for set_k_simplices in simplices: simplicial_complex.add_simplices_from(list(set_k_simplices)) - # TODO: need to check for edge preservation return simplicial_complex From cfd7f458587e908b5bfccb97179e2f18aa9991b8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lu=C3=ADs=20F=2E=20Pereira?= Date: Tue, 14 Jan 2025 17:41:35 -0800 Subject: [PATCH 12/28] Fix lifting tests (NB: same behavior, only adapted setup - with few exceptions) --- test/conftest.py | 48 +++--- .../liftings/cell/test_CellCyclesLifting.py | 10 +- .../hypergraph/test_HypergraphKHopLifting.py | 41 ++++-- ...test_HypergraphKNearestNeighborsLifting.py | 138 ++++++++++-------- .../test_SimplicialCliqueLifting.py | 38 ++++- .../test_SimplicialNeighborhoodLifting.py | 35 ++++- test/transforms/liftings/test_GraphLifting.py | 89 +++++------ 7 files changed, 243 insertions(+), 156 deletions(-) diff --git a/test/conftest.py b/test/conftest.py index c84a1b72..753d63b2 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -1,25 +1,25 @@ """Configuration file for pytest.""" + import networkx as nx import pytest import torch import torch_geometric -from topobenchmark.transforms.liftings.graph2simplicial import ( - SimplicialCliqueLifting -) -from topobenchmark.transforms.liftings.graph2cell import ( - CellCycleLifting + +from topobenchmark.transforms.liftings.graph2cell.cycle import CellCycleLifting +from topobenchmark.transforms.liftings.graph2simplicial.clique import ( + SimplicialCliqueLifting, ) @pytest.fixture def mocker_fixture(mocker): """Return pytest mocker, used when one want to use mocker in setup_method. - + Parameters ---------- mocker : pytest_mock.plugin.MockerFixture A pytest mocker. - + Returns ------- pytest_mock.plugin.MockerFixture @@ -31,7 +31,7 @@ def mocker_fixture(mocker): @pytest.fixture def simple_graph_0(): """Create a manual graph for testing purposes. - + Returns ------- torch_geometric.data.Data @@ -74,10 +74,11 @@ def simple_graph_0(): ) return data + @pytest.fixture def simple_graph_1(): """Create a manual graph for testing purposes. - + Returns ------- torch_geometric.data.Data @@ -133,37 +134,35 @@ def simple_graph_1(): return data - @pytest.fixture def sg1_clique_lifted(simple_graph_1): """Return a simple graph with a clique lifting. - + Parameters ---------- simple_graph_1 : torch_geometric.data.Data A simple graph data object. - + Returns ------- torch_geometric.data.Data A simple graph data object with a clique lifting. """ - lifting_signed = SimplicialCliqueLifting( - complex_dim=3, signed=True - ) + lifting_signed = SimplicialCliqueLifting(complex_dim=3, signed=True) data = lifting_signed(simple_graph_1) data.batch_0 = "null" return data + @pytest.fixture def sg1_cell_lifted(simple_graph_1): """Return a simple graph with a cell lifting. - + Parameters ---------- simple_graph_1 : torch_geometric.data.Data A simple graph data object. - + Returns ------- torch_geometric.data.Data @@ -178,7 +177,7 @@ def sg1_cell_lifted(simple_graph_1): @pytest.fixture def simple_graph_2(): """Create a manual graph for testing purposes. - + Returns ------- torch_geometric.data.Data @@ -244,7 +243,7 @@ def simple_graph_2(): @pytest.fixture def random_graph_input(): """Create a random graph for testing purposes. - + Returns ------- torch.Tensor @@ -261,13 +260,12 @@ def random_graph_input(): num_nodes = 8 d_feat = 12 x = torch.randn(num_nodes, 12) - edges_1 = torch.randint(0, num_nodes, (2, num_nodes*2)) - edges_2 = torch.randint(0, num_nodes, (2, num_nodes*2)) - + edges_1 = torch.randint(0, num_nodes, (2, num_nodes * 2)) + edges_2 = torch.randint(0, num_nodes, (2, num_nodes * 2)) + d_feat_1, d_feat_2 = 5, 17 - x_1 = torch.randn(num_nodes*2, d_feat_1) - x_2 = torch.randn(num_nodes*2, d_feat_2) + x_1 = torch.randn(num_nodes * 2, d_feat_1) + x_2 = torch.randn(num_nodes * 2, d_feat_2) return x, x_1, x_2, edges_1, edges_2 - diff --git a/test/transforms/liftings/cell/test_CellCyclesLifting.py b/test/transforms/liftings/cell/test_CellCyclesLifting.py index 54fd276f..c574992e 100644 --- a/test/transforms/liftings/cell/test_CellCyclesLifting.py +++ b/test/transforms/liftings/cell/test_CellCyclesLifting.py @@ -2,7 +2,9 @@ import torch -from topobenchmark.transforms.liftings.graph2cell import CellCycleLifting +from topobenchmark.data.utils import Data2NxGraph, TnxComplex2Dict +from topobenchmark.transforms.liftings.base import LiftingTransform +from topobenchmark.transforms.liftings.graph2cell.cycle import CellCycleLifting class TestCellCycleLifting: @@ -10,7 +12,11 @@ class TestCellCycleLifting: def setup_method(self): # Initialise the CellCycleLifting class - self.lifting = CellCycleLifting() + self.lifting = LiftingTransform( + CellCycleLifting(), + data2domain=Data2NxGraph(), + domain2dict=TnxComplex2Dict(), + ) def test_lift_topology(self, simple_graph_1): # Test the lift_topology method diff --git a/test/transforms/liftings/hypergraph/test_HypergraphKHopLifting.py b/test/transforms/liftings/hypergraph/test_HypergraphKHopLifting.py index 13285fc1..3fcc7ebb 100644 --- a/test/transforms/liftings/hypergraph/test_HypergraphKHopLifting.py +++ b/test/transforms/liftings/hypergraph/test_HypergraphKHopLifting.py @@ -2,7 +2,8 @@ import torch -from topobenchmark.transforms.liftings.graph2hypergraph import ( +from topobenchmark.transforms.liftings.base import LiftingTransform +from topobenchmark.transforms.liftings.graph2hypergraph.khop import ( HypergraphKHopLifting, ) @@ -11,15 +12,23 @@ class TestHypergraphKHopLifting: """Test the HypergraphKHopLifting class.""" def setup_method(self): - """ Setup the test.""" + """Setup the test.""" # Initialise the HypergraphKHopLifting class - self.lifting_k1 = HypergraphKHopLifting(k_value=1) - self.lifting_k2 = HypergraphKHopLifting(k_value=2) - self.lifting_edge_attr = HypergraphKHopLifting(k_value=1, preserve_edge_attr=True) + self.lifting_k1 = LiftingTransform(HypergraphKHopLifting(k_value=1)) + self.lifting_k2 = LiftingTransform(HypergraphKHopLifting(k_value=2)) + + # TODO: delete? + # NB: `preserve_edge_attr` is never used? therefore they're equivalent + # self.lifting_edge_attr = HypergraphKHopLifting( + # k_value=1, preserve_edge_attr=True + # ) + self.lifting_edge_attr = LiftingTransform( + HypergraphKHopLifting(k_value=1) + ) def test_lift_topology(self, simple_graph_2): - """ Test the lift_topology method. - + """Test the lift_topology method. + Parameters ---------- simple_graph_2 : Data @@ -78,10 +87,18 @@ def test_lift_topology(self, simple_graph_2): assert ( expected_n_hyperedges == lifted_data_k2.num_hyperedges ), "Something is wrong with the number of hyperedges (k=2)." - + self.data_edge_attr = simple_graph_2 - edge_attributes = torch.rand((self.data_edge_attr.edge_index.shape[1], 2)) + edge_attributes = torch.rand( + (self.data_edge_attr.edge_index.shape[1], 2) + ) self.data_edge_attr.edge_attr = edge_attributes - lifted_data_edge_attr = self.lifting_edge_attr.forward(self.data_edge_attr.clone()) - assert lifted_data_edge_attr.edge_attr is not None, "Edge attributes are not preserved." - assert torch.all(edge_attributes == lifted_data_edge_attr.edge_attr), "Edge attributes are not preserved correctly." + lifted_data_edge_attr = self.lifting_edge_attr.forward( + self.data_edge_attr.clone() + ) + assert ( + lifted_data_edge_attr.edge_attr is not None + ), "Edge attributes are not preserved." + assert torch.all( + edge_attributes == lifted_data_edge_attr.edge_attr + ), "Edge attributes are not preserved correctly." diff --git a/test/transforms/liftings/hypergraph/test_HypergraphKNearestNeighborsLifting.py b/test/transforms/liftings/hypergraph/test_HypergraphKNearestNeighborsLifting.py index 7e9d1216..069d7a3c 100644 --- a/test/transforms/liftings/hypergraph/test_HypergraphKNearestNeighborsLifting.py +++ b/test/transforms/liftings/hypergraph/test_HypergraphKNearestNeighborsLifting.py @@ -3,7 +3,8 @@ import pytest import torch from torch_geometric.data import Data -from topobenchmark.transforms.liftings.graph2hypergraph import ( + +from topobenchmark.transforms.liftings.graph2hypergraph.knn import ( HypergraphKNNLifting, ) @@ -13,7 +14,7 @@ class TestHypergraphKNNLifting: def setup_method(self): """Set up test fixtures before each test method. - + Creates instances of HypergraphKNNLifting with different k values and loop settings. """ @@ -23,88 +24,94 @@ def setup_method(self): def test_initialization(self): """Test initialization with different parameters.""" + # TODO: overkill, delete? + # Test default parameters lifting_default = HypergraphKNNLifting() - assert lifting_default.k == 1 - assert lifting_default.loop is True + assert lifting_default.transform.k == 1 + assert lifting_default.transform.loop is True # Test custom parameters lifting_custom = HypergraphKNNLifting(k_value=5, loop=False) - assert lifting_custom.k == 5 - assert lifting_custom.loop is False + assert lifting_custom.transform.k == 5 + assert lifting_custom.transform.loop is False def test_lift_topology_k2(self, simple_graph_2): """Test the lift_topology method with k=2. - + Parameters ---------- simple_graph_2 : torch_geometric.data.Data A simple graph fixture with 9 nodes arranged in a line pattern. """ - lifted_data_k2 = self.lifting_k2.lift_topology(simple_graph_2.clone()) + lifted_data_k2 = self.lifting_k2.lift(simple_graph_2.clone()) expected_n_hyperedges = 9 - expected_incidence_1 = torch.tensor([ - [1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0], - ]) + expected_incidence_1 = torch.tensor( + [ + [1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0], + ] + ) assert torch.equal( lifted_data_k2["incidence_hyperedges"].to_dense(), - expected_incidence_1 + expected_incidence_1, ), "Incorrect incidence_hyperedges for k=2" - + assert lifted_data_k2["num_hyperedges"] == expected_n_hyperedges assert torch.equal(lifted_data_k2["x_0"], simple_graph_2.x) def test_lift_topology_k3(self, simple_graph_2): """Test the lift_topology method with k=3. - + Parameters ---------- simple_graph_2 : torch_geometric.data.Data A simple graph fixture with 9 nodes arranged in a line pattern. """ - lifted_data_k3 = self.lifting_k3.lift_topology(simple_graph_2.clone()) + lifted_data_k3 = self.lifting_k3.lift(simple_graph_2.clone()) expected_n_hyperedges = 9 - expected_incidence_1 = torch.tensor([ - [1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0], - ]) + expected_incidence_1 = torch.tensor( + [ + [1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0], + ] + ) assert torch.equal( lifted_data_k3["incidence_hyperedges"].to_dense(), - expected_incidence_1 + expected_incidence_1, ), "Incorrect incidence_hyperedges for k=3" - + assert lifted_data_k3["num_hyperedges"] == expected_n_hyperedges assert torch.equal(lifted_data_k3["x_0"], simple_graph_2.x) def test_lift_topology_no_loop(self, simple_graph_2): """Test the lift_topology method with loop=False. - + Parameters ---------- simple_graph_2 : torch_geometric.data.Data A simple graph fixture with 9 nodes arranged in a line pattern. """ - lifted_data = self.lifting_no_loop.lift_topology(simple_graph_2.clone()) - + lifted_data = self.lifting_no_loop.lift(simple_graph_2.clone()) + # Verify no self-loops in the incidence matrix incidence_matrix = lifted_data["incidence_hyperedges"].to_dense() diagonal = torch.diag(incidence_matrix) @@ -115,11 +122,11 @@ def test_lift_topology_with_equal_features(self): # Create a graph where some nodes have identical features data = Data( x=torch.tensor([[1.0], [1.0], [2.0], [2.0]]), - edge_index=torch.tensor([[0, 1, 2, 3], [1, 0, 3, 2]]) + edge_index=torch.tensor([[0, 1, 2, 3], [1, 0, 3, 2]]), ) - - lifted_data = self.lifting_k2.lift_topology(data) - + + lifted_data = self.lifting_k2.lift(data) + # Verify the shape of the output assert lifted_data["incidence_hyperedges"].size() == (4, 4) assert lifted_data["num_hyperedges"] == 4 @@ -128,7 +135,7 @@ def test_lift_topology_with_equal_features(self): @pytest.mark.parametrize("k_value", [1, 2, 3, 4]) def test_different_k_values(self, k_value, simple_graph_2): """Test lift_topology with different k values. - + Parameters ---------- k_value : int @@ -137,29 +144,30 @@ def test_different_k_values(self, k_value, simple_graph_2): A simple graph fixture with 9 nodes arranged in a line pattern. """ lifting = HypergraphKNNLifting(k_value=k_value, loop=True) - lifted_data = lifting.lift_topology(simple_graph_2.clone()) - + lifted_data = lifting.lift(simple_graph_2.clone()) + # Verify basic properties assert lifted_data["num_hyperedges"] == simple_graph_2.x.size(0) incidence_matrix = lifted_data["incidence_hyperedges"].to_dense() - + # Check that each node is connected to at most k nodes - assert torch.all(incidence_matrix.sum(dim=1) <= k_value), \ - f"Some nodes are connected to more than {k_value} neighbors" + assert torch.all( + incidence_matrix.sum(dim=1) <= k_value + ), f"Some nodes are connected to more than {k_value} neighbors" def test_invalid_inputs(self): """Test handling of invalid inputs and edge cases.""" # Test with no x attribute (this should raise AttributeError) data_no_x = Data(edge_index=torch.tensor([[0, 1], [1, 0]])) with pytest.raises(AttributeError): - self.lifting_k2.lift_topology(data_no_x) + self.lifting_k2.lift(data_no_x) # Test single node case (edge case that should work) single_node_data = Data( x=torch.tensor([[1.0]], dtype=torch.float), - edge_index=torch.tensor([[0], [0]]) + edge_index=torch.tensor([[0], [0]]), ) - lifted_single = self.lifting_k2.lift_topology(single_node_data) + lifted_single = self.lifting_k2.lift(single_node_data) assert lifted_single["num_hyperedges"] == 1 assert lifted_single["incidence_hyperedges"].size() == (1, 1) assert torch.equal(lifted_single["x_0"], single_node_data.x) @@ -167,32 +175,30 @@ def test_invalid_inputs(self): # Test with identical nodes (edge case that should work) identical_nodes_data = Data( x=torch.tensor([[1.0], [1.0]], dtype=torch.float), - edge_index=torch.tensor([[0, 1], [1, 0]]) + edge_index=torch.tensor([[0, 1], [1, 0]]), ) - lifted_identical = self.lifting_k2.lift_topology(identical_nodes_data) + lifted_identical = self.lifting_k2.lift(identical_nodes_data) assert lifted_identical["num_hyperedges"] == 2 assert lifted_identical["incidence_hyperedges"].size() == (2, 2) assert torch.equal(lifted_identical["x_0"], identical_nodes_data.x) # Test with missing edge_index (this should work as KNNGraph will create edges) - data_no_edges = Data( - x=torch.tensor([[1.0], [2.0]], dtype=torch.float) - ) - lifted_no_edges = self.lifting_k2.lift_topology(data_no_edges) + data_no_edges = Data(x=torch.tensor([[1.0], [2.0]], dtype=torch.float)) + lifted_no_edges = self.lifting_k2.lift(data_no_edges) assert lifted_no_edges["num_hyperedges"] == 2 assert lifted_no_edges["incidence_hyperedges"].size() == (2, 2) assert torch.equal(lifted_no_edges["x_0"], data_no_edges.x) # Test with no data (should raise AttributeError) with pytest.raises(AttributeError): - self.lifting_k2.lift_topology(None) + self.lifting_k2.lift(None) # Test with empty tensor for x (should work but result in empty outputs) empty_data = Data( x=torch.tensor([], dtype=torch.float).reshape(0, 1), - edge_index=torch.tensor([], dtype=torch.long).reshape(2, 0) + edge_index=torch.tensor([], dtype=torch.long).reshape(2, 0), ) - lifted_empty = self.lifting_k2.lift_topology(empty_data) + lifted_empty = self.lifting_k2.lift(empty_data) assert lifted_empty["num_hyperedges"] == 0 assert lifted_empty["incidence_hyperedges"].size(0) == 0 @@ -203,13 +209,17 @@ def test_invalid_initialization(self): HypergraphKNNLifting(k_value=1.5) # Test with zero k_value - with pytest.raises(ValueError, match="k_value must be greater than or equal to 1"): + with pytest.raises( + ValueError, match="k_value must be greater than or equal to 1" + ): HypergraphKNNLifting(k_value=0) # Test with negative k_value - with pytest.raises(ValueError, match="k_value must be greater than or equal to 1"): + with pytest.raises( + ValueError, match="k_value must be greater than or equal to 1" + ): HypergraphKNNLifting(k_value=-1) # Test with non-boolean loop with pytest.raises(TypeError, match="loop must be a boolean"): - HypergraphKNNLifting(k_value=1, loop="True") \ No newline at end of file + HypergraphKNNLifting(k_value=1, loop="True") diff --git a/test/transforms/liftings/simplicial/test_SimplicialCliqueLifting.py b/test/transforms/liftings/simplicial/test_SimplicialCliqueLifting.py index 7d85b19e..fa36e072 100644 --- a/test/transforms/liftings/simplicial/test_SimplicialCliqueLifting.py +++ b/test/transforms/liftings/simplicial/test_SimplicialCliqueLifting.py @@ -2,11 +2,19 @@ import torch -from topobenchmark.transforms.liftings.graph2simplicial import ( - SimplicialCliqueLifting +from topobenchmark.data.utils import ( + Complex2Dict, + Data2NxGraph, + TnxComplex2Complex, +) +from topobenchmark.transforms.feature_liftings.projection_sum import ( + ProjectionSum, ) -from topobenchmark.transforms.converters import Data2NxGraph, Complex2Dict from topobenchmark.transforms.liftings.base import LiftingTransform +from topobenchmark.transforms.liftings.graph2simplicial.clique import ( + SimplicialCliqueLifting, +) + class TestSimplicialCliqueLifting: """Test the SimplicialCliqueLifting class.""" @@ -14,13 +22,25 @@ class TestSimplicialCliqueLifting: def setup_method(self): # Initialise the SimplicialCliqueLifting class data2graph = Data2NxGraph() - simplicial2dict_signed = Complex2Dict(signed=True) - simplicial2dict_unsigned = Complex2Dict(signed=False) lifting_map = SimplicialCliqueLifting(complex_dim=3) + feature_lifting = ProjectionSum() + domain2dict = Complex2Dict() - self.lifting_signed = LiftingTransform(data2graph, simplicial2dict_signed, lifting_map) - self.lifting_unsigned = LiftingTransform(data2graph, simplicial2dict_unsigned, lifting_map) + self.lifting_signed = LiftingTransform( + lifting=lifting_map, + feature_lifting=feature_lifting, + data2domain=data2graph, + domain2domain=TnxComplex2Complex(signed=True), + domain2dict=domain2dict, + ) + self.lifting_unsigned = LiftingTransform( + lifting=lifting_map, + feature_lifting=feature_lifting, + data2domain=data2graph, + domain2domain=TnxComplex2Complex(signed=False), + domain2dict=domain2dict, + ) def test_lift_topology(self, simple_graph_1): """Test the lift_topology method.""" @@ -207,6 +227,8 @@ def test_lift_topology(self, simple_graph_1): def test_lifted_features_signed(self, simple_graph_1): """Test the lift_features method in signed incidence cases.""" + # TODO: can be removed/moved; part of projection sum + self.data = simple_graph_1 # Test the lift_features method for signed case lifted_data = self.lifting_signed.forward(self.data) @@ -249,6 +271,8 @@ def test_lifted_features_signed(self, simple_graph_1): def test_lifted_features_unsigned(self, simple_graph_1): """Test the lift_features method in unsigned incidence cases.""" + # TODO: redundant. can be moved/removed + self.data = simple_graph_1 # Test the lift_features method for unsigned case lifted_data = self.lifting_unsigned.forward(self.data) diff --git a/test/transforms/liftings/simplicial/test_SimplicialNeighborhoodLifting.py b/test/transforms/liftings/simplicial/test_SimplicialNeighborhoodLifting.py index 5a03f67e..e21b8f99 100644 --- a/test/transforms/liftings/simplicial/test_SimplicialNeighborhoodLifting.py +++ b/test/transforms/liftings/simplicial/test_SimplicialNeighborhoodLifting.py @@ -2,19 +2,46 @@ import torch -from topobenchmark.transforms.liftings.graph2simplicial import ( +from topobenchmark.data.utils import ( + Complex2Dict, + Data2NxGraph, + TnxComplex2Complex, +) +from topobenchmark.transforms.feature_liftings.projection_sum import ( + ProjectionSum, +) +from topobenchmark.transforms.liftings.base import LiftingTransform +from topobenchmark.transforms.liftings.graph2simplicial.khop import ( SimplicialKHopLifting, ) +# TODO: rename for consistency? + class TestSimplicialKHopLifting: """Test the SimplicialKHopLifting class.""" def setup_method(self): # Initialise the SimplicialKHopLifting class - self.lifting_signed = SimplicialKHopLifting(complex_dim=3, signed=True) - self.lifting_unsigned = SimplicialKHopLifting( - complex_dim=3, signed=False + data2graph = Data2NxGraph() + feature_lifting = ProjectionSum() + domain2dict = Complex2Dict() + + lifting_map = SimplicialKHopLifting(complex_dim=3) + + self.lifting_signed = LiftingTransform( + lifting=lifting_map, + feature_lifting=feature_lifting, + data2domain=data2graph, + domain2domain=TnxComplex2Complex(signed=True), + domain2dict=domain2dict, + ) + self.lifting_unsigned = LiftingTransform( + lifting=lifting_map, + feature_lifting=feature_lifting, + data2domain=data2graph, + domain2domain=TnxComplex2Complex(signed=False), + domain2dict=domain2dict, ) def test_lift_topology(self, simple_graph_1): diff --git a/test/transforms/liftings/test_GraphLifting.py b/test/transforms/liftings/test_GraphLifting.py index c7acf454..546956c9 100644 --- a/test/transforms/liftings/test_GraphLifting.py +++ b/test/transforms/liftings/test_GraphLifting.py @@ -1,21 +1,42 @@ """Test the GraphLifting class.""" -import pytest + import torch +import torch_geometric from torch_geometric.data import Data -from topobenchmark.transforms.liftings import GraphLifting +from topobenchmark.transforms.feature_liftings.projection_sum import ( + ProjectionSum, +) +from topobenchmark.transforms.liftings.base import LiftingMap, LiftingTransform + + +def _data_has_edge_attr(data: torch_geometric.data.Data) -> bool: + r"""Check if the input data object has edge attributes. + + Parameters + ---------- + data : torch_geometric.data.Data + The input data. + + Returns + ------- + bool + Whether the data object has edge attributes. + """ + return hasattr(data, "edge_attr") and data.edge_attr is not None -class ConcreteGraphLifting(GraphLifting): + +class ConcreteGraphLifting(LiftingMap): """Concrete implementation of GraphLifting for testing.""" - - def lift_topology(self, data): + + def lift(self, data): """Implement the abstract lift_topology method. - + Parameters ---------- data : torch_geometric.data.Data The input data to be lifted. - + Returns ------- dict @@ -26,86 +47,70 @@ def lift_topology(self, data): class TestGraphLifting: """Test the GraphLifting class.""" - + def setup_method(self): """Set up test fixtures before each test method. - + Creates an instance of ConcreteGraphLifting with default parameters. """ - self.lifting = ConcreteGraphLifting( - feature_lifting="ProjectionSum", - preserve_edge_attr=False + self.lifting = LiftingTransform( + ConcreteGraphLifting(), feature_lifting=ProjectionSum() ) def test_data_has_edge_attr(self): """Test _data_has_edge_attr method with different data configurations.""" - + # Test case 1: Data with edge attributes data_with_edge_attr = Data( x=torch.tensor([[1.0], [2.0]]), edge_index=torch.tensor([[0, 1], [1, 0]]), - edge_attr=torch.tensor([[1.0], [1.0]]) + edge_attr=torch.tensor([[1.0], [1.0]]), ) - assert self.lifting._data_has_edge_attr(data_with_edge_attr) is True + assert _data_has_edge_attr(data_with_edge_attr) is True # Test case 2: Data without edge attributes data_without_edge_attr = Data( x=torch.tensor([[1.0], [2.0]]), - edge_index=torch.tensor([[0, 1], [1, 0]]) + edge_index=torch.tensor([[0, 1], [1, 0]]), ) - assert self.lifting._data_has_edge_attr(data_without_edge_attr) is False + assert _data_has_edge_attr(data_without_edge_attr) is False # Test case 3: Data with edge_attr set to None data_with_none_edge_attr = Data( x=torch.tensor([[1.0], [2.0]]), edge_index=torch.tensor([[0, 1], [1, 0]]), - edge_attr=None + edge_attr=None, ) - assert self.lifting._data_has_edge_attr(data_with_none_edge_attr) is False + assert _data_has_edge_attr(data_with_none_edge_attr) is False def test_data_has_edge_attr_empty_data(self): """Test _data_has_edge_attr method with empty data object.""" empty_data = Data() - assert self.lifting._data_has_edge_attr(empty_data) is False + assert _data_has_edge_attr(empty_data) is False def test_data_has_edge_attr_different_edge_formats(self): """Test _data_has_edge_attr method with different edge attribute formats.""" - + # Test with float edge attributes data_float_attr = Data( x=torch.tensor([[1.0], [2.0]]), edge_index=torch.tensor([[0, 1], [1, 0]]), - edge_attr=torch.tensor([[0.5], [0.5]]) + edge_attr=torch.tensor([[0.5], [0.5]]), ) - assert self.lifting._data_has_edge_attr(data_float_attr) is True + assert _data_has_edge_attr(data_float_attr) is True # Test with integer edge attributes data_int_attr = Data( x=torch.tensor([[1.0], [2.0]]), edge_index=torch.tensor([[0, 1], [1, 0]]), - edge_attr=torch.tensor([[1], [1]], dtype=torch.long) + edge_attr=torch.tensor([[1], [1]], dtype=torch.long), ) - assert self.lifting._data_has_edge_attr(data_int_attr) is True + assert _data_has_edge_attr(data_int_attr) is True # Test with multi-dimensional edge attributes data_multidim_attr = Data( x=torch.tensor([[1.0], [2.0]]), edge_index=torch.tensor([[0, 1], [1, 0]]), - edge_attr=torch.tensor([[1.0, 2.0], [2.0, 1.0]]) - ) - assert self.lifting._data_has_edge_attr(data_multidim_attr) is True - - @pytest.mark.parametrize("preserve_edge_attr", [True, False]) - def test_init_preserve_edge_attr(self, preserve_edge_attr): - """Test initialization with different preserve_edge_attr values. - - Parameters - ---------- - preserve_edge_attr : bool - Boolean value to test initialization with True and False values. - """ - lifting = ConcreteGraphLifting( - feature_lifting="ProjectionSum", - preserve_edge_attr=preserve_edge_attr + edge_attr=torch.tensor([[1.0, 2.0], [2.0, 1.0]]), ) - assert lifting.preserve_edge_attr == preserve_edge_attr \ No newline at end of file + assert _data_has_edge_attr(data_multidim_attr) is True From 5ac166f34f47a176931ff20d6c55830086bd77e8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lu=C3=ADs=20F=2E=20Pereira?= Date: Tue, 14 Jan 2025 17:47:16 -0800 Subject: [PATCH 13/28] Remove dead code --- .../liftings/test_AbstractLifting.py | 53 ------ topobenchmark/transforms/liftings/__init__.py | 20 -- topobenchmark/transforms/liftings/base.py | 58 ------ .../transforms/liftings/graph2cell/base.py | 57 ------ .../liftings/graph2hypergraph/base.py | 17 -- .../liftings/graph2simplicial/base.py | 69 ------- topobenchmark/transforms/liftings/liftings.py | 172 ------------------ 7 files changed, 446 deletions(-) delete mode 100644 test/transforms/liftings/test_AbstractLifting.py delete mode 100755 topobenchmark/transforms/liftings/graph2cell/base.py delete mode 100755 topobenchmark/transforms/liftings/graph2hypergraph/base.py delete mode 100755 topobenchmark/transforms/liftings/graph2simplicial/base.py delete mode 100644 topobenchmark/transforms/liftings/liftings.py diff --git a/test/transforms/liftings/test_AbstractLifting.py b/test/transforms/liftings/test_AbstractLifting.py deleted file mode 100644 index 49167cb1..00000000 --- a/test/transforms/liftings/test_AbstractLifting.py +++ /dev/null @@ -1,53 +0,0 @@ -"""Test AbstractLifting module.""" - -import pytest -import torch -from torch_geometric.data import Data -from topobenchmark.transforms.liftings import AbstractLifting - -class TestAbstractLifting: - """Test the AbstractLifting class.""" - - def setup_method(self): - """Set up test fixtures for each test method. - - Creates a concrete subclass of AbstractLifting for testing purposes. - """ - class ConcreteLifting(AbstractLifting): - """Concrete implementation of AbstractLifting for testing.""" - - def lift_topology(self, data): - """Implementation of abstract method that calls parent's method. - - Parameters - ---------- - data : torch_geometric.data.Data - The input data to be lifted. - - Returns - ------- - dict - Empty dictionary as this is just for testing. - - Raises - ------ - NotImplementedError - Always raises this error as it calls the parent's abstract method. - """ - return super().lift_topology(data) - - self.lifting = ConcreteLifting(feature_lifting=None) - - def test_lift_topology_raises_not_implemented(self): - """Test that the abstract lift_topology method raises NotImplementedError. - - Verifies that calling lift_topology on an abstract class implementation - raises NotImplementedError as expected. - """ - dummy_data = Data( - x=torch.tensor([[1.0], [2.0]]), - edge_index=torch.tensor([[0, 1], [1, 0]]) - ) - - with pytest.raises(NotImplementedError): - self.lifting.lift_topology(dummy_data) \ No newline at end of file diff --git a/topobenchmark/transforms/liftings/__init__.py b/topobenchmark/transforms/liftings/__init__.py index 4692ceaf..0776fee4 100755 --- a/topobenchmark/transforms/liftings/__init__.py +++ b/topobenchmark/transforms/liftings/__init__.py @@ -1,21 +1 @@ """This module implements the liftings for the topological transforms.""" - -from .base import AbstractLifting -from .liftings import ( - CellComplexLifting, - CombinatorialLifting, - GraphLifting, - HypergraphLifting, - PointCloudLifting, - SimplicialLifting, -) - -__all__ = [ - "AbstractLifting", - "CellComplexLifting", - "CombinatorialLifting", - "GraphLifting", - "HypergraphLifting", - "PointCloudLifting", - "SimplicialLifting", -] diff --git a/topobenchmark/transforms/liftings/base.py b/topobenchmark/transforms/liftings/base.py index d5c78dc9..ce3d1ab4 100644 --- a/topobenchmark/transforms/liftings/base.py +++ b/topobenchmark/transforms/liftings/base.py @@ -5,7 +5,6 @@ import torch_geometric from topobenchmark.data.utils import IdentityAdapter -from topobenchmark.transforms.feature_liftings import FEATURE_LIFTINGS from topobenchmark.transforms.feature_liftings.identity import Identity @@ -29,7 +28,6 @@ class LiftingTransform(torch_geometric.transforms.BaseTransform): Feature lifting map. """ - # NB: emulates previous AbstractLifting def __init__( self, lifting, @@ -79,7 +77,6 @@ def forward( lifted_topology = self.feature_lifting(lifted_topology) lifted_topology_dict = self.domain2dict(lifted_topology) - # TODO: make this line more clear return torch_geometric.data.Data( **initial_data, **lifted_topology_dict ) @@ -98,58 +95,3 @@ def __call__(self, domain): @abc.abstractmethod def lift(self, domain): """Lift domain.""" - - -class AbstractLifting(torch_geometric.transforms.BaseTransform): - r"""Abstract class for topological liftings. - - Parameters - ---------- - feature_lifting : str, optional - The feature lifting method to be used. Default is 'ProjectionSum'. - **kwargs : optional - Additional arguments for the class. - """ - - # TODO: delete - - def __init__(self, feature_lifting=None, **kwargs): - super().__init__() - self.feature_lifting = FEATURE_LIFTINGS[feature_lifting]() - self.neighborhoods = kwargs.get("neighborhoods") - - @abc.abstractmethod - def lift_topology(self, data: torch_geometric.data.Data) -> dict: - r"""Lift the topology of a graph to higher-order topological domains. - - Parameters - ---------- - data : torch_geometric.data.Data - The input data to be lifted. - - Returns - ------- - dict - The lifted topology. - """ - raise NotImplementedError - - def forward( - self, data: torch_geometric.data.Data - ) -> torch_geometric.data.Data: - r"""Apply the full lifting (topology + features) to the input data. - - Parameters - ---------- - data : torch_geometric.data.Data - The input data to be lifted. - - Returns - ------- - torch_geometric.data.Data - The lifted data. - """ - initial_data = data.to_dict() - lifted_topology = self.lift_topology(data) - lifted_topology = self.feature_lifting(lifted_topology) - return torch_geometric.data.Data(**initial_data, **lifted_topology) diff --git a/topobenchmark/transforms/liftings/graph2cell/base.py b/topobenchmark/transforms/liftings/graph2cell/base.py deleted file mode 100755 index aeff3646..00000000 --- a/topobenchmark/transforms/liftings/graph2cell/base.py +++ /dev/null @@ -1,57 +0,0 @@ -"""Abstract class for lifting graphs to cell complexes.""" - -import networkx as nx -import torch -from toponetx.classes import CellComplex - -from topobenchmark.data.utils.utils import get_complex_connectivity -from topobenchmark.transforms.liftings import GraphLifting - - -class Graph2CellLifting(GraphLifting): - r"""Abstract class for lifting graphs to cell complexes. - - Parameters - ---------- - complex_dim : int, optional - The dimension of the cell complex to be generated. Default is 2. - **kwargs : optional - Additional arguments for the class. - """ - - def __init__(self, complex_dim=2, **kwargs): - super().__init__(**kwargs) - self.complex_dim = complex_dim - self.type = "graph2cell" - - def _get_lifted_topology( - self, cell_complex: CellComplex, graph: nx.Graph - ) -> dict: - r"""Return the lifted topology. - - Parameters - ---------- - cell_complex : CellComplex - The cell complex. - graph : nx.Graph - The input graph. - - Returns - ------- - dict - The lifted topology. - """ - lifted_topology = get_complex_connectivity( - cell_complex, self.complex_dim, neighborhoods=self.neighborhoods - ) - lifted_topology["x_0"] = torch.stack( - list(cell_complex.get_cell_attributes("features", 0).values()) - ) - # If new edges have been added during the lifting process, we discard the edge attributes - if self.contains_edge_attr and cell_complex.shape[1] == ( - graph.number_of_edges() - ): - lifted_topology["x_1"] = torch.stack( - list(cell_complex.get_cell_attributes("features", 1).values()) - ) - return lifted_topology diff --git a/topobenchmark/transforms/liftings/graph2hypergraph/base.py b/topobenchmark/transforms/liftings/graph2hypergraph/base.py deleted file mode 100755 index e060e30e..00000000 --- a/topobenchmark/transforms/liftings/graph2hypergraph/base.py +++ /dev/null @@ -1,17 +0,0 @@ -"""Abstract class for lifting graphs to hypergraphs.""" - -from topobenchmark.transforms.liftings import GraphLifting - - -class Graph2HypergraphLifting(GraphLifting): - r"""Abstract class for lifting graphs to hypergraphs. - - Parameters - ---------- - **kwargs : optional - Additional arguments for the class. - """ - - def __init__(self, **kwargs): - super().__init__(**kwargs) - self.type = "graph2hypergraph" diff --git a/topobenchmark/transforms/liftings/graph2simplicial/base.py b/topobenchmark/transforms/liftings/graph2simplicial/base.py deleted file mode 100755 index e52449dc..00000000 --- a/topobenchmark/transforms/liftings/graph2simplicial/base.py +++ /dev/null @@ -1,69 +0,0 @@ -"""Abstract class for lifting graphs to simplicial complexes.""" - -import networkx as nx -import torch -from toponetx.classes import SimplicialComplex - -from topobenchmark.data.utils.utils import get_complex_connectivity -from topobenchmark.transforms.liftings import GraphLifting - - -class Graph2SimplicialLifting(GraphLifting): - r"""Abstract class for lifting graphs to simplicial complexes. - - Parameters - ---------- - complex_dim : int, optional - The maximum dimension of the simplicial complex to be generated. Default is 2. - **kwargs : optional - Additional arguments for the class. - """ - - def __init__(self, complex_dim=2, **kwargs): - super().__init__(**kwargs) - self.complex_dim = complex_dim - self.type = "graph2simplicial" - self.signed = kwargs.get("signed", False) - - def _get_lifted_topology( - self, simplicial_complex: SimplicialComplex, graph: nx.Graph - ) -> dict: - r"""Return the lifted topology. - - Parameters - ---------- - simplicial_complex : SimplicialComplex - The simplicial complex. - graph : nx.Graph - The input graph. - - Returns - ------- - dict - The lifted topology. - """ - lifted_topology = get_complex_connectivity( - simplicial_complex, - self.complex_dim, - neighborhoods=self.neighborhoods, - signed=self.signed, - ) - lifted_topology["x_0"] = torch.stack( - list( - simplicial_complex.get_simplex_attributes( - "features", 0 - ).values() - ) - ) - # If new edges have been added during the lifting process, we discard the edge attributes - if self.contains_edge_attr and simplicial_complex.shape[1] == ( - graph.number_of_edges() - ): - lifted_topology["x_1"] = torch.stack( - list( - simplicial_complex.get_simplex_attributes( - "features", 1 - ).values() - ) - ) - return lifted_topology diff --git a/topobenchmark/transforms/liftings/liftings.py b/topobenchmark/transforms/liftings/liftings.py deleted file mode 100644 index 9453eaa3..00000000 --- a/topobenchmark/transforms/liftings/liftings.py +++ /dev/null @@ -1,172 +0,0 @@ -"""This module implements the abstract classes for lifting graphs.""" - -import networkx as nx -import torch_geometric -from torch_geometric.utils.undirected import is_undirected, to_undirected - -from topobenchmark.transforms.liftings import AbstractLifting - - -class GraphLifting(AbstractLifting): - r"""Abstract class for lifting graph topologies to other domains. - - Parameters - ---------- - feature_lifting : str, optional - The feature lifting method to be used. Default is 'ProjectionSum'. - preserve_edge_attr : bool, optional - Whether to preserve edge attributes. Default is False. - **kwargs : optional - Additional arguments for the class. - """ - - def __init__( - self, - feature_lifting="ProjectionSum", - preserve_edge_attr=False, - **kwargs, - ): - super().__init__(feature_lifting=feature_lifting, **kwargs) - self.preserve_edge_attr = preserve_edge_attr - - def _data_has_edge_attr(self, data: torch_geometric.data.Data) -> bool: - r"""Check if the input data object has edge attributes. - - Parameters - ---------- - data : torch_geometric.data.Data - The input data. - - Returns - ------- - bool - Whether the data object has edge attributes. - """ - return hasattr(data, "edge_attr") and data.edge_attr is not None - - def _generate_graph_from_data( - self, data: torch_geometric.data.Data - ) -> nx.Graph: - r"""Generate a NetworkX graph from the input data object. - - Parameters - ---------- - data : torch_geometric.data.Data - The input data. - - Returns - ------- - nx.Graph - The generated NetworkX graph. - """ - # Check if data object have edge_attr, return list of tuples as [(node_id, {'features':data}, 'dim':1)] or ?? - nodes = [ - (n, dict(features=data.x[n], dim=0)) - for n in range(data.x.shape[0]) - ] - - if self.preserve_edge_attr and self._data_has_edge_attr(data): - # In case edge features are given, assign features to every edge - edge_index, edge_attr = ( - data.edge_index, - ( - data.edge_attr - if is_undirected(data.edge_index, data.edge_attr) - else to_undirected(data.edge_index, data.edge_attr) - ), - ) - edges = [ - (i.item(), j.item(), dict(features=edge_attr[edge_idx], dim=1)) - for edge_idx, (i, j) in enumerate( - zip(edge_index[0], edge_index[1], strict=False) - ) - ] - self.contains_edge_attr = True - else: - # If edge_attr is not present, return list list of edges - edges = [ - (i.item(), j.item(), {}) - for i, j in zip( - data.edge_index[0], data.edge_index[1], strict=False - ) - ] - self.contains_edge_attr = False - graph = nx.Graph() - graph.add_nodes_from(nodes) - graph.add_edges_from(edges) - return graph - - -class PointCloudLifting(AbstractLifting): - r"""Abstract class for lifting point clouds to other topological domains. - - Parameters - ---------- - feature_lifting : str, optional - The feature lifting method to be used. Default is 'ProjectionSum'. - **kwargs : optional - Additional arguments for the class. - """ - - def __init__(self, feature_lifting="ProjectionSum", **kwargs): - super().__init__(feature_lifting=feature_lifting, **kwargs) - - -class CellComplexLifting(AbstractLifting): - r"""Abstract class for lifting cell complexes to other domains. - - Parameters - ---------- - feature_lifting : str, optional - The feature lifting method to be used. Default is 'ProjectionSum'. - **kwargs : optional - Additional arguments for the class. - """ - - def __init__(self, feature_lifting="ProjectionSum", **kwargs): - super().__init__(feature_lifting=feature_lifting, **kwargs) - - -class SimplicialLifting(AbstractLifting): - r"""Abstract class for lifting simplicial complexes to other domains. - - Parameters - ---------- - feature_lifting : str, optional - The feature lifting method to be used. Default is 'ProjectionSum'. - **kwargs : optional - Additional arguments for the class. - """ - - def __init__(self, feature_lifting="ProjectionSum", **kwargs): - super().__init__(feature_lifting=feature_lifting, **kwargs) - - -class HypergraphLifting(AbstractLifting): - r"""Abstract class for lifting hypergraphs to other domains. - - Parameters - ---------- - feature_lifting : str, optional - The feature lifting method to be used. Default is 'ProjectionSum'. - **kwargs : optional - Additional arguments for the class. - """ - - def __init__(self, feature_lifting="ProjectionSum", **kwargs): - super().__init__(feature_lifting=feature_lifting, **kwargs) - - -class CombinatorialLifting(AbstractLifting): - r"""Abstract class for lifting combinatorial complexes to other domains. - - Parameters - ---------- - feature_lifting : str, optional - The feature lifting method to be used. Default is 'ProjectionSum'. - **kwargs : optional - Additional arguments for the class. - """ - - def __init__(self, feature_lifting="ProjectionSum", **kwargs): - super().__init__(feature_lifting=feature_lifting, **kwargs) From 305f4861670240717d677f87454c1344250e4286 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lu=C3=ADs=20F=2E=20Pereira?= Date: Wed, 15 Jan 2025 19:10:25 -0800 Subject: [PATCH 14/28] Update TRANSFORMS automatically dict creation/imports --- topobenchmark/transforms/__init__.py | 29 +---- topobenchmark/transforms/_utils.py | 53 +++++++++ .../transforms/data_manipulations/__init__.py | 83 +------------- .../transforms/feature_liftings/__init__.py | 104 +----------------- .../transforms/feature_liftings/identity.py | 2 +- .../feature_liftings/projection_sum.py | 6 +- topobenchmark/transforms/liftings/__init__.py | 14 +++ .../liftings/graph2cell/__init__.py | 99 ++--------------- .../liftings/graph2hypergraph/__init__.py | 99 ++--------------- .../liftings/graph2simplicial/__init__.py | 99 ++--------------- 10 files changed, 104 insertions(+), 484 deletions(-) create mode 100644 topobenchmark/transforms/_utils.py diff --git a/topobenchmark/transforms/__init__.py b/topobenchmark/transforms/__init__.py index 3f568814..62f8d85e 100755 --- a/topobenchmark/transforms/__init__.py +++ b/topobenchmark/transforms/__init__.py @@ -1,32 +1,11 @@ """This module contains the transforms for the topobenchmark package.""" -from typing import Any +from .data_manipulations import DATA_MANIPULATIONS +from .feature_liftings import FEATURE_LIFTINGS +from .liftings import LIFTINGS -from topobenchmark.transforms.data_manipulations import DATA_MANIPULATIONS -from topobenchmark.transforms.feature_liftings import FEATURE_LIFTINGS -from topobenchmark.transforms.liftings.graph2cell import GRAPH2CELL_LIFTINGS -from topobenchmark.transforms.liftings.graph2hypergraph import ( - GRAPH2HYPERGRAPH_LIFTINGS, -) -from topobenchmark.transforms.liftings.graph2simplicial import ( - GRAPH2SIMPLICIAL_LIFTINGS, -) - -LIFTINGS = { - **GRAPH2CELL_LIFTINGS, - **GRAPH2HYPERGRAPH_LIFTINGS, - **GRAPH2SIMPLICIAL_LIFTINGS, -} - -TRANSFORMS: dict[Any, Any] = { +TRANSFORMS = { **LIFTINGS, **FEATURE_LIFTINGS, **DATA_MANIPULATIONS, } - -__all__ = [ - "DATA_MANIPULATIONS", - "FEATURE_LIFTINGS", - "LIFTINGS", - "TRANSFORMS", -] diff --git a/topobenchmark/transforms/_utils.py b/topobenchmark/transforms/_utils.py new file mode 100644 index 00000000..f14d156e --- /dev/null +++ b/topobenchmark/transforms/_utils.py @@ -0,0 +1,53 @@ +import inspect +from importlib import util +from pathlib import Path + + +def discover_objs(package_path, condition=None): + """Dynamically discover all manipulation classes in the package. + + Parameters + ---------- + package_path : str + Path to the package's __init__.py file. + condition : callable + `(name, obj) -> bool` + + Returns + ------- + dict[str, type] + Dictionary mapping class names to their corresponding class objects. + """ + if condition is None: + condition = lambda name, obj: True + + objs = {} + + # Get the directory containing the manipulation modules + package_dir = Path(package_path).parent + + # Iterate through all .py files in the directory + for file_path in package_dir.glob("*.py"): + if file_path.stem == "__init__": + continue + + # Import the module + module_name = f"{Path(package_path).stem}.{file_path.stem}" + spec = util.spec_from_file_location(module_name, file_path) + if spec and spec.loader: + module = util.module_from_spec(spec) + spec.loader.exec_module(module) + + # Find all manipulation classes in the module + for name, obj in inspect.getmembers(module): + if ( + not inspect.isclass(obj) + or name.startswith("_") + or obj.__module__ != module.__name__ + ): + continue + + if condition(name, obj): + objs[name] = obj + + return objs diff --git a/topobenchmark/transforms/data_manipulations/__init__.py b/topobenchmark/transforms/data_manipulations/__init__.py index a17e506d..314d5fa6 100644 --- a/topobenchmark/transforms/data_manipulations/__init__.py +++ b/topobenchmark/transforms/data_manipulations/__init__.py @@ -1,86 +1,7 @@ """Data manipulations module with automated exports.""" -import inspect -from importlib import util -from pathlib import Path -from typing import Any +from topobenchmark.transforms._utils import discover_objs +DATA_MANIPULATIONS = discover_objs(__file__) -class ModuleExportsManager: - """Manages automatic discovery and registration of data manipulation classes.""" - - @staticmethod - def is_manipulation_class(obj: Any) -> bool: - """Check if an object is a valid manipulation class. - - Parameters - ---------- - obj : Any - The object to check if it's a valid manipulation class. - - Returns - ------- - bool - True if the object is a valid manipulation class (non-private class - defined in __main__), False otherwise. - """ - return ( - inspect.isclass(obj) - and obj.__module__ == "__main__" - and not obj.__name__.startswith("_") - ) - - @classmethod - def discover_manipulations(cls, package_path: str) -> dict[str, type]: - """Dynamically discover all manipulation classes in the package. - - Parameters - ---------- - package_path : str - Path to the package's __init__.py file. - - Returns - ------- - dict[str, type] - Dictionary mapping class names to their corresponding class objects. - """ - manipulations = {} - - # Get the directory containing the manipulation modules - package_dir = Path(package_path).parent - - # Iterate through all .py files in the directory - for file_path in package_dir.glob("*.py"): - if file_path.stem == "__init__": - continue - - # Import the module - module_name = f"{Path(package_path).stem}.{file_path.stem}" - spec = util.spec_from_file_location(module_name, file_path) - if spec and spec.loader: - module = util.module_from_spec(spec) - spec.loader.exec_module(module) - - # Find all manipulation classes in the module - for name, obj in inspect.getmembers(module): - if ( - inspect.isclass(obj) - and obj.__module__ == module.__name__ - and not name.startswith("_") - ): - manipulations[name] = obj # noqa: PERF403 - - return manipulations - - -# Create the exports manager -manager = ModuleExportsManager() - -# Automatically discover and populate DATA_MANIPULATIONS -DATA_MANIPULATIONS = manager.discover_manipulations(__file__) - -# Automatically generate __all__ -__all__ = [*DATA_MANIPULATIONS.keys(), "DATA_MANIPULATIONS"] - -# For backwards compatibility, also create individual imports locals().update(DATA_MANIPULATIONS) diff --git a/topobenchmark/transforms/feature_liftings/__init__.py b/topobenchmark/transforms/feature_liftings/__init__.py index ec4f763c..6e047683 100644 --- a/topobenchmark/transforms/feature_liftings/__init__.py +++ b/topobenchmark/transforms/feature_liftings/__init__.py @@ -1,104 +1,12 @@ """Feature lifting transforms with automated exports.""" -import inspect -from importlib import util -from pathlib import Path -from typing import Any +from topobenchmark.transforms._utils import discover_objs -from .identity import Identity # Import Identity for special case +from .base import FeatureLiftingMap - -class ModuleExportsManager: - """Manages automatic discovery and registration of feature lifting classes.""" - - @staticmethod - def is_lifting_class(obj: Any) -> bool: - """Check if an object is a valid lifting class. - - Parameters - ---------- - obj : Any - The object to check if it's a valid lifting class. - - Returns - ------- - bool - True if the object is a valid lifting class (non-private class - defined in __main__), False otherwise. - """ - return ( - inspect.isclass(obj) - and obj.__module__ == "__main__" - and not obj.__name__.startswith("_") - ) - - @classmethod - def discover_liftings( - cls, package_path: str, special_cases: dict[Any, type] | None = None - ) -> dict[str, type]: - """Dynamically discover all lifting classes in the package. - - Parameters - ---------- - package_path : str - Path to the package's __init__.py file. - special_cases : Optional[dict[Any, type]] - Dictionary of special case mappings (e.g., {None: Identity}), - by default None. - - Returns - ------- - dict[str, type] - Dictionary mapping class names to their corresponding class objects, - including any special cases if provided. - """ - liftings = {} - - # Get the directory containing the lifting modules - package_dir = Path(package_path).parent - - # Iterate through all .py files in the directory - for file_path in package_dir.glob("*.py"): - if file_path.stem == "__init__": - continue - - # Import the module - module_name = f"{Path(package_path).stem}.{file_path.stem}" - spec = util.spec_from_file_location(module_name, file_path) - if spec and spec.loader: - module = util.module_from_spec(spec) - spec.loader.exec_module(module) - - # Find all lifting classes in the module - for name, obj in inspect.getmembers(module): - if ( - inspect.isclass(obj) - and obj.__module__ == module.__name__ - and not name.startswith("_") - ): - liftings[name] = obj # noqa: PERF403 - - # Add special cases if provided - if special_cases: - liftings.update(special_cases) - - return liftings - - -# Create the exports manager -manager = ModuleExportsManager() - -# Automatically discover and populate FEATURE_LIFTINGS with special case for None -FEATURE_LIFTINGS = manager.discover_liftings( - __file__, special_cases={None: Identity} +FEATURE_LIFTINGS = discover_objs( + __file__, + condition=lambda name, obj: issubclass(obj, FeatureLiftingMap), ) -# Automatically generate __all__ (excluding None key) -__all__ = [name for name in FEATURE_LIFTINGS if isinstance(name, str)] + [ - "FEATURE_LIFTINGS" -] - -# For backwards compatibility, create individual imports (excluding None key) -locals().update( - {k: v for k, v in FEATURE_LIFTINGS.items() if isinstance(k, str)} -) +locals().update(FEATURE_LIFTINGS) diff --git a/topobenchmark/transforms/feature_liftings/identity.py b/topobenchmark/transforms/feature_liftings/identity.py index 9abf4e5d..e640bd06 100644 --- a/topobenchmark/transforms/feature_liftings/identity.py +++ b/topobenchmark/transforms/feature_liftings/identity.py @@ -1,6 +1,6 @@ """Identity transform that does nothing to the input data.""" -from .base import FeatureLiftingMap +from topobenchmark.transforms.feature_liftings.base import FeatureLiftingMap class Identity(FeatureLiftingMap): diff --git a/topobenchmark/transforms/feature_liftings/projection_sum.py b/topobenchmark/transforms/feature_liftings/projection_sum.py index a02a1db5..a756fd0e 100644 --- a/topobenchmark/transforms/feature_liftings/projection_sum.py +++ b/topobenchmark/transforms/feature_liftings/projection_sum.py @@ -2,7 +2,7 @@ import torch -from .base import FeatureLiftingMap +from topobenchmark.transforms.feature_liftings.base import FeatureLiftingMap class ProjectionSum(FeatureLiftingMap): @@ -13,12 +13,12 @@ def lift_features(self, domain): Parameters ---------- - data : PlainComplex + data : Complex The input data to be lifted. Returns ------- - PlainComplex + Complex Domain with the lifted features. """ for rank in range(domain.max_rank - 1): diff --git a/topobenchmark/transforms/liftings/__init__.py b/topobenchmark/transforms/liftings/__init__.py index 0776fee4..513f5035 100755 --- a/topobenchmark/transforms/liftings/__init__.py +++ b/topobenchmark/transforms/liftings/__init__.py @@ -1 +1,15 @@ """This module implements the liftings for the topological transforms.""" + +from .base import LiftingTransform # noqa: F401 +from .graph2cell import GRAPH2CELL_LIFTINGS +from .graph2hypergraph import GRAPH2HYPERGRAPH_LIFTINGS +from .graph2simplicial import GRAPH2SIMPLICIAL_LIFTINGS + +LIFTINGS = { + **GRAPH2CELL_LIFTINGS, + **GRAPH2HYPERGRAPH_LIFTINGS, + **GRAPH2SIMPLICIAL_LIFTINGS, +} + + +locals().update(LIFTINGS) diff --git a/topobenchmark/transforms/liftings/graph2cell/__init__.py b/topobenchmark/transforms/liftings/graph2cell/__init__.py index d0faae96..480ada64 100755 --- a/topobenchmark/transforms/liftings/graph2cell/__init__.py +++ b/topobenchmark/transforms/liftings/graph2cell/__init__.py @@ -1,96 +1,11 @@ """Graph2Cell liftings with automated exports.""" -import inspect -from importlib import util -from pathlib import Path -from typing import Any +from topobenchmark.transforms._utils import discover_objs +from topobenchmark.transforms.liftings.base import LiftingMap -from .base import Graph2CellLifting +GRAPH2CELL_LIFTINGS = discover_objs( + __file__, + condition=lambda name, obj: issubclass(obj, LiftingMap), +) - -class ModuleExportsManager: - """Manages automatic discovery and registration of Graph2Cell lifting classes.""" - - @staticmethod - def is_lifting_class(obj: Any) -> bool: - """Check if an object is a valid Graph2Cell lifting class. - - Parameters - ---------- - obj : Any - The object to check if it's a valid lifting class. - - Returns - ------- - bool - True if the object is a valid Graph2Cell lifting class (non-private class - inheriting from Graph2CellLifting), False otherwise. - """ - return ( - inspect.isclass(obj) - and obj.__module__ == "__main__" - and not obj.__name__.startswith("_") - and issubclass(obj, Graph2CellLifting) - and obj != Graph2CellLifting - ) - - @classmethod - def discover_liftings(cls, package_path: str) -> dict[str, type]: - """Dynamically discover all Graph2Cell lifting classes in the package. - - Parameters - ---------- - package_path : str - Path to the package's __init__.py file. - - Returns - ------- - dict[str, type] - Dictionary mapping class names to their corresponding class objects. - """ - liftings = {} - - # Get the directory containing the lifting modules - package_dir = Path(package_path).parent - - # Iterate through all .py files in the directory - for file_path in package_dir.glob("*.py"): - if file_path.stem == "__init__": - continue - - # Import the module - module_name = f"{Path(package_path).stem}.{file_path.stem}" - spec = util.spec_from_file_location(module_name, file_path) - if spec and spec.loader: - module = util.module_from_spec(spec) - spec.loader.exec_module(module) - - # Find all lifting classes in the module - for name, obj in inspect.getmembers(module): - if ( - inspect.isclass(obj) - and obj.__module__ == module.__name__ - and not name.startswith("_") - and issubclass(obj, Graph2CellLifting) - and obj != Graph2CellLifting - ): - liftings[name] = obj # noqa: PERF403 - - return liftings - - -# Create the exports manager -manager = ModuleExportsManager() - -# Automatically discover and populate GRAPH2CELL_LIFTINGS -GRAPH2CELL_LIFTINGS = manager.discover_liftings(__file__) - -# Automatically generate __all__ -__all__ = [ - *GRAPH2CELL_LIFTINGS.keys(), - "Graph2CellLifting", - "GRAPH2CELL_LIFTINGS", -] - -# For backwards compatibility, create individual imports -locals().update(**GRAPH2CELL_LIFTINGS) +locals().update(GRAPH2CELL_LIFTINGS) diff --git a/topobenchmark/transforms/liftings/graph2hypergraph/__init__.py b/topobenchmark/transforms/liftings/graph2hypergraph/__init__.py index acb89e0c..e7a5a815 100755 --- a/topobenchmark/transforms/liftings/graph2hypergraph/__init__.py +++ b/topobenchmark/transforms/liftings/graph2hypergraph/__init__.py @@ -1,96 +1,11 @@ """Graph2HypergraphLifting module with automated exports.""" -import inspect -from importlib import util -from pathlib import Path -from typing import Any +from topobenchmark.transforms._utils import discover_objs +from topobenchmark.transforms.liftings.base import LiftingMap -from .base import Graph2HypergraphLifting +GRAPH2HYPERGRAPH_LIFTINGS = discover_objs( + __file__, + condition=lambda name, obj: issubclass(obj, LiftingMap), +) - -class ModuleExportsManager: - """Manages automatic discovery and registration of Graph2Hypergraph lifting classes.""" - - @staticmethod - def is_lifting_class(obj: Any) -> bool: - """Check if an object is a valid Graph2Hypergraph lifting class. - - Parameters - ---------- - obj : Any - The object to check if it's a valid lifting class. - - Returns - ------- - bool - True if the object is a valid Graph2Hypergraph lifting class (non-private class - inheriting from Graph2HypergraphLifting), False otherwise. - """ - return ( - inspect.isclass(obj) - and obj.__module__ == "__main__" - and not obj.__name__.startswith("_") - and issubclass(obj, Graph2HypergraphLifting) - and obj != Graph2HypergraphLifting - ) - - @classmethod - def discover_liftings(cls, package_path: str) -> dict[str, type]: - """Dynamically discover all Graph2Hypergraph lifting classes in the package. - - Parameters - ---------- - package_path : str - Path to the package's __init__.py file. - - Returns - ------- - dict[str, type] - Dictionary mapping class names to their corresponding class objects. - """ - liftings = {} - - # Get the directory containing the lifting modules - package_dir = Path(package_path).parent - - # Iterate through all .py files in the directory - for file_path in package_dir.glob("*.py"): - if file_path.stem == "__init__": - continue - - # Import the module - module_name = f"{Path(package_path).stem}.{file_path.stem}" - spec = util.spec_from_file_location(module_name, file_path) - if spec and spec.loader: - module = util.module_from_spec(spec) - spec.loader.exec_module(module) - - # Find all lifting classes in the module - for name, obj in inspect.getmembers(module): - if ( - inspect.isclass(obj) - and obj.__module__ == module.__name__ - and not name.startswith("_") - and issubclass(obj, Graph2HypergraphLifting) - and obj != Graph2HypergraphLifting - ): - liftings[name] = obj # noqa: PERF403 - - return liftings - - -# Create the exports manager -manager = ModuleExportsManager() - -# Automatically discover and populate GRAPH2HYPERGRAPH_LIFTINGS -GRAPH2HYPERGRAPH_LIFTINGS = manager.discover_liftings(__file__) - -# Automatically generate __all__ -__all__ = [ - *GRAPH2HYPERGRAPH_LIFTINGS.keys(), - "Graph2HypergraphLifting", - "GRAPH2HYPERGRAPH_LIFTINGS", -] - -# For backwards compatibility, create individual imports -locals().update(**GRAPH2HYPERGRAPH_LIFTINGS) +locals().update(GRAPH2HYPERGRAPH_LIFTINGS) diff --git a/topobenchmark/transforms/liftings/graph2simplicial/__init__.py b/topobenchmark/transforms/liftings/graph2simplicial/__init__.py index 238691cd..9e77797b 100755 --- a/topobenchmark/transforms/liftings/graph2simplicial/__init__.py +++ b/topobenchmark/transforms/liftings/graph2simplicial/__init__.py @@ -1,96 +1,11 @@ """Graph2SimplicialLifting module with automated exports.""" -import inspect -from importlib import util -from pathlib import Path -from typing import Any +from topobenchmark.transforms._utils import discover_objs +from topobenchmark.transforms.liftings.base import LiftingMap -from .base import Graph2SimplicialLifting +GRAPH2SIMPLICIAL_LIFTINGS = discover_objs( + __file__, + condition=lambda name, obj: issubclass(obj, LiftingMap), +) - -class ModuleExportsManager: - """Manages automatic discovery and registration of Graph2Simplicial lifting classes.""" - - @staticmethod - def is_lifting_class(obj: Any) -> bool: - """Check if an object is a valid Graph2Simplicial lifting class. - - Parameters - ---------- - obj : Any - The object to check if it's a valid lifting class. - - Returns - ------- - bool - True if the object is a valid Graph2Simplicial lifting class (non-private class - inheriting from Graph2SimplicialLifting), False otherwise. - """ - return ( - inspect.isclass(obj) - and obj.__module__ == "__main__" - and not obj.__name__.startswith("_") - and issubclass(obj, Graph2SimplicialLifting) - and obj != Graph2SimplicialLifting - ) - - @classmethod - def discover_liftings(cls, package_path: str) -> dict[str, type]: - """Dynamically discover all Graph2Simplicial lifting classes in the package. - - Parameters - ---------- - package_path : str - Path to the package's __init__.py file. - - Returns - ------- - dict[str, type] - Dictionary mapping class names to their corresponding class objects. - """ - liftings = {} - - # Get the directory containing the lifting modules - package_dir = Path(package_path).parent - - # Iterate through all .py files in the directory - for file_path in package_dir.glob("*.py"): - if file_path.stem == "__init__": - continue - - # Import the module - module_name = f"{Path(package_path).stem}.{file_path.stem}" - spec = util.spec_from_file_location(module_name, file_path) - if spec and spec.loader: - module = util.module_from_spec(spec) - spec.loader.exec_module(module) - - # Find all lifting classes in the module - for name, obj in inspect.getmembers(module): - if ( - inspect.isclass(obj) - and obj.__module__ == module.__name__ - and not name.startswith("_") - and issubclass(obj, Graph2SimplicialLifting) - and obj != Graph2SimplicialLifting - ): - liftings[name] = obj # noqa: PERF403 - - return liftings - - -# Create the exports manager -manager = ModuleExportsManager() - -# Automatically discover and populate GRAPH2SIMPLICIAL_LIFTINGS -GRAPH2SIMPLICIAL_LIFTINGS = manager.discover_liftings(__file__) - -# Automatically generate __all__ -__all__ = [ - *GRAPH2SIMPLICIAL_LIFTINGS.keys(), - "Graph2SimplicialLifting", - "GRAPH2SIMPLICIAL_LIFTINGS", -] - -# For backwards compatibility, create individual imports -locals().update(**GRAPH2SIMPLICIAL_LIFTINGS) +locals().update(GRAPH2SIMPLICIAL_LIFTINGS) From f77ad640d9687bb253d75becae6ea35550ea6a36 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lu=C3=ADs=20F=2E=20Pereira?= Date: Wed, 15 Jan 2025 19:11:52 -0800 Subject: [PATCH 15/28] Fix handling of empty matrices due to inexisting dimension --- topobenchmark/data/utils/adapters.py | 37 +++++++++++-------- topobenchmark/data/utils/domain.py | 3 ++ .../liftings/graph2simplicial/clique.py | 3 ++ 3 files changed, 27 insertions(+), 16 deletions(-) diff --git a/topobenchmark/data/utils/adapters.py b/topobenchmark/data/utils/adapters.py index 9db40c08..342a2622 100644 --- a/topobenchmark/data/utils/adapters.py +++ b/topobenchmark/data/utils/adapters.py @@ -124,8 +124,9 @@ class TnxComplex2Complex(Adapter): Parameters ---------- complex_dim : int - Dimension of the desired subcomplex. + Dimension of the (sub)complex. If ``None``, adapts the (full) complex. + If greater than dimension of complex, pads with empty matrices. neighborhoods : list, optional List of neighborhoods of interest. signed : bool, optional @@ -136,13 +137,11 @@ class TnxComplex2Complex(Adapter): def __init__( self, - complex_dim=None, neighborhoods=None, signed=False, transfer_features=True, ): super().__init__() - self.complex_dim = complex_dim self.neighborhoods = neighborhoods self.signed = signed self.transfer_features = transfer_features @@ -160,7 +159,13 @@ def adapt(self, domain): """ # NB: just a slightly rewriting of get_complex_connectivity - dim = self.complex_dim or domain.dim + practical_dim = ( + domain.practical_dim + if hasattr(domain, "practical_dim") + else domain.dim + ) + dim = domain.dim + signed = self.signed neighborhoods = self.neighborhoods @@ -174,18 +179,20 @@ def adapt(self, domain): ] practical_shape = list( - np.pad(list(domain.shape), (0, dim + 1 - len(domain.shape))) + np.pad( + list(domain.shape), (0, practical_dim + 1 - len(domain.shape)) + ) ) data = { connectivity_info: [] for connectivity_info in connectivity_infos } - for rank_idx in range(dim + 1): + for rank in range(practical_dim + 1): for connectivity_info in connectivity_infos: try: data[connectivity_info].append( from_sparse( getattr(domain, f"{connectivity_info}_matrix")( - rank=rank_idx, signed=signed + rank=rank, signed=signed ) ) ) @@ -193,15 +200,15 @@ def adapt(self, domain): if connectivity_info == "incidence": data[connectivity_info].append( generate_zero_sparse_connectivity( - m=practical_shape[rank_idx - 1], - n=practical_shape[rank_idx], + m=practical_shape[rank - 1], + n=practical_shape[rank], ) ) else: data[connectivity_info].append( generate_zero_sparse_connectivity( - m=practical_shape[rank_idx], - n=practical_shape[rank_idx], + m=practical_shape[rank], + n=practical_shape[rank], ) ) @@ -228,6 +235,9 @@ def adapt(self, domain): rank_features = None data["features"].append(rank_features) + for _ in range(dim + 1, practical_dim + 1): + data["features"].append(None) + return Complex(**data) @@ -287,9 +297,6 @@ class TnxComplex2Dict(AdapterComposition): Parameters ---------- - complex_dim : int - Dimension of the desired subcomplex. - If ``None``, adapts the (full) complex. neighborhoods : list, optional List of neighborhoods of interest. signed : bool, optional @@ -300,13 +307,11 @@ class TnxComplex2Dict(AdapterComposition): def __init__( self, - complex_dim=None, neighborhoods=None, signed=False, transfer_features=True, ): tnxcomplex2complex = TnxComplex2Complex( - complex_dim=complex_dim, neighborhoods=neighborhoods, signed=signed, transfer_features=transfer_features, diff --git a/topobenchmark/data/utils/domain.py b/topobenchmark/data/utils/domain.py index 531592dc..8bf4f1d7 100644 --- a/topobenchmark/data/utils/domain.py +++ b/topobenchmark/data/utils/domain.py @@ -46,6 +46,9 @@ def shape(self): def max_rank(self): """Maximum rank of the complex. + NB: may differ from mathematical definition due to empty + matrices. + Returns ------- int diff --git a/topobenchmark/transforms/liftings/graph2simplicial/clique.py b/topobenchmark/transforms/liftings/graph2simplicial/clique.py index 37a5cc15..04baa1ef 100755 --- a/topobenchmark/transforms/liftings/graph2simplicial/clique.py +++ b/topobenchmark/transforms/liftings/graph2simplicial/clique.py @@ -50,4 +50,7 @@ def lift(self, domain): for set_k_simplices in simplices: simplicial_complex.add_simplices_from(list(set_k_simplices)) + # because Complex pads unexisting dimensions with empty matrices + simplicial_complex.practical_dim = self.complex_dim + return simplicial_complex From 0668f97821c607f4ccff4acdcec9f1c4c0e928ad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lu=C3=ADs=20F=2E=20Pereira?= Date: Wed, 15 Jan 2025 19:12:30 -0800 Subject: [PATCH 16/28] Update feature liftings due to new design --- .../feature_liftings/test_Concatenation.py | 24 ++-- .../feature_liftings/test_ProjectionSum.py | 59 +++++----- .../feature_liftings/test_SetLifting.py | 18 ++- .../feature_liftings/concatenation.py | 86 +++++---------- .../transforms/feature_liftings/set.py | 103 +++++++----------- 5 files changed, 125 insertions(+), 165 deletions(-) diff --git a/test/transforms/feature_liftings/test_Concatenation.py b/test/transforms/feature_liftings/test_Concatenation.py index a8f83d78..aff3a2c1 100644 --- a/test/transforms/feature_liftings/test_Concatenation.py +++ b/test/transforms/feature_liftings/test_Concatenation.py @@ -2,24 +2,34 @@ import torch -from topobenchmark.transforms.liftings.graph2simplicial import ( +from topobenchmark.data.utils import ( + Complex2Dict, + Data2NxGraph, + TnxComplex2Complex, +) +from topobenchmark.transforms.liftings import ( + LiftingTransform, SimplicialCliqueLifting, ) -class TestConcatention: +class TestConcatenation: """Test the Concatention feature lifting class.""" def setup_method(self): """Set up the test.""" # Initialize a lifting class - self.lifting = SimplicialCliqueLifting( - feature_lifting="Concatenation", complex_dim=3 + self.lifting = LiftingTransform( + SimplicialCliqueLifting(complex_dim=3), + feature_lifting="Concatenation", + data2domain=Data2NxGraph(), + domain2domain=TnxComplex2Complex(signed=False), + domain2dict=Complex2Dict(), ) def test_lift_features(self, simple_graph_0, simple_graph_1): """Test the lift_features method. - + Parameters ---------- simple_graph_0 : torch_geometric.data.Data @@ -27,12 +37,12 @@ def test_lift_features(self, simple_graph_0, simple_graph_1): simple_graph_1 : torch_geometric.data.Data A simple graph data object. """ - + data = simple_graph_0 # Test the lift_features method lifted_data = self.lifting.forward(data.clone()) assert lifted_data.x_2.shape == torch.Size([0, 6]) - + data = simple_graph_1 # Test the lift_features method lifted_data = self.lifting.forward(data.clone()) diff --git a/test/transforms/feature_liftings/test_ProjectionSum.py b/test/transforms/feature_liftings/test_ProjectionSum.py index 935a5148..b14ea5e8 100644 --- a/test/transforms/feature_liftings/test_ProjectionSum.py +++ b/test/transforms/feature_liftings/test_ProjectionSum.py @@ -2,7 +2,13 @@ import torch -from topobenchmark.transforms.liftings.graph2simplicial import ( +from topobenchmark.data.utils import ( + Complex2Dict, + Data2NxGraph, + TnxComplex2Complex, +) +from topobenchmark.transforms.liftings import ( + LiftingTransform, SimplicialCliqueLifting, ) @@ -13,13 +19,17 @@ class TestProjectionSum: def setup_method(self): """Set up the test.""" # Initialize a lifting class - self.lifting = SimplicialCliqueLifting( - feature_lifting="ProjectionSum", complex_dim=3 + self.lifting = LiftingTransform( + lifting=SimplicialCliqueLifting(complex_dim=3), + feature_lifting="ProjectionSum", + data2domain=Data2NxGraph(), + domain2domain=TnxComplex2Complex(), + domain2dict=Complex2Dict(), ) def test_lift_features(self, simple_graph_1): """Test the lift_features method. - + Parameters ---------- simple_graph_1 : torch_geometric.data.Data @@ -31,38 +41,27 @@ def test_lift_features(self, simple_graph_1): expected_x1 = torch.tensor( [ - [ 6.], - [ 11.], - [ 101.], - [5001.], - [ 15.], - [ 105.], - [ 60.], - [ 110.], - [ 510.], - [5010.], - [1050.], - [1500.], - [5500.] + [6.0], + [11.0], + [101.0], + [5001.0], + [15.0], + [105.0], + [60.0], + [110.0], + [510.0], + [5010.0], + [1050.0], + [1500.0], + [5500.0], ] ) expected_x2 = torch.tensor( - [ - [ 32.], - [ 212.], - [ 222.], - [10022.], - [ 230.], - [11020.] - ] + [[32.0], [212.0], [222.0], [10022.0], [230.0], [11020.0]] ) - expected_x3 = torch.tensor( - [ - [696.] - ] - ) + expected_x3 = torch.tensor([[696.0]]) assert ( expected_x1 == lifted_data.x_1 diff --git a/test/transforms/feature_liftings/test_SetLifting.py b/test/transforms/feature_liftings/test_SetLifting.py index 9b71816f..bf0c621f 100644 --- a/test/transforms/feature_liftings/test_SetLifting.py +++ b/test/transforms/feature_liftings/test_SetLifting.py @@ -2,7 +2,13 @@ import torch -from topobenchmark.transforms.liftings.graph2simplicial import ( +from topobenchmark.data.utils import ( + Complex2Dict, + Data2NxGraph, + TnxComplex2Complex, +) +from topobenchmark.transforms.liftings import ( + LiftingTransform, SimplicialCliqueLifting, ) @@ -13,13 +19,17 @@ class TestSetLifting: def setup_method(self): """Set up the test.""" # Initialize a lifting class - self.lifting = SimplicialCliqueLifting( - feature_lifting="Set", complex_dim=3 + self.lifting = LiftingTransform( + SimplicialCliqueLifting(complex_dim=3), + feature_lifting="Set", + data2domain=Data2NxGraph(), + domain2domain=TnxComplex2Complex(signed=False), + domain2dict=Complex2Dict(), ) def test_lift_features(self, simple_graph_1): """Test the lift_features method. - + Parameters ---------- simple_graph_1 : torch_geometric.data.Data diff --git a/topobenchmark/transforms/feature_liftings/concatenation.py b/topobenchmark/transforms/feature_liftings/concatenation.py index 5a69f46d..b26509d9 100644 --- a/topobenchmark/transforms/feature_liftings/concatenation.py +++ b/topobenchmark/transforms/feature_liftings/concatenation.py @@ -1,83 +1,53 @@ """Concatenation feature lifting.""" import torch -import torch_geometric +from topobenchmark.transforms.feature_liftings.base import FeatureLiftingMap -class Concatenation(torch_geometric.transforms.BaseTransform): - r"""Lift r-cell features to r+1-cells by concatenation. - Parameters - ---------- - **kwargs : optional - Additional arguments for the class. - """ - - def __init__(self, **kwargs): - super().__init__() +class Concatenation(FeatureLiftingMap): + """Lift r-cell features to r+1-cells by concatenation.""" def __repr__(self) -> str: return f"{self.__class__.__name__}()" - def lift_features( - self, data: torch_geometric.data.Data | dict - ) -> torch_geometric.data.Data | dict: + def lift_features(self, domain): r"""Concatenate r-cell features to obtain r+1-cell features. Parameters ---------- - data : torch_geometric.data.Data | dict + data : Complex The input data to be lifted. Returns ------- - torch_geometric.data.Data | dict - The lifted data. + Complex + Domain with the lifted features. """ - keys = sorted( - [ - key.split("_")[1] - for key in data - if "incidence" in key and "-" not in key - ] - ) - for elem in keys: - if f"x_{elem}" not in data: - idx_to_project = 0 if elem == "hyperedges" else int(elem) - 1 - incidence = data["incidence_" + elem] - _, n = incidence.shape + for rank in range(domain.max_rank - 1): + if domain.features[rank + 1] is not None: + continue - if n != 0: - idxs_list = [] - for n_feature in range(n): - idxs_for_feature = incidence.indices()[ - 0, incidence.indices()[1, :] == n_feature - ] - idxs_list.append(torch.sort(idxs_for_feature)[0]) + # TODO: different if hyperedges? + idx_to_project = rank - idxs = torch.stack(idxs_list, dim=0) - values = data[f"x_{idx_to_project}"][idxs].view(n, -1) - else: - m = data[f"x_{int(elem)-1}"].shape[1] * (int(elem) + 1) - values = torch.zeros([0, m]) + incidence = domain.incidence[rank + 1] + _, n = incidence.shape - data["x_" + elem] = values - return data + if n != 0: + idxs_list = [] + for n_feature in range(n): + idxs_for_feature = incidence.indices()[ + 0, incidence.indices()[1, :] == n_feature + ] + idxs_list.append(torch.sort(idxs_for_feature)[0]) - def forward( - self, data: torch_geometric.data.Data | dict - ) -> torch_geometric.data.Data | dict: - r"""Apply the lifting to the input data. + idxs = torch.stack(idxs_list, dim=0) + values = domain.features[idx_to_project][idxs].view(n, -1) + else: + m = domain.features[rank].shape[1] * (rank + 2) + values = torch.zeros([0, m]) - Parameters - ---------- - data : torch_geometric.data.Data | dict - The input data to be lifted. + domain.update_features(rank + 1, values) - Returns - ------- - torch_geometric.data.Data | dict - The lifted data. - """ - data = self.lift_features(data) - return data + return domain diff --git a/topobenchmark/transforms/feature_liftings/set.py b/topobenchmark/transforms/feature_liftings/set.py index 28ccd0cc..1886e25b 100644 --- a/topobenchmark/transforms/feature_liftings/set.py +++ b/topobenchmark/transforms/feature_liftings/set.py @@ -1,89 +1,60 @@ """Set lifting for r-cell features to r+1-cell features.""" import torch -import torch_geometric +from topobenchmark.transforms.feature_liftings.base import FeatureLiftingMap -class Set(torch_geometric.transforms.BaseTransform): - r"""Lift r-cell features to r+1-cells by set operations. - Parameters - ---------- - **kwargs : optional - Additional arguments for the class. - """ - - def __init__(self, **kwargs): - super().__init__() +class Set(FeatureLiftingMap): + """Lift r-cell features to r+1-cells by set operations.""" def __repr__(self) -> str: return f"{self.__class__.__name__}()" - def lift_features( - self, data: torch_geometric.data.Data | dict - ) -> torch_geometric.data.Data | dict: + def lift_features(self, domain): r"""Concatenate r-cell features to r+1-cell structures. Parameters ---------- - data : torch_geometric.data.Data | dict + data : Complex The input data to be lifted. Returns ------- - torch_geometric.data.Data | dict - The lifted data. + Complex + Domain with the lifted features. """ - keys = sorted( - [key.split("_")[1] for key in data if "incidence" in key] - ) - for elem in keys: - if f"x_{elem}" not in data: - # idx_to_project = 0 if elem == "hyperedges" else int(elem) - 1 - incidence = data["incidence_" + elem] - _, n = incidence.shape - - if n != 0: - idxs_list = [] - for n_feature in range(n): - idxs_for_feature = incidence.indices()[ - 0, incidence.indices()[1, :] == n_feature - ] - idxs_list.append(torch.sort(idxs_for_feature)[0]) - - idxs = torch.stack(idxs_list, dim=0) - if elem == "1": - values = idxs - else: - values = torch.sort( - torch.unique( - data["x_" + str(int(elem) - 1)][idxs].view( - idxs.shape[0], -1 - ), - dim=1, + for rank in range(domain.max_rank - 1): + if domain.features[rank + 1] is not None: + continue + + incidence = domain.incidence[rank + 1] + _, n = incidence.shape + + if n != 0: + idxs_list = [] + for n_feature in range(n): + idxs_for_feature = incidence.indices()[ + 0, incidence.indices()[1, :] == n_feature + ] + idxs_list.append(torch.sort(idxs_for_feature)[0]) + + idxs = torch.stack(idxs_list, dim=0) + if rank == 0: + values = idxs + else: + values = torch.sort( + torch.unique( + domain.features[rank][idxs].view( + idxs.shape[0], -1 ), dim=1, - )[0] - else: - values = torch.tensor([]) - - data["x_" + elem] = values - return data + ), + dim=1, + )[0] + else: + values = torch.tensor([]) - def forward( - self, data: torch_geometric.data.Data | dict - ) -> torch_geometric.data.Data | dict: - r"""Apply the lifting to the input data. + domain.update_features(rank + 1, values) - Parameters - ---------- - data : torch_geometric.data.Data | dict - The input data to be lifted. - - Returns - ------- - torch_geometric.data.Data | dict - The lifted data. - """ - data = self.lift_features(data) - return data + return domain From 190e2b89b7de4e47504211b5a6867a09025571e9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lu=C3=ADs=20F=2E=20Pereira?= Date: Wed, 15 Jan 2025 19:13:00 -0800 Subject: [PATCH 17/28] Add str-based instantiation to LiftingMap for backwards compatibility --- topobenchmark/transforms/liftings/base.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/topobenchmark/transforms/liftings/base.py b/topobenchmark/transforms/liftings/base.py index ce3d1ab4..d9fb0593 100644 --- a/topobenchmark/transforms/liftings/base.py +++ b/topobenchmark/transforms/liftings/base.py @@ -48,6 +48,16 @@ def __init__( if domain2domain is None: domain2domain = IdentityAdapter() + if isinstance(lifting, str): + from topobenchmark.transforms import TRANSFORMS + + lifting = TRANSFORMS[lifting]() + + if isinstance(feature_lifting, str): + from topobenchmark.transforms import TRANSFORMS + + feature_lifting = TRANSFORMS[feature_lifting]() + self.data2domain = data2domain self.domain2domain = domain2domain self.domain2dict = domain2dict From f774c97b192ee6e60ebe11938069b33789a57545 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lu=C3=ADs=20F=2E=20Pereira?= Date: Wed, 15 Jan 2025 19:13:46 -0800 Subject: [PATCH 18/28] Fix Data2NxGraph adapter --- topobenchmark/data/utils/adapters.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/topobenchmark/data/utils/adapters.py b/topobenchmark/data/utils/adapters.py index 342a2622..1dc4328f 100644 --- a/topobenchmark/data/utils/adapters.py +++ b/topobenchmark/data/utils/adapters.py @@ -85,14 +85,14 @@ def adapt(self, domain: torch_geometric.data.Data) -> nx.Graph: if self.preserve_edge_attr and self._data_has_edge_attr(domain): # In case edge features are given, assign features to every edge - edge_index, edge_attr = ( - domain.edge_index, - ( - domain.edge_attr - if is_undirected(domain.edge_index, domain.edge_attr) - else to_undirected(domain.edge_index, domain.edge_attr) - ), - ) + # TODO: confirm this is the desired behavior + if is_undirected(domain.edge_index, domain.edge_attr): + edge_index, edge_attr = (domain.edge_index, domain.edge_attr) + else: + edge_index, edge_attr = to_undirected( + domain.edge_index, domain.edge_attr + ) + edges = [ (i.item(), j.item(), dict(features=edge_attr[edge_idx], dim=1)) for edge_idx, (i, j) in enumerate( From bd243871552d45f0449800bd55891e563017fe10 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lu=C3=ADs=20F=2E=20Pereira?= Date: Wed, 15 Jan 2025 19:14:19 -0800 Subject: [PATCH 19/28] Fix failing data manipulation test (only setup) --- .../test_SimplicialCurvature.py | 45 ++++++++++++------- 1 file changed, 29 insertions(+), 16 deletions(-) diff --git a/test/transforms/data_manipulations/test_SimplicialCurvature.py b/test/transforms/data_manipulations/test_SimplicialCurvature.py index e4cb517b..e8199beb 100644 --- a/test/transforms/data_manipulations/test_SimplicialCurvature.py +++ b/test/transforms/data_manipulations/test_SimplicialCurvature.py @@ -2,8 +2,19 @@ import torch from torch_geometric.data import Data -from topobenchmark.transforms.data_manipulations import CalculateSimplicialCurvature -from topobenchmark.transforms.liftings.graph2simplicial import SimplicialCliqueLifting + +from topobenchmark.data.utils import ( + Complex2Dict, + Data2NxGraph, + TnxComplex2Complex, +) +from topobenchmark.transforms.data_manipulations import ( + CalculateSimplicialCurvature, +) +from topobenchmark.transforms.liftings import ( + LiftingTransform, + SimplicialCliqueLifting, +) class TestSimplicialCurvature: @@ -11,29 +22,31 @@ class TestSimplicialCurvature: def test_simplicial_curvature(self, simple_graph_1): """Test simplicial curvature calculation. - + Parameters ---------- simple_graph_1 : torch_geometric.data.Data A simple graph fixture. """ simplicial_curvature = CalculateSimplicialCurvature() - lifting_unsigned = SimplicialCliqueLifting( - complex_dim=3, signed=False + + lifting_unsigned = LiftingTransform( + lifting=SimplicialCliqueLifting(complex_dim=3), + data2domain=Data2NxGraph(), + domain2domain=TnxComplex2Complex(signed=False), + domain2dict=Complex2Dict(), ) + data = lifting_unsigned(simple_graph_1) - data['0_cell_degrees'] = torch.unsqueeze( - torch.sum(data['incidence_1'], dim=1).to_dense(), - dim=1 + data["0_cell_degrees"] = torch.unsqueeze( + torch.sum(data["incidence_1"], dim=1).to_dense(), dim=1 ) - data['1_cell_degrees'] = torch.unsqueeze( - torch.sum(data['incidence_2'], dim=1).to_dense(), - dim=1 + data["1_cell_degrees"] = torch.unsqueeze( + torch.sum(data["incidence_2"], dim=1).to_dense(), dim=1 ) - data['2_cell_degrees'] = torch.unsqueeze( - torch.sum(data['incidence_3'], dim=1).to_dense(), - dim=1 + data["2_cell_degrees"] = torch.unsqueeze( + torch.sum(data["incidence_3"], dim=1).to_dense(), dim=1 ) - + res = simplicial_curvature(data) - assert isinstance(res, Data) \ No newline at end of file + assert isinstance(res, Data) From a7a87553df3b7c1e82b2f94c12c582cc440d16e0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lu=C3=ADs=20F=2E=20Pereira?= Date: Thu, 16 Jan 2025 11:21:59 -0800 Subject: [PATCH 20/28] Fix TnxComplex2Complex adapter to handle CellComplex features --- topobenchmark/data/utils/adapters.py | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/topobenchmark/data/utils/adapters.py b/topobenchmark/data/utils/adapters.py index 1dc4328f..80295e69 100644 --- a/topobenchmark/data/utils/adapters.py +++ b/topobenchmark/data/utils/adapters.py @@ -5,6 +5,7 @@ import torch import torch_geometric from topomodelx.utils.sparse import from_sparse +from toponetx.classes import CellComplex, SimplicialComplex from torch_geometric.utils.undirected import is_undirected, to_undirected from topobenchmark.data.utils.domain import Complex @@ -123,10 +124,6 @@ class TnxComplex2Complex(Adapter): Parameters ---------- - complex_dim : int - Dimension of the (sub)complex. - If ``None``, adapts the (full) complex. - If greater than dimension of complex, pads with empty matrices. neighborhoods : list, optional List of neighborhoods of interest. signed : bool, optional @@ -216,17 +213,18 @@ def adapt(self, domain): if neighborhoods is not None: data = select_neighborhoods_of_interest(data, neighborhoods) - # TODO: simplex specific? - # TODO: how to do this for other? - if self.transfer_features and hasattr( - domain, "get_simplex_attributes" - ): + if self.transfer_features: + if isinstance(domain, SimplicialComplex): + get_features = domain.get_simplex_attributes + elif isinstance(domain, CellComplex): + get_features = domain.get_cell_attributes + else: + raise ValueError("Can't transfer features.") + # TODO: confirm features are in the right order; update this data["features"] = [] for rank in range(dim + 1): - rank_features_dict = domain.get_simplex_attributes( - "features", rank - ) + rank_features_dict = get_features("features", rank) if rank_features_dict: rank_features = torch.stack( list(rank_features_dict.values()) From 2ccbea30d59686f69199eb911eb373811f28c5d3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lu=C3=ADs=20F=2E=20Pereira?= Date: Thu, 16 Jan 2025 11:23:10 -0800 Subject: [PATCH 21/28] Add syntax sugar to instantiate graph2complex/simplicial lifting transforms --- topobenchmark/transforms/liftings/__init__.py | 6 +- topobenchmark/transforms/liftings/base.py | 55 ++++++++++++++++++- 2 files changed, 57 insertions(+), 4 deletions(-) diff --git a/topobenchmark/transforms/liftings/__init__.py b/topobenchmark/transforms/liftings/__init__.py index 513f5035..2c759ac3 100755 --- a/topobenchmark/transforms/liftings/__init__.py +++ b/topobenchmark/transforms/liftings/__init__.py @@ -1,6 +1,10 @@ """This module implements the liftings for the topological transforms.""" -from .base import LiftingTransform # noqa: F401 +from .base import ( # noqa: F401 + Graph2ComplexLiftingTransform, + Graph2SimplicialLiftingTransform, + LiftingTransform, +) from .graph2cell import GRAPH2CELL_LIFTINGS from .graph2hypergraph import GRAPH2HYPERGRAPH_LIFTINGS from .graph2simplicial import GRAPH2SIMPLICIAL_LIFTINGS diff --git a/topobenchmark/transforms/liftings/base.py b/topobenchmark/transforms/liftings/base.py index d9fb0593..3637f564 100644 --- a/topobenchmark/transforms/liftings/base.py +++ b/topobenchmark/transforms/liftings/base.py @@ -4,7 +4,12 @@ import torch_geometric -from topobenchmark.data.utils import IdentityAdapter +from topobenchmark.data.utils import ( + Complex2Dict, + Data2NxGraph, + IdentityAdapter, + TnxComplex2Complex, +) from topobenchmark.transforms.feature_liftings.identity import Identity @@ -13,14 +18,14 @@ class LiftingTransform(torch_geometric.transforms.BaseTransform): Parameters ---------- + lifting : LiftingMap + Lifting map. data2domain : Converter Conversion between ``torch_geometric.Data`` into domain for consumption by lifting. domain2dict : Converter Conversion between output domain of feature lifting and ``torch_geometric.Data``. - lifting : LiftingMap - Lifting map. domain2domain : Converter Conversion between output domain of lifting and input domain for feature lifting. @@ -92,6 +97,50 @@ def forward( ) +class Graph2ComplexLiftingTransform(LiftingTransform): + """Graph to complex lifting transform. + + Parameters + ---------- + lifting : LiftingMap + Lifting map. + feature_lifting : FeatureLiftingMap + Feature lifting map. + preserve_edge_attr : bool + Whether to preserve edge attributes. + neighborhoods : list, optional + List of neighborhoods of interest. + signed : bool, optional + If True, returns signed connectivity matrices. + transfer_features : bool, optional + Whether to transfer features. + """ + + def __init__( + self, + lifting, + feature_lifting="ProjectionSum", + preserve_edge_attr=False, + neighborhoods=None, + signed=False, + transfer_features=True, + ): + super().__init__( + lifting, + feature_lifting=feature_lifting, + data2domain=Data2NxGraph(preserve_edge_attr), + domain2domain=TnxComplex2Complex( + neighborhoods=neighborhoods, + signed=signed, + transfer_features=transfer_features, + ), + domain2dict=Complex2Dict(), + ) + + +Graph2SimplicialLiftingTransform = Graph2ComplexLiftingTransform + + class LiftingMap(abc.ABC): """Lifting map. From 7682feffad4b73fcc370266ce59cd7c6abc87d44 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lu=C3=ADs=20F=2E=20Pereira?= Date: Thu, 16 Jan 2025 11:24:25 -0800 Subject: [PATCH 22/28] Make imports shorter and use newly added syntax sugar --- test/conftest.py | 4 ++-- .../test_SimplicialCurvature.py | 14 +++-------- .../feature_liftings/test_Concatenation.py | 15 ++++-------- .../feature_liftings/test_ProjectionSum.py | 12 ++-------- .../feature_liftings/test_SetLifting.py | 14 +++-------- .../liftings/cell/test_CellCyclesLifting.py | 13 ++++------- .../hypergraph/test_HypergraphKHopLifting.py | 4 ++-- ...test_HypergraphKNearestNeighborsLifting.py | 4 +--- .../test_SimplicialCliqueLifting.py | 23 +++++-------------- .../test_SimplicialNeighborhoodLifting.py | 23 +++++-------------- topobenchmark/data/utils/__init__.py | 2 +- 11 files changed, 35 insertions(+), 93 deletions(-) diff --git a/test/conftest.py b/test/conftest.py index 753d63b2..9a70c6a1 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -5,8 +5,8 @@ import torch import torch_geometric -from topobenchmark.transforms.liftings.graph2cell.cycle import CellCycleLifting -from topobenchmark.transforms.liftings.graph2simplicial.clique import ( +from topobenchmark.transforms.liftings import ( + CellCycleLifting, SimplicialCliqueLifting, ) diff --git a/test/transforms/data_manipulations/test_SimplicialCurvature.py b/test/transforms/data_manipulations/test_SimplicialCurvature.py index e8199beb..e90d4e68 100644 --- a/test/transforms/data_manipulations/test_SimplicialCurvature.py +++ b/test/transforms/data_manipulations/test_SimplicialCurvature.py @@ -3,16 +3,11 @@ import torch from torch_geometric.data import Data -from topobenchmark.data.utils import ( - Complex2Dict, - Data2NxGraph, - TnxComplex2Complex, -) from topobenchmark.transforms.data_manipulations import ( CalculateSimplicialCurvature, ) from topobenchmark.transforms.liftings import ( - LiftingTransform, + Graph2SimplicialLiftingTransform, SimplicialCliqueLifting, ) @@ -30,11 +25,8 @@ def test_simplicial_curvature(self, simple_graph_1): """ simplicial_curvature = CalculateSimplicialCurvature() - lifting_unsigned = LiftingTransform( - lifting=SimplicialCliqueLifting(complex_dim=3), - data2domain=Data2NxGraph(), - domain2domain=TnxComplex2Complex(signed=False), - domain2dict=Complex2Dict(), + lifting_unsigned = Graph2SimplicialLiftingTransform( + lifting=SimplicialCliqueLifting(complex_dim=3) ) data = lifting_unsigned(simple_graph_1) diff --git a/test/transforms/feature_liftings/test_Concatenation.py b/test/transforms/feature_liftings/test_Concatenation.py index aff3a2c1..9474e8da 100644 --- a/test/transforms/feature_liftings/test_Concatenation.py +++ b/test/transforms/feature_liftings/test_Concatenation.py @@ -2,13 +2,8 @@ import torch -from topobenchmark.data.utils import ( - Complex2Dict, - Data2NxGraph, - TnxComplex2Complex, -) from topobenchmark.transforms.liftings import ( - LiftingTransform, + Graph2SimplicialLiftingTransform, SimplicialCliqueLifting, ) @@ -19,12 +14,10 @@ class TestConcatenation: def setup_method(self): """Set up the test.""" # Initialize a lifting class - self.lifting = LiftingTransform( - SimplicialCliqueLifting(complex_dim=3), + + self.lifting = Graph2SimplicialLiftingTransform( + lifting=SimplicialCliqueLifting(complex_dim=3), feature_lifting="Concatenation", - data2domain=Data2NxGraph(), - domain2domain=TnxComplex2Complex(signed=False), - domain2dict=Complex2Dict(), ) def test_lift_features(self, simple_graph_0, simple_graph_1): diff --git a/test/transforms/feature_liftings/test_ProjectionSum.py b/test/transforms/feature_liftings/test_ProjectionSum.py index b14ea5e8..a6ad8cdf 100644 --- a/test/transforms/feature_liftings/test_ProjectionSum.py +++ b/test/transforms/feature_liftings/test_ProjectionSum.py @@ -2,13 +2,8 @@ import torch -from topobenchmark.data.utils import ( - Complex2Dict, - Data2NxGraph, - TnxComplex2Complex, -) from topobenchmark.transforms.liftings import ( - LiftingTransform, + Graph2SimplicialLiftingTransform, SimplicialCliqueLifting, ) @@ -19,12 +14,9 @@ class TestProjectionSum: def setup_method(self): """Set up the test.""" # Initialize a lifting class - self.lifting = LiftingTransform( + self.lifting = Graph2SimplicialLiftingTransform( lifting=SimplicialCliqueLifting(complex_dim=3), feature_lifting="ProjectionSum", - data2domain=Data2NxGraph(), - domain2domain=TnxComplex2Complex(), - domain2dict=Complex2Dict(), ) def test_lift_features(self, simple_graph_1): diff --git a/test/transforms/feature_liftings/test_SetLifting.py b/test/transforms/feature_liftings/test_SetLifting.py index bf0c621f..584f9724 100644 --- a/test/transforms/feature_liftings/test_SetLifting.py +++ b/test/transforms/feature_liftings/test_SetLifting.py @@ -2,13 +2,8 @@ import torch -from topobenchmark.data.utils import ( - Complex2Dict, - Data2NxGraph, - TnxComplex2Complex, -) from topobenchmark.transforms.liftings import ( - LiftingTransform, + Graph2SimplicialLiftingTransform, SimplicialCliqueLifting, ) @@ -19,12 +14,9 @@ class TestSetLifting: def setup_method(self): """Set up the test.""" # Initialize a lifting class - self.lifting = LiftingTransform( - SimplicialCliqueLifting(complex_dim=3), + self.lifting = Graph2SimplicialLiftingTransform( + lifting=SimplicialCliqueLifting(complex_dim=3), feature_lifting="Set", - data2domain=Data2NxGraph(), - domain2domain=TnxComplex2Complex(signed=False), - domain2dict=Complex2Dict(), ) def test_lift_features(self, simple_graph_1): diff --git a/test/transforms/liftings/cell/test_CellCyclesLifting.py b/test/transforms/liftings/cell/test_CellCyclesLifting.py index c574992e..7235b20f 100644 --- a/test/transforms/liftings/cell/test_CellCyclesLifting.py +++ b/test/transforms/liftings/cell/test_CellCyclesLifting.py @@ -2,9 +2,10 @@ import torch -from topobenchmark.data.utils import Data2NxGraph, TnxComplex2Dict -from topobenchmark.transforms.liftings.base import LiftingTransform -from topobenchmark.transforms.liftings.graph2cell.cycle import CellCycleLifting +from topobenchmark.transforms.liftings import ( + CellCycleLifting, + Graph2ComplexLiftingTransform, +) class TestCellCycleLifting: @@ -12,11 +13,7 @@ class TestCellCycleLifting: def setup_method(self): # Initialise the CellCycleLifting class - self.lifting = LiftingTransform( - CellCycleLifting(), - data2domain=Data2NxGraph(), - domain2dict=TnxComplex2Dict(), - ) + self.lifting = Graph2ComplexLiftingTransform(CellCycleLifting()) def test_lift_topology(self, simple_graph_1): # Test the lift_topology method diff --git a/test/transforms/liftings/hypergraph/test_HypergraphKHopLifting.py b/test/transforms/liftings/hypergraph/test_HypergraphKHopLifting.py index 3fcc7ebb..8fd1b75b 100644 --- a/test/transforms/liftings/hypergraph/test_HypergraphKHopLifting.py +++ b/test/transforms/liftings/hypergraph/test_HypergraphKHopLifting.py @@ -2,9 +2,9 @@ import torch -from topobenchmark.transforms.liftings.base import LiftingTransform -from topobenchmark.transforms.liftings.graph2hypergraph.khop import ( +from topobenchmark.transforms.liftings import ( HypergraphKHopLifting, + LiftingTransform, ) diff --git a/test/transforms/liftings/hypergraph/test_HypergraphKNearestNeighborsLifting.py b/test/transforms/liftings/hypergraph/test_HypergraphKNearestNeighborsLifting.py index 069d7a3c..23dc5d35 100644 --- a/test/transforms/liftings/hypergraph/test_HypergraphKNearestNeighborsLifting.py +++ b/test/transforms/liftings/hypergraph/test_HypergraphKNearestNeighborsLifting.py @@ -4,9 +4,7 @@ import torch from torch_geometric.data import Data -from topobenchmark.transforms.liftings.graph2hypergraph.knn import ( - HypergraphKNNLifting, -) +from topobenchmark.transforms.liftings import HypergraphKNNLifting class TestHypergraphKNNLifting: diff --git a/test/transforms/liftings/simplicial/test_SimplicialCliqueLifting.py b/test/transforms/liftings/simplicial/test_SimplicialCliqueLifting.py index fa36e072..a2c32ebf 100644 --- a/test/transforms/liftings/simplicial/test_SimplicialCliqueLifting.py +++ b/test/transforms/liftings/simplicial/test_SimplicialCliqueLifting.py @@ -2,16 +2,11 @@ import torch -from topobenchmark.data.utils import ( - Complex2Dict, - Data2NxGraph, - TnxComplex2Complex, -) from topobenchmark.transforms.feature_liftings.projection_sum import ( ProjectionSum, ) -from topobenchmark.transforms.liftings.base import LiftingTransform -from topobenchmark.transforms.liftings.graph2simplicial.clique import ( +from topobenchmark.transforms.liftings import ( + Graph2SimplicialLiftingTransform, SimplicialCliqueLifting, ) @@ -21,25 +16,19 @@ class TestSimplicialCliqueLifting: def setup_method(self): # Initialise the SimplicialCliqueLifting class - data2graph = Data2NxGraph() lifting_map = SimplicialCliqueLifting(complex_dim=3) feature_lifting = ProjectionSum() - domain2dict = Complex2Dict() - self.lifting_signed = LiftingTransform( + self.lifting_signed = Graph2SimplicialLiftingTransform( lifting=lifting_map, feature_lifting=feature_lifting, - data2domain=data2graph, - domain2domain=TnxComplex2Complex(signed=True), - domain2dict=domain2dict, + signed=True, ) - self.lifting_unsigned = LiftingTransform( + self.lifting_unsigned = Graph2SimplicialLiftingTransform( lifting=lifting_map, feature_lifting=feature_lifting, - data2domain=data2graph, - domain2domain=TnxComplex2Complex(signed=False), - domain2dict=domain2dict, + signed=False, ) def test_lift_topology(self, simple_graph_1): diff --git a/test/transforms/liftings/simplicial/test_SimplicialNeighborhoodLifting.py b/test/transforms/liftings/simplicial/test_SimplicialNeighborhoodLifting.py index e21b8f99..6a81d9f2 100644 --- a/test/transforms/liftings/simplicial/test_SimplicialNeighborhoodLifting.py +++ b/test/transforms/liftings/simplicial/test_SimplicialNeighborhoodLifting.py @@ -2,16 +2,11 @@ import torch -from topobenchmark.data.utils import ( - Complex2Dict, - Data2NxGraph, - TnxComplex2Complex, -) from topobenchmark.transforms.feature_liftings.projection_sum import ( ProjectionSum, ) -from topobenchmark.transforms.liftings.base import LiftingTransform -from topobenchmark.transforms.liftings.graph2simplicial.khop import ( +from topobenchmark.transforms.liftings import ( + Graph2SimplicialLiftingTransform, SimplicialKHopLifting, ) @@ -23,25 +18,19 @@ class TestSimplicialKHopLifting: def setup_method(self): # Initialise the SimplicialKHopLifting class - data2graph = Data2NxGraph() feature_lifting = ProjectionSum() - domain2dict = Complex2Dict() lifting_map = SimplicialKHopLifting(complex_dim=3) - self.lifting_signed = LiftingTransform( + self.lifting_signed = Graph2SimplicialLiftingTransform( lifting=lifting_map, feature_lifting=feature_lifting, - data2domain=data2graph, - domain2domain=TnxComplex2Complex(signed=True), - domain2dict=domain2dict, + signed=True, ) - self.lifting_unsigned = LiftingTransform( + self.lifting_unsigned = Graph2SimplicialLiftingTransform( lifting=lifting_map, feature_lifting=feature_lifting, - data2domain=data2graph, - domain2domain=TnxComplex2Complex(signed=False), - domain2dict=domain2dict, + signed=False, ) def test_lift_topology(self, simple_graph_1): diff --git a/topobenchmark/data/utils/__init__.py b/topobenchmark/data/utils/__init__.py index d7010c2b..34fc79f3 100644 --- a/topobenchmark/data/utils/__init__.py +++ b/topobenchmark/data/utils/__init__.py @@ -1,7 +1,7 @@ """Init file for data/utils module.""" from .adapters import * -from .domain import Complex +from .domain import Complex # noqa: F401 from .utils import ( ensure_serializable, # noqa: F401 generate_zero_sparse_connectivity, # noqa: F401 From f3e3f88b5b503ef84570bc3b7fcb384a405ace31 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lu=C3=ADs=20F=2E=20Pereira?= Date: Thu, 16 Jan 2025 15:31:25 -0800 Subject: [PATCH 23/28] Update domain to accomodate hypergraph data --- .../liftings/cell/test_CellCyclesLifting.py | 4 +- .../hypergraph/test_HypergraphKHopLifting.py | 12 +- topobenchmark/data/utils/__init__.py | 2 +- topobenchmark/data/utils/adapters.py | 39 ++++-- topobenchmark/data/utils/domain.py | 113 +++++++++++------- .../feature_liftings/concatenation.py | 18 +-- .../feature_liftings/projection_sum.py | 16 +-- .../transforms/feature_liftings/set.py | 16 +-- topobenchmark/transforms/liftings/__init__.py | 2 + topobenchmark/transforms/liftings/base.py | 29 +++-- .../liftings/graph2hypergraph/khop.py | 13 +- 11 files changed, 166 insertions(+), 98 deletions(-) diff --git a/test/transforms/liftings/cell/test_CellCyclesLifting.py b/test/transforms/liftings/cell/test_CellCyclesLifting.py index 7235b20f..706e1f9d 100644 --- a/test/transforms/liftings/cell/test_CellCyclesLifting.py +++ b/test/transforms/liftings/cell/test_CellCyclesLifting.py @@ -4,7 +4,7 @@ from topobenchmark.transforms.liftings import ( CellCycleLifting, - Graph2ComplexLiftingTransform, + Graph2CellLiftingTransform, ) @@ -13,7 +13,7 @@ class TestCellCycleLifting: def setup_method(self): # Initialise the CellCycleLifting class - self.lifting = Graph2ComplexLiftingTransform(CellCycleLifting()) + self.lifting = Graph2CellLiftingTransform(CellCycleLifting()) def test_lift_topology(self, simple_graph_1): # Test the lift_topology method diff --git a/test/transforms/liftings/hypergraph/test_HypergraphKHopLifting.py b/test/transforms/liftings/hypergraph/test_HypergraphKHopLifting.py index 8fd1b75b..68326f11 100644 --- a/test/transforms/liftings/hypergraph/test_HypergraphKHopLifting.py +++ b/test/transforms/liftings/hypergraph/test_HypergraphKHopLifting.py @@ -3,8 +3,8 @@ import torch from topobenchmark.transforms.liftings import ( + Graph2HypergraphLiftingTransform, HypergraphKHopLifting, - LiftingTransform, ) @@ -14,15 +14,19 @@ class TestHypergraphKHopLifting: def setup_method(self): """Setup the test.""" # Initialise the HypergraphKHopLifting class - self.lifting_k1 = LiftingTransform(HypergraphKHopLifting(k_value=1)) - self.lifting_k2 = LiftingTransform(HypergraphKHopLifting(k_value=2)) + self.lifting_k1 = Graph2HypergraphLiftingTransform( + HypergraphKHopLifting(k_value=1) + ) + self.lifting_k2 = Graph2HypergraphLiftingTransform( + HypergraphKHopLifting(k_value=2) + ) # TODO: delete? # NB: `preserve_edge_attr` is never used? therefore they're equivalent # self.lifting_edge_attr = HypergraphKHopLifting( # k_value=1, preserve_edge_attr=True # ) - self.lifting_edge_attr = LiftingTransform( + self.lifting_edge_attr = Graph2HypergraphLiftingTransform( HypergraphKHopLifting(k_value=1) ) diff --git a/topobenchmark/data/utils/__init__.py b/topobenchmark/data/utils/__init__.py index 34fc79f3..de796c1d 100644 --- a/topobenchmark/data/utils/__init__.py +++ b/topobenchmark/data/utils/__init__.py @@ -1,7 +1,7 @@ """Init file for data/utils module.""" from .adapters import * -from .domain import Complex # noqa: F401 +from .domain import ComplexData, HypergraphData # noqa: F401 from .utils import ( ensure_serializable, # noqa: F401 generate_zero_sparse_connectivity, # noqa: F401 diff --git a/topobenchmark/data/utils/adapters.py b/topobenchmark/data/utils/adapters.py index 80295e69..b049d49c 100644 --- a/topobenchmark/data/utils/adapters.py +++ b/topobenchmark/data/utils/adapters.py @@ -8,7 +8,7 @@ from toponetx.classes import CellComplex, SimplicialComplex from torch_geometric.utils.undirected import is_undirected, to_undirected -from topobenchmark.data.utils.domain import Complex +from topobenchmark.data.utils.domain import ComplexData from topobenchmark.data.utils.utils import ( generate_zero_sparse_connectivity, select_neighborhoods_of_interest, @@ -115,7 +115,7 @@ def adapt(self, domain: torch_geometric.data.Data) -> nx.Graph: return graph -class TnxComplex2Complex(Adapter): +class TnxComplex2ComplexData(Adapter): """toponetx.Complex to Complex adaptation. NB: order of features plays a crucial role, as ``Complex`` @@ -236,18 +236,18 @@ def adapt(self, domain): for _ in range(dim + 1, practical_dim + 1): data["features"].append(None) - return Complex(**data) + return ComplexData(**data) -class Complex2Dict(Adapter): - """Complex to dict adaptation.""" +class ComplexData2Dict(Adapter): + """ComplexData to dict adaptation.""" def adapt(self, domain): """Adapt Complex to dict. Parameters ---------- - domain : toponetx.Complex + domain : ComplexData Returns ------- @@ -277,6 +277,29 @@ def adapt(self, domain): return data +class HypergraphData2Dict(Adapter): + """HypergraphData to dict adaptation.""" + + def adapt(self, domain): + """Adapt HypergraphData to dict. + + Parameters + ---------- + domain : HypergraphData + + Returns + ------- + dict + """ + hyperedges_key = domain.keys()[-1] + return { + "incidence_hyperedges": domain.incidence[hyperedges_key], + "num_hyperedges": domain.num_hyperedges, + "x_0": domain.features[0], + "x_hyperedges": domain.features[hyperedges_key], + } + + class AdapterComposition(Adapter): def __init__(self, adapters): super().__init__() @@ -309,10 +332,10 @@ def __init__( signed=False, transfer_features=True, ): - tnxcomplex2complex = TnxComplex2Complex( + tnxcomplex2complex = TnxComplex2ComplexData( neighborhoods=neighborhoods, signed=signed, transfer_features=transfer_features, ) - complex2dict = Complex2Dict() + complex2dict = ComplexData2Dict() super().__init__(adapters=(tnxcomplex2complex, complex2dict)) diff --git a/topobenchmark/data/utils/domain.py b/topobenchmark/data/utils/domain.py index 8bf4f1d7..57790162 100644 --- a/topobenchmark/data/utils/domain.py +++ b/topobenchmark/data/utils/domain.py @@ -1,4 +1,44 @@ -class Complex: +import abc + + +class Data(abc.ABC): + def __init__(self, incidence, features): + self.incidence = incidence + self.features = features + + @abc.abstractmethod + def keys(self): + pass + + def update_features(self, rank, values): + """Update features. + + Parameters + ---------- + rank : int + Rank of simplices the features belong to. + values : array-like + New features for the rank-simplices. + """ + self.features[rank] = values + + @property + def shape(self): + """Shape of the complex. + + Returns + ------- + list[int] + """ + return [ + None + if self.incidence[key] is None + else self.incidence[key].shape[-1] + for key in self.keys() + ] + + +class ComplexData(Data): def __init__( self, incidence, @@ -9,10 +49,6 @@ def __init__( hodge_laplacian, features=None, ): - # TODO: allow None with nice error message if callable? - - # TODO: make this private? do not allow for changes in these values? - self.incidence = incidence self.down_laplacian = down_laplacian self.up_laplacian = up_laplacian self.adjacency = adjacency @@ -20,53 +56,42 @@ def __init__( self.hodge_laplacian = hodge_laplacian if features is None: - features = [None for _ in range(len(self.incidence))] + features = [None for _ in range(len(incidence))] else: - for rank, dim in enumerate(self.shape): + for rank, incidence_ in enumerate(incidence): # TODO: make error message more informative if ( features[rank] is not None - and features[rank].shape[0] != dim + and features[rank].shape[0] != incidence_.shape[-1] ): raise ValueError("Features have wrong shape.") - self.features = features + super().__init__(incidence, features) - @property - def shape(self): - """Shape of the complex. + def keys(self): + return list(range(len(self.incidence))) - Returns - ------- - list[int] - """ - return [incidence.shape[-1] for incidence in self.incidence] - - @property - def max_rank(self): - """Maximum rank of the complex. - - NB: may differ from mathematical definition due to empty - matrices. - - Returns - ------- - int - """ - return len(self.incidence) - def update_features(self, rank, values): - """Update features. - - Parameters - ---------- - rank : int - Rank of simplices the features belong to. - values : array-like - New features for the rank-simplices. - """ - self.features[rank] = values +class HypergraphData(Data): + def __init__( + self, + incidence_hyperedges, + num_hyperedges, + incidence_0=None, + x_0=None, + x_hyperedges=None, + ): + self._hyperedges_key = 1 + incidence = { + 0: incidence_0, + self._hyperedges_key: incidence_hyperedges, + } + features = { + 0: x_0, + self._hyperedges_key: x_hyperedges, + } + super().__init__(incidence, features) + self.num_hyperedges = num_hyperedges - def reset_features(self): - """Reset features.""" - self.features = [None for _ in self.features] + def keys(self): + return [0, self._hyperedges_key] diff --git a/topobenchmark/transforms/feature_liftings/concatenation.py b/topobenchmark/transforms/feature_liftings/concatenation.py index b26509d9..44e3b192 100644 --- a/topobenchmark/transforms/feature_liftings/concatenation.py +++ b/topobenchmark/transforms/feature_liftings/concatenation.py @@ -24,14 +24,13 @@ def lift_features(self, domain): Complex Domain with the lifted features. """ - for rank in range(domain.max_rank - 1): - if domain.features[rank + 1] is not None: + for key, next_key in zip( + domain.keys(), domain.keys()[1:], strict=False + ): + if domain.features[next_key] is not None: continue - # TODO: different if hyperedges? - idx_to_project = rank - - incidence = domain.incidence[rank + 1] + incidence = domain.incidence[next_key] _, n = incidence.shape if n != 0: @@ -43,11 +42,12 @@ def lift_features(self, domain): idxs_list.append(torch.sort(idxs_for_feature)[0]) idxs = torch.stack(idxs_list, dim=0) - values = domain.features[idx_to_project][idxs].view(n, -1) + values = domain.features[key][idxs].view(n, -1) else: - m = domain.features[rank].shape[1] * (rank + 2) + # NB: only works if key represents rank + m = domain.features[key].shape[1] * (next_key + 1) values = torch.zeros([0, m]) - domain.update_features(rank + 1, values) + domain.update_features(next_key, values) return domain diff --git a/topobenchmark/transforms/feature_liftings/projection_sum.py b/topobenchmark/transforms/feature_liftings/projection_sum.py index a756fd0e..757234a7 100644 --- a/topobenchmark/transforms/feature_liftings/projection_sum.py +++ b/topobenchmark/transforms/feature_liftings/projection_sum.py @@ -13,23 +13,25 @@ def lift_features(self, domain): Parameters ---------- - data : Complex + data : Data The input data to be lifted. Returns ------- - Complex + Data Domain with the lifted features. """ - for rank in range(domain.max_rank - 1): - if domain.features[rank + 1] is not None: + for key, next_key in zip( + domain.keys(), domain.keys()[1:], strict=False + ): + if domain.features[next_key] is not None: continue domain.update_features( - rank + 1, + next_key, torch.matmul( - torch.abs(domain.incidence[rank + 1].t()), - domain.features[rank], + torch.abs(domain.incidence[next_key].t()), + domain.features[key], ), ) diff --git a/topobenchmark/transforms/feature_liftings/set.py b/topobenchmark/transforms/feature_liftings/set.py index 1886e25b..54ac1b9d 100644 --- a/topobenchmark/transforms/feature_liftings/set.py +++ b/topobenchmark/transforms/feature_liftings/set.py @@ -24,11 +24,13 @@ def lift_features(self, domain): Complex Domain with the lifted features. """ - for rank in range(domain.max_rank - 1): - if domain.features[rank + 1] is not None: + for key, next_key in zip( + domain.keys(), domain.keys()[1:], strict=False + ): + if domain.features[next_key] is not None: continue - incidence = domain.incidence[rank + 1] + incidence = domain.incidence[next_key] _, n = incidence.shape if n != 0: @@ -40,14 +42,12 @@ def lift_features(self, domain): idxs_list.append(torch.sort(idxs_for_feature)[0]) idxs = torch.stack(idxs_list, dim=0) - if rank == 0: + if key == 0: values = idxs else: values = torch.sort( torch.unique( - domain.features[rank][idxs].view( - idxs.shape[0], -1 - ), + domain.features[key][idxs].view(idxs.shape[0], -1), dim=1, ), dim=1, @@ -55,6 +55,6 @@ def lift_features(self, domain): else: values = torch.tensor([]) - domain.update_features(rank + 1, values) + domain.update_features(next_key, values) return domain diff --git a/topobenchmark/transforms/liftings/__init__.py b/topobenchmark/transforms/liftings/__init__.py index 2c759ac3..322c43a1 100755 --- a/topobenchmark/transforms/liftings/__init__.py +++ b/topobenchmark/transforms/liftings/__init__.py @@ -1,7 +1,9 @@ """This module implements the liftings for the topological transforms.""" from .base import ( # noqa: F401 + Graph2CellLiftingTransform, Graph2ComplexLiftingTransform, + Graph2HypergraphLiftingTransform, Graph2SimplicialLiftingTransform, LiftingTransform, ) diff --git a/topobenchmark/transforms/liftings/base.py b/topobenchmark/transforms/liftings/base.py index 3637f564..13f1f443 100644 --- a/topobenchmark/transforms/liftings/base.py +++ b/topobenchmark/transforms/liftings/base.py @@ -5,12 +5,12 @@ import torch_geometric from topobenchmark.data.utils import ( - Complex2Dict, + ComplexData2Dict, Data2NxGraph, + HypergraphData2Dict, IdentityAdapter, - TnxComplex2Complex, + TnxComplex2ComplexData, ) -from topobenchmark.transforms.feature_liftings.identity import Identity class LiftingTransform(torch_geometric.transforms.BaseTransform): @@ -39,11 +39,8 @@ def __init__( data2domain=None, domain2dict=None, domain2domain=None, - feature_lifting=None, + feature_lifting="ProjectionSum", ): - if feature_lifting is None: - feature_lifting = Identity() - if data2domain is None: data2domain = IdentityAdapter() @@ -129,16 +126,30 @@ def __init__( lifting, feature_lifting=feature_lifting, data2domain=Data2NxGraph(preserve_edge_attr), - domain2domain=TnxComplex2Complex( + domain2domain=TnxComplex2ComplexData( neighborhoods=neighborhoods, signed=signed, transfer_features=transfer_features, ), - domain2dict=Complex2Dict(), + domain2dict=ComplexData2Dict(), ) Graph2SimplicialLiftingTransform = Graph2ComplexLiftingTransform +Graph2CellLiftingTransform = Graph2ComplexLiftingTransform + + +class Graph2HypergraphLiftingTransform(LiftingTransform): + def __init__( + self, + lifting, + feature_lifting="ProjectionSum", + ): + super().__init__( + lifting, + feature_lifting=feature_lifting, + domain2dict=HypergraphData2Dict(), + ) class LiftingMap(abc.ABC): diff --git a/topobenchmark/transforms/liftings/graph2hypergraph/khop.py b/topobenchmark/transforms/liftings/graph2hypergraph/khop.py index f8997e31..7c56006c 100755 --- a/topobenchmark/transforms/liftings/graph2hypergraph/khop.py +++ b/topobenchmark/transforms/liftings/graph2hypergraph/khop.py @@ -3,6 +3,7 @@ import torch import torch_geometric +from topobenchmark.data.utils import HypergraphData from topobenchmark.transforms.liftings.base import LiftingMap @@ -36,7 +37,7 @@ def lift(self, data: torch_geometric.data.Data) -> dict: Returns ------- - dict + HypergraphData The lifted topology. """ # Check if data has instance x: @@ -72,8 +73,8 @@ def lift(self, data: torch_geometric.data.Data) -> dict: num_hyperedges = incidence_1.shape[1] incidence_1 = torch.Tensor(incidence_1).to_sparse_coo() - return { - "incidence_hyperedges": incidence_1, - "num_hyperedges": num_hyperedges, - "x_0": data.x, - } + return HypergraphData( + incidence_hyperedges=incidence_1, + num_hyperedges=num_hyperedges, + x_0=data.x, + ) From f4010fad4856a3423d59354ea4a2adf31a455d8f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lu=C3=ADs=20F=2E=20Pereira?= Date: Thu, 16 Jan 2025 15:31:50 -0800 Subject: [PATCH 24/28] Update data_transform to handle new liftings design --- topobenchmark/transforms/data_transform.py | 69 +++++++++++++++++++--- 1 file changed, 61 insertions(+), 8 deletions(-) diff --git a/topobenchmark/transforms/data_transform.py b/topobenchmark/transforms/data_transform.py index da9e883b..8106f829 100755 --- a/topobenchmark/transforms/data_transform.py +++ b/topobenchmark/transforms/data_transform.py @@ -1,8 +1,54 @@ """DataTransform class.""" +import inspect + import torch_geometric -from topobenchmark.transforms import TRANSFORMS +from topobenchmark.transforms import LIFTINGS, TRANSFORMS +from topobenchmark.transforms.liftings import ( + GRAPH2CELL_LIFTINGS, + GRAPH2HYPERGRAPH_LIFTINGS, + GRAPH2SIMPLICIAL_LIFTINGS, + Graph2CellLiftingTransform, + Graph2HypergraphLiftingTransform, + Graph2SimplicialLiftingTransform, + LiftingTransform, +) + +_map_lifting_types = { + "graph2cell": (GRAPH2CELL_LIFTINGS, Graph2CellLiftingTransform), + "graph2hypergraph": ( + GRAPH2HYPERGRAPH_LIFTINGS, + Graph2HypergraphLiftingTransform, + ), + "graph2simplicial": ( + GRAPH2SIMPLICIAL_LIFTINGS, + Graph2SimplicialLiftingTransform, + ), +} + + +def _map_lifting_name(lifting_name): + for liftings_dict, Transform in _map_lifting_types.values(): + if lifting_name in liftings_dict: + return Transform + + return LiftingTransform + + +def _route_lifting_kwargs(kwargs, LiftingMap): + lifting_map_sign = inspect.signature(LiftingMap) + + lifting_map_kwargs = {} + transform_kwargs = {} + + for key, value in kwargs.items(): + if key in lifting_map_sign.parameters: + lifting_map_kwargs[key] = value + else: + transform_kwargs[key] = value + + return lifting_map_kwargs, transform_kwargs class DataTransform(torch_geometric.transforms.BaseTransform): @@ -19,14 +65,21 @@ class DataTransform(torch_geometric.transforms.BaseTransform): def __init__(self, transform_name, **kwargs): super().__init__() - kwargs["transform_name"] = transform_name - self.parameters = kwargs + if transform_name not in LIFTINGS: + kwargs["transform_name"] = transform_name + transform = TRANSFORMS[transform_name](**kwargs) + else: + LiftingMap_ = TRANSFORMS[transform_name] + Transform = _map_lifting_name(transform_name) + lifting_map_kwargs, transform_kwargs = _route_lifting_kwargs( + kwargs, LiftingMap_ + ) + + lifting_map = LiftingMap_(**lifting_map_kwargs) + transform = Transform(lifting_map, **transform_kwargs) - self.transform = ( - TRANSFORMS[transform_name](**kwargs) - if transform_name is not None - else None - ) + self.parameters = kwargs + self.transform = transform def forward( self, data: torch_geometric.data.Data From 27962fc1e7c75c6a4b1cc99a951451c09b4c3d6f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lu=C3=ADs=20F=2E=20Pereira?= Date: Thu, 16 Jan 2025 15:49:23 -0800 Subject: [PATCH 25/28] Fix failing tests --- test/conftest.py | 8 ++- test/nn/backbones/simplicial/test_sccnn.py | 59 +++++++++++------- test/nn/wrappers/cell/test_cell_wrappers.py | 46 ++++++-------- .../wrappers/simplicial/test_SCCNNWrapper.py | 62 ++++++++++--------- topobenchmark/transforms/data_transform.py | 7 ++- 5 files changed, 97 insertions(+), 85 deletions(-) diff --git a/test/conftest.py b/test/conftest.py index 9a70c6a1..d8cf94d0 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -7,6 +7,8 @@ from topobenchmark.transforms.liftings import ( CellCycleLifting, + Graph2CellLiftingTransform, + Graph2SimplicialLiftingTransform, SimplicialCliqueLifting, ) @@ -148,7 +150,9 @@ def sg1_clique_lifted(simple_graph_1): torch_geometric.data.Data A simple graph data object with a clique lifting. """ - lifting_signed = SimplicialCliqueLifting(complex_dim=3, signed=True) + lifting_signed = Graph2SimplicialLiftingTransform( + SimplicialCliqueLifting(complex_dim=3), signed=True + ) data = lifting_signed(simple_graph_1) data.batch_0 = "null" return data @@ -168,7 +172,7 @@ def sg1_cell_lifted(simple_graph_1): torch_geometric.data.Data A simple graph data object with a cell lifting. """ - lifting = CellCycleLifting() + lifting = Graph2CellLiftingTransform(CellCycleLifting()) data = lifting(simple_graph_1) data.batch_0 = "null" return data diff --git a/test/nn/backbones/simplicial/test_sccnn.py b/test/nn/backbones/simplicial/test_sccnn.py index 19e2b774..09e86342 100644 --- a/test/nn/backbones/simplicial/test_sccnn.py +++ b/test/nn/backbones/simplicial/test_sccnn.py @@ -1,38 +1,53 @@ """Unit tests for SCCNN""" -import torch -from torch_geometric.utils import get_laplacian -from ...._utils.nn_module_auto_test import NNModuleAutoTest from topobenchmark.nn.backbones.simplicial import SCCNNCustom -from topobenchmark.transforms.liftings.graph2simplicial import ( +from topobenchmark.transforms.liftings import ( + Graph2SimplicialLiftingTransform, SimplicialCliqueLifting, ) +from ...._utils.nn_module_auto_test import NNModuleAutoTest + def test_SCCNNCustom(simple_graph_1): - lifting_signed = SimplicialCliqueLifting( - complex_dim=3, signed=True - ) + lifting_signed = Graph2SimplicialLiftingTransform( + SimplicialCliqueLifting(complex_dim=3), signed=True + ) data = lifting_signed(simple_graph_1) out_dim = 4 conv_order = 1 sc_order = 3 laplacian_all = ( - data.hodge_laplacian_0, - data.down_laplacian_1, - data.up_laplacian_1, - data.down_laplacian_2, - data.up_laplacian_2, - ) + data.hodge_laplacian_0, + data.down_laplacian_1, + data.up_laplacian_1, + data.down_laplacian_2, + data.up_laplacian_2, + ) incidence_all = (data.incidence_1, data.incidence_2) - expected_shapes = [(data.x.shape[0], out_dim), (data.x_1.shape[0], out_dim), (data.x_2.shape[0], out_dim)] + expected_shapes = [ + (data.x.shape[0], out_dim), + (data.x_1.shape[0], out_dim), + (data.x_2.shape[0], out_dim), + ] - auto_test = NNModuleAutoTest([ - { - "module" : SCCNNCustom, - "init": ((data.x.shape[1], data.x_1.shape[1], data.x_2.shape[1]), (out_dim, out_dim, out_dim), conv_order, sc_order), - "forward": ((data.x, data.x_1, data.x_2), laplacian_all, incidence_all), - "assert_shape": expected_shapes - }, - ]) + auto_test = NNModuleAutoTest( + [ + { + "module": SCCNNCustom, + "init": ( + (data.x.shape[1], data.x_1.shape[1], data.x_2.shape[1]), + (out_dim, out_dim, out_dim), + conv_order, + sc_order, + ), + "forward": ( + (data.x, data.x_1, data.x_2), + laplacian_all, + incidence_all, + ), + "assert_shape": expected_shapes, + }, + ] + ) auto_test.run() diff --git a/test/nn/wrappers/cell/test_cell_wrappers.py b/test/nn/wrappers/cell/test_cell_wrappers.py index 45b69888..fb551a67 100644 --- a/test/nn/wrappers/cell/test_cell_wrappers.py +++ b/test/nn/wrappers/cell/test_cell_wrappers.py @@ -1,23 +1,14 @@ """Unit tests for cell model wrappers""" -import torch -from torch_geometric.utils import get_laplacian -from ...._utils.nn_module_auto_test import NNModuleAutoTest -from ...._utils.flow_mocker import FlowMocker -from unittest.mock import MagicMock +from topomodelx.nn.cell.ccxn import CCXN +from topomodelx.nn.cell.cwn import CWN +from topobenchmark.nn.backbones.cell.cccn import CCCN from topobenchmark.nn.wrappers import ( - AbstractWrapper, CCCNWrapper, - CANWrapper, CCXNWrapper, - CWNWrapper + CWNWrapper, ) -from topomodelx.nn.cell.can import CAN -from topomodelx.nn.cell.ccxn import CCXN -from topomodelx.nn.cell.cwn import CWN -from topobenchmark.nn.backbones.cell.cccn import CCCN -from unittest.mock import MagicMock class TestCellWrappers: @@ -27,11 +18,9 @@ def test_CCCNWrapper(self, sg1_clique_lifted): num_cell_dimensions = 2 wrapper = CCCNWrapper( - CCCN( - data.x_1.shape[1] - ), - out_channels=out_channels, - num_cell_dimensions=num_cell_dimensions + CCCN(data.x_1.shape[1]), + out_channels=out_channels, + num_cell_dimensions=num_cell_dimensions, ) out = wrapper(data) @@ -44,11 +33,9 @@ def test_CCXNWrapper(self, sg1_cell_lifted): num_cell_dimensions = 2 wrapper = CCXNWrapper( - CCXN( - data.x_0.shape[1], data.x_1.shape[1], out_channels - ), - out_channels=out_channels, - num_cell_dimensions=num_cell_dimensions + CCXN(data.x_0.shape[1], data.x_1.shape[1], out_channels), + out_channels=out_channels, + num_cell_dimensions=num_cell_dimensions, ) out = wrapper(data) @@ -63,13 +50,16 @@ def test_CWNWrapper(self, sg1_cell_lifted): wrapper = CWNWrapper( CWN( - data.x_0.shape[1], data.x_1.shape[1], data.x_2.shape[1], hid_channels, 2 - ), - out_channels=out_channels, - num_cell_dimensions=num_cell_dimensions + data.x_0.shape[1], + data.x_1.shape[1], + data.x_2.shape[1], + hid_channels, + 2, + ), + out_channels=out_channels, + num_cell_dimensions=num_cell_dimensions, ) out = wrapper(data) for key in ["labels", "batch_0", "x_0", "x_1", "x_2"]: assert key in out - diff --git a/test/nn/wrappers/simplicial/test_SCCNNWrapper.py b/test/nn/wrappers/simplicial/test_SCCNNWrapper.py index f3614a7b..bc3e1807 100644 --- a/test/nn/wrappers/simplicial/test_SCCNNWrapper.py +++ b/test/nn/wrappers/simplicial/test_SCCNNWrapper.py @@ -1,26 +1,24 @@ """Unit tests for simplicial model wrappers""" -import torch -from torch_geometric.utils import get_laplacian -from ...._utils.nn_module_auto_test import NNModuleAutoTest -from ...._utils.flow_mocker import FlowMocker -from topobenchmark.nn.backbones.simplicial import SCCNNCustom from topomodelx.nn.simplicial.san import SAN -from topomodelx.nn.simplicial.scn2 import SCN2 from topomodelx.nn.simplicial.sccn import SCCN +from topomodelx.nn.simplicial.scn2 import SCN2 + +from topobenchmark.nn.backbones.simplicial import SCCNNCustom from topobenchmark.nn.wrappers import ( - SCCNWrapper, - SCCNNWrapper, SANWrapper, - SCNWrapper + SCCNNWrapper, + SCCNWrapper, + SCNWrapper, ) + class TestSimplicialWrappers: """Test simplicial model wrappers.""" def test_SCCNNWrapper(self, sg1_clique_lifted): """Test SCCNNWrapper. - + Parameters ---------- sg1_clique_lifted : torch_geometric.data.Data @@ -30,12 +28,17 @@ def test_SCCNNWrapper(self, sg1_clique_lifted): out_dim = 4 conv_order = 1 sc_order = 3 - init_args = (data.x_0.shape[1], data.x_1.shape[1], data.x_2.shape[1]), (out_dim, out_dim, out_dim), conv_order, sc_order + init_args = ( + (data.x_0.shape[1], data.x_1.shape[1], data.x_2.shape[1]), + (out_dim, out_dim, out_dim), + conv_order, + sc_order, + ) wrapper = SCCNNWrapper( - SCCNNCustom(*init_args), - out_channels=out_dim, - num_cell_dimensions=3 + SCCNNCustom(*init_args), + out_channels=out_dim, + num_cell_dimensions=3, ) out = wrapper(data) # Assert keys in output @@ -44,20 +47,20 @@ def test_SCCNNWrapper(self, sg1_clique_lifted): def test_SANWarpper(self, sg1_clique_lifted): """Test SANWarpper. - + Parameters ---------- sg1_clique_lifted : torch_geometric.data.Data - A fixture of simple graph 1 lifted with SimlicialCliqueLifting + A fixture of simple graph 1 lifted with SimlicialCliqueLifting """ data = sg1_clique_lifted out_dim = data.x_0.shape[1] hidden_channels = data.x_0.shape[1] wrapper = SANWrapper( - SAN(data.x_0.shape[1], hidden_channels), - out_channels=out_dim, - num_cell_dimensions=3 + SAN(data.x_0.shape[1], hidden_channels), + out_channels=out_dim, + num_cell_dimensions=3, ) out = wrapper(data) # Assert keys in output @@ -66,19 +69,19 @@ def test_SANWarpper(self, sg1_clique_lifted): def test_SCNWrapper(self, sg1_clique_lifted): """Test SCNWrapper. - + Parameters ---------- sg1_clique_lifted : torch_geometric.data.Data - A fixture of simple graph 1 lifted with SimlicialCliqueLifting + A fixture of simple graph 1 lifted with SimlicialCliqueLifting """ data = sg1_clique_lifted out_dim = data.x_0.shape[1] wrapper = SCNWrapper( - SCN2(data.x_0.shape[1], data.x_1.shape[1], data.x_2.shape[1]), - out_channels=out_dim, - num_cell_dimensions=3 + SCN2(data.x_0.shape[1], data.x_1.shape[1], data.x_2.shape[1]), + out_channels=out_dim, + num_cell_dimensions=3, ) out = wrapper(data) # Assert keys in output @@ -87,23 +90,22 @@ def test_SCNWrapper(self, sg1_clique_lifted): def test_SCCNWrapper(self, sg1_clique_lifted): """Test SCCNWrapper. - + Parameters ---------- sg1_clique_lifted : torch_geometric.data.Data - A fixture of simple graph 1 lifted with SimlicialCliqueLifting + A fixture of simple graph 1 lifted with SimlicialCliqueLifting """ data = sg1_clique_lifted out_dim = data.x_0.shape[1] max_rank = 2 wrapper = SCCNWrapper( - SCCN(data.x_0.shape[1], max_rank), - out_channels=out_dim, - num_cell_dimensions=3 + SCCN(data.x_0.shape[1], max_rank), + out_channels=out_dim, + num_cell_dimensions=3, ) out = wrapper(data) # Assert keys in output for key in ["labels", "batch_0", "x_0", "x_1", "x_2"]: assert key in out - diff --git a/topobenchmark/transforms/data_transform.py b/topobenchmark/transforms/data_transform.py index 8106f829..af48dc88 100755 --- a/topobenchmark/transforms/data_transform.py +++ b/topobenchmark/transforms/data_transform.py @@ -36,8 +36,9 @@ def _map_lifting_name(lifting_name): return LiftingTransform -def _route_lifting_kwargs(kwargs, LiftingMap): +def _route_lifting_kwargs(kwargs, LiftingMap, Transform): lifting_map_sign = inspect.signature(LiftingMap) + transform_sign = inspect.signature(Transform) lifting_map_kwargs = {} transform_kwargs = {} @@ -45,7 +46,7 @@ def _route_lifting_kwargs(kwargs, LiftingMap): for key, value in kwargs.items(): if key in lifting_map_sign.parameters: lifting_map_kwargs[key] = value - else: + elif key in transform_sign.parameters: transform_kwargs[key] = value return lifting_map_kwargs, transform_kwargs @@ -72,7 +73,7 @@ def __init__(self, transform_name, **kwargs): LiftingMap_ = TRANSFORMS[transform_name] Transform = _map_lifting_name(transform_name) lifting_map_kwargs, transform_kwargs = _route_lifting_kwargs( - kwargs, LiftingMap_ + kwargs, LiftingMap_, Transform ) lifting_map = LiftingMap_(**lifting_map_kwargs) From 1f5ad563eba1c84100e8c7561244625cb79b9fa2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lu=C3=ADs=20F=2E=20Pereira?= Date: Thu, 16 Jan 2025 16:26:15 -0800 Subject: [PATCH 26/28] Fix tutorial_lifting --- topobenchmark/transforms/__init__.py | 24 +++- topobenchmark/transforms/data_transform.py | 31 ++--- topobenchmark/transforms/liftings/__init__.py | 1 + .../liftings/graph2simplicial/clique.py | 2 +- tutorials/tutorial_lifting.ipynb | 124 ++++++++++-------- 5 files changed, 110 insertions(+), 72 deletions(-) diff --git a/topobenchmark/transforms/__init__.py b/topobenchmark/transforms/__init__.py index 62f8d85e..20840dfe 100755 --- a/topobenchmark/transforms/__init__.py +++ b/topobenchmark/transforms/__init__.py @@ -2,10 +2,32 @@ from .data_manipulations import DATA_MANIPULATIONS from .feature_liftings import FEATURE_LIFTINGS -from .liftings import LIFTINGS +from .liftings import ( + GRAPH2CELL_LIFTINGS, + GRAPH2HYPERGRAPH_LIFTINGS, + GRAPH2SIMPLICIAL_LIFTINGS, + LIFTINGS, +) TRANSFORMS = { **LIFTINGS, **FEATURE_LIFTINGS, **DATA_MANIPULATIONS, } + + +_map_lifting_type_to_dict = { + "graph2cell": GRAPH2CELL_LIFTINGS, + "graph2hypergraph": GRAPH2HYPERGRAPH_LIFTINGS, + "graph2simplicial": GRAPH2SIMPLICIAL_LIFTINGS, +} + + +def add_lifting_map(LiftingMap, lifting_type, name=None): + if name is None: + name = LiftingMap.__name__ + + liftings_dict = _map_lifting_type_to_dict[lifting_type] + + for dict_ in (liftings_dict, LIFTINGS, TRANSFORMS): + dict_[name] = LiftingMap diff --git a/topobenchmark/transforms/data_transform.py b/topobenchmark/transforms/data_transform.py index af48dc88..c1cda424 100755 --- a/topobenchmark/transforms/data_transform.py +++ b/topobenchmark/transforms/data_transform.py @@ -4,34 +4,29 @@ import torch_geometric -from topobenchmark.transforms import LIFTINGS, TRANSFORMS +from topobenchmark.transforms import ( + LIFTINGS, + TRANSFORMS, + _map_lifting_type_to_dict, +) from topobenchmark.transforms.liftings import ( - GRAPH2CELL_LIFTINGS, - GRAPH2HYPERGRAPH_LIFTINGS, - GRAPH2SIMPLICIAL_LIFTINGS, Graph2CellLiftingTransform, Graph2HypergraphLiftingTransform, Graph2SimplicialLiftingTransform, LiftingTransform, ) -_map_lifting_types = { - "graph2cell": (GRAPH2CELL_LIFTINGS, Graph2CellLiftingTransform), - "graph2hypergraph": ( - GRAPH2HYPERGRAPH_LIFTINGS, - Graph2HypergraphLiftingTransform, - ), - "graph2simplicial": ( - GRAPH2SIMPLICIAL_LIFTINGS, - Graph2SimplicialLiftingTransform, - ), +_map_lifting_type_to_transform = { + "graph2cell": Graph2CellLiftingTransform, + "graph2hypergraph": Graph2HypergraphLiftingTransform, + "graph2simplicial": Graph2SimplicialLiftingTransform, } -def _map_lifting_name(lifting_name): - for liftings_dict, Transform in _map_lifting_types.values(): +def _map_lifting_to_transform(lifting_name): + for key, liftings_dict in _map_lifting_type_to_dict.items(): if lifting_name in liftings_dict: - return Transform + return _map_lifting_type_to_transform[key] return LiftingTransform @@ -71,7 +66,7 @@ def __init__(self, transform_name, **kwargs): transform = TRANSFORMS[transform_name](**kwargs) else: LiftingMap_ = TRANSFORMS[transform_name] - Transform = _map_lifting_name(transform_name) + Transform = _map_lifting_to_transform(transform_name) lifting_map_kwargs, transform_kwargs = _route_lifting_kwargs( kwargs, LiftingMap_, Transform ) diff --git a/topobenchmark/transforms/liftings/__init__.py b/topobenchmark/transforms/liftings/__init__.py index 322c43a1..10e1e3c1 100755 --- a/topobenchmark/transforms/liftings/__init__.py +++ b/topobenchmark/transforms/liftings/__init__.py @@ -5,6 +5,7 @@ Graph2ComplexLiftingTransform, Graph2HypergraphLiftingTransform, Graph2SimplicialLiftingTransform, + LiftingMap, LiftingTransform, ) from .graph2cell import GRAPH2CELL_LIFTINGS diff --git a/topobenchmark/transforms/liftings/graph2simplicial/clique.py b/topobenchmark/transforms/liftings/graph2simplicial/clique.py index 04baa1ef..41047a62 100755 --- a/topobenchmark/transforms/liftings/graph2simplicial/clique.py +++ b/topobenchmark/transforms/liftings/graph2simplicial/clique.py @@ -50,7 +50,7 @@ def lift(self, domain): for set_k_simplices in simplices: simplicial_complex.add_simplices_from(list(set_k_simplices)) - # because Complex pads unexisting dimensions with empty matrices + # because ComplexData pads unexisting dimensions with empty matrices simplicial_complex.practical_dim = self.complex_dim return simplicial_complex diff --git a/tutorials/tutorial_lifting.ipynb b/tutorials/tutorial_lifting.ipynb index d1a77003..af533a1b 100644 --- a/tutorials/tutorial_lifting.ipynb +++ b/tutorials/tutorial_lifting.ipynb @@ -56,8 +56,6 @@ "\n", "import lightning as pl\n", "import networkx as nx\n", - "import hydra\n", - "import torch_geometric\n", "from omegaconf import OmegaConf\n", "from topomodelx.nn.simplicial.scn2 import SCN2\n", "from toponetx.classes import SimplicialComplex\n", @@ -72,8 +70,8 @@ "from topobenchmark.nn.readouts import PropagateSignalDown\n", "from topobenchmark.nn.wrappers.simplicial import SCNWrapper\n", "from topobenchmark.optimizer import TBOptimizer\n", - "from topobenchmark.transforms.liftings.graph2simplicial import (\n", - " Graph2SimplicialLifting,\n", + "from topobenchmark.transforms.liftings import (\n", + " LiftingMap,\n", ")" ] }, @@ -101,14 +99,17 @@ " \"data_domain\": \"graph\",\n", " \"data_type\": \"TUDataset\",\n", " \"data_name\": \"MUTAG\",\n", - " \"data_dir\": \"./data/MUTAG/\"}\n", + " \"data_dir\": \"./data/MUTAG/\",\n", + "}\n", "\n", "\n", - "transform_config = { \"clique_lifting\":\n", - " {\"_target_\": \"__main__.SimplicialCliquesLEQLifting\",\n", - " \"transform_name\": \"SimplicialCliquesLEQLifting\",\n", - " \"transform_type\": \"lifting\",\n", - " \"complex_dim\": 3,}\n", + "transform_config = {\n", + " \"clique_lifting\": {\n", + " \"_target_\": \"topobenchmark.transforms.data_transform.DataTransform\",\n", + " \"transform_name\": \"SimplicialCliquesLEQLifting\",\n", + " \"transform_type\": \"lifting\",\n", + " \"complex_dim\": 3,\n", + " }\n", "}\n", "\n", "split_config = {\n", @@ -138,21 +139,19 @@ "}\n", "\n", "loss_config = {\n", - " \"dataset_loss\": \n", - " {\n", - " \"task\": \"classification\", \n", - " \"loss_type\": \"cross_entropy\"\n", - " }\n", + " \"dataset_loss\": {\"task\": \"classification\", \"loss_type\": \"cross_entropy\"}\n", "}\n", "\n", - "evaluator_config = {\"task\": \"classification\",\n", - " \"num_classes\": out_channels,\n", - " \"metrics\": [\"accuracy\", \"precision\", \"recall\"]}\n", + "evaluator_config = {\n", + " \"task\": \"classification\",\n", + " \"num_classes\": out_channels,\n", + " \"metrics\": [\"accuracy\", \"precision\", \"recall\"],\n", + "}\n", "\n", - "optimizer_config = {\"optimizer_id\": \"Adam\",\n", - " \"parameters\":\n", - " {\"lr\": 0.001,\"weight_decay\": 0.0005}\n", - " }\n", + "optimizer_config = {\n", + " \"optimizer_id\": \"Adam\",\n", + " \"parameters\": {\"lr\": 0.001, \"weight_decay\": 0.0005},\n", + "}\n", "\n", "\n", "loader_config = OmegaConf.create(loader_config)\n", @@ -174,6 +173,7 @@ "def wrapper(**factory_kwargs):\n", " def factory(backbone):\n", " return SCNWrapper(backbone, **factory_kwargs)\n", + "\n", " return factory" ] }, @@ -197,16 +197,15 @@ "metadata": {}, "outputs": [], "source": [ - "class SimplicialCliquesLEQLifting(Graph2SimplicialLifting):\n", + "class SimplicialCliquesLEQLifting(LiftingMap):\n", " r\"\"\"Lifts graphs to simplicial complex domain by identifying the cliques as k-simplices. Only the cliques with size smaller or equal to the max complex dimension are considered.\n", - " \n", - " Args:\n", - " kwargs (optional): Additional arguments for the class.\n", " \"\"\"\n", - " def __init__(self, **kwargs):\n", - " super().__init__(**kwargs)\n", + " def __init__(self, complex_dim=2):\n", + " super().__init__()\n", + " self.complex_dim = complex_dim\n", + "\n", "\n", - " def lift_topology(self, data: torch_geometric.data.Data) -> dict:\n", + " def lift(self, domain) -> dict:\n", " r\"\"\"Lifts the topology of a graph to a simplicial complex by identifying the cliques as k-simplices. Only the cliques with size smaller or equal to the max complex dimension are considered.\n", "\n", " Args:\n", @@ -214,11 +213,14 @@ " Returns:\n", " dict: The lifted topology.\n", " \"\"\"\n", - " graph = self._generate_graph_from_data(data)\n", + " graph = domain\n", + "\n", " simplicial_complex = SimplicialComplex(graph)\n", " cliques = nx.find_cliques(graph)\n", - " \n", - " simplices: list[set[tuple[Any, ...]]] = [set() for _ in range(2, self.complex_dim + 1)]\n", + "\n", + " simplices: list[set[tuple[Any, ...]]] = [\n", + " set() for _ in range(2, self.complex_dim + 1)\n", + " ]\n", " for clique in cliques:\n", " if len(clique) <= self.complex_dim + 1:\n", " for i in range(2, self.complex_dim + 1):\n", @@ -227,8 +229,11 @@ "\n", " for set_k_simplices in simplices:\n", " simplicial_complex.add_simplices_from(list(set_k_simplices))\n", + " \n", + " # because ComplexData pads unexisting dimensions with empty matrices\n", + " simplicial_complex.practical_dim = self.complex_dim\n", "\n", - " return self._get_lifted_topology(simplicial_complex, graph)\n" + " return simplicial_complex" ] }, { @@ -251,9 +256,9 @@ "metadata": {}, "outputs": [], "source": [ - "from topobenchmark.transforms import TRANSFORMS\n", + "from topobenchmark.transforms import add_lifting_map\n", "\n", - "TRANSFORMS[\"SimplicialCliquesLEQLifting\"] = SimplicialCliquesLEQLifting" + "add_lifting_map(SimplicialCliquesLEQLifting, \"graph2simplicial\")" ] }, { @@ -275,8 +280,12 @@ "dataset, dataset_dir = graph_loader.load()\n", "\n", "preprocessor = PreProcessor(dataset, dataset_dir, transform_config)\n", - "dataset_train, dataset_val, dataset_test = preprocessor.load_dataset_splits(split_config)\n", - "datamodule = TBDataloader(dataset_train, dataset_val, dataset_test, batch_size=32)" + "dataset_train, dataset_val, dataset_test = preprocessor.load_dataset_splits(\n", + " split_config\n", + ")\n", + "datamodule = TBDataloader(\n", + " dataset_train, dataset_val, dataset_test, batch_size=32\n", + ")" ] }, { @@ -299,12 +308,19 @@ "metadata": {}, "outputs": [], "source": [ - "backbone = SCN2(in_channels_0=dim_hidden,in_channels_1=dim_hidden,in_channels_2=dim_hidden)\n", + "backbone = SCN2(\n", + " in_channels_0=dim_hidden,\n", + " in_channels_1=dim_hidden,\n", + " in_channels_2=dim_hidden,\n", + ")\n", "backbone_wrapper = wrapper(**wrapper_config)\n", "\n", "readout = PropagateSignalDown(**readout_config)\n", "loss = TBLoss(**loss_config)\n", - "feature_encoder = AllCellFeatureEncoder(in_channels=[in_channels, in_channels, in_channels], out_channels=dim_hidden)\n", + "feature_encoder = AllCellFeatureEncoder(\n", + " in_channels=[in_channels, in_channels, in_channels],\n", + " out_channels=dim_hidden,\n", + ")\n", "\n", "evaluator = TBEvaluator(**evaluator_config)\n", "optimizer = TBOptimizer(**optimizer_config)" @@ -316,14 +332,16 @@ "metadata": {}, "outputs": [], "source": [ - "model = TBModel(backbone=backbone,\n", - " backbone_wrapper=backbone_wrapper,\n", - " readout=readout,\n", - " loss=loss,\n", - " feature_encoder=feature_encoder,\n", - " evaluator=evaluator,\n", - " optimizer=optimizer,\n", - " compile=False,)" + "model = TBModel(\n", + " backbone=backbone,\n", + " backbone_wrapper=backbone_wrapper,\n", + " readout=readout,\n", + " loss=loss,\n", + " feature_encoder=feature_encoder,\n", + " evaluator=evaluator,\n", + " optimizer=optimizer,\n", + " compile=False,\n", + ")" ] }, { @@ -386,7 +404,9 @@ ], "source": [ "# Increase the number of epochs to get better results\n", - "trainer = pl.Trainer(max_epochs=50, accelerator=\"cpu\", enable_progress_bar=False)\n", + "trainer = pl.Trainer(\n", + " max_epochs=50, accelerator=\"cpu\", enable_progress_bar=False\n", + ")\n", "\n", "trainer.fit(model, datamodule)\n", "train_metrics = trainer.callback_metrics" @@ -415,9 +435,9 @@ } ], "source": [ - "print(' Training metrics\\n', '-'*26)\n", + "print(\" Training metrics\\n\", \"-\" * 26)\n", "for key in train_metrics:\n", - " print('{:<21s} {:>5.4f}'.format(key+':', train_metrics[key].item()))" + " print(\"{:<21s} {:>5.4f}\".format(key + \":\", train_metrics[key].item()))" ] }, { @@ -505,9 +525,9 @@ } ], "source": [ - "print(' Testing metrics\\n', '-'*25)\n", + "print(\" Testing metrics\\n\", \"-\" * 25)\n", "for key in test_metrics:\n", - " print('{:<20s} {:>5.4f}'.format(key+':', test_metrics[key].item()))" + " print(\"{:<20s} {:>5.4f}\".format(key + \":\", test_metrics[key].item()))" ] }, { From 81df9ac6e05a75d633079dbc03a5d3a094e6534e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lu=C3=ADs=20F=2E=20Pereira?= Date: Thu, 16 Jan 2025 19:48:35 -0800 Subject: [PATCH 27/28] Remove use of lambda func --- topobenchmark/transforms/_utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/topobenchmark/transforms/_utils.py b/topobenchmark/transforms/_utils.py index f14d156e..c2e0750c 100644 --- a/topobenchmark/transforms/_utils.py +++ b/topobenchmark/transforms/_utils.py @@ -19,7 +19,9 @@ def discover_objs(package_path, condition=None): Dictionary mapping class names to their corresponding class objects. """ if condition is None: - condition = lambda name, obj: True + + def condition(name, obj): + return True objs = {} From 4f10f7b6125895d11e9dc3f28aa369b789bae139 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lu=C3=ADs=20F=2E=20Pereira?= Date: Tue, 21 Jan 2025 18:05:39 -0800 Subject: [PATCH 28/28] Bump codecov to v5 --- .github/workflows/test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 20ebf7d8..e6dbd429 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -44,7 +44,7 @@ jobs: pytest --cov --cov-report=xml:coverage.xml test/ - name: Upload coverage reports to Codecov - uses: codecov/codecov-action@v4.0.1 + uses: codecov/codecov-action@v5 with: token: ${{ secrets.CODECOV_TOKEN }} file: coverage.xml