Add model SANN #103

Open
wants to merge 42 commits into base: main
Commits
268c398
ADD: Base setup for SANN
martin-carrasco Sep 2, 2024
4cda3cc
ADD: Tutorial for SANN
martin-carrasco Sep 12, 2024
1154d45
ADD: Precompute embeddings for SANN
martin-carrasco Sep 19, 2024
21a6cff
FIX: SAN cell embeddings to encoder
martin-carrasco Sep 19, 2024
14e229e
FIX: Remove useless comments and fix indexing mixup
martin-carrasco Oct 3, 2024
80c8b6d
FIX: Add complex_dim config support and comment laplacian normalizati…
martin-carrasco Oct 3, 2024
adf8109
ADD: Batching support and a not so nice fix
martin-carrasco Oct 3, 2024
1c36e24
FIX: Changes to
martin-carrasco Oct 16, 2024
ae4cba4
FIX: Custom dataloader removal
martin-carrasco Nov 4, 2024
9d9a754
REMOVE: Testing file for SANN
martin-carrasco Nov 4, 2024
c583efc
FIX: Should be elif
martin-carrasco Nov 4, 2024
d174ede
FIX: Naming
martin-carrasco Nov 6, 2024
7552967
Merge branch 'main' of https://github.com/geometric-intelligence/Topo…
levtelyatnikov Nov 8, 2024
8af8846
Merge branch 'main' of github.com:geometric-intelligence/TopoBenchmar…
gbg141 Nov 8, 2024
50371d1
Merge branch 'main' of github.com:geometric-intelligence/TopoBenchmar…
gbg141 Nov 9, 2024
dd433e3
Enabling model-based default transform
gbg141 Nov 9, 2024
8be8641
Adapting config files
gbg141 Nov 9, 2024
50bad08
Adding tutorial
gbg141 Nov 9, 2024
3c9df01
Start model adaptation
gbg141 Nov 9, 2024
a559490
Fix bug in precompute_khop_features
gbg141 Nov 9, 2024
c9b6585
Loading and transform working
gbg141 Nov 9, 2024
bc1722b
Merge branch 'martin-sann' of https://github.com/pyt-team/TopoBenchma…
martin-carrasco Nov 9, 2024
a5b8d34
ADD: Fix for SANN and related tests
martin-carrasco Nov 9, 2024
6b60c28
FIX: Configurations and removing useless eavluation stuff
martin-carrasco Nov 9, 2024
a0ac551
ADD: Two tests for SANN
martin-carrasco Nov 9, 2024
5b4018d
SaNN running on PROTEINS
gbg141 Nov 13, 2024
991285e
Config update
gbg141 Nov 13, 2024
297ed6a
Merge branch 'martin-sann' of https://github.com/geometric-intelligen…
levtelyatnikov Nov 19, 2024
d3bbf4d
config_rrsolver
levtelyatnikov Nov 20, 2024
ffbc669
infer_in_khop_feature_dim working
gbg141 Nov 20, 2024
cee8ee3
working sann for some graph level datasets
levtelyatnikov Nov 21, 2024
1ff67c2
issue with mutag dataset
levtelyatnikov Nov 21, 2024
bd2188f
Fix bug with preserve_edge_attr
gbg141 Nov 21, 2024
02bad97
Adding new config resolvers' tests
gbg141 Nov 21, 2024
0b4a0e6
Minor
gbg141 Nov 21, 2024
b59c99f
Merge branch 'main' into martin-sann
gbg141 Nov 21, 2024
62c7c8f
Fix ruff
gbg141 Nov 21, 2024
3f2774c
Merge branch 'main' of github.com:geometric-intelligence/TopoBenchmar…
gbg141 Nov 22, 2024
5af0de2
FIX: automatic t-hop neighbourhood calculation based on config
martin-carrasco Nov 25, 2024
d4497fe
Merge branch 'martin-sann' of https://github.com/pyt-team/TopoBenchma…
martin-carrasco Nov 25, 2024
2872d0d
FIX: Sann config
martin-carrasco Nov 28, 2024
9f526a2
FIX: San configurations
martin-carrasco Nov 29, 2024
3 changes: 2 additions & 1 deletion configs/dataset/graph/MUTAG.yaml
@@ -19,7 +19,8 @@ parameters:
task_level: graph
# Lifting parameters
max_dim_if_lifted: 3 # This is the maximum dimension of the simplicial complex in the dataset
preserve_edge_attr_if_lifted: True
preserve_edge_attr_if_lifted: ${set_preserve_edge_attr:${model.model_name},True} # Second argument is the default value
# (in case the model is compatible)

#splits
split_params:
47 changes: 47 additions & 0 deletions configs/model/simplicial/sann.yaml
@@ -0,0 +1,47 @@
_target_: topobenchmarkx.model.TBXModel

model_name: sann
model_domain: simplicial

feature_encoder:
_target_: topobenchmarkx.nn.encoders.${model.feature_encoder.encoder_name}
encoder_name: SANNFeatureEncoder
dataset_in_channels: ${infer_in_channels:${dataset},${oc.select:transforms,null}}
in_channels: ${transforms.sann_encoding.in_channels}
max_hop: ${transforms.sann_encoding.max_hop}
out_channels: 64
proj_dropout: 0.0
selected_dimensions:
- 0
- 1
- 2
feature_lifting: Duplicate
all_ones: true

backbone:
_target_: topobenchmarkx.nn.backbones.simplicial.sann.SANN
in_channels: ${model.backbone.hidden_channels}
n_layers: 2
max_hop: ${transforms.sann_encoding.max_hop}
hidden_channels: ${model.feature_encoder.out_channels}

backbone_wrapper:
_target_: topobenchmarkx.nn.wrappers.SANNWrapper
_partial_: true
wrapper_name: SANNWrapper
out_channels: ${model.feature_encoder.out_channels}
num_cell_dimensions: ${infere_num_cell_dimensions:${oc.select:model.feature_encoder.selected_dimensions,null},${model.feature_encoder.in_channels}}

readout:
_target_: topobenchmarkx.nn.readouts.${model.readout.readout_name}
readout_name: SANNReadout # Use <NoReadOut> in case readout is not needed Options: PropagateSignalDown
max_hop: ${transforms.sann_encoding.max_hop}
num_cell_dimensions: ${infere_num_cell_dimensions:${oc.select:model.feature_encoder.selected_dimensions,null},${model.feature_encoder.in_channels}} # The highest order of cell dimensions to consider
hidden_dim: ${model.feature_encoder.out_channels}
out_channels: ${dataset.parameters.num_classes}
task_level: ${dataset.parameters.task_level}
pooling_type: sum
complex_dim: ${oc.select:dataset.parameters.max_dim_if_lifted,3}

# compile model for faster training with pytorch 2.0
compile: false
4 changes: 2 additions & 2 deletions configs/run.yaml
@@ -4,8 +4,8 @@
# order of defaults determines the order in which configs override each other
defaults:
- _self_
- dataset: graph/cocitation_cora
- model: cell/topotune
- dataset: graph/MUTAG
- model: simplicial/sann
- transforms: ${get_default_transform:${dataset},${model}} #no_transform
- optimizer: default
- loss: default
@@ -0,0 +1,10 @@
_target_: topobenchmarkx.transforms.data_transform.DataTransform
transform_name: "PrecomputeKHopFeatures"
transform_type: "data manipulation"
max_hop: 3
complex_dim: ${oc.select:dataset.parameters.max_dim_if_lifted,3}
in_channels: ${infer_in_khop_feature_dim:${model.feature_encoder.dataset_in_channels},${.max_hop}}
# in_features: ${infer_in_sann_khop_feature_dim:${model},${3}}


# change of the transforms: ${new function that looks at dataset.dim (list [3,5,1])} --> list of dim
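For intuition, the kind of precomputation a transform like this performs can be sketched with plain adjacency powers: for each hop t up to `max_hop`, aggregate the features of the nodes reachable within t hops. The code below is a deliberately simplified stand-in (numpy, plain graphs, no simplicial structure) and not the actual `PrecomputeKHopFeatures` implementation:

```python
import numpy as np

def precompute_khop_features(adj, x, max_hop):
    """Toy k-hop feature precompute: for each hop t, sum the features of
    all nodes reachable within t hops (only the node itself at t=0).
    Returns a dict mapping hop t -> (n, d) aggregated features, loosely
    mirroring the per-hop x{dim}_{hop} keys used by the PR's fixtures.
    """
    n = adj.shape[0]
    reach = np.eye(n, dtype=bool)  # hop 0: every node reaches itself
    features = {0: x.copy()}
    for t in range(1, max_hop + 1):
        # reachable within t hops = within t-1 hops, or one edge further
        step = (reach.astype(int) @ (adj > 0).astype(int)) > 0
        reach = reach | step
        features[t] = reach.astype(x.dtype) @ x
    return features

# 4-node path graph 0-1-2-3 with all-ones features
adj = np.zeros((4, 4))
for i, j in [(0, 1), (1, 2), (2, 3)]:
    adj[i, j] = adj[j, i] = 1.0
x = np.ones((4, 2))
feats = precompute_khop_features(adj, x, max_hop=2)
```

With all-ones features, each entry of `feats[t]` simply counts the nodes within t hops, which makes the hop masks easy to sanity-check.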
2 changes: 1 addition & 1 deletion configs/transforms/liftings/graph2simplicial/clique.yaml
@@ -4,5 +4,5 @@ transform_name: "SimplicialCliqueLifting"
complex_dim: ${oc.select:dataset.parameters.max_dim_if_lifted,3}
preserve_edge_attr: ${oc.select:dataset.parameters.preserve_edge_attr_if_lifted,False}
signed: False
feature_lifting: ProjectionSum
feature_lifting: ${oc.select:model.feature_encoder.feature_lifting,ProjectionSum}
neighborhoods: ${oc.select:model.backbone.neighborhoods,null}
4 changes: 4 additions & 0 deletions configs/transforms/model_defaults/sann.yaml
@@ -0,0 +1,4 @@
# USE python -m topobenchmarkx transforms.one_hot_node_degree_features.degrees_fields=x to run this config
defaults:
- liftings@_here_: ${get_required_lifting:graph,${model}}
- data_manipulations@sann_encoding: precompute_khop_features
20 changes: 20 additions & 0 deletions scripts/test_sann.sh
@@ -0,0 +1,20 @@
python -m topobenchmarkx \
model=simplicial/sann \
dataset=graph/PROTEINS \
optimizer.parameters.lr=0.001 \
optimizer.parameters.weight_decay=0.0001 \
model.backbone.n_layers=2 \
model.feature_encoder.proj_dropout=0.0 \
dataset.dataloader_params.batch_size=64 \
dataset.split_params.data_seed=0 \
dataset.split_params.split_type=k-fold \
dataset.split_params.k=10 \
optimizer.scheduler=null \
trainer.max_epochs=50 \
trainer.min_epochs=50 \
trainer.devices=1 \
trainer.accelerator=cpu \
trainer.check_val_every_n_epoch=1 \
callbacks.early_stopping.patience=100\
logger.wandb.project=TopoBenchmarkX_main\
--multirun
106 changes: 102 additions & 4 deletions test/_utils/nn_module_auto_test.py
@@ -6,19 +6,27 @@
import copy

class NNModuleAutoTest:
r"""Test the following cases:
r"""Test the following cases.

1) Assert that the module returns at least one tensor
2) Reproducibility. Assert that the module returns the same output when called with the same data
Additionally
3) Assert returned shape.
Important! If module returns multiple tensor. The shapes for assertion must be in list() not (!!!) tuple()
Important! If the module returns multiple tensors, the shapes for assertion must be in a list(), not (!!!) a tuple().

Parameters
----------
params : list of dict
List of dictionaries with parameters.
"""
SEED = 0

def __init__(self, params):
self.params = params

def run(self):
r"""Run the test.
"""
for param in self.params:
assert "module" in param and "init" in param and "forward" in param
module = self.exec_func(param["module"], param["init"])
@@ -40,6 +48,22 @@ def run(self):
self.assert_shape(result, param["assert_shape"])

def exec_twice(self, module, inp_1, inp_2):
""" Execute the module twice with different inputs.

Parameters
----------
module : torch.nn.Module
Module to be tested.
inp_1 : tuple or dict
Input for the module.
inp_2 : tuple or dict
Input for the module.

Returns
-------
tuple
Outputs of the module for the two executions.
"""
torch.manual_seed(self.SEED)
result = self.exec_func(module, inp_1)

@@ -49,39 +73,113 @@ def exec_twice(self, module, inp_1, inp_2):
return result, result_2

def exec_func(self, func, args):
""" Execute the function with the arguments.

Parameters
----------
func : function
Function to be executed.
args : tuple or dict
Arguments for the function.

Returns
-------
any
Output of the function.
"""
if type(args) == tuple:
return func(*args)
elif type(args) == dict:
return func(**args)
else:
raise TypeError(f"{type(args)} is not correct type for funcntion arguments.")
raise TypeError(f"{type(args)} is not a correct type for function arguments.")

def clone_input(self, args):
""" Clone the input arguments.

Parameters
----------
args : tuple or dict
Arguments to be cloned.

Returns
-------
tuple or dict
Cloned arguments.
"""
if type(args) == tuple:
return tuple(self.clone_object(a) for a in args)
elif type(args) == dict:
return {k: self.clone_object(v) for k, v in args.items()}

def clone_object(self, obj):
""" Clone the object.

Parameters
----------
obj : any
Object to be cloned.

Returns
-------
any
Cloned object.
"""
if hasattr(obj, "clone"):
return obj.clone()
else:
return copy.deepcopy(obj)

def assert_return_tensor(self, result):
assert any(isinstance(r, torch.Tensor) for r in result)
""" Assert if the module return at least one tensor.

Parameters
----------
result : any
Output of the module.
"""
if all(isinstance(r, tuple) for r in result):
assert any([all([isinstance(r, torch.Tensor) for r in tup]) for tup in result])
else:
assert any(isinstance(r, torch.Tensor) for r in result)

def assert_equal_output(self, module, result, result_2):
""" Assert that the module return the same output when called with the same data.

Parameters
----------
module : torch.nn.Module
Module to be tested.
result : any
Output of the module for the first input.
result_2 : any
Output of the module for the second input.
"""
assert len(result) == len(result_2)

for i, r1 in enumerate(result):
r2 = result_2[i]
if isinstance(r1, torch.Tensor):
assert torch.equal(r1, r2)
elif isinstance(r1, tuple) and isinstance(r2, tuple):
for r1_, r2_ in zip(r1, r2):
if isinstance(r1_, torch.Tensor) and isinstance(r2_, torch.Tensor):
assert torch.equal(r1_, r2_)
else:
assert r1_ == r2_
else:
assert r1 == r2

def assert_shape(self, result, shapes):
""" Assert returned shape.

Parameters
----------
result : any
Output of the module.
shapes : list
List of shapes to be asserted.
"""
i = 0
for t in result:
if isinstance(t, torch.Tensor):
34 changes: 34 additions & 0 deletions test/conftest.py
@@ -9,6 +9,9 @@
from topobenchmarkx.transforms.liftings.graph2cell import (
CellCycleLifting
)
from topobenchmarkx.transforms.data_manipulations.precompute_khop_features import (
PrecomputeKHopFeatures
)


@pytest.fixture
@@ -155,6 +158,37 @@ def sg1_clique_lifted(simple_graph_1):
data.batch_0 = "null"
return data

@pytest.fixture
def sg1_clique_lifted_precompute_k_hop(simple_graph_1):
"""Return a simple graph with a clique lifting and a precomputed k-hop neighbourhood embedding.

Parameters
----------
simple_graph_1 : torch_geometric.data.Data
A simple graph data object.

Returns
-------
torch_geometric.data.Data
A simple graph data object with a clique lifting and a K-neighbourhood embedding.
"""
max_hop=2
complex_dim=3
lifting_signed = SimplicialCliqueLifting(
complex_dim=complex_dim, signed=True
)
data = lifting_signed(simple_graph_1)
precompute_k_hop = PrecomputeKHopFeatures(max_hop=max_hop, complex_dim=complex_dim)
data = precompute_k_hop(data)
# Set all k-hop dimensions to 1 to standardize testing
for i in range(max_hop+1):
for j in range(complex_dim):
data[f"x{j}_{i}"] = data[f"x{j}_{i}"][:, 0:1]
data.batch_0 = "null"
data.batch_1 = "null"
data.batch_2 = "null"
return data

@pytest.fixture
def sg1_cell_lifted(simple_graph_1):
"""Return a simple graph with a cell lifting.
51 changes: 51 additions & 0 deletions test/nn/backbones/simplicial/test_sann.py
@@ -0,0 +1,51 @@
"""Unit tests for SANN."""

import torch
from torch_geometric.utils import get_laplacian
from ...._utils.nn_module_auto_test import NNModuleAutoTest
from topobenchmarkx.nn.backbones.simplicial import SANN
from topobenchmarkx.transforms.liftings.graph2simplicial import (
SimplicialCliqueLifting,
)
from topobenchmarkx.transforms.data_manipulations.precompute_khop_features import (
PrecomputeKHopFeatures
)


def test_SANN(simple_graph_1):
"""Test SANN.

Test the SANN backbone module.

Parameters
----------
simple_graph_1 : torch_geometric.data.Data
A fixture of simple graph 1.
"""
max_hop = 2
complex_dim = 3
lifting_signed = SimplicialCliqueLifting(
complex_dim=complex_dim, signed=True
)
precompute_k_hop = PrecomputeKHopFeatures(max_hop=max_hop, complex_dim=complex_dim)
data = lifting_signed(simple_graph_1)
data = precompute_k_hop(data)
out_dim = 4

# Set all k-hop dimensions to 1 to standardize testing
for i in range(max_hop+1):
for j in range(complex_dim):
data[f"x{j}_{i}"] = data[f"x{j}_{i}"][:, 0:1]

x_in = tuple(tuple(data[f"x{i}_{j}"] for j in range(max_hop+1)) for i in range(complex_dim))
expected_shapes = [(data.x.shape[0], out_dim), (data.x_1.shape[0], out_dim), (data.x_2.shape[0], out_dim)]

auto_test = NNModuleAutoTest([
{
"module" : SANN,
"init": ((1, 1, 1), 1, 'lrelu', complex_dim, max_hop, 2),
"forward": (x_in, ),
"assert_shape": expected_shapes
},
])
auto_test.run()
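The core contract `NNModuleAutoTest` enforces — fixed seed, cloned inputs, identical outputs — can be illustrated without torch. The sketch below uses the stdlib only; `auto_test` and `noisy_scale` are illustrative names, not part of the repo:

```python
import copy
import random

SEED = 0

def auto_test(module, args, clone=copy.deepcopy):
    """Minimal stand-in for the NNModuleAutoTest reproducibility check:
    run `module` twice from the same seed on cloned inputs and assert
    the outputs match."""
    random.seed(SEED)
    out_1 = module(*args)
    args_2 = clone(args)  # fresh copies, as NNModuleAutoTest.clone_input does
    random.seed(SEED)     # reset the RNG so stochastic behaviour replays
    out_2 = module(*args_2)
    assert out_1 == out_2, "module is not reproducible under a fixed seed"
    return out_1

# A "module" with internal randomness, standing in for e.g. dropout noise
def noisy_scale(xs):
    return [x * random.random() for x in xs]

result = auto_test(noisy_scale, ([1.0, 2.0, 3.0],))
```

Cloning the inputs before the second run matters: it rules out the module passing the check only because it mutated its arguments in place the first time.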