-
Notifications
You must be signed in to change notification settings - Fork 81
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #142 from ajbrent/sca-cmps-branch
Simplicial Complex Autoencoder (CMPS)
- Loading branch information
Showing
3 changed files
with
735 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
"""Test the SCA layer.""" | ||
|
||
import torch | ||
|
||
from topomodelx.base.conv import Conv | ||
from topomodelx.nn.simplicial.sca_cmps_layer import SCACMPSLayer | ||
|
||
|
||
class TestSCALayer:
    """Unit tests for the SCA (CMPS) layer."""

    def test_cmps_forward(self):
        """Test the forward pass of the SCA layer using CMPS.

        Builds random down-Laplacians and transposed incidence matrices for a
        4-dimensional complex and checks that the output feature tensor for
        each dimension keeps its (n_chains, channels) shape.
        """
        channels_list = [3, 5, 6, 8]
        n_chains_list = [10, 20, 15, 5]
        down_lap_list = []
        incidence_t_list = []
        # Neighborhood matrices exist only for dimensions >= 1.
        for i in range(1, len(n_chains_list)):
            lap_down = torch.randint(0, 2, (n_chains_list[i], n_chains_list[i])).float()
            incidence_transpose = torch.randint(
                0, 2, (n_chains_list[i], n_chains_list[i - 1])
            ).float()
            down_lap_list.append(lap_down)
            incidence_t_list.append(incidence_transpose)

        x_list = [torch.randn(n, chan) for chan, n in zip(channels_list, n_chains_list)]

        sca = SCACMPSLayer(
            channels_list=channels_list,
            complex_dim=len(n_chains_list),
        )
        output = sca.forward(x_list, down_lap_list, incidence_t_list)

        for x, n, chan in zip(output, n_chains_list, channels_list):
            assert x.shape == (n, chan)

    def test_reset_parameters(self):
        """Test that reset_parameters re-initializes every Conv sub-layer.

        The parameters are perturbed in place first, so a successful reset is
        distinguishable from a no-op: after ``reset_parameters`` the values
        must differ from the perturbed snapshot.
        """
        channels = [2, 2, 2, 2]
        dim = 4

        sca = SCACMPSLayer(channels, dim)

        # Perturb every Conv parameter and snapshot the perturbed values.
        # Cloning is essential: list(sub.parameters()) alone would hold live
        # references that reset_parameters() overwrites, making any later
        # comparison vacuous.
        mutated_params = []
        for module in sca.modules():
            if isinstance(module, torch.nn.ModuleList):
                for sub in module:
                    if isinstance(sub, Conv):
                        with torch.no_grad():
                            for param in sub.parameters():
                                param.add_(1.0)
                        mutated_params.append(
                            [p.detach().clone() for p in sub.parameters()]
                        )

        sca.reset_parameters()

        count = 0
        idx = 0
        for module in sca.modules():
            if isinstance(module, torch.nn.ModuleList):
                for sub in module:
                    if isinstance(sub, Conv):
                        for param, stale in zip(sub.parameters(), mutated_params[idx]):
                            # Re-initialized values should no longer match the
                            # +1.0-shifted snapshot.
                            assert not torch.allclose(param, stale)
                            count += 1
                        idx += 1

        assert count > 0  # Ensuring the isinstance filters actually matched.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,142 @@ | ||
