diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 20ebf7d8..e6dbd429 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -44,7 +44,7 @@ jobs: pytest --cov --cov-report=xml:coverage.xml test/ - name: Upload coverage reports to Codecov - uses: codecov/codecov-action@v4.0.1 + uses: codecov/codecov-action@v5 with: token: ${{ secrets.CODECOV_TOKEN }} file: coverage.xml diff --git a/test/conftest.py b/test/conftest.py index c84a1b72..d8cf94d0 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -1,25 +1,27 @@ """Configuration file for pytest.""" + import networkx as nx import pytest import torch import torch_geometric -from topobenchmark.transforms.liftings.graph2simplicial import ( - SimplicialCliqueLifting -) -from topobenchmark.transforms.liftings.graph2cell import ( - CellCycleLifting + +from topobenchmark.transforms.liftings import ( + CellCycleLifting, + Graph2CellLiftingTransform, + Graph2SimplicialLiftingTransform, + SimplicialCliqueLifting, ) @pytest.fixture def mocker_fixture(mocker): """Return pytest mocker, used when one want to use mocker in setup_method. - + Parameters ---------- mocker : pytest_mock.plugin.MockerFixture A pytest mocker. - + Returns ------- pytest_mock.plugin.MockerFixture @@ -31,7 +33,7 @@ def mocker_fixture(mocker): @pytest.fixture def simple_graph_0(): """Create a manual graph for testing purposes. - + Returns ------- torch_geometric.data.Data @@ -74,10 +76,11 @@ def simple_graph_0(): ) return data + @pytest.fixture def simple_graph_1(): """Create a manual graph for testing purposes. - + Returns ------- torch_geometric.data.Data @@ -133,43 +136,43 @@ def simple_graph_1(): return data - @pytest.fixture def sg1_clique_lifted(simple_graph_1): """Return a simple graph with a clique lifting. - + Parameters ---------- simple_graph_1 : torch_geometric.data.Data A simple graph data object. - + Returns ------- torch_geometric.data.Data A simple graph data object with a clique lifting. """ - lifting_signed = SimplicialCliqueLifting( - complex_dim=3, signed=True - ) + lifting_signed = Graph2SimplicialLiftingTransform( + SimplicialCliqueLifting(complex_dim=3), signed=True + ) data = lifting_signed(simple_graph_1) data.batch_0 = "null" return data + @pytest.fixture def sg1_cell_lifted(simple_graph_1): """Return a simple graph with a cell lifting. - + Parameters ---------- simple_graph_1 : torch_geometric.data.Data A simple graph data object. - + Returns ------- torch_geometric.data.Data A simple graph data object with a cell lifting. """ - lifting = CellCycleLifting() + lifting = Graph2CellLiftingTransform(CellCycleLifting()) data = lifting(simple_graph_1) data.batch_0 = "null" return data @@ -178,7 +181,7 @@ def sg1_cell_lifted(simple_graph_1): @pytest.fixture def simple_graph_2(): """Create a manual graph for testing purposes. - + Returns ------- torch_geometric.data.Data @@ -244,7 +247,7 @@ def simple_graph_2(): @pytest.fixture def random_graph_input(): """Create a random graph for testing purposes. - + Returns ------- torch.Tensor @@ -261,13 +264,12 @@ def random_graph_input(): num_nodes = 8 d_feat = 12 x = torch.randn(num_nodes, 12) - edges_1 = torch.randint(0, num_nodes, (2, num_nodes*2)) - edges_2 = torch.randint(0, num_nodes, (2, num_nodes*2)) - + edges_1 = torch.randint(0, num_nodes, (2, num_nodes * 2)) + edges_2 = torch.randint(0, num_nodes, (2, num_nodes * 2)) + d_feat_1, d_feat_2 = 5, 17 - x_1 = torch.randn(num_nodes*2, d_feat_1) - x_2 = torch.randn(num_nodes*2, d_feat_2) + x_1 = torch.randn(num_nodes * 2, d_feat_1) + x_2 = torch.randn(num_nodes * 2, d_feat_2) return x, x_1, x_2, edges_1, edges_2 - diff --git a/test/nn/backbones/simplicial/test_sccnn.py b/test/nn/backbones/simplicial/test_sccnn.py index 4f4c5f68..7546713c 100644 --- a/test/nn/backbones/simplicial/test_sccnn.py +++ b/test/nn/backbones/simplicial/test_sccnn.py @@ -2,40 +2,57 @@ import pytest import torch -from torch_geometric.utils import get_laplacian -from ...._utils.nn_module_auto_test import NNModuleAutoTest + from topobenchmark.nn.backbones.simplicial import SCCNNCustom -from topobenchmark.transforms.liftings.graph2simplicial import ( +from topobenchmark.transforms.liftings import ( + Graph2SimplicialLiftingTransform, SimplicialCliqueLifting, ) +from ...._utils.nn_module_auto_test import NNModuleAutoTest + def test_SCCNNCustom(simple_graph_1): - lifting_signed = SimplicialCliqueLifting( - complex_dim=3, signed=True - ) + lifting_signed = Graph2SimplicialLiftingTransform( + SimplicialCliqueLifting(complex_dim=3), signed=True + ) data = lifting_signed(simple_graph_1) out_dim = 4 conv_order = 1 sc_order = 3 laplacian_all = ( - data.hodge_laplacian_0, - data.down_laplacian_1, - data.up_laplacian_1, - data.down_laplacian_2, - data.up_laplacian_2, - ) + data.hodge_laplacian_0, + data.down_laplacian_1, + data.up_laplacian_1, + data.down_laplacian_2, + data.up_laplacian_2, + ) incidence_all = (data.incidence_1, data.incidence_2) - expected_shapes = [(data.x.shape[0], out_dim), (data.x_1.shape[0], out_dim), (data.x_2.shape[0], out_dim)] - - auto_test = NNModuleAutoTest([ - { - "module" : SCCNNCustom, - "init": ((data.x.shape[1], data.x_1.shape[1], data.x_2.shape[1]), (out_dim, out_dim, out_dim), conv_order, sc_order), - "forward": ((data.x, data.x_1, data.x_2), laplacian_all, incidence_all), - "assert_shape": expected_shapes - }, - ]) + expected_shapes = [ + (data.x.shape[0], out_dim), + (data.x_1.shape[0], out_dim), + (data.x_2.shape[0], out_dim), + ] + + auto_test = NNModuleAutoTest( + [ + { + "module": SCCNNCustom, + "init": ( + (data.x.shape[1], data.x_1.shape[1], data.x_2.shape[1]), + (out_dim, out_dim, out_dim), + conv_order, + sc_order, + ), + "forward": ( + (data.x, data.x_1, data.x_2), + laplacian_all, + incidence_all, + ), + "assert_shape": expected_shapes, + }, + ] + ) auto_test.run() @@ -46,189 +63,202 @@ def create_sample_data(): x = torch.randn(num_nodes, 3) # 3 node features x_1 = torch.randn(8, 4) # 8 edges with 4 features x_2 = torch.randn(6, 5) # 6 faces with 5 features - + # Create sample Laplacians and incidence matrices hodge_laplacian_0 = torch.sparse_coo_tensor(size=(num_nodes, num_nodes)) down_laplacian_1 = torch.sparse_coo_tensor(size=(8, 8)) up_laplacian_1 = torch.sparse_coo_tensor(size=(8, 8)) down_laplacian_2 = torch.sparse_coo_tensor(size=(6, 6)) up_laplacian_2 = torch.sparse_coo_tensor(size=(6, 6)) - + incidence_1 = torch.sparse_coo_tensor(size=(num_nodes, 8)) incidence_2 = torch.sparse_coo_tensor(size=(8, 6)) - + return { - 'x': x, - 'x_1': x_1, - 'x_2': x_2, - 'laplacian_all': (hodge_laplacian_0, down_laplacian_1, up_laplacian_1, down_laplacian_2, up_laplacian_2), - 'incidence_all': (incidence_1, incidence_2) + "x": x, + "x_1": x_1, + "x_2": x_2, + "laplacian_all": ( + hodge_laplacian_0, + down_laplacian_1, + up_laplacian_1, + down_laplacian_2, + up_laplacian_2, + ), + "incidence_all": (incidence_1, incidence_2), } + def test_sccnn_basic_initialization(): """Test basic initialization of SCCNNCustom.""" in_channels = (3, 4, 5) hidden_channels = (6, 6, 6) - + # Test basic initialization model = SCCNNCustom( in_channels_all=in_channels, hidden_channels_all=hidden_channels, conv_order=2, - sc_order=3 + sc_order=3, ) assert model is not None - + # Verify layer structure assert len(model.layers) == 2 # Default n_layers is 2 - assert hasattr(model, 'in_linear_0') - assert hasattr(model, 'in_linear_1') - assert hasattr(model, 'in_linear_2') + assert hasattr(model, "in_linear_0") + assert hasattr(model, "in_linear_1") + assert hasattr(model, "in_linear_2") + def test_update_functions(): """Test different update functions in the SCCNN.""" in_channels = (3, 4, 5) hidden_channels = (6, 6, 6) - + # Test sigmoid update function model = SCCNNCustom( in_channels_all=in_channels, hidden_channels_all=hidden_channels, conv_order=2, sc_order=3, - update_func="sigmoid" + update_func="sigmoid", ) assert model is not None - + # Test ReLU update function model = SCCNNCustom( in_channels_all=in_channels, hidden_channels_all=hidden_channels, conv_order=2, sc_order=3, - update_func="relu" + update_func="relu", ) assert model is not None + def test_aggr_norm(create_sample_data): """Test aggregation normalization functionality.""" data = create_sample_data - + model = SCCNNCustom( in_channels_all=(3, 4, 5), hidden_channels_all=(6, 6, 6), conv_order=2, sc_order=3, - aggr_norm=True + aggr_norm=True, ) - + # Forward pass with aggregation normalization output = model( - (data['x'], data['x_1'], data['x_2']), - data['laplacian_all'], - data['incidence_all'] + (data["x"], data["x_1"], data["x_2"]), + data["laplacian_all"], + data["incidence_all"], ) - + assert len(output) == 3 assert all(torch.isfinite(out).all() for out in output) + def test_different_conv_orders(): """Test SCCNN with different convolution orders.""" in_channels = (3, 4, 5) hidden_channels = (6, 6, 6) - + # Test with conv_order = 1 model1 = SCCNNCustom( in_channels_all=in_channels, hidden_channels_all=hidden_channels, conv_order=1, - sc_order=3 + sc_order=3, ) assert model1 is not None - + # Test with conv_order = 3 model2 = SCCNNCustom( in_channels_all=in_channels, hidden_channels_all=hidden_channels, conv_order=3, - sc_order=3 + sc_order=3, ) assert model2 is not None - + # Test invalid conv_order with pytest.raises(AssertionError): model = SCCNNCustom( in_channels_all=in_channels, hidden_channels_all=hidden_channels, conv_order=0, - sc_order=3 + sc_order=3, ) + def test_different_sc_orders(): """Test SCCNN with different simplicial complex orders.""" in_channels = (3, 4, 5) hidden_channels = (6, 6, 6) - + # Test with sc_order = 2 model1 = SCCNNCustom( in_channels_all=in_channels, hidden_channels_all=hidden_channels, conv_order=2, - sc_order=2 + sc_order=2, ) assert model1 is not None - + # Test with sc_order > 2 model2 = SCCNNCustom( in_channels_all=in_channels, hidden_channels_all=hidden_channels, conv_order=2, - sc_order=3 + sc_order=3, ) assert model2 is not None + def test_forward_shapes(create_sample_data): """Test output shapes for different input configurations.""" data = create_sample_data - + model = SCCNNCustom( in_channels_all=(3, 4, 5), hidden_channels_all=(6, 6, 6), conv_order=2, - sc_order=3 + sc_order=3, ) - + output = model( - (data['x'], data['x_1'], data['x_2']), - data['laplacian_all'], - data['incidence_all'] + (data["x"], data["x_1"], data["x_2"]), + data["laplacian_all"], + data["incidence_all"], ) - - assert output[0].shape == (data['x'].shape[0], 6) - assert output[1].shape == (data['x_1'].shape[0], 6) - assert output[2].shape == (data['x_2'].shape[0], 6) + + assert output[0].shape == (data["x"].shape[0], 6) + assert output[1].shape == (data["x_1"].shape[0], 6) + assert output[2].shape == (data["x_2"].shape[0], 6) + def test_n_layers(): """Test SCCNN with different numbers of layers.""" in_channels = (3, 4, 5) hidden_channels = (6, 6, 6) - + # Test with 1 layer model1 = SCCNNCustom( in_channels_all=in_channels, hidden_channels_all=hidden_channels, conv_order=2, sc_order=3, - n_layers=1 + n_layers=1, ) assert len(model1.layers) == 1 - + # Test with 3 layers model2 = SCCNNCustom( in_channels_all=in_channels, hidden_channels_all=hidden_channels, conv_order=2, sc_order=3, - n_layers=3 + n_layers=3, ) - assert len(model2.layers) == 3 \ No newline at end of file + assert len(model2.layers) == 3 diff --git a/test/nn/wrappers/cell/test_cell_wrappers.py b/test/nn/wrappers/cell/test_cell_wrappers.py index 45b69888..fb551a67 100644 --- a/test/nn/wrappers/cell/test_cell_wrappers.py +++ b/test/nn/wrappers/cell/test_cell_wrappers.py @@ -1,23 +1,14 @@ """Unit tests for cell model wrappers""" -import torch -from torch_geometric.utils import get_laplacian -from ...._utils.nn_module_auto_test import NNModuleAutoTest -from ...._utils.flow_mocker import FlowMocker -from unittest.mock import MagicMock +from topomodelx.nn.cell.ccxn import CCXN +from topomodelx.nn.cell.cwn import CWN +from topobenchmark.nn.backbones.cell.cccn import CCCN from topobenchmark.nn.wrappers import ( - AbstractWrapper, CCCNWrapper, - CANWrapper, CCXNWrapper, - CWNWrapper + CWNWrapper, ) -from topomodelx.nn.cell.can import CAN -from topomodelx.nn.cell.ccxn import CCXN -from topomodelx.nn.cell.cwn import CWN -from topobenchmark.nn.backbones.cell.cccn import CCCN -from unittest.mock import MagicMock class TestCellWrappers: @@ -27,11 +18,9 @@ def test_CCCNWrapper(self, sg1_clique_lifted): num_cell_dimensions = 2 wrapper = CCCNWrapper( - CCCN( - data.x_1.shape[1] - ), - out_channels=out_channels, - num_cell_dimensions=num_cell_dimensions + CCCN(data.x_1.shape[1]), + out_channels=out_channels, + num_cell_dimensions=num_cell_dimensions, ) out = wrapper(data) @@ -44,11 +33,9 @@ def test_CCXNWrapper(self, sg1_cell_lifted): num_cell_dimensions = 2 wrapper = CCXNWrapper( - CCXN( - data.x_0.shape[1], data.x_1.shape[1], out_channels - ), - out_channels=out_channels, - num_cell_dimensions=num_cell_dimensions + CCXN(data.x_0.shape[1], data.x_1.shape[1], out_channels), + out_channels=out_channels, + num_cell_dimensions=num_cell_dimensions, ) out = wrapper(data) @@ -63,13 +50,16 @@ def test_CWNWrapper(self, sg1_cell_lifted): wrapper = CWNWrapper( CWN( - data.x_0.shape[1], data.x_1.shape[1], data.x_2.shape[1], hid_channels, 2 - ), - out_channels=out_channels, - num_cell_dimensions=num_cell_dimensions + data.x_0.shape[1], + data.x_1.shape[1], + data.x_2.shape[1], + hid_channels, + 2, + ), + out_channels=out_channels, + num_cell_dimensions=num_cell_dimensions, ) out = wrapper(data) for key in ["labels", "batch_0", "x_0", "x_1", "x_2"]: assert key in out - diff --git a/test/nn/wrappers/simplicial/test_SCCNNWrapper.py b/test/nn/wrappers/simplicial/test_SCCNNWrapper.py index f3614a7b..bc3e1807 100644 --- a/test/nn/wrappers/simplicial/test_SCCNNWrapper.py +++ b/test/nn/wrappers/simplicial/test_SCCNNWrapper.py @@ -1,26 +1,24 @@ """Unit tests for simplicial model wrappers""" -import torch -from torch_geometric.utils import get_laplacian -from ...._utils.nn_module_auto_test import NNModuleAutoTest -from ...._utils.flow_mocker import FlowMocker -from topobenchmark.nn.backbones.simplicial import SCCNNCustom from topomodelx.nn.simplicial.san import SAN -from topomodelx.nn.simplicial.scn2 import SCN2 from topomodelx.nn.simplicial.sccn import SCCN +from topomodelx.nn.simplicial.scn2 import SCN2 + +from topobenchmark.nn.backbones.simplicial import SCCNNCustom from topobenchmark.nn.wrappers import ( - SCCNWrapper, - SCCNNWrapper, SANWrapper, - SCNWrapper + SCCNNWrapper, + SCCNWrapper, + SCNWrapper, ) + class TestSimplicialWrappers: """Test simplicial model wrappers.""" def test_SCCNNWrapper(self, sg1_clique_lifted): """Test SCCNNWrapper. - + Parameters ---------- sg1_clique_lifted : torch_geometric.data.Data @@ -30,12 +28,17 @@ def test_SCCNNWrapper(self, sg1_clique_lifted): out_dim = 4 conv_order = 1 sc_order = 3 - init_args = (data.x_0.shape[1], data.x_1.shape[1], data.x_2.shape[1]), (out_dim, out_dim, out_dim), conv_order, sc_order + init_args = ( + (data.x_0.shape[1], data.x_1.shape[1], data.x_2.shape[1]), + (out_dim, out_dim, out_dim), + conv_order, + sc_order, + ) wrapper = SCCNNWrapper( - SCCNNCustom(*init_args), - out_channels=out_dim, - num_cell_dimensions=3 + SCCNNCustom(*init_args), + out_channels=out_dim, + num_cell_dimensions=3, ) out = wrapper(data) # Assert keys in output @@ -44,20 +47,20 @@ def test_SCCNNWrapper(self, sg1_clique_lifted): def test_SANWarpper(self, sg1_clique_lifted): """Test SANWarpper. - + Parameters ---------- sg1_clique_lifted : torch_geometric.data.Data - A fixture of simple graph 1 lifted with SimlicialCliqueLifting + A fixture of simple graph 1 lifted with SimlicialCliqueLifting """ data = sg1_clique_lifted out_dim = data.x_0.shape[1] hidden_channels = data.x_0.shape[1] wrapper = SANWrapper( - SAN(data.x_0.shape[1], hidden_channels), - out_channels=out_dim, - num_cell_dimensions=3 + SAN(data.x_0.shape[1], hidden_channels), + out_channels=out_dim, + num_cell_dimensions=3, ) out = wrapper(data) # Assert keys in output @@ -66,19 +69,19 @@ def test_SANWarpper(self, sg1_clique_lifted): def test_SCNWrapper(self, sg1_clique_lifted): """Test SCNWrapper. - + Parameters ---------- sg1_clique_lifted : torch_geometric.data.Data - A fixture of simple graph 1 lifted with SimlicialCliqueLifting + A fixture of simple graph 1 lifted with SimlicialCliqueLifting """ data = sg1_clique_lifted out_dim = data.x_0.shape[1] wrapper = SCNWrapper( - SCN2(data.x_0.shape[1], data.x_1.shape[1], data.x_2.shape[1]), - out_channels=out_dim, - num_cell_dimensions=3 + SCN2(data.x_0.shape[1], data.x_1.shape[1], data.x_2.shape[1]), + out_channels=out_dim, + num_cell_dimensions=3, ) out = wrapper(data) # Assert keys in output @@ -87,23 +90,22 @@ def test_SCNWrapper(self, sg1_clique_lifted): def test_SCCNWrapper(self, sg1_clique_lifted): """Test SCCNWrapper. - + Parameters ---------- sg1_clique_lifted : torch_geometric.data.Data - A fixture of simple graph 1 lifted with SimlicialCliqueLifting + A fixture of simple graph 1 lifted with SimlicialCliqueLifting """ data = sg1_clique_lifted out_dim = data.x_0.shape[1] max_rank = 2 wrapper = SCCNWrapper( - SCCN(data.x_0.shape[1], max_rank), - out_channels=out_dim, - num_cell_dimensions=3 + SCCN(data.x_0.shape[1], max_rank), + out_channels=out_dim, + num_cell_dimensions=3, ) out = wrapper(data) # Assert keys in output for key in ["labels", "batch_0", "x_0", "x_1", "x_2"]: assert key in out - diff --git a/test/transforms/data_manipulations/test_SimplicialCurvature.py b/test/transforms/data_manipulations/test_SimplicialCurvature.py index e4cb517b..e90d4e68 100644 --- a/test/transforms/data_manipulations/test_SimplicialCurvature.py +++ b/test/transforms/data_manipulations/test_SimplicialCurvature.py @@ -2,8 +2,14 @@ import torch from torch_geometric.data import Data -from topobenchmark.transforms.data_manipulations import CalculateSimplicialCurvature -from topobenchmark.transforms.liftings.graph2simplicial import SimplicialCliqueLifting + +from topobenchmark.transforms.data_manipulations import ( + CalculateSimplicialCurvature, +) +from topobenchmark.transforms.liftings import ( + Graph2SimplicialLiftingTransform, + SimplicialCliqueLifting, +) class TestSimplicialCurvature: @@ -11,29 +17,28 @@ class TestSimplicialCurvature: def test_simplicial_curvature(self, simple_graph_1): """Test simplicial curvature calculation. - + Parameters ---------- simple_graph_1 : torch_geometric.data.Data A simple graph fixture. """ simplicial_curvature = CalculateSimplicialCurvature() - lifting_unsigned = SimplicialCliqueLifting( - complex_dim=3, signed=False + + lifting_unsigned = Graph2SimplicialLiftingTransform( + lifting=SimplicialCliqueLifting(complex_dim=3) ) + data = lifting_unsigned(simple_graph_1) - data['0_cell_degrees'] = torch.unsqueeze( - torch.sum(data['incidence_1'], dim=1).to_dense(), - dim=1 + data["0_cell_degrees"] = torch.unsqueeze( + torch.sum(data["incidence_1"], dim=1).to_dense(), dim=1 ) - data['1_cell_degrees'] = torch.unsqueeze( - torch.sum(data['incidence_2'], dim=1).to_dense(), - dim=1 + data["1_cell_degrees"] = torch.unsqueeze( + torch.sum(data["incidence_2"], dim=1).to_dense(), dim=1 ) - data['2_cell_degrees'] = torch.unsqueeze( - torch.sum(data['incidence_3'], dim=1).to_dense(), - dim=1 + data["2_cell_degrees"] = torch.unsqueeze( + torch.sum(data["incidence_3"], dim=1).to_dense(), dim=1 ) - + res = simplicial_curvature(data) - assert isinstance(res, Data) \ No newline at end of file + assert isinstance(res, Data) diff --git a/test/transforms/feature_liftings/test_Concatenation.py b/test/transforms/feature_liftings/test_Concatenation.py index a8f83d78..9474e8da 100644 --- a/test/transforms/feature_liftings/test_Concatenation.py +++ b/test/transforms/feature_liftings/test_Concatenation.py @@ -2,24 +2,27 @@ import torch -from topobenchmark.transforms.liftings.graph2simplicial import ( +from topobenchmark.transforms.liftings import ( + Graph2SimplicialLiftingTransform, SimplicialCliqueLifting, ) -class TestConcatention: +class TestConcatenation: """Test the Concatention feature lifting class.""" def setup_method(self): """Set up the test.""" # Initialize a lifting class - self.lifting = SimplicialCliqueLifting( - feature_lifting="Concatenation", complex_dim=3 + + self.lifting = Graph2SimplicialLiftingTransform( + lifting=SimplicialCliqueLifting(complex_dim=3), + feature_lifting="Concatenation", ) def test_lift_features(self, simple_graph_0, simple_graph_1): """Test the lift_features method. - + Parameters ---------- simple_graph_0 : torch_geometric.data.Data @@ -27,12 +30,12 @@ def test_lift_features(self, simple_graph_0, simple_graph_1): simple_graph_1 : torch_geometric.data.Data A simple graph data object. """ - + data = simple_graph_0 # Test the lift_features method lifted_data = self.lifting.forward(data.clone()) assert lifted_data.x_2.shape == torch.Size([0, 6]) - + data = simple_graph_1 # Test the lift_features method lifted_data = self.lifting.forward(data.clone()) diff --git a/test/transforms/feature_liftings/test_ProjectionSum.py b/test/transforms/feature_liftings/test_ProjectionSum.py index 935a5148..a6ad8cdf 100644 --- a/test/transforms/feature_liftings/test_ProjectionSum.py +++ b/test/transforms/feature_liftings/test_ProjectionSum.py @@ -2,7 +2,8 @@ import torch -from topobenchmark.transforms.liftings.graph2simplicial import ( +from topobenchmark.transforms.liftings import ( + Graph2SimplicialLiftingTransform, SimplicialCliqueLifting, ) @@ -13,13 +14,14 @@ class TestProjectionSum: def setup_method(self): """Set up the test.""" # Initialize a lifting class - self.lifting = SimplicialCliqueLifting( - feature_lifting="ProjectionSum", complex_dim=3 + self.lifting = Graph2SimplicialLiftingTransform( + lifting=SimplicialCliqueLifting(complex_dim=3), + feature_lifting="ProjectionSum", ) def test_lift_features(self, simple_graph_1): """Test the lift_features method. - + Parameters ---------- simple_graph_1 : torch_geometric.data.Data @@ -31,38 +33,27 @@ def test_lift_features(self, simple_graph_1): expected_x1 = torch.tensor( [ - [ 6.], - [ 11.], - [ 101.], - [5001.], - [ 15.], - [ 105.], - [ 60.], - [ 110.], - [ 510.], - [5010.], - [1050.], - [1500.], - [5500.] + [6.0], + [11.0], + [101.0], + [5001.0], + [15.0], + [105.0], + [60.0], + [110.0], + [510.0], + [5010.0], + [1050.0], + [1500.0], + [5500.0], ] ) expected_x2 = torch.tensor( - [ - [ 32.], - [ 212.], - [ 222.], - [10022.], - [ 230.], - [11020.] - ] + [[32.0], [212.0], [222.0], [10022.0], [230.0], [11020.0]] ) - expected_x3 = torch.tensor( - [ - [696.] - ] - ) + expected_x3 = torch.tensor([[696.0]]) assert ( expected_x1 == lifted_data.x_1 diff --git a/test/transforms/feature_liftings/test_SetLifting.py b/test/transforms/feature_liftings/test_SetLifting.py index 9b71816f..584f9724 100644 --- a/test/transforms/feature_liftings/test_SetLifting.py +++ b/test/transforms/feature_liftings/test_SetLifting.py @@ -2,7 +2,8 @@ import torch -from topobenchmark.transforms.liftings.graph2simplicial import ( +from topobenchmark.transforms.liftings import ( + Graph2SimplicialLiftingTransform, SimplicialCliqueLifting, ) @@ -13,13 +14,14 @@ class TestSetLifting: def setup_method(self): """Set up the test.""" # Initialize a lifting class - self.lifting = SimplicialCliqueLifting( - feature_lifting="Set", complex_dim=3 + self.lifting = Graph2SimplicialLiftingTransform( + lifting=SimplicialCliqueLifting(complex_dim=3), + feature_lifting="Set", ) def test_lift_features(self, simple_graph_1): """Test the lift_features method. - + Parameters ---------- simple_graph_1 : torch_geometric.data.Data diff --git a/test/transforms/liftings/cell/test_CellCyclesLifting.py b/test/transforms/liftings/cell/test_CellCyclesLifting.py index 54fd276f..706e1f9d 100644 --- a/test/transforms/liftings/cell/test_CellCyclesLifting.py +++ b/test/transforms/liftings/cell/test_CellCyclesLifting.py @@ -2,7 +2,10 @@ import torch -from topobenchmark.transforms.liftings.graph2cell import CellCycleLifting +from topobenchmark.transforms.liftings import ( + CellCycleLifting, + Graph2CellLiftingTransform, +) class TestCellCycleLifting: @@ -10,7 +13,7 @@ class TestCellCycleLifting: def setup_method(self): # Initialise the CellCycleLifting class - self.lifting = CellCycleLifting() + self.lifting = Graph2CellLiftingTransform(CellCycleLifting()) def test_lift_topology(self, simple_graph_1): # Test the lift_topology method diff --git a/test/transforms/liftings/hypergraph/test_HypergraphKHopLifting.py b/test/transforms/liftings/hypergraph/test_HypergraphKHopLifting.py index 13285fc1..68326f11 100644 --- a/test/transforms/liftings/hypergraph/test_HypergraphKHopLifting.py +++ b/test/transforms/liftings/hypergraph/test_HypergraphKHopLifting.py @@ -2,7 +2,8 @@ import torch -from topobenchmark.transforms.liftings.graph2hypergraph import ( +from topobenchmark.transforms.liftings import ( + Graph2HypergraphLiftingTransform, HypergraphKHopLifting, ) @@ -11,15 +12,27 @@ class TestHypergraphKHopLifting: """Test the HypergraphKHopLifting class.""" def setup_method(self): - """ Setup the test.""" + """Setup the test.""" # Initialise the HypergraphKHopLifting class - self.lifting_k1 = HypergraphKHopLifting(k_value=1) - self.lifting_k2 = HypergraphKHopLifting(k_value=2) - self.lifting_edge_attr = HypergraphKHopLifting(k_value=1, preserve_edge_attr=True) + self.lifting_k1 = Graph2HypergraphLiftingTransform( + HypergraphKHopLifting(k_value=1) + ) + self.lifting_k2 = Graph2HypergraphLiftingTransform( + HypergraphKHopLifting(k_value=2) + ) + + # TODO: delete? + # NB: `preserve_edge_attr` is never used? therefore they're equivalent + # self.lifting_edge_attr = HypergraphKHopLifting( + # k_value=1, preserve_edge_attr=True + # ) + self.lifting_edge_attr = Graph2HypergraphLiftingTransform( + HypergraphKHopLifting(k_value=1) + ) def test_lift_topology(self, simple_graph_2): - """ Test the lift_topology method. - + """Test the lift_topology method. + Parameters ---------- simple_graph_2 : Data @@ -78,10 +91,18 @@ def test_lift_topology(self, simple_graph_2): assert ( expected_n_hyperedges == lifted_data_k2.num_hyperedges ), "Something is wrong with the number of hyperedges (k=2)." - + self.data_edge_attr = simple_graph_2 - edge_attributes = torch.rand((self.data_edge_attr.edge_index.shape[1], 2)) + edge_attributes = torch.rand( + (self.data_edge_attr.edge_index.shape[1], 2) + ) self.data_edge_attr.edge_attr = edge_attributes - lifted_data_edge_attr = self.lifting_edge_attr.forward(self.data_edge_attr.clone()) - assert lifted_data_edge_attr.edge_attr is not None, "Edge attributes are not preserved." - assert torch.all(edge_attributes == lifted_data_edge_attr.edge_attr), "Edge attributes are not preserved correctly." + lifted_data_edge_attr = self.lifting_edge_attr.forward( + self.data_edge_attr.clone() + ) + assert ( + lifted_data_edge_attr.edge_attr is not None + ), "Edge attributes are not preserved." + assert torch.all( + edge_attributes == lifted_data_edge_attr.edge_attr + ), "Edge attributes are not preserved correctly." diff --git a/test/transforms/liftings/hypergraph/test_HypergraphKNearestNeighborsLifting.py b/test/transforms/liftings/hypergraph/test_HypergraphKNearestNeighborsLifting.py index 7e9d1216..23dc5d35 100644 --- a/test/transforms/liftings/hypergraph/test_HypergraphKNearestNeighborsLifting.py +++ b/test/transforms/liftings/hypergraph/test_HypergraphKNearestNeighborsLifting.py @@ -3,9 +3,8 @@ import pytest import torch from torch_geometric.data import Data -from topobenchmark.transforms.liftings.graph2hypergraph import ( - HypergraphKNNLifting, -) + +from topobenchmark.transforms.liftings import HypergraphKNNLifting class TestHypergraphKNNLifting: @@ -13,7 +12,7 @@ class TestHypergraphKNNLifting: def setup_method(self): """Set up test fixtures before each test method. - + Creates instances of HypergraphKNNLifting with different k values and loop settings. """ @@ -23,88 +22,94 @@ def setup_method(self): def test_initialization(self): """Test initialization with different parameters.""" + # TODO: overkill, delete? + # Test default parameters lifting_default = HypergraphKNNLifting() - assert lifting_default.k == 1 - assert lifting_default.loop is True + assert lifting_default.transform.k == 1 + assert lifting_default.transform.loop is True # Test custom parameters lifting_custom = HypergraphKNNLifting(k_value=5, loop=False) - assert lifting_custom.k == 5 - assert lifting_custom.loop is False + assert lifting_custom.transform.k == 5 + assert lifting_custom.transform.loop is False def test_lift_topology_k2(self, simple_graph_2): """Test the lift_topology method with k=2. - + Parameters ---------- simple_graph_2 : torch_geometric.data.Data A simple graph fixture with 9 nodes arranged in a line pattern. """ - lifted_data_k2 = self.lifting_k2.lift_topology(simple_graph_2.clone()) + lifted_data_k2 = self.lifting_k2.lift(simple_graph_2.clone()) expected_n_hyperedges = 9 - expected_incidence_1 = torch.tensor([ - [1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0], - ]) + expected_incidence_1 = torch.tensor( + [ + [1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0], + ] + ) assert torch.equal( lifted_data_k2["incidence_hyperedges"].to_dense(), - expected_incidence_1 + expected_incidence_1, ), "Incorrect incidence_hyperedges for k=2" - + assert lifted_data_k2["num_hyperedges"] == expected_n_hyperedges assert torch.equal(lifted_data_k2["x_0"], simple_graph_2.x) def test_lift_topology_k3(self, simple_graph_2): """Test the lift_topology method with k=3. - + Parameters ---------- simple_graph_2 : torch_geometric.data.Data A simple graph fixture with 9 nodes arranged in a line pattern. """ - lifted_data_k3 = self.lifting_k3.lift_topology(simple_graph_2.clone()) + lifted_data_k3 = self.lifting_k3.lift(simple_graph_2.clone()) expected_n_hyperedges = 9 - expected_incidence_1 = torch.tensor([ - [1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0], - ]) + expected_incidence_1 = torch.tensor( + [ + [1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0], + ] + ) assert torch.equal( lifted_data_k3["incidence_hyperedges"].to_dense(), - expected_incidence_1 + expected_incidence_1, ), "Incorrect incidence_hyperedges for k=3" - + assert lifted_data_k3["num_hyperedges"] == expected_n_hyperedges assert torch.equal(lifted_data_k3["x_0"], simple_graph_2.x) def test_lift_topology_no_loop(self, simple_graph_2): """Test the lift_topology method with loop=False. - + Parameters ---------- simple_graph_2 : torch_geometric.data.Data A simple graph fixture with 9 nodes arranged in a line pattern. """ - lifted_data = self.lifting_no_loop.lift_topology(simple_graph_2.clone()) - + lifted_data = self.lifting_no_loop.lift(simple_graph_2.clone()) + # Verify no self-loops in the incidence matrix incidence_matrix = lifted_data["incidence_hyperedges"].to_dense() diagonal = torch.diag(incidence_matrix) @@ -115,11 +120,11 @@ def test_lift_topology_with_equal_features(self): # Create a graph where some nodes have identical features data = Data( x=torch.tensor([[1.0], [1.0], [2.0], [2.0]]), - edge_index=torch.tensor([[0, 1, 2, 3], [1, 0, 3, 2]]) + edge_index=torch.tensor([[0, 1, 2, 3], [1, 0, 3, 2]]), ) - - lifted_data = self.lifting_k2.lift_topology(data) - + + lifted_data = self.lifting_k2.lift(data) + # Verify the shape of the output assert lifted_data["incidence_hyperedges"].size() == (4, 4) assert lifted_data["num_hyperedges"] == 4 @@ -128,7 +133,7 @@ def test_lift_topology_with_equal_features(self): @pytest.mark.parametrize("k_value", [1, 2, 3, 4]) def test_different_k_values(self, k_value, simple_graph_2): """Test lift_topology with different k values. - + Parameters ---------- k_value : int @@ -137,29 +142,30 @@ def test_different_k_values(self, k_value, simple_graph_2): A simple graph fixture with 9 nodes arranged in a line pattern. """ lifting = HypergraphKNNLifting(k_value=k_value, loop=True) - lifted_data = lifting.lift_topology(simple_graph_2.clone()) - + lifted_data = lifting.lift(simple_graph_2.clone()) + # Verify basic properties assert lifted_data["num_hyperedges"] == simple_graph_2.x.size(0) incidence_matrix = lifted_data["incidence_hyperedges"].to_dense() - + # Check that each node is connected to at most k nodes - assert torch.all(incidence_matrix.sum(dim=1) <= k_value), \ - f"Some nodes are connected to more than {k_value} neighbors" + assert torch.all( + incidence_matrix.sum(dim=1) <= k_value + ), f"Some nodes are connected to more than {k_value} neighbors" def test_invalid_inputs(self): """Test handling of invalid inputs and edge cases.""" # Test with no x attribute (this should raise AttributeError) data_no_x = Data(edge_index=torch.tensor([[0, 1], [1, 0]])) with pytest.raises(AttributeError): - self.lifting_k2.lift_topology(data_no_x) + self.lifting_k2.lift(data_no_x) # Test single node case (edge case that should work) single_node_data = Data( x=torch.tensor([[1.0]], dtype=torch.float), - edge_index=torch.tensor([[0], [0]]) + edge_index=torch.tensor([[0], [0]]), ) - lifted_single = self.lifting_k2.lift_topology(single_node_data) + lifted_single = self.lifting_k2.lift(single_node_data) assert lifted_single["num_hyperedges"] == 1 assert lifted_single["incidence_hyperedges"].size() == (1, 1) assert torch.equal(lifted_single["x_0"], single_node_data.x) @@ -167,32 +173,30 @@ def test_invalid_inputs(self): # Test with identical nodes (edge case that should work) identical_nodes_data = Data( x=torch.tensor([[1.0], [1.0]], dtype=torch.float), - edge_index=torch.tensor([[0, 1], [1, 0]]) + edge_index=torch.tensor([[0, 1], [1, 0]]), ) - lifted_identical = self.lifting_k2.lift_topology(identical_nodes_data) + lifted_identical = self.lifting_k2.lift(identical_nodes_data) assert lifted_identical["num_hyperedges"] == 2 assert lifted_identical["incidence_hyperedges"].size() == (2, 2) assert torch.equal(lifted_identical["x_0"], identical_nodes_data.x) # Test with missing edge_index (this should work as KNNGraph will create edges) - data_no_edges = Data( - x=torch.tensor([[1.0], [2.0]], dtype=torch.float) - ) - lifted_no_edges = self.lifting_k2.lift_topology(data_no_edges) + data_no_edges = Data(x=torch.tensor([[1.0], [2.0]], dtype=torch.float)) + lifted_no_edges = self.lifting_k2.lift(data_no_edges) assert lifted_no_edges["num_hyperedges"] == 2 assert lifted_no_edges["incidence_hyperedges"].size() == (2, 2) assert torch.equal(lifted_no_edges["x_0"], data_no_edges.x) # Test with no data (should raise AttributeError) with pytest.raises(AttributeError): - self.lifting_k2.lift_topology(None) + self.lifting_k2.lift(None) # Test with empty tensor for x (should work but result in empty outputs) empty_data = Data( x=torch.tensor([], dtype=torch.float).reshape(0, 1), - edge_index=torch.tensor([], dtype=torch.long).reshape(2, 0) + edge_index=torch.tensor([], dtype=torch.long).reshape(2, 0), ) - lifted_empty = self.lifting_k2.lift_topology(empty_data) + lifted_empty = self.lifting_k2.lift(empty_data) assert lifted_empty["num_hyperedges"] == 0 assert lifted_empty["incidence_hyperedges"].size(0) == 0 @@ -203,13 +207,17 @@ def test_invalid_initialization(self): HypergraphKNNLifting(k_value=1.5) # Test with zero k_value - with pytest.raises(ValueError, match="k_value must be greater than or equal to 1"): + with pytest.raises( + ValueError, match="k_value must be greater than or equal to 1" + ): HypergraphKNNLifting(k_value=0) # Test with negative k_value - with pytest.raises(ValueError, match="k_value must be greater than or equal to 1"): + with pytest.raises( + ValueError, match="k_value must be greater than or equal to 1" + ): HypergraphKNNLifting(k_value=-1) # Test with non-boolean loop with pytest.raises(TypeError, match="loop must be a boolean"): - HypergraphKNNLifting(k_value=1, loop="True") \ No newline at end of file + HypergraphKNNLifting(k_value=1, loop="True") diff --git a/test/transforms/liftings/simplicial/test_SimplicialCliqueLifting.py b/test/transforms/liftings/simplicial/test_SimplicialCliqueLifting.py index 9cd80058..a2c32ebf 100644 --- a/test/transforms/liftings/simplicial/test_SimplicialCliqueLifting.py +++ b/test/transforms/liftings/simplicial/test_SimplicialCliqueLifting.py @@ -2,7 +2,11 @@ import torch -from topobenchmark.transforms.liftings.graph2simplicial import ( +from topobenchmark.transforms.feature_liftings.projection_sum import ( + ProjectionSum, +) +from topobenchmark.transforms.liftings import ( + Graph2SimplicialLiftingTransform, SimplicialCliqueLifting, ) @@ -12,11 +16,19 @@ class TestSimplicialCliqueLifting: def setup_method(self): # Initialise the SimplicialCliqueLifting class - self.lifting_signed = SimplicialCliqueLifting( - complex_dim=3, signed=True + + lifting_map = SimplicialCliqueLifting(complex_dim=3) + feature_lifting = ProjectionSum() + + self.lifting_signed = Graph2SimplicialLiftingTransform( + lifting=lifting_map, + feature_lifting=feature_lifting, + signed=True, ) - self.lifting_unsigned = SimplicialCliqueLifting( - complex_dim=3, signed=False + self.lifting_unsigned = Graph2SimplicialLiftingTransform( + lifting=lifting_map, + feature_lifting=feature_lifting, + signed=False, ) def test_lift_topology(self, simple_graph_1): @@ -204,6 +216,8 @@ def test_lift_topology(self, simple_graph_1): def test_lifted_features_signed(self, simple_graph_1): """Test the lift_features method in signed incidence cases.""" + # TODO: can be removed/moved; part of projection sum + self.data = simple_graph_1 # Test the lift_features method for signed case lifted_data = self.lifting_signed.forward(self.data) @@ -246,6 +260,8 @@ def test_lifted_features_signed(self, simple_graph_1): def test_lifted_features_unsigned(self, simple_graph_1): """Test the lift_features method in unsigned incidence cases.""" + # TODO: redundant. can be moved/removed + self.data = simple_graph_1 # Test the lift_features method for unsigned case lifted_data = self.lifting_unsigned.forward(self.data) diff --git a/test/transforms/liftings/simplicial/test_SimplicialNeighborhoodLifting.py b/test/transforms/liftings/simplicial/test_SimplicialNeighborhoodLifting.py index 5a03f67e..6a81d9f2 100644 --- a/test/transforms/liftings/simplicial/test_SimplicialNeighborhoodLifting.py +++ b/test/transforms/liftings/simplicial/test_SimplicialNeighborhoodLifting.py @@ -2,19 +2,35 @@ import torch -from topobenchmark.transforms.liftings.graph2simplicial import ( +from topobenchmark.transforms.feature_liftings.projection_sum import ( + ProjectionSum, +) +from topobenchmark.transforms.liftings import ( + Graph2SimplicialLiftingTransform, SimplicialKHopLifting, ) +# TODO: rename for consistency? + class TestSimplicialKHopLifting: """Test the SimplicialKHopLifting class.""" def setup_method(self): # Initialise the SimplicialKHopLifting class - self.lifting_signed = SimplicialKHopLifting(complex_dim=3, signed=True) - self.lifting_unsigned = SimplicialKHopLifting( - complex_dim=3, signed=False + feature_lifting = ProjectionSum() + + lifting_map = SimplicialKHopLifting(complex_dim=3) + + self.lifting_signed = Graph2SimplicialLiftingTransform( + lifting=lifting_map, + feature_lifting=feature_lifting, + signed=True, + ) + self.lifting_unsigned = Graph2SimplicialLiftingTransform( + lifting=lifting_map, + feature_lifting=feature_lifting, + signed=False, ) def test_lift_topology(self, simple_graph_1): diff --git a/test/transforms/liftings/test_AbstractLifting.py b/test/transforms/liftings/test_AbstractLifting.py deleted file mode 100644 index 49167cb1..00000000 --- a/test/transforms/liftings/test_AbstractLifting.py +++ /dev/null @@ -1,53 +0,0 @@ -"""Test AbstractLifting module.""" - -import pytest -import torch -from torch_geometric.data import Data -from topobenchmark.transforms.liftings import AbstractLifting - -class TestAbstractLifting: - """Test the AbstractLifting class.""" - - def setup_method(self): - """Set up test fixtures for each test method. - - Creates a concrete subclass of AbstractLifting for testing purposes. - """ - class ConcreteLifting(AbstractLifting): - """Concrete implementation of AbstractLifting for testing.""" - - def lift_topology(self, data): - """Implementation of abstract method that calls parent's method. - - Parameters - ---------- - data : torch_geometric.data.Data - The input data to be lifted. - - Returns - ------- - dict - Empty dictionary as this is just for testing. - - Raises - ------ - NotImplementedError - Always raises this error as it calls the parent's abstract method. - """ - return super().lift_topology(data) - - self.lifting = ConcreteLifting(feature_lifting=None) - - def test_lift_topology_raises_not_implemented(self): - """Test that the abstract lift_topology method raises NotImplementedError. - - Verifies that calling lift_topology on an abstract class implementation - raises NotImplementedError as expected. - """ - dummy_data = Data( - x=torch.tensor([[1.0], [2.0]]), - edge_index=torch.tensor([[0, 1], [1, 0]]) - ) - - with pytest.raises(NotImplementedError): - self.lifting.lift_topology(dummy_data) \ No newline at end of file diff --git a/test/transforms/liftings/test_GraphLifting.py b/test/transforms/liftings/test_GraphLifting.py index c7acf454..546956c9 100644 --- a/test/transforms/liftings/test_GraphLifting.py +++ b/test/transforms/liftings/test_GraphLifting.py @@ -1,21 +1,42 @@ """Test the GraphLifting class.""" -import pytest + import torch +import torch_geometric from torch_geometric.data import Data -from topobenchmark.transforms.liftings import GraphLifting +from topobenchmark.transforms.feature_liftings.projection_sum import ( + ProjectionSum, +) +from topobenchmark.transforms.liftings.base import LiftingMap, LiftingTransform + + +def _data_has_edge_attr(data: torch_geometric.data.Data) -> bool: + r"""Check if the input data object has edge attributes. + + Parameters + ---------- + data : torch_geometric.data.Data + The input data. + + Returns + ------- + bool + Whether the data object has edge attributes. + """ + return hasattr(data, "edge_attr") and data.edge_attr is not None -class ConcreteGraphLifting(GraphLifting): + +class ConcreteGraphLifting(LiftingMap): """Concrete implementation of GraphLifting for testing.""" - - def lift_topology(self, data): + + def lift(self, data): """Implement the abstract lift_topology method. - + Parameters ---------- data : torch_geometric.data.Data The input data to be lifted. - + Returns ------- dict @@ -26,86 +47,70 @@ def lift_topology(self, data): class TestGraphLifting: """Test the GraphLifting class.""" - + def setup_method(self): """Set up test fixtures before each test method. - + Creates an instance of ConcreteGraphLifting with default parameters. """ - self.lifting = ConcreteGraphLifting( - feature_lifting="ProjectionSum", - preserve_edge_attr=False + self.lifting = LiftingTransform( + ConcreteGraphLifting(), feature_lifting=ProjectionSum() ) def test_data_has_edge_attr(self): """Test _data_has_edge_attr method with different data configurations.""" - + # Test case 1: Data with edge attributes data_with_edge_attr = Data( x=torch.tensor([[1.0], [2.0]]), edge_index=torch.tensor([[0, 1], [1, 0]]), - edge_attr=torch.tensor([[1.0], [1.0]]) + edge_attr=torch.tensor([[1.0], [1.0]]), ) - assert self.lifting._data_has_edge_attr(data_with_edge_attr) is True + assert _data_has_edge_attr(data_with_edge_attr) is True # Test case 2: Data without edge attributes data_without_edge_attr = Data( x=torch.tensor([[1.0], [2.0]]), - edge_index=torch.tensor([[0, 1], [1, 0]]) + edge_index=torch.tensor([[0, 1], [1, 0]]), ) - assert self.lifting._data_has_edge_attr(data_without_edge_attr) is False + assert _data_has_edge_attr(data_without_edge_attr) is False # Test case 3: Data with edge_attr set to None data_with_none_edge_attr = Data( x=torch.tensor([[1.0], [2.0]]), edge_index=torch.tensor([[0, 1], [1, 0]]), - edge_attr=None + edge_attr=None, ) - assert self.lifting._data_has_edge_attr(data_with_none_edge_attr) is False + assert _data_has_edge_attr(data_with_none_edge_attr) is False def test_data_has_edge_attr_empty_data(self): """Test _data_has_edge_attr method with empty data object.""" empty_data = Data() - assert self.lifting._data_has_edge_attr(empty_data) is False + assert _data_has_edge_attr(empty_data) is False def test_data_has_edge_attr_different_edge_formats(self): """Test _data_has_edge_attr method with different edge attribute formats.""" - + # Test with float edge attributes data_float_attr = Data( x=torch.tensor([[1.0], [2.0]]), edge_index=torch.tensor([[0, 1], [1, 0]]), - edge_attr=torch.tensor([[0.5], [0.5]]) + edge_attr=torch.tensor([[0.5], [0.5]]), ) - assert self.lifting._data_has_edge_attr(data_float_attr) is True + assert _data_has_edge_attr(data_float_attr) is True # Test with integer edge attributes data_int_attr = Data( x=torch.tensor([[1.0], [2.0]]), edge_index=torch.tensor([[0, 1], [1, 0]]), - edge_attr=torch.tensor([[1], [1]], dtype=torch.long) + edge_attr=torch.tensor([[1], [1]], dtype=torch.long), ) - assert self.lifting._data_has_edge_attr(data_int_attr) is True + assert _data_has_edge_attr(data_int_attr) is True # Test with multi-dimensional edge attributes data_multidim_attr = Data( x=torch.tensor([[1.0], [2.0]]), edge_index=torch.tensor([[0, 1], [1, 0]]), - edge_attr=torch.tensor([[1.0, 2.0], [2.0, 1.0]]) - ) - assert self.lifting._data_has_edge_attr(data_multidim_attr) is True - - @pytest.mark.parametrize("preserve_edge_attr", [True, False]) - def test_init_preserve_edge_attr(self, preserve_edge_attr): - """Test initialization with different preserve_edge_attr values. - - Parameters - ---------- - preserve_edge_attr : bool - Boolean value to test initialization with True and False values. - """ - lifting = ConcreteGraphLifting( - feature_lifting="ProjectionSum", - preserve_edge_attr=preserve_edge_attr + edge_attr=torch.tensor([[1.0, 2.0], [2.0, 1.0]]), ) - assert lifting.preserve_edge_attr == preserve_edge_attr \ No newline at end of file + assert _data_has_edge_attr(data_multidim_attr) is True diff --git a/topobenchmark/data/utils/__init__.py b/topobenchmark/data/utils/__init__.py index 70126253..de796c1d 100644 --- a/topobenchmark/data/utils/__init__.py +++ b/topobenchmark/data/utils/__init__.py @@ -1,5 +1,7 @@ """Init file for data/utils module.""" +from .adapters import * +from .domain import ComplexData, HypergraphData # noqa: F401 from .utils import ( ensure_serializable, # noqa: F401 generate_zero_sparse_connectivity, # noqa: F401 diff --git a/topobenchmark/data/utils/adapters.py b/topobenchmark/data/utils/adapters.py new file mode 100644 index 00000000..b049d49c --- /dev/null +++ b/topobenchmark/data/utils/adapters.py @@ -0,0 +1,341 @@ +import abc + +import networkx as nx +import numpy as np +import torch +import torch_geometric +from topomodelx.utils.sparse import from_sparse +from toponetx.classes import CellComplex, SimplicialComplex +from torch_geometric.utils.undirected import is_undirected, to_undirected + +from topobenchmark.data.utils.domain import ComplexData +from topobenchmark.data.utils.utils import ( + generate_zero_sparse_connectivity, + select_neighborhoods_of_interest, +) + + +class Adapter(abc.ABC): + """Adapt between data structures representing the same domain.""" + + def __call__(self, domain): + """Adapt domain's data structure.""" + return self.adapt(domain) + + @abc.abstractmethod + def adapt(self, domain): + """Adapt domain's data structure.""" + + +class IdentityAdapter(Adapter): + """Identity adaptation. + + Retrieves same data structure for domain. + """ + + def adapt(self, domain): + """Adapt domain.""" + return domain + + +class Data2NxGraph(Adapter): + """Data to nx.Graph adaptation. + + Parameters + ---------- + preserve_edge_attr : bool + Whether to preserve edge attributes. + """ + + def __init__(self, preserve_edge_attr=False): + self.preserve_edge_attr = preserve_edge_attr + + def _data_has_edge_attr(self, data: torch_geometric.data.Data) -> bool: + r"""Check if the input data object has edge attributes. + + Parameters + ---------- + data : torch_geometric.data.Data + The input data. + + Returns + ------- + bool + Whether the data object has edge attributes. + """ + return hasattr(data, "edge_attr") and data.edge_attr is not None + + def adapt(self, domain: torch_geometric.data.Data) -> nx.Graph: + r"""Generate a NetworkX graph from the input data object. + + Parameters + ---------- + domain : torch_geometric.data.Data + The input data. + + Returns + ------- + nx.Graph + The generated NetworkX graph. + """ + # Check if data object have edge_attr, return list of tuples as [(node_id, {'features':data}, 'dim':1)] or ?? + nodes = [ + (n, dict(features=domain.x[n], dim=0)) + for n in range(domain.x.shape[0]) + ] + + if self.preserve_edge_attr and self._data_has_edge_attr(domain): + # In case edge features are given, assign features to every edge + # TODO: confirm this is the desired behavior + if is_undirected(domain.edge_index, domain.edge_attr): + edge_index, edge_attr = (domain.edge_index, domain.edge_attr) + else: + edge_index, edge_attr = to_undirected( + domain.edge_index, domain.edge_attr + ) + + edges = [ + (i.item(), j.item(), dict(features=edge_attr[edge_idx], dim=1)) + for edge_idx, (i, j) in enumerate( + zip(edge_index[0], edge_index[1], strict=False) + ) + ] + + else: + # If edge_attr is not present, return list list of edges + edges = [ + (i.item(), j.item(), {}) + for i, j in zip( + domain.edge_index[0], domain.edge_index[1], strict=False + ) + ] + graph = nx.Graph() + graph.add_nodes_from(nodes) + graph.add_edges_from(edges) + return graph + + +class TnxComplex2ComplexData(Adapter): + """toponetx.Complex to Complex adaptation. + + NB: order of features plays a crucial role, as ``Complex`` + simply stores them as lists (i.e. the reference to the indices + of the simplex are lost). + + Parameters + ---------- + neighborhoods : list, optional + List of neighborhoods of interest. + signed : bool, optional + If True, returns signed connectivity matrices. + transfer_features : bool, optional + Whether to transfer features. + """ + + def __init__( + self, + neighborhoods=None, + signed=False, + transfer_features=True, + ): + super().__init__() + self.neighborhoods = neighborhoods + self.signed = signed + self.transfer_features = transfer_features + + def adapt(self, domain): + """Adapt toponetx.Complex to Complex. + + Parameters + ---------- + domain : toponetx.Complex + + Returns + ------- + Complex + """ + # NB: just a slightly rewriting of get_complex_connectivity + + practical_dim = ( + domain.practical_dim + if hasattr(domain, "practical_dim") + else domain.dim + ) + dim = domain.dim + + signed = self.signed + neighborhoods = self.neighborhoods + + connectivity_infos = [ + "incidence", + "down_laplacian", + "up_laplacian", + "adjacency", + "coadjacency", + "hodge_laplacian", + ] + + practical_shape = list( + np.pad( + list(domain.shape), (0, practical_dim + 1 - len(domain.shape)) + ) + ) + data = { + connectivity_info: [] for connectivity_info in connectivity_infos + } + for rank in range(practical_dim + 1): + for connectivity_info in connectivity_infos: + try: + data[connectivity_info].append( + from_sparse( + getattr(domain, f"{connectivity_info}_matrix")( + rank=rank, signed=signed + ) + ) + ) + except ValueError: + if connectivity_info == "incidence": + data[connectivity_info].append( + generate_zero_sparse_connectivity( + m=practical_shape[rank - 1], + n=practical_shape[rank], + ) + ) + else: + data[connectivity_info].append( + generate_zero_sparse_connectivity( + m=practical_shape[rank], + n=practical_shape[rank], + ) + ) + + # TODO: handle this + if neighborhoods is not None: + data = select_neighborhoods_of_interest(data, neighborhoods) + + if self.transfer_features: + if isinstance(domain, SimplicialComplex): + get_features = domain.get_simplex_attributes + elif isinstance(domain, CellComplex): + get_features = domain.get_cell_attributes + else: + raise ValueError("Can't transfer features.") + + # TODO: confirm features are in the right order; update this + data["features"] = [] + for rank in range(dim + 1): + rank_features_dict = get_features("features", rank) + if rank_features_dict: + rank_features = torch.stack( + list(rank_features_dict.values()) + ) + else: + rank_features = None + data["features"].append(rank_features) + + for _ in range(dim + 1, practical_dim + 1): + data["features"].append(None) + + return ComplexData(**data) + + +class ComplexData2Dict(Adapter): + """ComplexData to dict adaptation.""" + + def adapt(self, domain): + """Adapt Complex to dict. + + Parameters + ---------- + domain : ComplexData + + Returns + ------- + dict + """ + data = {} + connectivity_infos = [ + "incidence", + "down_laplacian", + "up_laplacian", + "adjacency", + "coadjacency", + "hodge_laplacian", + ] + for connectivity_info in connectivity_infos: + info = getattr(domain, connectivity_info) + for rank, rank_info in enumerate(info): + data[f"{connectivity_info}_{rank}"] = rank_info + + # TODO: handle neighborhoods + data["shape"] = domain.shape + + for index, values in enumerate(domain.features): + if values is not None: + data[f"x_{index}"] = values + + return data + + +class HypergraphData2Dict(Adapter): + """HypergraphData to dict adaptation.""" + + def adapt(self, domain): + """Adapt HypergraphData to dict. + + Parameters + ---------- + domain : HypergraphData + + Returns + ------- + dict + """ + hyperedges_key = domain.keys()[-1] + return { + "incidence_hyperedges": domain.incidence[hyperedges_key], + "num_hyperedges": domain.num_hyperedges, + "x_0": domain.features[0], + "x_hyperedges": domain.features[hyperedges_key], + } + + +class AdapterComposition(Adapter): + def __init__(self, adapters): + super().__init__() + self.adapters = adapters + + def adapt(self, domain): + """Adapt domain""" + for adapter in self.adapters: + domain = adapter(domain) + + return domain + + +class TnxComplex2Dict(AdapterComposition): + """toponetx.Complex to dict adaptation. + + Parameters + ---------- + neighborhoods : list, optional + List of neighborhoods of interest. + signed : bool, optional + If True, returns signed connectivity matrices. + transfer_features : bool, optional + Whether to transfer features. + """ + + def __init__( + self, + neighborhoods=None, + signed=False, + transfer_features=True, + ): + tnxcomplex2complex = TnxComplex2ComplexData( + neighborhoods=neighborhoods, + signed=signed, + transfer_features=transfer_features, + ) + complex2dict = ComplexData2Dict() + super().__init__(adapters=(tnxcomplex2complex, complex2dict)) diff --git a/topobenchmark/data/utils/domain.py b/topobenchmark/data/utils/domain.py new file mode 100644 index 00000000..57790162 --- /dev/null +++ b/topobenchmark/data/utils/domain.py @@ -0,0 +1,97 @@ +import abc + + +class Data(abc.ABC): + def __init__(self, incidence, features): + self.incidence = incidence + self.features = features + + @abc.abstractmethod + def keys(self): + pass + + def update_features(self, rank, values): + """Update features. + + Parameters + ---------- + rank : int + Rank of simplices the features belong to. + values : array-like + New features for the rank-simplices. + """ + self.features[rank] = values + + @property + def shape(self): + """Shape of the complex. + + Returns + ------- + list[int] + """ + return [ + None + if self.incidence[key] is None + else self.incidence[key].shape[-1] + for key in self.keys() + ] + + +class ComplexData(Data): + def __init__( + self, + incidence, + down_laplacian, + up_laplacian, + adjacency, + coadjacency, + hodge_laplacian, + features=None, + ): + self.down_laplacian = down_laplacian + self.up_laplacian = up_laplacian + self.adjacency = adjacency + self.coadjacency = coadjacency + self.hodge_laplacian = hodge_laplacian + + if features is None: + features = [None for _ in range(len(incidence))] + else: + for rank, incidence_ in enumerate(incidence): + # TODO: make error message more informative + if ( + features[rank] is not None + and features[rank].shape[0] != incidence_.shape[-1] + ): + raise ValueError("Features have wrong shape.") + + super().__init__(incidence, features) + + def keys(self): + return list(range(len(self.incidence))) + + +class HypergraphData(Data): + def __init__( + self, + incidence_hyperedges, + num_hyperedges, + incidence_0=None, + x_0=None, + x_hyperedges=None, + ): + self._hyperedges_key = 1 + incidence = { + 0: incidence_0, + self._hyperedges_key: incidence_hyperedges, + } + features = { + 0: x_0, + self._hyperedges_key: x_hyperedges, + } + super().__init__(incidence, features) + self.num_hyperedges = num_hyperedges + + def keys(self): + return [0, self._hyperedges_key] diff --git a/topobenchmark/transforms/__init__.py b/topobenchmark/transforms/__init__.py index 3f568814..20840dfe 100755 --- a/topobenchmark/transforms/__init__.py +++ b/topobenchmark/transforms/__init__.py @@ -1,32 +1,33 @@ """This module contains the transforms for the topobenchmark package.""" -from typing import Any - -from topobenchmark.transforms.data_manipulations import DATA_MANIPULATIONS -from topobenchmark.transforms.feature_liftings import FEATURE_LIFTINGS -from topobenchmark.transforms.liftings.graph2cell import GRAPH2CELL_LIFTINGS -from topobenchmark.transforms.liftings.graph2hypergraph import ( +from .data_manipulations import DATA_MANIPULATIONS +from .feature_liftings import FEATURE_LIFTINGS +from .liftings import ( + GRAPH2CELL_LIFTINGS, GRAPH2HYPERGRAPH_LIFTINGS, -) -from topobenchmark.transforms.liftings.graph2simplicial import ( GRAPH2SIMPLICIAL_LIFTINGS, + LIFTINGS, ) -LIFTINGS = { - **GRAPH2CELL_LIFTINGS, - **GRAPH2HYPERGRAPH_LIFTINGS, - **GRAPH2SIMPLICIAL_LIFTINGS, -} - -TRANSFORMS: dict[Any, Any] = { +TRANSFORMS = { **LIFTINGS, **FEATURE_LIFTINGS, **DATA_MANIPULATIONS, } -__all__ = [ - "DATA_MANIPULATIONS", - "FEATURE_LIFTINGS", - "LIFTINGS", - "TRANSFORMS", -] + +_map_lifting_type_to_dict = { + "graph2cell": GRAPH2CELL_LIFTINGS, + "graph2hypergraph": GRAPH2HYPERGRAPH_LIFTINGS, + "graph2simplicial": GRAPH2SIMPLICIAL_LIFTINGS, +} + + +def add_lifting_map(LiftingMap, lifting_type, name=None): + if name is None: + name = LiftingMap.__name__ + + liftings_dict = _map_lifting_type_to_dict[lifting_type] + + for dict_ in (liftings_dict, LIFTINGS, TRANSFORMS): + dict_[name] = LiftingMap diff --git a/topobenchmark/transforms/_utils.py b/topobenchmark/transforms/_utils.py new file mode 100644 index 00000000..c2e0750c --- /dev/null +++ b/topobenchmark/transforms/_utils.py @@ -0,0 +1,55 @@ +import inspect +from importlib import util +from pathlib import Path + + +def discover_objs(package_path, condition=None): + """Dynamically discover all manipulation classes in the package. + + Parameters + ---------- + package_path : str + Path to the package's __init__.py file. + condition : callable + `(name, obj) -> bool` + + Returns + ------- + dict[str, type] + Dictionary mapping class names to their corresponding class objects. + """ + if condition is None: + + def condition(name, obj): + return True + + objs = {} + + # Get the directory containing the manipulation modules + package_dir = Path(package_path).parent + + # Iterate through all .py files in the directory + for file_path in package_dir.glob("*.py"): + if file_path.stem == "__init__": + continue + + # Import the module + module_name = f"{Path(package_path).stem}.{file_path.stem}" + spec = util.spec_from_file_location(module_name, file_path) + if spec and spec.loader: + module = util.module_from_spec(spec) + spec.loader.exec_module(module) + + # Find all manipulation classes in the module + for name, obj in inspect.getmembers(module): + if ( + not inspect.isclass(obj) + or name.startswith("_") + or obj.__module__ != module.__name__ + ): + continue + + if condition(name, obj): + objs[name] = obj + + return objs diff --git a/topobenchmark/transforms/data_manipulations/__init__.py b/topobenchmark/transforms/data_manipulations/__init__.py index a17e506d..314d5fa6 100644 --- a/topobenchmark/transforms/data_manipulations/__init__.py +++ b/topobenchmark/transforms/data_manipulations/__init__.py @@ -1,86 +1,7 @@ """Data manipulations module with automated exports.""" -import inspect -from importlib import util -from pathlib import Path -from typing import Any +from topobenchmark.transforms._utils import discover_objs +DATA_MANIPULATIONS = discover_objs(__file__) -class ModuleExportsManager: - """Manages automatic discovery and registration of data manipulation classes.""" - - @staticmethod - def is_manipulation_class(obj: Any) -> bool: - """Check if an object is a valid manipulation class. - - Parameters - ---------- - obj : Any - The object to check if it's a valid manipulation class. - - Returns - ------- - bool - True if the object is a valid manipulation class (non-private class - defined in __main__), False otherwise. - """ - return ( - inspect.isclass(obj) - and obj.__module__ == "__main__" - and not obj.__name__.startswith("_") - ) - - @classmethod - def discover_manipulations(cls, package_path: str) -> dict[str, type]: - """Dynamically discover all manipulation classes in the package. - - Parameters - ---------- - package_path : str - Path to the package's __init__.py file. - - Returns - ------- - dict[str, type] - Dictionary mapping class names to their corresponding class objects. - """ - manipulations = {} - - # Get the directory containing the manipulation modules - package_dir = Path(package_path).parent - - # Iterate through all .py files in the directory - for file_path in package_dir.glob("*.py"): - if file_path.stem == "__init__": - continue - - # Import the module - module_name = f"{Path(package_path).stem}.{file_path.stem}" - spec = util.spec_from_file_location(module_name, file_path) - if spec and spec.loader: - module = util.module_from_spec(spec) - spec.loader.exec_module(module) - - # Find all manipulation classes in the module - for name, obj in inspect.getmembers(module): - if ( - inspect.isclass(obj) - and obj.__module__ == module.__name__ - and not name.startswith("_") - ): - manipulations[name] = obj # noqa: PERF403 - - return manipulations - - -# Create the exports manager -manager = ModuleExportsManager() - -# Automatically discover and populate DATA_MANIPULATIONS -DATA_MANIPULATIONS = manager.discover_manipulations(__file__) - -# Automatically generate __all__ -__all__ = [*DATA_MANIPULATIONS.keys(), "DATA_MANIPULATIONS"] - -# For backwards compatibility, also create individual imports locals().update(DATA_MANIPULATIONS) diff --git a/topobenchmark/transforms/data_transform.py b/topobenchmark/transforms/data_transform.py index da9e883b..c1cda424 100755 --- a/topobenchmark/transforms/data_transform.py +++ b/topobenchmark/transforms/data_transform.py @@ -1,8 +1,50 @@ """DataTransform class.""" +import inspect + import torch_geometric -from topobenchmark.transforms import TRANSFORMS +from topobenchmark.transforms import ( + LIFTINGS, + TRANSFORMS, + _map_lifting_type_to_dict, +) +from topobenchmark.transforms.liftings import ( + Graph2CellLiftingTransform, + Graph2HypergraphLiftingTransform, + Graph2SimplicialLiftingTransform, + LiftingTransform, +) + +_map_lifting_type_to_transform = { + "graph2cell": Graph2CellLiftingTransform, + "graph2hypergraph": Graph2HypergraphLiftingTransform, + "graph2simplicial": Graph2SimplicialLiftingTransform, +} + + +def _map_lifting_to_transform(lifting_name): + for key, liftings_dict in _map_lifting_type_to_dict.items(): + if lifting_name in liftings_dict: + return _map_lifting_type_to_transform[key] + + return LiftingTransform + + +def _route_lifting_kwargs(kwargs, LiftingMap, Transform): + lifting_map_sign = inspect.signature(LiftingMap) + transform_sign = inspect.signature(Transform) + + lifting_map_kwargs = {} + transform_kwargs = {} + + for key, value in kwargs.items(): + if key in lifting_map_sign.parameters: + lifting_map_kwargs[key] = value + elif key in transform_sign.parameters: + transform_kwargs[key] = value + + return lifting_map_kwargs, transform_kwargs class DataTransform(torch_geometric.transforms.BaseTransform): @@ -19,14 +61,21 @@ class DataTransform(torch_geometric.transforms.BaseTransform): def __init__(self, transform_name, **kwargs): super().__init__() - kwargs["transform_name"] = transform_name - self.parameters = kwargs + if transform_name not in LIFTINGS: + kwargs["transform_name"] = transform_name + transform = TRANSFORMS[transform_name](**kwargs) + else: + LiftingMap_ = TRANSFORMS[transform_name] + Transform = _map_lifting_to_transform(transform_name) + lifting_map_kwargs, transform_kwargs = _route_lifting_kwargs( + kwargs, LiftingMap_, Transform + ) + + lifting_map = LiftingMap_(**lifting_map_kwargs) + transform = Transform(lifting_map, **transform_kwargs) - self.transform = ( - TRANSFORMS[transform_name](**kwargs) - if transform_name is not None - else None - ) + self.parameters = kwargs + self.transform = transform def forward( self, data: torch_geometric.data.Data diff --git a/topobenchmark/transforms/feature_liftings/__init__.py b/topobenchmark/transforms/feature_liftings/__init__.py index ec4f763c..6e047683 100644 --- a/topobenchmark/transforms/feature_liftings/__init__.py +++ b/topobenchmark/transforms/feature_liftings/__init__.py @@ -1,104 +1,12 @@ """Feature lifting transforms with automated exports.""" -import inspect -from importlib import util -from pathlib import Path -from typing import Any +from topobenchmark.transforms._utils import discover_objs -from .identity import Identity # Import Identity for special case +from .base import FeatureLiftingMap - -class ModuleExportsManager: - """Manages automatic discovery and registration of feature lifting classes.""" - - @staticmethod - def is_lifting_class(obj: Any) -> bool: - """Check if an object is a valid lifting class. - - Parameters - ---------- - obj : Any - The object to check if it's a valid lifting class. - - Returns - ------- - bool - True if the object is a valid lifting class (non-private class - defined in __main__), False otherwise. - """ - return ( - inspect.isclass(obj) - and obj.__module__ == "__main__" - and not obj.__name__.startswith("_") - ) - - @classmethod - def discover_liftings( - cls, package_path: str, special_cases: dict[Any, type] | None = None - ) -> dict[str, type]: - """Dynamically discover all lifting classes in the package. - - Parameters - ---------- - package_path : str - Path to the package's __init__.py file. - special_cases : Optional[dict[Any, type]] - Dictionary of special case mappings (e.g., {None: Identity}), - by default None. - - Returns - ------- - dict[str, type] - Dictionary mapping class names to their corresponding class objects, - including any special cases if provided. - """ - liftings = {} - - # Get the directory containing the lifting modules - package_dir = Path(package_path).parent - - # Iterate through all .py files in the directory - for file_path in package_dir.glob("*.py"): - if file_path.stem == "__init__": - continue - - # Import the module - module_name = f"{Path(package_path).stem}.{file_path.stem}" - spec = util.spec_from_file_location(module_name, file_path) - if spec and spec.loader: - module = util.module_from_spec(spec) - spec.loader.exec_module(module) - - # Find all lifting classes in the module - for name, obj in inspect.getmembers(module): - if ( - inspect.isclass(obj) - and obj.__module__ == module.__name__ - and not name.startswith("_") - ): - liftings[name] = obj # noqa: PERF403 - - # Add special cases if provided - if special_cases: - liftings.update(special_cases) - - return liftings - - -# Create the exports manager -manager = ModuleExportsManager() - -# Automatically discover and populate FEATURE_LIFTINGS with special case for None -FEATURE_LIFTINGS = manager.discover_liftings( - __file__, special_cases={None: Identity} +FEATURE_LIFTINGS = discover_objs( + __file__, + condition=lambda name, obj: issubclass(obj, FeatureLiftingMap), ) -# Automatically generate __all__ (excluding None key) -__all__ = [name for name in FEATURE_LIFTINGS if isinstance(name, str)] + [ - "FEATURE_LIFTINGS" -] - -# For backwards compatibility, create individual imports (excluding None key) -locals().update( - {k: v for k, v in FEATURE_LIFTINGS.items() if isinstance(k, str)} -) +locals().update(FEATURE_LIFTINGS) diff --git a/topobenchmark/transforms/feature_liftings/base.py b/topobenchmark/transforms/feature_liftings/base.py new file mode 100644 index 00000000..c5969398 --- /dev/null +++ b/topobenchmark/transforms/feature_liftings/base.py @@ -0,0 +1,13 @@ +import abc + + +class FeatureLiftingMap(abc.ABC): + """Feature lifting map.""" + + def __call__(self, domain): + """Lift features of a domain.""" + return self.lift_features(domain) + + @abc.abstractmethod + def lift_features(self, domain): + """Lift features of a domain.""" diff --git a/topobenchmark/transforms/feature_liftings/concatenation.py b/topobenchmark/transforms/feature_liftings/concatenation.py index 5a69f46d..44e3b192 100644 --- a/topobenchmark/transforms/feature_liftings/concatenation.py +++ b/topobenchmark/transforms/feature_liftings/concatenation.py @@ -1,83 +1,53 @@ """Concatenation feature lifting.""" import torch -import torch_geometric +from topobenchmark.transforms.feature_liftings.base import FeatureLiftingMap -class Concatenation(torch_geometric.transforms.BaseTransform): - r"""Lift r-cell features to r+1-cells by concatenation. - Parameters - ---------- - **kwargs : optional - Additional arguments for the class. - """ - - def __init__(self, **kwargs): - super().__init__() +class Concatenation(FeatureLiftingMap): + """Lift r-cell features to r+1-cells by concatenation.""" def __repr__(self) -> str: return f"{self.__class__.__name__}()" - def lift_features( - self, data: torch_geometric.data.Data | dict - ) -> torch_geometric.data.Data | dict: + def lift_features(self, domain): r"""Concatenate r-cell features to obtain r+1-cell features. Parameters ---------- - data : torch_geometric.data.Data | dict - The input data to be lifted. - - Returns - ------- - torch_geometric.data.Data | dict - The lifted data. - """ - keys = sorted( - [ - key.split("_")[1] - for key in data - if "incidence" in key and "-" not in key - ] - ) - for elem in keys: - if f"x_{elem}" not in data: - idx_to_project = 0 if elem == "hyperedges" else int(elem) - 1 - incidence = data["incidence_" + elem] - _, n = incidence.shape - - if n != 0: - idxs_list = [] - for n_feature in range(n): - idxs_for_feature = incidence.indices()[ - 0, incidence.indices()[1, :] == n_feature - ] - idxs_list.append(torch.sort(idxs_for_feature)[0]) - - idxs = torch.stack(idxs_list, dim=0) - values = data[f"x_{idx_to_project}"][idxs].view(n, -1) - else: - m = data[f"x_{int(elem)-1}"].shape[1] * (int(elem) + 1) - values = torch.zeros([0, m]) - - data["x_" + elem] = values - return data - - def forward( - self, data: torch_geometric.data.Data | dict - ) -> torch_geometric.data.Data | dict: - r"""Apply the lifting to the input data. - - Parameters - ---------- - data : torch_geometric.data.Data | dict + data : Complex The input data to be lifted. Returns ------- - torch_geometric.data.Data | dict - The lifted data. + Complex + Domain with the lifted features. """ - data = self.lift_features(data) - return data + for key, next_key in zip( + domain.keys(), domain.keys()[1:], strict=False + ): + if domain.features[next_key] is not None: + continue + + incidence = domain.incidence[next_key] + _, n = incidence.shape + + if n != 0: + idxs_list = [] + for n_feature in range(n): + idxs_for_feature = incidence.indices()[ + 0, incidence.indices()[1, :] == n_feature + ] + idxs_list.append(torch.sort(idxs_for_feature)[0]) + + idxs = torch.stack(idxs_list, dim=0) + values = domain.features[key][idxs].view(n, -1) + else: + # NB: only works if key represents rank + m = domain.features[key].shape[1] * (next_key + 1) + values = torch.zeros([0, m]) + + domain.update_features(next_key, values) + + return domain diff --git a/topobenchmark/transforms/feature_liftings/identity.py b/topobenchmark/transforms/feature_liftings/identity.py index 93806f1d..e640bd06 100644 --- a/topobenchmark/transforms/feature_liftings/identity.py +++ b/topobenchmark/transforms/feature_liftings/identity.py @@ -1,36 +1,13 @@ """Identity transform that does nothing to the input data.""" -import torch_geometric +from topobenchmark.transforms.feature_liftings.base import FeatureLiftingMap -class Identity(torch_geometric.transforms.BaseTransform): - r"""An identity transform that does nothing to the input data. +class Identity(FeatureLiftingMap): + """Identity feature lifting map.""" - Parameters - ---------- - **kwargs : optional - Parameters for the base transform. - """ + # TODO: rename to IdentityFeatureLifting - def __init__(self, **kwargs): - super().__init__() - self.type = "domain2domain" - self.parameters = kwargs - - def __repr__(self) -> str: - return f"{self.__class__.__name__}(type={self.type!r}, parameters={self.parameters!r})" - - def forward(self, data: torch_geometric.data.Data): - r"""Apply the transform to the input data. - - Parameters - ---------- - data : torch_geometric.data.Data - The input data. - - Returns - ------- - torch_geometric.data.Data - The same data. - """ - return data + def lift_features(self, domain): + """Lift features of a domain using identity map.""" + return domain diff --git a/topobenchmark/transforms/feature_liftings/projection_sum.py b/topobenchmark/transforms/feature_liftings/projection_sum.py index 3cce03eb..757234a7 100644 --- a/topobenchmark/transforms/feature_liftings/projection_sum.py +++ b/topobenchmark/transforms/feature_liftings/projection_sum.py @@ -1,69 +1,38 @@ """ProjectionSum class.""" import torch -import torch_geometric +from topobenchmark.transforms.feature_liftings.base import FeatureLiftingMap -class ProjectionSum(torch_geometric.transforms.BaseTransform): - r"""Lift r-cell features to r+1-cells by projection. - Parameters - ---------- - **kwargs : optional - Additional arguments for the class. - """ +class ProjectionSum(FeatureLiftingMap): + r"""Lift r-cell features to r+1-cells by projection.""" - def __init__(self, **kwargs): - super().__init__() - - def __repr__(self) -> str: - return f"{self.__class__.__name__}()" - - def lift_features( - self, data: torch_geometric.data.Data | dict - ) -> torch_geometric.data.Data | dict: + def lift_features(self, domain): r"""Project r-cell features of a graph to r+1-cell structures. Parameters ---------- - data : torch_geometric.data.Data | dict - The input data to be lifted. - - Returns - ------- - torch_geometric.data.Data | dict - The data with the lifted features. - """ - keys = sorted( - [ - key.split("_")[1] - for key in data - if ("incidence" in key and "-" not in key) - ] - ) - for elem in keys: - if f"x_{elem}" not in data: - idx_to_project = 0 if elem == "hyperedges" else int(elem) - 1 - data["x_" + elem] = torch.matmul( - abs(data["incidence_" + elem].t()), - data[f"x_{idx_to_project}"], - ) - return data - - def forward( - self, data: torch_geometric.data.Data | dict - ) -> torch_geometric.data.Data | dict: - r"""Apply the lifting to the input data. - - Parameters - ---------- - data : torch_geometric.data.Data | dict + data : Data The input data to be lifted. Returns ------- - torch_geometric.data.Data | dict - The lifted data. + Data + Domain with the lifted features. """ - data = self.lift_features(data) - return data + for key, next_key in zip( + domain.keys(), domain.keys()[1:], strict=False + ): + if domain.features[next_key] is not None: + continue + + domain.update_features( + next_key, + torch.matmul( + torch.abs(domain.incidence[next_key].t()), + domain.features[key], + ), + ) + + return domain diff --git a/topobenchmark/transforms/feature_liftings/set.py b/topobenchmark/transforms/feature_liftings/set.py index 28ccd0cc..54ac1b9d 100644 --- a/topobenchmark/transforms/feature_liftings/set.py +++ b/topobenchmark/transforms/feature_liftings/set.py @@ -1,89 +1,60 @@ """Set lifting for r-cell features to r+1-cell features.""" import torch -import torch_geometric +from topobenchmark.transforms.feature_liftings.base import FeatureLiftingMap -class Set(torch_geometric.transforms.BaseTransform): - r"""Lift r-cell features to r+1-cells by set operations. - Parameters - ---------- - **kwargs : optional - Additional arguments for the class. - """ - - def __init__(self, **kwargs): - super().__init__() +class Set(FeatureLiftingMap): + """Lift r-cell features to r+1-cells by set operations.""" def __repr__(self) -> str: return f"{self.__class__.__name__}()" - def lift_features( - self, data: torch_geometric.data.Data | dict - ) -> torch_geometric.data.Data | dict: + def lift_features(self, domain): r"""Concatenate r-cell features to r+1-cell structures. Parameters ---------- - data : torch_geometric.data.Data | dict + data : Complex The input data to be lifted. Returns ------- - torch_geometric.data.Data | dict - The lifted data. + Complex + Domain with the lifted features. """ - keys = sorted( - [key.split("_")[1] for key in data if "incidence" in key] - ) - for elem in keys: - if f"x_{elem}" not in data: - # idx_to_project = 0 if elem == "hyperedges" else int(elem) - 1 - incidence = data["incidence_" + elem] - _, n = incidence.shape - - if n != 0: - idxs_list = [] - for n_feature in range(n): - idxs_for_feature = incidence.indices()[ - 0, incidence.indices()[1, :] == n_feature - ] - idxs_list.append(torch.sort(idxs_for_feature)[0]) - - idxs = torch.stack(idxs_list, dim=0) - if elem == "1": - values = idxs - else: - values = torch.sort( - torch.unique( - data["x_" + str(int(elem) - 1)][idxs].view( - idxs.shape[0], -1 - ), - dim=1, - ), - dim=1, - )[0] + for key, next_key in zip( + domain.keys(), domain.keys()[1:], strict=False + ): + if domain.features[next_key] is not None: + continue + + incidence = domain.incidence[next_key] + _, n = incidence.shape + + if n != 0: + idxs_list = [] + for n_feature in range(n): + idxs_for_feature = incidence.indices()[ + 0, incidence.indices()[1, :] == n_feature + ] + idxs_list.append(torch.sort(idxs_for_feature)[0]) + + idxs = torch.stack(idxs_list, dim=0) + if key == 0: + values = idxs else: - values = torch.tensor([]) - - data["x_" + elem] = values - return data + values = torch.sort( + torch.unique( + domain.features[key][idxs].view(idxs.shape[0], -1), + dim=1, + ), + dim=1, + )[0] + else: + values = torch.tensor([]) - def forward( - self, data: torch_geometric.data.Data | dict - ) -> torch_geometric.data.Data | dict: - r"""Apply the lifting to the input data. + domain.update_features(next_key, values) - Parameters - ---------- - data : torch_geometric.data.Data | dict - The input data to be lifted. - - Returns - ------- - torch_geometric.data.Data | dict - The lifted data. - """ - data = self.lift_features(data) - return data + return domain diff --git a/topobenchmark/transforms/liftings/__init__.py b/topobenchmark/transforms/liftings/__init__.py index 4692ceaf..10e1e3c1 100755 --- a/topobenchmark/transforms/liftings/__init__.py +++ b/topobenchmark/transforms/liftings/__init__.py @@ -1,21 +1,22 @@ """This module implements the liftings for the topological transforms.""" -from .base import AbstractLifting -from .liftings import ( - CellComplexLifting, - CombinatorialLifting, - GraphLifting, - HypergraphLifting, - PointCloudLifting, - SimplicialLifting, +from .base import ( # noqa: F401 + Graph2CellLiftingTransform, + Graph2ComplexLiftingTransform, + Graph2HypergraphLiftingTransform, + Graph2SimplicialLiftingTransform, + LiftingMap, + LiftingTransform, ) +from .graph2cell import GRAPH2CELL_LIFTINGS +from .graph2hypergraph import GRAPH2HYPERGRAPH_LIFTINGS +from .graph2simplicial import GRAPH2SIMPLICIAL_LIFTINGS -__all__ = [ - "AbstractLifting", - "CellComplexLifting", - "CombinatorialLifting", - "GraphLifting", - "HypergraphLifting", - "PointCloudLifting", - "SimplicialLifting", -] +LIFTINGS = { + **GRAPH2CELL_LIFTINGS, + **GRAPH2HYPERGRAPH_LIFTINGS, + **GRAPH2SIMPLICIAL_LIFTINGS, +} + + +locals().update(LIFTINGS) diff --git a/topobenchmark/transforms/liftings/base.py b/topobenchmark/transforms/liftings/base.py index 99bd720e..13f1f443 100644 --- a/topobenchmark/transforms/liftings/base.py +++ b/topobenchmark/transforms/liftings/base.py @@ -1,43 +1,70 @@ """Abstract class for topological liftings.""" -from abc import abstractmethod +import abc import torch_geometric -from topobenchmark.transforms.feature_liftings import FEATURE_LIFTINGS +from topobenchmark.data.utils import ( + ComplexData2Dict, + Data2NxGraph, + HypergraphData2Dict, + IdentityAdapter, + TnxComplex2ComplexData, +) -class AbstractLifting(torch_geometric.transforms.BaseTransform): - r"""Abstract class for topological liftings. +class LiftingTransform(torch_geometric.transforms.BaseTransform): + """Lifting transform. Parameters ---------- - feature_lifting : str, optional - The feature lifting method to be used. Default is 'ProjectionSum'. - **kwargs : optional - Additional arguments for the class. + lifting : LiftingMap + Lifting map. + data2domain : Converter + Conversion between ``torch_geometric.Data`` into + domain for consumption by lifting. + domain2dict : Converter + Conversion between output domain of feature lifting + and ``torch_geometric.Data``. + domain2domain : Converter + Conversion between output domain of lifting + and input domain for feature lifting. + feature_lifting : FeatureLiftingMap + Feature lifting map. """ - def __init__(self, feature_lifting=None, **kwargs): - super().__init__() - self.feature_lifting = FEATURE_LIFTINGS[feature_lifting]() - self.neighborhoods = kwargs.get("neighborhoods") + def __init__( + self, + lifting, + data2domain=None, + domain2dict=None, + domain2domain=None, + feature_lifting="ProjectionSum", + ): + if data2domain is None: + data2domain = IdentityAdapter() - @abstractmethod - def lift_topology(self, data: torch_geometric.data.Data) -> dict: - r"""Lift the topology of a graph to higher-order topological domains. + if domain2dict is None: + domain2dict = IdentityAdapter() - Parameters - ---------- - data : torch_geometric.data.Data - The input data to be lifted. + if domain2domain is None: + domain2domain = IdentityAdapter() - Returns - ------- - dict - The lifted topology. - """ - raise NotImplementedError + if isinstance(lifting, str): + from topobenchmark.transforms import TRANSFORMS + + lifting = TRANSFORMS[lifting]() + + if isinstance(feature_lifting, str): + from topobenchmark.transforms import TRANSFORMS + + feature_lifting = TRANSFORMS[feature_lifting]() + + self.data2domain = data2domain + self.domain2domain = domain2domain + self.domain2dict = domain2dict + self.lifting = lifting + self.feature_lifting = feature_lifting def forward( self, data: torch_geometric.data.Data @@ -55,6 +82,86 @@ def forward( The lifted data. """ initial_data = data.to_dict() - lifted_topology = self.lift_topology(data) + + domain = self.data2domain(data) + lifted_topology = self.lifting(domain) + lifted_topology = self.domain2domain(lifted_topology) lifted_topology = self.feature_lifting(lifted_topology) - return torch_geometric.data.Data(**initial_data, **lifted_topology) + lifted_topology_dict = self.domain2dict(lifted_topology) + + return torch_geometric.data.Data( + **initial_data, **lifted_topology_dict + ) + + +class Graph2ComplexLiftingTransform(LiftingTransform): + """Graph to complex lifting transform. + + Parameters + ---------- + lifting : LiftingMap + Lifting map. + feature_lifting : FeatureLiftingMap + Feature lifting map. + preserve_edge_attr : bool + Whether to preserve edge attributes. + neighborhoods : list, optional + List of neighborhoods of interest. + signed : bool, optional + If True, returns signed connectivity matrices. + transfer_features : bool, optional + Whether to transfer features. + """ + + def __init__( + self, + lifting, + feature_lifting="ProjectionSum", + preserve_edge_attr=False, + neighborhoods=None, + signed=False, + transfer_features=True, + ): + super().__init__( + lifting, + feature_lifting=feature_lifting, + data2domain=Data2NxGraph(preserve_edge_attr), + domain2domain=TnxComplex2ComplexData( + neighborhoods=neighborhoods, + signed=signed, + transfer_features=transfer_features, + ), + domain2dict=ComplexData2Dict(), + ) + + +Graph2SimplicialLiftingTransform = Graph2ComplexLiftingTransform +Graph2CellLiftingTransform = Graph2ComplexLiftingTransform + + +class Graph2HypergraphLiftingTransform(LiftingTransform): + def __init__( + self, + lifting, + feature_lifting="ProjectionSum", + ): + super().__init__( + lifting, + feature_lifting=feature_lifting, + domain2dict=HypergraphData2Dict(), + ) + + +class LiftingMap(abc.ABC): + """Lifting map. + + Lifts a domain into another. + """ + + def __call__(self, domain): + """Lift domain.""" + return self.lift(domain) + + @abc.abstractmethod + def lift(self, domain): + """Lift domain.""" diff --git a/topobenchmark/transforms/liftings/graph2cell/__init__.py b/topobenchmark/transforms/liftings/graph2cell/__init__.py index d0faae96..480ada64 100755 --- a/topobenchmark/transforms/liftings/graph2cell/__init__.py +++ b/topobenchmark/transforms/liftings/graph2cell/__init__.py @@ -1,96 +1,11 @@ """Graph2Cell liftings with automated exports.""" -import inspect -from importlib import util -from pathlib import Path -from typing import Any +from topobenchmark.transforms._utils import discover_objs +from topobenchmark.transforms.liftings.base import LiftingMap -from .base import Graph2CellLifting +GRAPH2CELL_LIFTINGS = discover_objs( + __file__, + condition=lambda name, obj: issubclass(obj, LiftingMap), +) - -class ModuleExportsManager: - """Manages automatic discovery and registration of Graph2Cell lifting classes.""" - - @staticmethod - def is_lifting_class(obj: Any) -> bool: - """Check if an object is a valid Graph2Cell lifting class. - - Parameters - ---------- - obj : Any - The object to check if it's a valid lifting class. - - Returns - ------- - bool - True if the object is a valid Graph2Cell lifting class (non-private class - inheriting from Graph2CellLifting), False otherwise. - """ - return ( - inspect.isclass(obj) - and obj.__module__ == "__main__" - and not obj.__name__.startswith("_") - and issubclass(obj, Graph2CellLifting) - and obj != Graph2CellLifting - ) - - @classmethod - def discover_liftings(cls, package_path: str) -> dict[str, type]: - """Dynamically discover all Graph2Cell lifting classes in the package. - - Parameters - ---------- - package_path : str - Path to the package's __init__.py file. - - Returns - ------- - dict[str, type] - Dictionary mapping class names to their corresponding class objects. - """ - liftings = {} - - # Get the directory containing the lifting modules - package_dir = Path(package_path).parent - - # Iterate through all .py files in the directory - for file_path in package_dir.glob("*.py"): - if file_path.stem == "__init__": - continue - - # Import the module - module_name = f"{Path(package_path).stem}.{file_path.stem}" - spec = util.spec_from_file_location(module_name, file_path) - if spec and spec.loader: - module = util.module_from_spec(spec) - spec.loader.exec_module(module) - - # Find all lifting classes in the module - for name, obj in inspect.getmembers(module): - if ( - inspect.isclass(obj) - and obj.__module__ == module.__name__ - and not name.startswith("_") - and issubclass(obj, Graph2CellLifting) - and obj != Graph2CellLifting - ): - liftings[name] = obj # noqa: PERF403 - - return liftings - - -# Create the exports manager -manager = ModuleExportsManager() - -# Automatically discover and populate GRAPH2CELL_LIFTINGS -GRAPH2CELL_LIFTINGS = manager.discover_liftings(__file__) - -# Automatically generate __all__ -__all__ = [ - *GRAPH2CELL_LIFTINGS.keys(), - "Graph2CellLifting", - "GRAPH2CELL_LIFTINGS", -] - -# For backwards compatibility, create individual imports -locals().update(**GRAPH2CELL_LIFTINGS) +locals().update(GRAPH2CELL_LIFTINGS) diff --git a/topobenchmark/transforms/liftings/graph2cell/base.py b/topobenchmark/transforms/liftings/graph2cell/base.py deleted file mode 100755 index aeff3646..00000000 --- a/topobenchmark/transforms/liftings/graph2cell/base.py +++ /dev/null @@ -1,57 +0,0 @@ -"""Abstract class for lifting graphs to cell complexes.""" - -import networkx as nx -import torch -from toponetx.classes import CellComplex - -from topobenchmark.data.utils.utils import get_complex_connectivity -from topobenchmark.transforms.liftings import GraphLifting - - -class Graph2CellLifting(GraphLifting): - r"""Abstract class for lifting graphs to cell complexes. - - Parameters - ---------- - complex_dim : int, optional - The dimension of the cell complex to be generated. Default is 2. - **kwargs : optional - Additional arguments for the class. - """ - - def __init__(self, complex_dim=2, **kwargs): - super().__init__(**kwargs) - self.complex_dim = complex_dim - self.type = "graph2cell" - - def _get_lifted_topology( - self, cell_complex: CellComplex, graph: nx.Graph - ) -> dict: - r"""Return the lifted topology. - - Parameters - ---------- - cell_complex : CellComplex - The cell complex. - graph : nx.Graph - The input graph. - - Returns - ------- - dict - The lifted topology. - """ - lifted_topology = get_complex_connectivity( - cell_complex, self.complex_dim, neighborhoods=self.neighborhoods - ) - lifted_topology["x_0"] = torch.stack( - list(cell_complex.get_cell_attributes("features", 0).values()) - ) - # If new edges have been added during the lifting process, we discard the edge attributes - if self.contains_edge_attr and cell_complex.shape[1] == ( - graph.number_of_edges() - ): - lifted_topology["x_1"] = torch.stack( - list(cell_complex.get_cell_attributes("features", 1).values()) - ) - return lifted_topology diff --git a/topobenchmark/transforms/liftings/graph2cell/cycle.py b/topobenchmark/transforms/liftings/graph2cell/cycle.py index 31e94d8b..63160701 100755 --- a/topobenchmark/transforms/liftings/graph2cell/cycle.py +++ b/topobenchmark/transforms/liftings/graph2cell/cycle.py @@ -1,15 +1,12 @@ """This module implements the cycle lifting for graphs to cell complexes.""" import networkx as nx -import torch_geometric from toponetx.classes import CellComplex -from topobenchmark.transforms.liftings.graph2cell.base import ( - Graph2CellLifting, -) +from topobenchmark.transforms.liftings.base import LiftingMap -class CellCycleLifting(Graph2CellLifting): +class CellCycleLifting(LiftingMap): r"""Lift graphs to cell complexes. The algorithm creates 2-cells by identifying the cycles and considering them as 2-cells. @@ -18,39 +15,40 @@ class CellCycleLifting(Graph2CellLifting): ---------- max_cell_length : int, optional The maximum length of the cycles to be lifted. Default is None. - **kwargs : optional - Additional arguments for the class. """ - def __init__(self, max_cell_length=None, **kwargs): - super().__init__(**kwargs) - self.complex_dim = 2 + def __init__(self, max_cell_length=None): + super().__init__() + self._complex_dim = 2 self.max_cell_length = max_cell_length - def lift_topology(self, data: torch_geometric.data.Data) -> dict: + def lift(self, domain): r"""Find the cycles of a graph and lifts them to 2-cells. Parameters ---------- - data : torch_geometric.data.Data - The input data to be lifted. + domain : nx.Graph + Graph to be lifted. Returns ------- - dict - The lifted topology. + CellComplex + The cell complex. """ - G = self._generate_graph_from_data(data) - cycles = nx.cycle_basis(G) - cell_complex = CellComplex(G) + graph = domain + + cycles = nx.cycle_basis(graph) + cell_complex = CellComplex(graph) # Eliminate self-loop cycles cycles = [cycle for cycle in cycles if len(cycle) != 1] - # Eliminate cycles that are greater than the max_cell_lenght + + # Eliminate cycles that are greater than the max_cell_length if self.max_cell_length is not None: cycles = [ cycle for cycle in cycles if len(cycle) <= self.max_cell_length ] if len(cycles) != 0: - cell_complex.add_cells_from(cycles, rank=self.complex_dim) - return self._get_lifted_topology(cell_complex, G) + cell_complex.add_cells_from(cycles, rank=self._complex_dim) + + return cell_complex diff --git a/topobenchmark/transforms/liftings/graph2hypergraph/__init__.py b/topobenchmark/transforms/liftings/graph2hypergraph/__init__.py index acb89e0c..e7a5a815 100755 --- a/topobenchmark/transforms/liftings/graph2hypergraph/__init__.py +++ b/topobenchmark/transforms/liftings/graph2hypergraph/__init__.py @@ -1,96 +1,11 @@ """Graph2HypergraphLifting module with automated exports.""" -import inspect -from importlib import util -from pathlib import Path -from typing import Any +from topobenchmark.transforms._utils import discover_objs +from topobenchmark.transforms.liftings.base import LiftingMap -from .base import Graph2HypergraphLifting +GRAPH2HYPERGRAPH_LIFTINGS = discover_objs( + __file__, + condition=lambda name, obj: issubclass(obj, LiftingMap), +) - -class ModuleExportsManager: - """Manages automatic discovery and registration of Graph2Hypergraph lifting classes.""" - - @staticmethod - def is_lifting_class(obj: Any) -> bool: - """Check if an object is a valid Graph2Hypergraph lifting class. - - Parameters - ---------- - obj : Any - The object to check if it's a valid lifting class. - - Returns - ------- - bool - True if the object is a valid Graph2Hypergraph lifting class (non-private class - inheriting from Graph2HypergraphLifting), False otherwise. - """ - return ( - inspect.isclass(obj) - and obj.__module__ == "__main__" - and not obj.__name__.startswith("_") - and issubclass(obj, Graph2HypergraphLifting) - and obj != Graph2HypergraphLifting - ) - - @classmethod - def discover_liftings(cls, package_path: str) -> dict[str, type]: - """Dynamically discover all Graph2Hypergraph lifting classes in the package. - - Parameters - ---------- - package_path : str - Path to the package's __init__.py file. - - Returns - ------- - dict[str, type] - Dictionary mapping class names to their corresponding class objects. - """ - liftings = {} - - # Get the directory containing the lifting modules - package_dir = Path(package_path).parent - - # Iterate through all .py files in the directory - for file_path in package_dir.glob("*.py"): - if file_path.stem == "__init__": - continue - - # Import the module - module_name = f"{Path(package_path).stem}.{file_path.stem}" - spec = util.spec_from_file_location(module_name, file_path) - if spec and spec.loader: - module = util.module_from_spec(spec) - spec.loader.exec_module(module) - - # Find all lifting classes in the module - for name, obj in inspect.getmembers(module): - if ( - inspect.isclass(obj) - and obj.__module__ == module.__name__ - and not name.startswith("_") - and issubclass(obj, Graph2HypergraphLifting) - and obj != Graph2HypergraphLifting - ): - liftings[name] = obj # noqa: PERF403 - - return liftings - - -# Create the exports manager -manager = ModuleExportsManager() - -# Automatically discover and populate GRAPH2HYPERGRAPH_LIFTINGS -GRAPH2HYPERGRAPH_LIFTINGS = manager.discover_liftings(__file__) - -# Automatically generate __all__ -__all__ = [ - *GRAPH2HYPERGRAPH_LIFTINGS.keys(), - "Graph2HypergraphLifting", - "GRAPH2HYPERGRAPH_LIFTINGS", -] - -# For backwards compatibility, create individual imports -locals().update(**GRAPH2HYPERGRAPH_LIFTINGS) +locals().update(GRAPH2HYPERGRAPH_LIFTINGS) diff --git a/topobenchmark/transforms/liftings/graph2hypergraph/base.py b/topobenchmark/transforms/liftings/graph2hypergraph/base.py deleted file mode 100755 index e060e30e..00000000 --- a/topobenchmark/transforms/liftings/graph2hypergraph/base.py +++ /dev/null @@ -1,17 +0,0 @@ -"""Abstract class for lifting graphs to hypergraphs.""" - -from topobenchmark.transforms.liftings import GraphLifting - - -class Graph2HypergraphLifting(GraphLifting): - r"""Abstract class for lifting graphs to hypergraphs. - - Parameters - ---------- - **kwargs : optional - Additional arguments for the class. - """ - - def __init__(self, **kwargs): - super().__init__(**kwargs) - self.type = "graph2hypergraph" diff --git a/topobenchmark/transforms/liftings/graph2hypergraph/khop.py b/topobenchmark/transforms/liftings/graph2hypergraph/khop.py index 298fa135..7c56006c 100755 --- a/topobenchmark/transforms/liftings/graph2hypergraph/khop.py +++ b/topobenchmark/transforms/liftings/graph2hypergraph/khop.py @@ -3,12 +3,11 @@ import torch import torch_geometric -from topobenchmark.transforms.liftings.graph2hypergraph import ( - Graph2HypergraphLifting, -) +from topobenchmark.data.utils import HypergraphData +from topobenchmark.transforms.liftings.base import LiftingMap -class HypergraphKHopLifting(Graph2HypergraphLifting): +class HypergraphKHopLifting(LiftingMap): r"""Lift graph to hypergraphs by considering k-hop neighborhoods. The class transforms graphs to hypergraph domain by considering k-hop neighborhoods of @@ -19,18 +18,16 @@ class HypergraphKHopLifting(Graph2HypergraphLifting): ---------- k_value : int, optional The number of hops to consider. Default is 1. - **kwargs : optional - Additional arguments for the class. """ - def __init__(self, k_value=1, **kwargs): - super().__init__(**kwargs) - self.k = k_value + def __init__(self, k_value=1): + super().__init__() + self.n_hops = k_value def __repr__(self) -> str: - return f"{self.__class__.__name__}(k={self.k!r})" + return f"{self.__class__.__name__}(k={self.n_hops!r})" - def lift_topology(self, data: torch_geometric.data.Data) -> dict: + def lift(self, data: torch_geometric.data.Data) -> dict: r"""Lift a graphs to hypergraphs by considering k-hop neighborhoods. Parameters @@ -40,7 +37,7 @@ def lift_topology(self, data: torch_geometric.data.Data) -> dict: Returns ------- - dict + HypergraphData The lifted topology. """ # Check if data has instance x: @@ -70,14 +67,14 @@ def lift_topology(self, data: torch_geometric.data.Data) -> dict: for n in range(num_nodes): neighbors, _, _, _ = torch_geometric.utils.k_hop_subgraph( - n, self.k, edge_index + n, self.n_hops, edge_index ) incidence_1[n, neighbors] = 1 num_hyperedges = incidence_1.shape[1] incidence_1 = torch.Tensor(incidence_1).to_sparse_coo() - return { - "incidence_hyperedges": incidence_1, - "num_hyperedges": num_hyperedges, - "x_0": data.x, - } + return HypergraphData( + incidence_hyperedges=incidence_1, + num_hyperedges=num_hyperedges, + x_0=data.x, + ) diff --git a/topobenchmark/transforms/liftings/graph2hypergraph/knn.py b/topobenchmark/transforms/liftings/graph2hypergraph/knn.py index 03d0a13a..5b0de672 100755 --- a/topobenchmark/transforms/liftings/graph2hypergraph/knn.py +++ b/topobenchmark/transforms/liftings/graph2hypergraph/knn.py @@ -3,12 +3,10 @@ import torch import torch_geometric -from topobenchmark.transforms.liftings.graph2hypergraph import ( - Graph2HypergraphLifting, -) +from topobenchmark.transforms.liftings.base import LiftingMap -class HypergraphKNNLifting(Graph2HypergraphLifting): +class HypergraphKNNLifting(LiftingMap): r"""Lift graphs to hypergraph domain by considering k-nearest neighbors. Parameters @@ -17,8 +15,6 @@ class HypergraphKNNLifting(Graph2HypergraphLifting): The number of nearest neighbors to consider. Must be positive. Default is 1. loop : bool, optional If True the hyperedges will contain the node they were created from. - **kwargs : optional - Additional arguments for the class. Raises ------ @@ -28,8 +24,8 @@ class HypergraphKNNLifting(Graph2HypergraphLifting): If k_value is not an integer or if loop is not a boolean. """ - def __init__(self, k_value=1, loop=True, **kwargs): - super().__init__(**kwargs) + def __init__(self, k_value=1, loop=True): + super().__init__() # Validate k_value if not isinstance(k_value, int): @@ -41,11 +37,9 @@ def __init__(self, k_value=1, loop=True, **kwargs): if not isinstance(loop, bool): raise TypeError("loop must be a boolean") - self.k = k_value - self.loop = loop - self.transform = torch_geometric.transforms.KNNGraph(self.k, self.loop) + self.transform = torch_geometric.transforms.KNNGraph(k_value, loop) - def lift_topology(self, data: torch_geometric.data.Data) -> dict: + def lift(self, data: torch_geometric.data.Data) -> dict: r"""Lift a graph to hypergraph by considering k-nearest neighbors. Parameters @@ -64,7 +58,7 @@ def lift_topology(self, data: torch_geometric.data.Data) -> dict: incidence_1 = torch.zeros(num_nodes, num_nodes) data_lifted = self.transform(data) # check for loops, since KNNGraph is inconsistent with nodes with equal features - if self.loop: + if self.transform.loop: for i in range(num_nodes): if not torch.any( torch.all( diff --git a/topobenchmark/transforms/liftings/graph2simplicial/__init__.py b/topobenchmark/transforms/liftings/graph2simplicial/__init__.py index 238691cd..9e77797b 100755 --- a/topobenchmark/transforms/liftings/graph2simplicial/__init__.py +++ b/topobenchmark/transforms/liftings/graph2simplicial/__init__.py @@ -1,96 +1,11 @@ """Graph2SimplicialLifting module with automated exports.""" -import inspect -from importlib import util -from pathlib import Path -from typing import Any +from topobenchmark.transforms._utils import discover_objs +from topobenchmark.transforms.liftings.base import LiftingMap -from .base import Graph2SimplicialLifting +GRAPH2SIMPLICIAL_LIFTINGS = discover_objs( + __file__, + condition=lambda name, obj: issubclass(obj, LiftingMap), +) - -class ModuleExportsManager: - """Manages automatic discovery and registration of Graph2Simplicial lifting classes.""" - - @staticmethod - def is_lifting_class(obj: Any) -> bool: - """Check if an object is a valid Graph2Simplicial lifting class. - - Parameters - ---------- - obj : Any - The object to check if it's a valid lifting class. - - Returns - ------- - bool - True if the object is a valid Graph2Simplicial lifting class (non-private class - inheriting from Graph2SimplicialLifting), False otherwise. - """ - return ( - inspect.isclass(obj) - and obj.__module__ == "__main__" - and not obj.__name__.startswith("_") - and issubclass(obj, Graph2SimplicialLifting) - and obj != Graph2SimplicialLifting - ) - - @classmethod - def discover_liftings(cls, package_path: str) -> dict[str, type]: - """Dynamically discover all Graph2Simplicial lifting classes in the package. - - Parameters - ---------- - package_path : str - Path to the package's __init__.py file. - - Returns - ------- - dict[str, type] - Dictionary mapping class names to their corresponding class objects. - """ - liftings = {} - - # Get the directory containing the lifting modules - package_dir = Path(package_path).parent - - # Iterate through all .py files in the directory - for file_path in package_dir.glob("*.py"): - if file_path.stem == "__init__": - continue - - # Import the module - module_name = f"{Path(package_path).stem}.{file_path.stem}" - spec = util.spec_from_file_location(module_name, file_path) - if spec and spec.loader: - module = util.module_from_spec(spec) - spec.loader.exec_module(module) - - # Find all lifting classes in the module - for name, obj in inspect.getmembers(module): - if ( - inspect.isclass(obj) - and obj.__module__ == module.__name__ - and not name.startswith("_") - and issubclass(obj, Graph2SimplicialLifting) - and obj != Graph2SimplicialLifting - ): - liftings[name] = obj # noqa: PERF403 - - return liftings - - -# Create the exports manager -manager = ModuleExportsManager() - -# Automatically discover and populate GRAPH2SIMPLICIAL_LIFTINGS -GRAPH2SIMPLICIAL_LIFTINGS = manager.discover_liftings(__file__) - -# Automatically generate __all__ -__all__ = [ - *GRAPH2SIMPLICIAL_LIFTINGS.keys(), - "Graph2SimplicialLifting", - "GRAPH2SIMPLICIAL_LIFTINGS", -] - -# For backwards compatibility, create individual imports -locals().update(**GRAPH2SIMPLICIAL_LIFTINGS) +locals().update(GRAPH2SIMPLICIAL_LIFTINGS) diff --git a/topobenchmark/transforms/liftings/graph2simplicial/base.py b/topobenchmark/transforms/liftings/graph2simplicial/base.py deleted file mode 100755 index e52449dc..00000000 --- a/topobenchmark/transforms/liftings/graph2simplicial/base.py +++ /dev/null @@ -1,69 +0,0 @@ -"""Abstract class for lifting graphs to simplicial complexes.""" - -import networkx as nx -import torch -from toponetx.classes import SimplicialComplex - -from topobenchmark.data.utils.utils import get_complex_connectivity -from topobenchmark.transforms.liftings import GraphLifting - - -class Graph2SimplicialLifting(GraphLifting): - r"""Abstract class for lifting graphs to simplicial complexes. - - Parameters - ---------- - complex_dim : int, optional - The maximum dimension of the simplicial complex to be generated. Default is 2. - **kwargs : optional - Additional arguments for the class. - """ - - def __init__(self, complex_dim=2, **kwargs): - super().__init__(**kwargs) - self.complex_dim = complex_dim - self.type = "graph2simplicial" - self.signed = kwargs.get("signed", False) - - def _get_lifted_topology( - self, simplicial_complex: SimplicialComplex, graph: nx.Graph - ) -> dict: - r"""Return the lifted topology. - - Parameters - ---------- - simplicial_complex : SimplicialComplex - The simplicial complex. - graph : nx.Graph - The input graph. - - Returns - ------- - dict - The lifted topology. - """ - lifted_topology = get_complex_connectivity( - simplicial_complex, - self.complex_dim, - neighborhoods=self.neighborhoods, - signed=self.signed, - ) - lifted_topology["x_0"] = torch.stack( - list( - simplicial_complex.get_simplex_attributes( - "features", 0 - ).values() - ) - ) - # If new edges have been added during the lifting process, we discard the edge attributes - if self.contains_edge_attr and simplicial_complex.shape[1] == ( - graph.number_of_edges() - ): - lifted_topology["x_1"] = torch.stack( - list( - simplicial_complex.get_simplex_attributes( - "features", 1 - ).values() - ) - ) - return lifted_topology diff --git a/topobenchmark/transforms/liftings/graph2simplicial/clique.py b/topobenchmark/transforms/liftings/graph2simplicial/clique.py index 502144fa..41047a62 100755 --- a/topobenchmark/transforms/liftings/graph2simplicial/clique.py +++ b/topobenchmark/transforms/liftings/graph2simplicial/clique.py @@ -1,50 +1,47 @@ """This module implements the CliqueLifting class, which lifts graphs to simplicial complexes.""" from itertools import combinations -from typing import Any import networkx as nx -import torch_geometric from toponetx.classes import SimplicialComplex -from topobenchmark.transforms.liftings.graph2simplicial import ( - Graph2SimplicialLifting, -) +from topobenchmark.transforms.liftings.base import LiftingMap -class SimplicialCliqueLifting(Graph2SimplicialLifting): +class SimplicialCliqueLifting(LiftingMap): r"""Lift graphs to simplicial complex domain. - The algorithm creates simplices by identifying the cliques and considering them as simplices of the same dimension. + The algorithm creates simplices by identifying the cliques + and considering them as simplices of the same dimension. Parameters ---------- - **kwargs : optional - Additional arguments for the class. + complex_dim : int + Dimension of the subcomplex. """ - def __init__(self, **kwargs): - super().__init__(**kwargs) + def __init__(self, complex_dim=2): + super().__init__() + self.complex_dim = complex_dim - def lift_topology(self, data: torch_geometric.data.Data) -> dict: + def lift(self, domain): r"""Lift the topology of a graph to a simplicial complex. Parameters ---------- - data : torch_geometric.data.Data - The input data to be lifted. + domain : nx.Graph + Graph to be lifted. Returns ------- - dict - The lifted topology. + toponetx.Complex + Lifted simplicial complex. """ - graph = self._generate_graph_from_data(data) + graph = domain + simplicial_complex = SimplicialComplex(graph) cliques = nx.find_cliques(graph) - simplices: list[set[tuple[Any, ...]]] = [ - set() for _ in range(2, self.complex_dim + 1) - ] + simplices = [set() for _ in range(2, self.complex_dim + 1)] for clique in cliques: for i in range(2, self.complex_dim + 1): for c in combinations(clique, i + 1): @@ -53,4 +50,7 @@ def lift_topology(self, data: torch_geometric.data.Data) -> dict: for set_k_simplices in simplices: simplicial_complex.add_simplices_from(list(set_k_simplices)) - return self._get_lifted_topology(simplicial_complex, graph) + # because ComplexData pads unexisting dimensions with empty matrices + simplicial_complex.practical_dim = self.complex_dim + + return simplicial_complex diff --git a/topobenchmark/transforms/liftings/graph2simplicial/khop.py b/topobenchmark/transforms/liftings/graph2simplicial/khop.py index 50239f18..dc9e13e2 100755 --- a/topobenchmark/transforms/liftings/graph2simplicial/khop.py +++ b/topobenchmark/transforms/liftings/graph2simplicial/khop.py @@ -4,15 +4,14 @@ from itertools import combinations from typing import Any +import torch import torch_geometric from toponetx.classes import SimplicialComplex -from topobenchmark.transforms.liftings.graph2simplicial.base import ( - Graph2SimplicialLifting, -) +from topobenchmark.transforms.liftings.base import LiftingMap -class SimplicialKHopLifting(Graph2SimplicialLifting): +class SimplicialKHopLifting(LiftingMap): r"""Lift graphs to simplicial complex domain. The function lifts a graph to a simplicial complex by considering k-hop @@ -23,38 +22,43 @@ class SimplicialKHopLifting(Graph2SimplicialLifting): Parameters ---------- + complex_dim : int + Dimension of the desired complex. max_k_simplices : int, optional The maximum number of k-simplices to consider. Default is 5000. - **kwargs : optional - Additional arguments for the class. """ - def __init__(self, max_k_simplices=5000, **kwargs): - super().__init__(**kwargs) + def __init__(self, complex_dim=3, max_k_simplices=5000): + super().__init__() + self.complex_dim = complex_dim self.max_k_simplices = max_k_simplices def __repr__(self) -> str: return f"{self.__class__.__name__}(max_k_simplices={self.max_k_simplices!r})" - def lift_topology(self, data: torch_geometric.data.Data) -> dict: + def lift(self, domain): r"""Lift the topology to simplicial complex domain. Parameters ---------- - data : torch_geometric.data.Data - The input data to be lifted. + domain : nx.Graph + Graph to be lifted. Returns ------- - dict - The lifted topology. + toponetx.Complex + Lifted simplicial complex. """ - graph = self._generate_graph_from_data(data) + graph = domain + simplicial_complex = SimplicialComplex(graph) - edge_index = torch_geometric.utils.to_undirected(data.edge_index) + edge_index = torch_geometric.utils.to_undirected( + torch.tensor(list(zip(*graph.edges, strict=False))) + ) simplices: list[set[tuple[Any, ...]]] = [ set() for _ in range(2, self.complex_dim + 1) ] + for n in range(graph.number_of_nodes()): # Find 1-hop node n neighbors neighbors, _, _, _ = torch_geometric.utils.k_hop_subgraph( @@ -67,10 +71,12 @@ def lift_topology(self, data: torch_geometric.data.Data) -> dict: for i in range(1, self.complex_dim): for c in combinations(neighbors, i + 1): simplices[i - 1].add(tuple(c)) + for set_k_simplices in simplices: list_k_simplices = list(set_k_simplices) if len(set_k_simplices) > self.max_k_simplices: random.shuffle(list_k_simplices) list_k_simplices = list_k_simplices[: self.max_k_simplices] simplicial_complex.add_simplices_from(list_k_simplices) - return self._get_lifted_topology(simplicial_complex, graph) + + return simplicial_complex diff --git a/topobenchmark/transforms/liftings/liftings.py b/topobenchmark/transforms/liftings/liftings.py deleted file mode 100644 index 9453eaa3..00000000 --- a/topobenchmark/transforms/liftings/liftings.py +++ /dev/null @@ -1,172 +0,0 @@ -"""This module implements the abstract classes for lifting graphs.""" - -import networkx as nx -import torch_geometric -from torch_geometric.utils.undirected import is_undirected, to_undirected - -from topobenchmark.transforms.liftings import AbstractLifting - - -class GraphLifting(AbstractLifting): - r"""Abstract class for lifting graph topologies to other domains. - - Parameters - ---------- - feature_lifting : str, optional - The feature lifting method to be used. Default is 'ProjectionSum'. - preserve_edge_attr : bool, optional - Whether to preserve edge attributes. Default is False. - **kwargs : optional - Additional arguments for the class. - """ - - def __init__( - self, - feature_lifting="ProjectionSum", - preserve_edge_attr=False, - **kwargs, - ): - super().__init__(feature_lifting=feature_lifting, **kwargs) - self.preserve_edge_attr = preserve_edge_attr - - def _data_has_edge_attr(self, data: torch_geometric.data.Data) -> bool: - r"""Check if the input data object has edge attributes. - - Parameters - ---------- - data : torch_geometric.data.Data - The input data. - - Returns - ------- - bool - Whether the data object has edge attributes. - """ - return hasattr(data, "edge_attr") and data.edge_attr is not None - - def _generate_graph_from_data( - self, data: torch_geometric.data.Data - ) -> nx.Graph: - r"""Generate a NetworkX graph from the input data object. - - Parameters - ---------- - data : torch_geometric.data.Data - The input data. - - Returns - ------- - nx.Graph - The generated NetworkX graph. - """ - # Check if data object have edge_attr, return list of tuples as [(node_id, {'features':data}, 'dim':1)] or ?? - nodes = [ - (n, dict(features=data.x[n], dim=0)) - for n in range(data.x.shape[0]) - ] - - if self.preserve_edge_attr and self._data_has_edge_attr(data): - # In case edge features are given, assign features to every edge - edge_index, edge_attr = ( - data.edge_index, - ( - data.edge_attr - if is_undirected(data.edge_index, data.edge_attr) - else to_undirected(data.edge_index, data.edge_attr) - ), - ) - edges = [ - (i.item(), j.item(), dict(features=edge_attr[edge_idx], dim=1)) - for edge_idx, (i, j) in enumerate( - zip(edge_index[0], edge_index[1], strict=False) - ) - ] - self.contains_edge_attr = True - else: - # If edge_attr is not present, return list list of edges - edges = [ - (i.item(), j.item(), {}) - for i, j in zip( - data.edge_index[0], data.edge_index[1], strict=False - ) - ] - self.contains_edge_attr = False - graph = nx.Graph() - graph.add_nodes_from(nodes) - graph.add_edges_from(edges) - return graph - - -class PointCloudLifting(AbstractLifting): - r"""Abstract class for lifting point clouds to other topological domains. - - Parameters - ---------- - feature_lifting : str, optional - The feature lifting method to be used. Default is 'ProjectionSum'. - **kwargs : optional - Additional arguments for the class. - """ - - def __init__(self, feature_lifting="ProjectionSum", **kwargs): - super().__init__(feature_lifting=feature_lifting, **kwargs) - - -class CellComplexLifting(AbstractLifting): - r"""Abstract class for lifting cell complexes to other domains. - - Parameters - ---------- - feature_lifting : str, optional - The feature lifting method to be used. Default is 'ProjectionSum'. - **kwargs : optional - Additional arguments for the class. - """ - - def __init__(self, feature_lifting="ProjectionSum", **kwargs): - super().__init__(feature_lifting=feature_lifting, **kwargs) - - -class SimplicialLifting(AbstractLifting): - r"""Abstract class for lifting simplicial complexes to other domains. - - Parameters - ---------- - feature_lifting : str, optional - The feature lifting method to be used. Default is 'ProjectionSum'. - **kwargs : optional - Additional arguments for the class. - """ - - def __init__(self, feature_lifting="ProjectionSum", **kwargs): - super().__init__(feature_lifting=feature_lifting, **kwargs) - - -class HypergraphLifting(AbstractLifting): - r"""Abstract class for lifting hypergraphs to other domains. - - Parameters - ---------- - feature_lifting : str, optional - The feature lifting method to be used. Default is 'ProjectionSum'. - **kwargs : optional - Additional arguments for the class. - """ - - def __init__(self, feature_lifting="ProjectionSum", **kwargs): - super().__init__(feature_lifting=feature_lifting, **kwargs) - - -class CombinatorialLifting(AbstractLifting): - r"""Abstract class for lifting combinatorial complexes to other domains. - - Parameters - ---------- - feature_lifting : str, optional - The feature lifting method to be used. Default is 'ProjectionSum'. - **kwargs : optional - Additional arguments for the class. - """ - - def __init__(self, feature_lifting="ProjectionSum", **kwargs): - super().__init__(feature_lifting=feature_lifting, **kwargs) diff --git a/tutorials/tutorial_lifting.ipynb b/tutorials/tutorial_lifting.ipynb index d1a77003..af533a1b 100644 --- a/tutorials/tutorial_lifting.ipynb +++ b/tutorials/tutorial_lifting.ipynb @@ -56,8 +56,6 @@ "\n", "import lightning as pl\n", "import networkx as nx\n", - "import hydra\n", - "import torch_geometric\n", "from omegaconf import OmegaConf\n", "from topomodelx.nn.simplicial.scn2 import SCN2\n", "from toponetx.classes import SimplicialComplex\n", @@ -72,8 +70,8 @@ "from topobenchmark.nn.readouts import PropagateSignalDown\n", "from topobenchmark.nn.wrappers.simplicial import SCNWrapper\n", "from topobenchmark.optimizer import TBOptimizer\n", - "from topobenchmark.transforms.liftings.graph2simplicial import (\n", - " Graph2SimplicialLifting,\n", + "from topobenchmark.transforms.liftings import (\n", + " LiftingMap,\n", ")" ] }, @@ -101,14 +99,17 @@ " \"data_domain\": \"graph\",\n", " \"data_type\": \"TUDataset\",\n", " \"data_name\": \"MUTAG\",\n", - " \"data_dir\": \"./data/MUTAG/\"}\n", + " \"data_dir\": \"./data/MUTAG/\",\n", + "}\n", "\n", "\n", - "transform_config = { \"clique_lifting\":\n", - " {\"_target_\": \"__main__.SimplicialCliquesLEQLifting\",\n", - " \"transform_name\": \"SimplicialCliquesLEQLifting\",\n", - " \"transform_type\": \"lifting\",\n", - " \"complex_dim\": 3,}\n", + "transform_config = {\n", + " \"clique_lifting\": {\n", + " \"_target_\": \"topobenchmark.transforms.data_transform.DataTransform\",\n", + " \"transform_name\": \"SimplicialCliquesLEQLifting\",\n", + " \"transform_type\": \"lifting\",\n", + " \"complex_dim\": 3,\n", + " }\n", "}\n", "\n", "split_config = {\n", @@ -138,21 +139,19 @@ "}\n", "\n", "loss_config = {\n", - " \"dataset_loss\": \n", - " {\n", - " \"task\": \"classification\", \n", - " \"loss_type\": \"cross_entropy\"\n", - " }\n", + " \"dataset_loss\": {\"task\": \"classification\", \"loss_type\": \"cross_entropy\"}\n", "}\n", "\n", - "evaluator_config = {\"task\": \"classification\",\n", - " \"num_classes\": out_channels,\n", - " \"metrics\": [\"accuracy\", \"precision\", \"recall\"]}\n", + "evaluator_config = {\n", + " \"task\": \"classification\",\n", + " \"num_classes\": out_channels,\n", + " \"metrics\": [\"accuracy\", \"precision\", \"recall\"],\n", + "}\n", "\n", - "optimizer_config = {\"optimizer_id\": \"Adam\",\n", - " \"parameters\":\n", - " {\"lr\": 0.001,\"weight_decay\": 0.0005}\n", - " }\n", + "optimizer_config = {\n", + " \"optimizer_id\": \"Adam\",\n", + " \"parameters\": {\"lr\": 0.001, \"weight_decay\": 0.0005},\n", + "}\n", "\n", "\n", "loader_config = OmegaConf.create(loader_config)\n", @@ -174,6 +173,7 @@ "def wrapper(**factory_kwargs):\n", " def factory(backbone):\n", " return SCNWrapper(backbone, **factory_kwargs)\n", + "\n", " return factory" ] }, @@ -197,16 +197,15 @@ "metadata": {}, "outputs": [], "source": [ - "class SimplicialCliquesLEQLifting(Graph2SimplicialLifting):\n", + "class SimplicialCliquesLEQLifting(LiftingMap):\n", " r\"\"\"Lifts graphs to simplicial complex domain by identifying the cliques as k-simplices. Only the cliques with size smaller or equal to the max complex dimension are considered.\n", - " \n", - " Args:\n", - " kwargs (optional): Additional arguments for the class.\n", " \"\"\"\n", - " def __init__(self, **kwargs):\n", - " super().__init__(**kwargs)\n", + " def __init__(self, complex_dim=2):\n", + " super().__init__()\n", + " self.complex_dim = complex_dim\n", + "\n", "\n", - " def lift_topology(self, data: torch_geometric.data.Data) -> dict:\n", + " def lift(self, domain) -> dict:\n", " r\"\"\"Lifts the topology of a graph to a simplicial complex by identifying the cliques as k-simplices. Only the cliques with size smaller or equal to the max complex dimension are considered.\n", "\n", " Args:\n", @@ -214,11 +213,14 @@ " Returns:\n", " dict: The lifted topology.\n", " \"\"\"\n", - " graph = self._generate_graph_from_data(data)\n", + " graph = domain\n", + "\n", " simplicial_complex = SimplicialComplex(graph)\n", " cliques = nx.find_cliques(graph)\n", - " \n", - " simplices: list[set[tuple[Any, ...]]] = [set() for _ in range(2, self.complex_dim + 1)]\n", + "\n", + " simplices: list[set[tuple[Any, ...]]] = [\n", + " set() for _ in range(2, self.complex_dim + 1)\n", + " ]\n", " for clique in cliques:\n", " if len(clique) <= self.complex_dim + 1:\n", " for i in range(2, self.complex_dim + 1):\n", @@ -227,8 +229,11 @@ "\n", " for set_k_simplices in simplices:\n", " simplicial_complex.add_simplices_from(list(set_k_simplices))\n", + " \n", + " # because ComplexData pads unexisting dimensions with empty matrices\n", + " simplicial_complex.practical_dim = self.complex_dim\n", "\n", - " return self._get_lifted_topology(simplicial_complex, graph)\n" + " return simplicial_complex" ] }, { @@ -251,9 +256,9 @@ "metadata": {}, "outputs": [], "source": [ - "from topobenchmark.transforms import TRANSFORMS\n", + "from topobenchmark.transforms import add_lifting_map\n", "\n", - "TRANSFORMS[\"SimplicialCliquesLEQLifting\"] = SimplicialCliquesLEQLifting" + "add_lifting_map(SimplicialCliquesLEQLifting, \"graph2simplicial\")" ] }, { @@ -275,8 +280,12 @@ "dataset, dataset_dir = graph_loader.load()\n", "\n", "preprocessor = PreProcessor(dataset, dataset_dir, transform_config)\n", - "dataset_train, dataset_val, dataset_test = preprocessor.load_dataset_splits(split_config)\n", - "datamodule = TBDataloader(dataset_train, dataset_val, dataset_test, batch_size=32)" + "dataset_train, dataset_val, dataset_test = preprocessor.load_dataset_splits(\n", + " split_config\n", + ")\n", + "datamodule = TBDataloader(\n", + " dataset_train, dataset_val, dataset_test, batch_size=32\n", + ")" ] }, { @@ -299,12 +308,19 @@ "metadata": {}, "outputs": [], "source": [ - "backbone = SCN2(in_channels_0=dim_hidden,in_channels_1=dim_hidden,in_channels_2=dim_hidden)\n", + "backbone = SCN2(\n", + " in_channels_0=dim_hidden,\n", + " in_channels_1=dim_hidden,\n", + " in_channels_2=dim_hidden,\n", + ")\n", "backbone_wrapper = wrapper(**wrapper_config)\n", "\n", "readout = PropagateSignalDown(**readout_config)\n", "loss = TBLoss(**loss_config)\n", - "feature_encoder = AllCellFeatureEncoder(in_channels=[in_channels, in_channels, in_channels], out_channels=dim_hidden)\n", + "feature_encoder = AllCellFeatureEncoder(\n", + " in_channels=[in_channels, in_channels, in_channels],\n", + " out_channels=dim_hidden,\n", + ")\n", "\n", "evaluator = TBEvaluator(**evaluator_config)\n", "optimizer = TBOptimizer(**optimizer_config)" @@ -316,14 +332,16 @@ "metadata": {}, "outputs": [], "source": [ - "model = TBModel(backbone=backbone,\n", - " backbone_wrapper=backbone_wrapper,\n", - " readout=readout,\n", - " loss=loss,\n", - " feature_encoder=feature_encoder,\n", - " evaluator=evaluator,\n", - " optimizer=optimizer,\n", - " compile=False,)" + "model = TBModel(\n", + " backbone=backbone,\n", + " backbone_wrapper=backbone_wrapper,\n", + " readout=readout,\n", + " loss=loss,\n", + " feature_encoder=feature_encoder,\n", + " evaluator=evaluator,\n", + " optimizer=optimizer,\n", + " compile=False,\n", + ")" ] }, { @@ -386,7 +404,9 @@ ], "source": [ "# Increase the number of epochs to get better results\n", - "trainer = pl.Trainer(max_epochs=50, accelerator=\"cpu\", enable_progress_bar=False)\n", + "trainer = pl.Trainer(\n", + " max_epochs=50, accelerator=\"cpu\", enable_progress_bar=False\n", + ")\n", "\n", "trainer.fit(model, datamodule)\n", "train_metrics = trainer.callback_metrics" @@ -415,9 +435,9 @@ } ], "source": [ - "print(' Training metrics\\n', '-'*26)\n", + "print(\" Training metrics\\n\", \"-\" * 26)\n", "for key in train_metrics:\n", - " print('{:<21s} {:>5.4f}'.format(key+':', train_metrics[key].item()))" + " print(\"{:<21s} {:>5.4f}\".format(key + \":\", train_metrics[key].item()))" ] }, { @@ -505,9 +525,9 @@ } ], "source": [ - "print(' Testing metrics\\n', '-'*25)\n", + "print(\" Testing metrics\\n\", \"-\" * 25)\n", "for key in test_metrics:\n", - " print('{:<20s} {:>5.4f}'.format(key+':', test_metrics[key].item()))" + " print(\"{:<20s} {:>5.4f}\".format(key + \":\", test_metrics[key].item()))" ] }, {