Check hypergraphs #256

Merged (26 commits, Mar 20, 2024)
Commits
- c2e4132 changed torch.mm for topomodelx.base.Conv for consistency (Coerulatus, Nov 8, 2023)
- 97c62e4 changes splits 5 to 3 (levtelyatnikov, Nov 9, 2023)
- fefc88b Revert "changed torch.mm for topomodelx.base.Conv for consistency" (Coerulatus, Nov 10, 2023)
- 04953be changed torch.mm to topomodelx.base.Conv for consistency (Coerulatus, Nov 10, 2023)
- 62b7c5a changed torch.mm for conv and introduced linear transformation in uni… (Coerulatus, Nov 10, 2023)
- 1b7b01d Merge branch 'hypergraph_checks' of https://github.com/pyt-team/TopoM… (Coerulatus, Nov 10, 2023)
- 80ed75f Merge remote-tracking branch 'origin/main' into add_conv (Coerulatus, Nov 10, 2023)
- 3d5c6b7 added batch norm to be consistent with implementation (Coerulatus, Nov 10, 2023)
- 76f0bdb moved batch norm to the end of the layer, unisage now passes more par… (Coerulatus, Nov 10, 2023)
- 937ca45 removed linear transformation in vertex2edge to be consistent with fo… (Coerulatus, Nov 10, 2023)
- 02c731d added the ability to pass different graphs instead of only one fixed … (Coerulatus, Nov 10, 2023)
- c1d54f9 added alpha to choose max number of nodes to sample in neighborhood (Coerulatus, Nov 13, 2023)
- a89d1f1 changed batch normalization to row normalization (Coerulatus, Nov 16, 2023)
- 31d297c added dropout (Coerulatus, Nov 16, 2023)
- 712cd85 added dropout (Coerulatus, Nov 16, 2023)
- c941ee5 Merge branch 'main' of https://github.com/pyt-team/TopoModelX into ad… (Coerulatus, Nov 16, 2023)
- 73df4e3 removed batchnorm test since batchnorm was removed from model (Coerulatus, Nov 16, 2023)
- 9426857 added default value for use_norm (Coerulatus, Nov 16, 2023)
- 02e16f7 fixed reset parameters (Coerulatus, Nov 16, 2023)
- 2e8c285 Merge branch 'main' of https://github.com/pyt-team/TopoModelX into ad… (Coerulatus, Nov 27, 2023)
- b0c36f5 formatting changes (Coerulatus, Nov 27, 2023)
- c0c0e5d fixed bug when not passing incidence_1 matrix (Coerulatus, Dec 1, 2023)
- cfbde34 updated tests to improve coverage (Coerulatus, Dec 1, 2023)
- 985934a fix format (Coerulatus, Dec 13, 2023)
- 5b8a423 Merge branch 'main' of https://github.com/pyt-team/TopoModelX into ad… (Coerulatus, Dec 13, 2023)
- c95265f Merge branch 'main' into add_conv (Coerulatus, Mar 20, 2024)
24 changes: 23 additions & 1 deletion test/nn/hypergraph/test_hnhn_layer.py
@@ -23,7 +23,20 @@ def template_layer(self):
incidence_1=incidence_1,
)

def test_forward(self, template_layer):
@pytest.fixture
def template_layer2(self):
"""Initialize and return an HNHN layer."""
self.in_channels = 5
self.hidden_channels = 8

return HNHNLayer(
in_channels=self.in_channels,
hidden_channels=self.hidden_channels,
incidence_1=None,
bias_init="xavier_normal",
)

def test_forward(self, template_layer, template_layer2):
"""Test the forward pass of the HNHN layer."""
n_nodes, n_edges = template_layer.incidence_1.shape

@@ -33,6 +46,15 @@ def test_forward(self, template_layer):
assert x_0_out.shape == (n_nodes, self.hidden_channels)
assert x_1_out.shape == (n_edges, self.hidden_channels)

n_nodes = 10
n_edges = 20
incidence_1 = torch.randint(0, 2, (n_nodes, n_edges)).float()

x_0_out, x_1_out = template_layer2.forward(x_0, incidence_1)

assert x_0_out.shape == (n_nodes, self.hidden_channels)
assert x_1_out.shape == (n_edges, self.hidden_channels)

return

def test_compute_normalization_matrices(self, template_layer):
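Note: the new template_layer2 fixture covers an HNHNLayer built without a fixed incidence matrix and with bias_init="xavier_normal". A minimal usage sketch of that path; the import path and tensor sizes are assumptions based on the repository layout and the test above, not part of the diff.

import torch
from topomodelx.nn.hypergraph.hnhn_layer import HNHNLayer  # assumed import path

# Layer built without a fixed hypergraph; biases use Xavier-normal initialization.
layer = HNHNLayer(in_channels=5, hidden_channels=8, incidence_1=None, bias_init="xavier_normal")

n_nodes, n_edges = 10, 20
x_0 = torch.randn(n_nodes, 5)
incidence_1 = torch.randint(0, 2, (n_nodes, n_edges)).float()  # B_1: nodes x hyperedges

# The incidence matrix is supplied at call time instead of at construction.
x_0_out, x_1_out = layer(x_0, incidence_1)
assert x_0_out.shape == (n_nodes, 8)
assert x_1_out.shape == (n_edges, 8)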
19 changes: 18 additions & 1 deletion test/nn/hypergraph/test_hypersage_layer.py
@@ -15,14 +15,24 @@ def hypersage_layer(self):
out_channels = 30
return HyperSAGELayer(in_channels, out_channels)

def test_forward(self, hypersage_layer):
@pytest.fixture
def hypersage_layer_alpha(self):
"""Return a HyperSAGE layer."""
in_channels = 10
out_channels = 30
return HyperSAGELayer(in_channels, out_channels, alpha=1)

def test_forward(self, hypersage_layer, hypersage_layer_alpha):
"""Test the forward pass of the HyperSAGE layer."""
x_2 = torch.randn(3, 10)
incidence_2 = torch.tensor(
[[1, 0], [0, 1], [1, 1]], dtype=torch.float32
).to_sparse()
output = hypersage_layer.forward(x_2, incidence_2)
output2 = hypersage_layer_alpha.forward(x_2, incidence_2)

assert output.shape == (3, 30)
assert output2.shape == (3, 30)

def test_forward_with_invalid_input(self, hypersage_layer):
"""Test the forward pass of the HyperSAGE layer with invalid input."""
@@ -65,6 +75,13 @@ def test_update_sigmoid(self, hypersage_layer):
assert torch.is_tensor(updated)
assert updated.shape == (10, 20)

def test_update_invalid(self, hypersage_layer):
"""Test the update function with update_func = "invalid"."""
hypersage_layer.update_func = "invalid"
inputs = torch.randn(10, 20)
with pytest.raises(RuntimeError):
hypersage_layer.update(inputs)

def test_aggregation_invald(self, hypersage_layer):
"""Test the aggregation function with invalid mode."""
x_messages = torch.zeros(3, 10)
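Note: the alpha argument exercised by the new hypersage_layer_alpha fixture was added in commit c1d54f9 to cap the number of nodes sampled per neighborhood. A minimal sketch mirroring the test above; the import path is an assumption based on the repository layout.

import torch
from topomodelx.nn.hypergraph.hypersage_layer import HyperSAGELayer  # assumed import path

# alpha limits how many nodes are sampled from each neighborhood (commit c1d54f9).
layer = HyperSAGELayer(in_channels=10, out_channels=30, alpha=1)

x = torch.randn(3, 10)
incidence = torch.tensor([[1, 0], [0, 1], [1, 1]], dtype=torch.float32).to_sparse()

out = layer(x, incidence)
assert out.shape == (3, 30)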
35 changes: 32 additions & 3 deletions test/nn/hypergraph/test_unigcnii_layer.py
@@ -18,14 +18,36 @@ def unigcnii_layer(self):
in_channels=in_channels, hidden_channels=in_channels, alpha=alpha, beta=beta
)

def test_forward(self, unigcnii_layer):
@pytest.fixture
def unigcnii_layer2(self):
"""Return a uniGCNII layer."""
in_channels = 10
alpha = 0.1
beta = 0.1
return UniGCNIILayer(
in_channels=in_channels,
hidden_channels=in_channels,
alpha=alpha,
beta=beta,
use_norm=True,
)

def test_forward(self, unigcnii_layer, unigcnii_layer2):
"""Test the forward pass."""
n_nodes, in_channels = 3, 10
x_0 = torch.randn(n_nodes, in_channels)
incidence_1 = torch.tensor([[1, 0], [1, 1], [0, 1]], dtype=torch.float32)
x_0, _ = unigcnii_layer.forward(x_0, incidence_1)
incidence_1 = torch.tensor(
[[1, 1, 0], [1, 1, 1], [0, 1, 1]], dtype=torch.float32
)
x_0, x_1 = unigcnii_layer.forward(x_0, incidence_1)

assert x_0.shape == torch.Size([n_nodes, in_channels])
assert x_1.shape == torch.Size([3, in_channels])

x_0, x_1 = unigcnii_layer2.forward(x_0, incidence_1)

assert x_0.shape == torch.Size([n_nodes, in_channels])
assert x_1.shape == torch.Size([3, in_channels])

def test_forward_with_skip(self):
"""Test the forward pass where alpha=1 and beta=0.
@@ -45,3 +67,10 @@ def test_forward_with_skip(self):
x_0, _ = layer(x_0, incidence_1, x_skip)

torch.testing.assert_close(x_0, x_skip, rtol=1e-4, atol=1e-4)

def test_reset_params(self, unigcnii_layer):
"""Test reset parameters."""
unigcnii_layer.linear.weight.requires_grad = False
unigcnii_layer.linear.weight.fill_(0)
unigcnii_layer.reset_parameters()
assert torch.max(unigcnii_layer.linear.weight) > 0
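Note: the second fixture covers the new use_norm flag (see commits a89d1f1 and 9426857). A minimal sketch of that configuration, mirroring the test above; the import path and the meaning of use_norm are assumptions, not statements from the diff.

import torch
from topomodelx.nn.hypergraph.unigcnii_layer import UniGCNIILayer  # assumed import path

# use_norm=True presumably enables the row normalization introduced in commit a89d1f1.
layer = UniGCNIILayer(in_channels=10, hidden_channels=10, alpha=0.1, beta=0.1, use_norm=True)

x_0 = torch.randn(3, 10)
incidence_1 = torch.tensor([[1, 1, 0], [1, 1, 1], [0, 1, 1]], dtype=torch.float32)

x_0, x_1 = layer(x_0, incidence_1)
assert x_0.shape == (3, 10) and x_1.shape == (3, 10)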
17 changes: 14 additions & 3 deletions test/nn/hypergraph/test_unigin_layer.py
@@ -10,20 +10,31 @@ class TestUniGINLayer:
"""Test the UniGIN layer."""

@pytest.fixture
def UniGIN_layer(self):
def unigin_layer(self):
"""Return a UniGIN layer."""
self.in_channels = 10
return UniGINLayer(in_channels=self.in_channels)

def test_forward(self, UniGIN_layer):
@pytest.fixture
def unigin_layer2(self):
"""Return a UniGIN layer."""
self.in_channels = 10
return UniGINLayer(in_channels=self.in_channels, use_norm=True)

def test_forward(self, unigin_layer, unigin_layer2):
"""Test the forward pass of the UniGIN layer."""
n_nodes, n_edges = 2, 3
incidence = torch.from_numpy(
np.random.default_rng().random((n_nodes, n_edges))
).to_sparse()
incidence = incidence.float()
x_0 = torch.rand(n_nodes, self.in_channels).float()
x_0, x_1 = UniGIN_layer.forward(x_0, incidence)
x_0, x_1 = unigin_layer.forward(x_0, incidence)

assert x_0.shape == torch.Size([n_nodes, self.in_channels])
assert x_1.shape == torch.Size([n_edges, self.in_channels])

x_0, x_1 = unigin_layer2.forward(x_0, incidence)

assert x_0.shape == torch.Size([n_nodes, self.in_channels])
assert x_1.shape == torch.Size([n_edges, self.in_channels])
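Note: a minimal sketch of the use_norm variant covered by the unigin_layer2 fixture; the import path and the meaning of use_norm are assumptions based on the repository layout and commit messages.

import numpy as np
import torch
from topomodelx.nn.hypergraph.unigin_layer import UniGINLayer  # assumed import path

# use_norm=True presumably enables the row normalization added in this PR.
layer = UniGINLayer(in_channels=10, use_norm=True)

n_nodes, n_edges = 2, 3
incidence = torch.from_numpy(np.random.default_rng().random((n_nodes, n_edges))).to_sparse().float()
x_0 = torch.rand(n_nodes, 10)

x_0, x_1 = layer(x_0, incidence)
assert x_0.shape == (n_nodes, 10)
assert x_1.shape == (n_edges, 10)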
37 changes: 19 additions & 18 deletions test/nn/hypergraph/test_unisage_layer.py
@@ -9,17 +9,27 @@ class TestUniSAGELayer:
"""Tests for UniSAGE Layer."""

@pytest.fixture
def uniSAGE_layer(self):
def unisage_layer(self):
"""Fixture for uniSAGE layer."""
in_channels = 10
out_channels = 30
return UniSAGELayer(in_channels, out_channels)

def test_forward(self, uniSAGE_layer):
@pytest.fixture
def unisage_layer2(self):
"""Fixture for uniSAGE layer."""
in_channels = 10
out_channels = 30
return UniSAGELayer(in_channels, out_channels, use_norm=True)

def test_forward(self, unisage_layer, unisage_layer2):
"""Test forward pass."""
x = torch.randn(3, 10)
incidence = torch.tensor([[1, 1, 0], [1, 1, 1], [0, 1, 1]], dtype=torch.float32)
x_0, x_1 = uniSAGE_layer.forward(x, incidence)
x_0, x_1 = unisage_layer.forward(x, incidence)
assert x_0.shape == torch.Size([3, 30])
assert x_1.shape == torch.Size([3, 30])
x_0, x_1 = unisage_layer2.forward(x, incidence)
assert x_0.shape == torch.Size([3, 30])
assert x_1.shape == torch.Size([3, 30])

@@ -33,7 +43,7 @@ def test_sum_aggregator(self):
assert x_0.shape == torch.Size([3, 30])
assert x_1.shape == torch.Size([3, 30])

def test_aggregator_validation(self, uniSAGE_layer):
def test_aggregator_validation(self, unisage_layer):
"""Test aggregator validation."""
with pytest.raises(Exception) as exc_info:
_ = UniSAGELayer(10, 30, e_aggr="invalid_aggregator")
@@ -42,18 +52,9 @@ def test_aggregator_validation(self, uniSAGE_layer):
== "Unsupported aggregator: invalid_aggregator, should be 'sum', 'mean',"
)

def test_reset_params(self, uniSAGE_layer):
def test_reset_params(self, unisage_layer):
"""Test reset parameters."""
uniSAGE_layer.linear.weight.requires_grad = False
uniSAGE_layer.linear.weight.fill_(0)
uniSAGE_layer.reset_parameters()
assert torch.max(uniSAGE_layer.linear.weight) > 0

def test_batchnorm(self, uniSAGE_layer):
"""Test batchnorm."""
x = torch.randn(3, 10)
incidence = torch.tensor([[1, 1, 0], [1, 1, 1], [0, 1, 1]], dtype=torch.float32)
layer = UniSAGELayer(10, 30, e_aggr="sum", use_bn=True)
layer(x, incidence)
assert layer.bn is not None
assert layer.bn.num_batches_tracked.item() == 1
unisage_layer.linear.weight.requires_grad = False
unisage_layer.linear.weight.fill_(0)
unisage_layer.reset_parameters()
assert torch.max(unisage_layer.linear.weight) > 0
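Note: use_norm appears to replace the removed use_bn batch-norm flag (the batchnorm test above was dropped, and commit a89d1f1 switched to row normalization). A minimal sketch of the normalized variant; the import path is assumed from the repository layout.

import torch
from topomodelx.nn.hypergraph.unisage_layer import UniSAGELayer  # assumed import path

layer = UniSAGELayer(10, 30, use_norm=True)  # row normalization instead of the old batch norm

x = torch.randn(3, 10)
incidence = torch.tensor([[1, 1, 0], [1, 1, 1], [0, 1, 1]], dtype=torch.float32)

x_0, x_1 = layer(x, incidence)
assert x_0.shape == (3, 30) and x_1.shape == (3, 30)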
2 changes: 1 addition & 1 deletion topomodelx/nn/hypergraph/allset_transformer.py
@@ -24,7 +24,7 @@ class AllSetTransformer(torch.nn.Module):
Dropout probability.
mlp_num_layers : int, default: 2
Number of layers in the MLP.
mlp_dropout: float, default: 0.2
mlp_dropout : float, default: 0.2
Dropout probability in the MLP.
References
4 changes: 2 additions & 2 deletions topomodelx/nn/hypergraph/hmpnn_layer.py
@@ -122,10 +122,10 @@ class HMPNNLayer(nn.Module):
----------
in_channels : int
Dimension of input features.
node_to_hyperedge_messaging_func: None
node_to_hyperedge_messaging_func : None
Node messaging function as a callable or nn.Module object. If not given, a linear plus sigmoid
function is used, according to the paper.
hyperedge_to_node_messaging_func: None
hyperedge_to_node_messaging_func : None
Hyperedge messaging function as a callable or nn.Module object. It gets hyperedge input features
and aggregated messages of nodes as input and returns hyperedge messages. If not given, two inputs
are concatenated and a linear layer reducing back to in_channels plus sigmoid is applied, according
12 changes: 9 additions & 3 deletions topomodelx/nn/hypergraph/hnhn.py
@@ -18,6 +18,8 @@ class HNHN(torch.nn.Module):
Incidence matrix mapping edges to nodes (B_1).
n_layers : int, default = 2
Number of HNHN message passing layers.
layer_drop : float, default = 0.2
Dropout rate for the hidden features.
References
----------
@@ -27,7 +29,9 @@
https://grlplus.github.io/papers/40.pdf
"""

def __init__(self, in_channels, hidden_channels, incidence_1, n_layers=2):
def __init__(
self, in_channels, hidden_channels, incidence_1, n_layers=2, layer_drop=0.2
):
super().__init__()

self.layers = torch.nn.ModuleList(
@@ -38,8 +42,9 @@ def __init__(self, in_channels, hidden_channels, incidence_1, n_layers=2):
)
for i in range(n_layers)
)
self.layer_drop = torch.nn.Dropout(layer_drop)

def forward(self, x_0):
def forward(self, x_0, incidence_1=None):
"""Forward computation.
Parameters
@@ -58,6 +63,7 @@ def forward(self, x_0):
Output hyperedge features.
"""
for layer in self.layers:
x_0, x_1 = layer(x_0)
x_0, x_1 = layer(x_0, incidence_1)
x_0 = self.layer_drop(x_0)

return x_0, x_1
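Note: a minimal sketch of the updated HNHN model API, with the new layer_drop dropout and the incidence matrix optionally passed per forward call. The import path and shapes are assumptions based on the repository layout, not part of the diff.

import torch
from topomodelx.nn.hypergraph.hnhn import HNHN  # assumed import path

n_nodes, n_edges = 10, 20
incidence_1 = torch.randint(0, 2, (n_nodes, n_edges)).float()

# layer_drop is applied to the node features after every HNHN layer.
model = HNHN(in_channels=5, hidden_channels=8, incidence_1=incidence_1, n_layers=2, layer_drop=0.2)

x_0 = torch.randn(n_nodes, 5)
# The incidence matrix can also be passed per call; it is forwarded to every layer.
x_0_out, x_1_out = model(x_0, incidence_1)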
25 changes: 17 additions & 8 deletions topomodelx/nn/hypergraph/hnhn_layer.py
@@ -62,7 +62,7 @@ def __init__(
self,
in_channels,
hidden_channels,
incidence_1,
incidence_1=None,
use_bias: bool = True,
use_normalized_incidence: bool = True,
alpha: float = -1.5,
@@ -76,7 +76,8 @@ def __init__(
self.bias_gain = bias_gain
self.use_normalized_incidence = use_normalized_incidence
self.incidence_1 = incidence_1
self.incidence_1_transpose = incidence_1.transpose(1, 0)
if incidence_1 is not None:
self.incidence_1_transpose = incidence_1.transpose(1, 0)

self.conv_0_to_1 = Conv(
in_channels=in_channels,
@@ -98,9 +99,10 @@ def __init__(
if self.use_normalized_incidence:
self.alpha = alpha
self.beta = beta
self.n_nodes, self.n_edges = self.incidence_1.shape
self.compute_normalization_matrices()
self.normalize_incidence_matrices()
if incidence_1 is not None:
self.n_nodes, self.n_edges = self.incidence_1.shape
self.compute_normalization_matrices()
self.normalize_incidence_matrices()

def compute_normalization_matrices(self) -> None:
"""Compute the normalization matrices for the incidence matrices."""
@@ -158,7 +160,7 @@ def reset_parameters(self) -> None:
if self.use_bias:
self.init_biases()

def forward(self, x_0):
def forward(self, x_0, incidence_1=None):
r"""Forward computation.
The forward pass was initially proposed in [1]_.
@@ -182,8 +184,8 @@ def forward(self, x_0):
----------
x_0 : torch.Tensor, shape = (n_nodes, channels_node)
Input features on the hypernodes.
x_1 : torch.Tensor, shape = (n_edges, channels_edge)
Input features on the hyperedges.
incidence_1: torch.Tensor, shape = (n_nodes, n_edges)
Incidence matrix mapping edges to nodes (B_1).
Returns
-------
@@ -192,6 +194,13 @@ def forward(self, x_0):
x_1 : torch.Tensor, shape = (n_edges, channels_edge)
Output features on the hyperedges.
"""
if incidence_1 is not None:
self.incidence_1 = incidence_1
self.incidence_1_transpose = incidence_1.transpose(1, 0)
if self.use_normalized_incidence:
self.n_nodes, self.n_edges = incidence_1.shape
self.compute_normalization_matrices()
self.normalize_incidence_matrices()
# Move incidence matrices to device
self.incidence_1 = self.incidence_1.to(x_0.device)
self.incidence_1_transpose = self.incidence_1_transpose.to(x_0.device)
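Note: with this change HNHNLayer can be built with incidence_1=None and given the hypergraph at forward time; the normalization matrices are then recomputed from the matrix passed in. A hedged sketch of that pattern; the import path is assumed from the repository layout.

import torch
from topomodelx.nn.hypergraph.hnhn_layer import HNHNLayer  # assumed import path

layer = HNHNLayer(in_channels=5, hidden_channels=8, incidence_1=None)

# First hypergraph: 10 nodes, 20 hyperedges.
B1_a = torch.randint(0, 2, (10, 20)).float()
x0_a, x1_a = layer(torch.randn(10, 5), B1_a)  # normalization computed from B1_a

# The same layer can be reused on a hypergraph of a different size,
# since its weights depend only on the channel dimensions.
B1_b = torch.randint(0, 2, (4, 6)).float()
x0_b, x1_b = layer(torch.randn(4, 5), B1_b)  # normalization recomputed from B1_b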
5 changes: 5 additions & 0 deletions topomodelx/nn/hypergraph/hypergat.py
@@ -16,6 +16,8 @@ class HyperGAT(torch.nn.Module):
Dimension of the hidden features.
n_layers : int, default = 2
Amount of message passing layers.
layer_drop: float, default = 0.2
Dropout rate for the hidden features.
References
----------
@@ -29,6 +31,7 @@ def __init__(
in_channels,
hidden_channels,
n_layers=2,
layer_drop=0.2,
):
super().__init__()

@@ -39,6 +42,7 @@ def __init__(
)
for i in range(n_layers)
)
self.layer_drop = torch.nn.Dropout(layer_drop)
Collaborator comment: 👍


def forward(self, x_0, incidence_1):
"""Forward computation through layers, then linear layer, then global max pooling.
@@ -59,5 +63,6 @@ def forward(self, x_0, incidence_1):
"""
for layer in self.layers:
x_0, x_1 = layer.forward(x_0, incidence_1)
x_0 = self.layer_drop(x_0)

return x_0, x_1
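Note: a minimal sketch of HyperGAT with the new layer_drop argument; the import path and the dense incidence format used here are assumptions based on the repository layout, not guarantees from the diff.

import torch
from topomodelx.nn.hypergraph.hypergat import HyperGAT  # assumed import path

# layer_drop adds dropout on the node features after each HyperGAT layer.
model = HyperGAT(in_channels=8, hidden_channels=16, n_layers=2, layer_drop=0.2)

n_nodes, n_edges = 6, 4
x_0 = torch.randn(n_nodes, 8)
incidence_1 = torch.randint(0, 2, (n_nodes, n_edges)).float()

x_0_out, x_1_out = model(x_0, incidence_1)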