From 10552958bde86cf6251881244ca42764fab2effa Mon Sep 17 00:00:00 2001 From: Mustafa Hajij Date: Wed, 6 Sep 2023 22:30:04 -0700 Subject: [PATCH] add unisage --- test/nn/hypergraph/test_unisage.py | 25 +++++++++++++ topomodelx/nn/hypergraph/unisage.py | 57 +++++++++++++++++++++++++++++ 2 files changed, 82 insertions(+) create mode 100644 test/nn/hypergraph/test_unisage.py create mode 100644 topomodelx/nn/hypergraph/unisage.py diff --git a/test/nn/hypergraph/test_unisage.py b/test/nn/hypergraph/test_unisage.py new file mode 100644 index 00000000..021b2d33 --- /dev/null +++ b/test/nn/hypergraph/test_unisage.py @@ -0,0 +1,25 @@ +"""Test the UniGIN class.""" + +import numpy as np +import torch + +from topomodelx.nn.hypergraph.unisage import UniSAGE + + +class TestUniGIN: + """Test the UniGIN.""" + + def test_fowared(self): + """Test forward method.""" + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + incidence = torch.from_numpy(np.random.rand(2, 2)).to_sparse_csr() + incidence = incidence.float().to(device) + model = UniSAGE(channels_edge=2, channels_node=2, n_layers=2).to(device) + x_0 = torch.rand(2, 2) + + x_0 = torch.tensor(x_0).float().to(device) + + y1 = model(x_0, incidence) + + assert y1.shape == torch.Size([2]) diff --git a/topomodelx/nn/hypergraph/unisage.py b/topomodelx/nn/hypergraph/unisage.py new file mode 100644 index 00000000..9380d575 --- /dev/null +++ b/topomodelx/nn/hypergraph/unisage.py @@ -0,0 +1,57 @@ +"""UniSAGE class.""" + +import torch + +from topomodelx.nn.hypergraph.unisage_layer import UniSAGELayer + + +class UniSAGE(torch.nn.Module): + """Neural network implementation of UniSAGE for hypergraph classification. + + Parameters + ---------- + channels_edge : int + Dimension of edge features + channels_node : int + Dimension of node features + n_layer : 2 + Amount of message passing layers. + + """ + + def __init__(self, channels_edge, channels_node, n_layers=2): + super().__init__() + layers = [] + for _ in range(n_layers): + layers.append( + UniSAGELayer( + in_channels=channels_edge, + out_channels=channels_edge, + ) + ) + self.layers = torch.nn.ModuleList(layers) + self.linear = torch.nn.Linear(channels_edge, 1) + + def forward(self, x_1, incidence_1): + """Forward computation through layers, then linear layer, then global max pooling. + + Parameters + ---------- + x_1 : tensor + shape = [n_edges, channels_edge] + Edge features. + + incidence_1 : tensor + shape = [n_nodes, n_edges] + Boundary matrix of rank 1. + + Returns + ------- + _ : tensor + shape = [1] + Label assigned to whole complex. + """ + for layer in self.layers: + x_1 = layer(x_1, incidence_1) + pooled_x = torch.max(x_1, dim=0)[0] + return torch.sigmoid(self.linear(pooled_x))