"""Simplical Complex Autoencoder Layer.""" | ||
import torch | ||
|
||
from topomodelx.base.aggregation import Aggregation | ||
from topomodelx.base.conv import Conv | ||
|
||
|
||
class SCACMPSLayer(torch.nn.Module):
    """Layer of a Simplicial Complex Autoencoder (SCA) using the Coadjacency Message Passing Scheme (CMPS).

    Implementation of the SCA layer proposed in [HZPMC22]_.

    Notes
    -----
    This is the architecture proposed for complex classification.

    References
    ----------
    .. [HZPMC22] Hajij, Zamzmi, Papamarkou, Maroulas, Cai.
        Simplicial Complex Autoencoder
        https://arxiv.org/pdf/2103.04046.pdf

    Parameters
    ----------
    channels_list: list[int]
        Dimension of features at each dimension.
    complex_dim: int
        Highest dimension of chains on the input simplicial complexes.
    att: bool
        Whether to use attention.
    """

    def __init__(
        self,
        channels_list,
        complex_dim,
        att=False,
    ):
        super().__init__()
        self.att = att
        self.dim = complex_dim
        self.channels_list = channels_list

        # One pair of convolutions per dimension >= 1: a coadjacency
        # (down-Laplacian) pass within the dimension, and an incidence pass
        # lifting features from the dimension below.
        lap_layers = []
        inc_layers = []
        for i in range(1, complex_dim):
            lap_layers.append(
                Conv(
                    in_channels=channels_list[i],
                    out_channels=channels_list[i],
                    att=att,
                )
            )
            inc_layers.append(
                Conv(
                    in_channels=channels_list[i - 1],
                    out_channels=channels_list[i],
                    att=att,
                )
            )

        self.lap_layers = torch.nn.ModuleList(lap_layers)
        self.inc_layers = torch.nn.ModuleList(inc_layers)
        # Within-dimension (intra) aggregation: plain sum, no update function.
        self.aggr = Aggregation(
            aggr_func="sum",
            update_func=None,
        )
        # Across-message (inter) aggregation: mean followed by ReLU.
        self.inter_aggr = Aggregation(
            aggr_func="mean",
            update_func="relu",
        )

    def reset_parameters(self):
        r"""Reset parameters of each layer."""
        for layer in self.lap_layers:
            if isinstance(layer, Conv):
                layer.reset_parameters()
        for layer in self.inc_layers:
            if isinstance(layer, Conv):
                layer.reset_parameters()

    def weight_func(self, x):
        r"""Weight function for intra aggregation layer according to [HZPMC22]_.

        This is the logistic sigmoid; ``torch.sigmoid`` is used instead of the
        literal ``1 / (1 + exp(-x))`` for numerical stability at large |x|.
        """
        return torch.sigmoid(x)

    def intra_aggr(self, x):
        r"""Attention-like reweighting of rows of ``x``, based on the use by [HZPMC22]_.

        Each row is scaled by a sigmoid-weighted similarity between the row
        and the (ReLU of the) sum of all rows.
        """
        x_list = list(torch.split(x, 1, dim=0))
        x_weight = self.aggr(x_list)
        x_weight = torch.matmul(torch.relu(x_weight), x.transpose(1, 0))
        x_weight = self.weight_func(x_weight)
        x = x_weight.transpose(1, 0) * x
        return x

    def forward(self, x_list, down_lap_list, incidencet_list):
        r"""Forward pass.

        The forward pass was initially proposed in [HZPMC22]_.
        Its equations are given in [TNN23]_ and graphically illustrated in [PSHM23]_.

        Coadjacency Message Passing Scheme

        .. math::
            \begin{align*}
            &\quad m_{y \rightarrow x}^{(r \rightarrow r'' \rightarrow r)} = M(h_{x}^{t, (r)}, h_{y}^{t, (r)}, att(h_{x}^{t, (r)}, h_{y}^{t, (r)}), x, y, \Theta^t) \qquad \text{where } r'' < r < r'\\
            &\quad m_{y \rightarrow x}^{(r'' \rightarrow r)} = M(h_{x}^{t, (r)}, h_{y}^{t, (r'')}, att(h_{x}^{t, (r)}, h_{y}^{t, (r'')}), x, y, \Theta^t)\\
            &\quad m_x^{(r \rightarrow r)} = \text{AGG}_{y \in \mathcal{L}_\downarrow(x)} m_{y \rightarrow x}^{(r \rightarrow r)}\\
            &\quad m_x^{(r'' \rightarrow r)} = \text{AGG}_{y \in \mathcal{B}(x)} m_{y \rightarrow x}^{(r'' \rightarrow r)}\\
            &\quad m_x^{(r)} = \text{AGG}_{\mathcal{N}_k \in \mathcal{N}}(m_x^{(k)})\\
            &\quad h_{x}^{t+1, (r)} = U(h_x^{t, (r)}, m_{x}^{(r)})
            \end{align*}

        References
        ----------
        .. [HZPMC22] Hajij, Zamzmi, Papamarkou, Maroulas, Cai.
            Simplicial Complex Autoencoder
            https://arxiv.org/pdf/2103.04046.pdf
        .. [TNN23] Equations of Topological Neural Networks.
            https://github.com/awesome-tnns/awesome-tnns/
        .. [PSHM23] Papillon, Sanborn, Hajij, Miolane.
            Architectures of Topological Deep Learning: A Survey on Topological Neural Networks.
            (2023) https://arxiv.org/abs/2304.10031.

        Parameters
        ----------
        x_list: list[torch.Tensor]
            List of tensors holding the features of each chain at each level.
        down_lap_list: list[torch.Tensor]
            List of down laplacian matrices for skeletons from 1 dimension to the dimension of the simplicial complex.
        incidencet_list: list[torch.Tensor]
            List of transpose incidence matrices for skeletons from 1 dimension to the dimension of the simplicial complex.

        Returns
        -------
        x_list: list[torch.Tensor]
            Output for skeletons of each dimension (the node features are left untouched: x_list[0]).
        """
        # Shallow-copy so the caller's list is not mutated in place; only the
        # returned list carries the updated per-dimension features.
        x_list = list(x_list)
        for i in range(1, self.dim):
            # Coadjacency message within dimension i.
            x_lap = self.lap_layers[i - 1](x_list[i], down_lap_list[i - 1])
            # Incidence message lifted from dimension i - 1.
            x_inc = self.inc_layers[i - 1](x_list[i - 1], incidencet_list[i - 1])

            x_lap = self.intra_aggr(x_lap)
            x_inc = self.intra_aggr(x_inc)

            x_list[i] = self.inter_aggr([x_lap, x_inc])

        return x_list
Oops, something went wrong.