Skip to content

Commit

Permalink
add unisage
Browse files Browse the repository at this point in the history
  • Loading branch information
mhajij committed Sep 7, 2023
1 parent 4e61117 commit 1055295
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 0 deletions.
25 changes: 25 additions & 0 deletions test/nn/hypergraph/test_unisage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
"""Test the UniGIN class."""

import numpy as np
import torch

from topomodelx.nn.hypergraph.unisage import UniSAGE


class TestUniGIN:
"""Test the UniGIN."""

def test_fowared(self):
"""Test forward method."""
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

incidence = torch.from_numpy(np.random.rand(2, 2)).to_sparse_csr()
incidence = incidence.float().to(device)
model = UniSAGE(channels_edge=2, channels_node=2, n_layers=2).to(device)
x_0 = torch.rand(2, 2)

x_0 = torch.tensor(x_0).float().to(device)

y1 = model(x_0, incidence)

assert y1.shape == torch.Size([2])
57 changes: 57 additions & 0 deletions topomodelx/nn/hypergraph/unisage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
"""UniSAGE class."""

import torch

from topomodelx.nn.hypergraph.unisage_layer import UniSAGELayer


class UniSAGE(torch.nn.Module):
"""Neural network implementation of UniSAGE for hypergraph classification.
Parameters
----------
channels_edge : int
Dimension of edge features
channels_node : int
Dimension of node features
n_layer : 2
Amount of message passing layers.
"""

def __init__(self, channels_edge, channels_node, n_layers=2):
super().__init__()
layers = []
for _ in range(n_layers):
layers.append(
UniSAGELayer(
in_channels=channels_edge,
out_channels=channels_edge,
)
)
self.layers = torch.nn.ModuleList(layers)
self.linear = torch.nn.Linear(channels_edge, 1)

def forward(self, x_1, incidence_1):
"""Forward computation through layers, then linear layer, then global max pooling.
Parameters
----------
x_1 : tensor
shape = [n_edges, channels_edge]
Edge features.
incidence_1 : tensor
shape = [n_nodes, n_edges]
Boundary matrix of rank 1.
Returns
-------
_ : tensor
shape = [1]
Label assigned to whole complex.
"""
for layer in self.layers:
x_1 = layer(x_1, incidence_1)
pooled_x = torch.max(x_1, dim=0)[0]
return torch.sigmoid(self.linear(pooled_x))

0 comments on commit 1055295

Please sign in to comment.