Merge pull request #128 from spindro/allsettransformer
Implementation of AllSetTransformer
ninamiolane authored Aug 24, 2023
2 parents 922f771 + 70f4c2b commit 664dcb9
Showing 6 changed files with 1,203 additions and 23 deletions.
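For orientation, here is a minimal usage sketch of the new layer, mirroring the fixture and forward-pass test added in test/nn/hypergraph/test_allsettransformer.py below. The hyperparameters and the incidence matrix are taken directly from those tests; the snippet itself is illustrative and not part of the diff.

import torch

from topomodelx.nn.hypergraph.allsettransformer_layer import AllSetTransformerLayer

# Build the layer with the hyperparameters used in the test fixture.
layer = AllSetTransformerLayer(
    in_channels=10,
    hidden_channels=64,
    heads=4,
    number_queries=1,
    dropout=0.0,
    mlp_num_layers=1,
    mlp_activation=None,
    mlp_dropout=0.0,
    mlp_norm=None,
)

# Node features: 3 nodes with 10 channels each.
x_0 = torch.randn(3, 10)
# Sparse node-to-hyperedge incidence matrix (3 nodes x 3 hyperedges), as in the tests.
incidence_1 = torch.tensor(
    [[1, 0, 0], [0, 1, 1], [1, 1, 1]], dtype=torch.float32
).to_sparse()

# Forward pass; the tests assert an output of shape (3, 64).
x_0_out = layer(x_0, incidence_1)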
2 changes: 1 addition & 1 deletion .gitignore
@@ -14,7 +14,7 @@ coverage.xml
**/__pycache__/**
__pycache__/
docs/build/

venv_topo/
TopoNetX/
topomodelx/nn/cell/attcxn_layer.py
topomodelx/base/debug.ipynb
10 changes: 5 additions & 5 deletions docs/contributing/index.rst
@@ -68,27 +68,27 @@ Test functions should be located in files whose filenames start with `test_`.
def test_capital_case():
    assert add(4, 5) == 9

Use an `assert` statement to check that the function under test returns the correct output.

Run Tests
~~~~~~~~~

Install `pytest`, which is the software tool used to run tests:

.. code-block:: bash

    $ pip install -e .[dev]

Then run the test using:

.. code-block:: bash

    $ pytest test_add.py

Verify that the code you have added does not break `TopoModelX` by running all the tests:

.. code-block:: bash

    $ pytest test/
Write Documentation
@@ -97,7 +97,7 @@ Write Documentation
Building the documentation requires installing specific requirements.

.. code-block:: bash

    $ pip install -e .[doc]
Intro to Docstrings
147 changes: 147 additions & 0 deletions test/nn/hypergraph/test_allsettransformer.py
@@ -0,0 +1,147 @@
"""Test the AllSetTransformer layer."""
import pytest
import torch

from topomodelx.nn.hypergraph.allsettransformer_layer import MLP, AllSetTransformerLayer


class TestAllSetTransformerLayer:
    """Test the AllSetTransformer layer."""

    @pytest.fixture
    def allsettransformer_layer(self):
        """Return an allsettransformer layer."""
        in_dim = 10
        hid_dim = 64
        heads = 4
        layer = AllSetTransformerLayer(
            in_channels=in_dim,
            hidden_channels=hid_dim,
            heads=heads,
            number_queries=1,
            dropout=0.0,
            mlp_num_layers=1,
            mlp_activation=None,
            mlp_dropout=0.0,
            mlp_norm=None,
        )
        return layer

    def test_forward(self, allsettransformer_layer):
        """Test the forward pass of the allsettransformer layer."""
        x_0 = torch.randn(3, 10)
        incidence_1 = torch.tensor(
            [[1, 0, 0], [0, 1, 1], [1, 1, 1]], dtype=torch.float32
        ).to_sparse()
        output = allsettransformer_layer.forward(x_0, incidence_1)
        assert output.shape == (3, 64)

    def test_forward_with_invalid_input(self, allsettransformer_layer):
        """Test the forward pass of the allsettransformer layer with invalid input."""
        x_0 = torch.randn(4, 10)
        incidence_1 = torch.tensor(
            [[1, 0, 0], [0, 1, 1], [1, 1, 1]], dtype=torch.float32
        ).to_sparse()
        with pytest.raises(ValueError):
            allsettransformer_layer.forward(x_0, incidence_1)

    def test_reset_parameters(self, allsettransformer_layer):
        """Test the reset_parameters method."""
        in_dim = 10
        hid_dim = 64
        heads = 4

        allsettransformer_layer.reset_parameters()
        assert allsettransformer_layer.vertex2edge.mlp[0].weight.requires_grad
        assert allsettransformer_layer.edge2vertex.mlp[0].weight.requires_grad

        # Test with attention weights & xavier_uniform
        allsettransformer_layer.vertex2edge.multihead_att.initialization = (
            "xavier_uniform"
        )
        allsettransformer_layer.edge2vertex.multihead_att.initialization = (
            "xavier_uniform"
        )
        allsettransformer_layer.reset_parameters()
        assert allsettransformer_layer.vertex2edge.multihead_att.K_weight.shape == (
            heads,
            in_dim,
            hid_dim // heads,
        )
        assert allsettransformer_layer.edge2vertex.multihead_att.K_weight.shape == (
            heads,
            hid_dim,
            hid_dim // heads,
        )

    def test_initialisation_heads_zero(self):
        """Test the initialisation of the allsettransformer layer with invalid input."""
        with pytest.raises(ValueError):
            heads = 0
            _ = AllSetTransformerLayer(
                in_channels=10,
                hidden_channels=64,
                heads=heads,
            )

    def test_initialisation_heads_wrong(self):
        """Test the initialisation of the allsettransformer layer with invalid input."""
        with pytest.raises(ValueError):
            in_channels = 10
            heads = 3

            _ = AllSetTransformerLayer(
                in_channels=in_channels,
                hidden_channels=64,
                heads=heads,
            )

    def test_initialisation_mlp_num_layers_zero(self):
        """Test the initialisation of the allsettransformer layer with invalid input."""
        with pytest.raises(ValueError):
            mlp_num_layers = 0
            _ = AllSetTransformerLayer(
                in_channels=10,
                hidden_channels=64,
                heads=4,
                mlp_num_layers=mlp_num_layers,
            )

    def test_initialisation_mlp_num_layers_negative(self):
        """Test the initialisation of the allsettransformer layer with invalid input."""
        with pytest.raises(ValueError):
            mlp_num_layers = -1
            _ = AllSetTransformerLayer(
                in_channels=10,
                hidden_channels=64,
                heads=4,
                mlp_num_layers=mlp_num_layers,
            )

    def test_MLP(self):
        """Test the MLP class (used in AllSetTransformerLayer)."""
        in_channels_ = [10]
        hidden_channels_ = [[64], [64, 64]]
        norm_layers = [None, torch.nn.LayerNorm]
        activation_layers = [torch.nn.ReLU, torch.nn.LeakyReLU]
        dropouts = [0.0, 0.5]
        bias_ = [True, False]

        for in_channels in in_channels_:
            for hidden_channels in hidden_channels_:
                for norm_layer in norm_layers:
                    for activation_layer in activation_layers:
                        for dropout in dropouts:
                            for bias in bias_:
                                mlp = MLP(
                                    in_channels=in_channels,
                                    hidden_channels=hidden_channels,
                                    norm_layer=norm_layer,
                                    activation_layer=activation_layer,
                                    dropout=dropout,
                                    bias=bias,
                                )
                                assert mlp is not None
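For completeness, the standalone MLP helper exercised by test_MLP above can also be instantiated directly; a minimal sketch, assuming the keyword arguments used in that test:

import torch

from topomodelx.nn.hypergraph.allsettransformer_layer import MLP

# One configuration from the parameter grid exercised in test_MLP above.
mlp = MLP(
    in_channels=10,
    hidden_channels=[64, 64],
    norm_layer=torch.nn.LayerNorm,
    activation_layer=torch.nn.ReLU,
    dropout=0.5,
    bias=True,
)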