Skip to content

Commit

Permalink
Merge pull request #269 from pyt-team/240-review-nncell-and-nncombina…
Browse files Browse the repository at this point in the history
…torial-models

240 review nncell and nncombinatorial models + Fixed SAN bugs
  • Loading branch information
ninamiolane authored Feb 20, 2024
2 parents aa07026 + cb7f074 commit a158cbf
Show file tree
Hide file tree
Showing 17 changed files with 603 additions and 497 deletions.
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -121,3 +121,6 @@ checks = [
"EX01",
"SA01"
]
exclude = [
'\.undocumented_method$',
]
5 changes: 2 additions & 3 deletions test/nn/cell/test_can.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ def test_forward(self):
out_channels=2,
dropout=0.5,
heads=1,
num_classes=1,
n_layers=2,
att_lift=False,
).to(device)
Expand All @@ -36,5 +35,5 @@ def test_forward(self):
adjacency_2 = adjacency_1.float().to(device)
incidence_2 = adjacency_1.float().to(device)

y = model(x_0, x_1, adjacency_1, adjacency_2, incidence_2)
assert y.shape == torch.Size([1])
x_1 = model(x_0, x_1, adjacency_1, adjacency_2, incidence_2)
assert x_1.shape == torch.Size([1, 2])
7 changes: 4 additions & 3 deletions test/nn/cell/test_ccxn.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ def test_forward(self):
in_channels_0=2,
in_channels_1=2,
in_channels_2=2,
num_classes=1,
n_layers=2,
att=False,
).to(device)
Expand All @@ -33,5 +32,7 @@ def test_forward(self):
adjacency_1 = adjacency_1.float().to(device)
incidence_2 = incidence_2.float().to(device)

y = model(x_0, x_1, adjacency_1, incidence_2)
assert y.shape == torch.Size([1])
x_0, x_1, x_2 = model(x_0, x_1, adjacency_1, incidence_2)
assert x_0.shape == torch.Size([2, 2])
assert x_1.shape == torch.Size([2, 2])
assert x_2.shape == torch.Size([2, 2])
7 changes: 4 additions & 3 deletions test/nn/cell/test_cwn.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ def test_forward(self):
in_channels_1=2,
in_channels_2=2,
hid_channels=16,
num_classes=1,
n_layers=2,
).to(device)

Expand All @@ -36,5 +35,7 @@ def test_forward(self):
incidence_2 = incidence_2.float().to(device)
incidence_1_t = incidence_1_t.float().to(device)

y = model(x_0, x_1, x_2, adjacency_1, incidence_2, incidence_1_t)
assert y.shape == torch.Size([1])
x_0, x_1, x_2 = model(x_0, x_1, x_2, adjacency_1, incidence_2, incidence_1_t)
assert x_0.shape == torch.Size([2, 16])
assert x_1.shape == torch.Size([2, 16])
assert x_2.shape == torch.Size([2, 16])
8 changes: 5 additions & 3 deletions test/nn/combinatorial/test_hmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def test_forward(self):
intermediate_channels = [2, 2, 2]
final_channels = [2, 2, 2]
channels_per_layer = [[in_channels, intermediate_channels, final_channels]]
model = HMC(channels_per_layer, negative_slope=0.2, num_classes=2).to(device)
model = HMC(channels_per_layer, negative_slope=0.2).to(device)

x_0 = torch.rand(2, 2)
x_1 = torch.rand(2, 2)
Expand All @@ -29,7 +29,7 @@ def test_forward(self):
)
adjacency_0 = adjacency_0.float().to(device)

y = model(
x_0, x_1, x_2 = model(
x_0,
x_1,
x_2,
Expand All @@ -39,4 +39,6 @@ def test_forward(self):
adjacency_0,
adjacency_0,
)
assert y.shape == torch.Size([2])
assert x_0.shape == torch.Size([2, 2])
assert x_1.shape == torch.Size([2, 2])
assert x_2.shape == torch.Size([2, 2])
4 changes: 2 additions & 2 deletions test/nn/simplicial/test_san.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,15 @@ def test_forward(self):
in_channels=in_channels,
hidden_channels=hidden_channels,
out_channels=out_channels,
n_layers=1,
n_layers=3,
)
laplacian_down_1 = from_sparse(simplicial_complex.down_laplacian_matrix(rank=1))
laplacian_up_1 = from_sparse(simplicial_complex.up_laplacian_matrix(rank=1))

assert torch.any(
torch.isclose(
model(x, laplacian_up_1, laplacian_down_1)[0],
torch.tensor([2.8254, -0.9797]),
torch.tensor([-2.5604, -3.5924]),
rtol=1e-02,
)
)
Expand Down
26 changes: 9 additions & 17 deletions topomodelx/nn/cell/can.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,10 @@ class CAN(torch.nn.Module):
Number of input channels for the edge-level input.
out_channels : int
Number of output channels.
num_classes : int
Number of output classes.
dropout : float, optional
Dropout probability. Default is 0.5.
heads : int, optional
Number of attention heads. Default is 3.
Number of attention heads. Default is 2.
concat : bool, optional
Whether to concatenate the output channels of attention heads. Default is True.
skip_connection : bool, optional
Expand All @@ -33,6 +31,8 @@ class CAN(torch.nn.Module):
Number of CAN layers.
att_lift : bool, default=True
Whether to apply a lift the signal from node-level to edge-level input.
k_pool : float, default=0.5
The pooling ratio i.e, the fraction of r-cells to keep after the pooling operation.
References
----------
Expand All @@ -47,14 +47,14 @@ def __init__(
in_channels_0,
in_channels_1,
out_channels,
num_classes,
dropout=0.5,
heads=3,
heads=2,
concat=True,
skip_connection=True,
att_activation=torch.nn.LeakyReLU(0.2),
n_layers=2,
att_lift=True,
k_pool=0.5,
):
super().__init__()

Expand Down Expand Up @@ -98,16 +98,14 @@ def __init__(

layers.append(
PoolLayer(
k_pool=0.5,
k_pool=k_pool,
in_channels_0=out_channels * heads,
signal_pool_activation=torch.nn.Sigmoid(),
readout=True,
)
)

self.layers = torch.nn.ModuleList(layers)
self.lin_0 = torch.nn.Linear(heads * out_channels, 128)
self.lin_1 = torch.nn.Linear(128, num_classes)

def forward(
self, x_0, x_1, neighborhood_0_to_0, lower_neighborhood, upper_neighborhood
Expand All @@ -129,8 +127,8 @@ def forward(
Returns
-------
torch.Tensor
Output prediction for the cell complex.
torch.Tensor, shape = (num_pooled_edges, heads * out_channels)
Final hidden representations of pooled edges.
"""
if hasattr(self, "lift_layer"):
x_1 = self.lift_layer(x_0, neighborhood_0_to_0, x_1)
Expand All @@ -144,10 +142,4 @@ def forward(
x_1 = layer(x_1, lower_neighborhood, upper_neighborhood)
x_1 = F.dropout(x_1, p=0.5, training=self.training)

# max pooling over all nodes in each graph
x = x_1.max(dim=0)[0]

# Feed-Foward Neural Network to predict the graph label
out = self.lin_1(torch.nn.functional.relu(self.lin_0(x)))

return out
return x_1
24 changes: 11 additions & 13 deletions topomodelx/nn/cell/can_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@ class PoolLayer(MessagePassing):
Parameters
----------
k_pool : float in (0, 1]
The pooling ratio i.e, the fraction of edges to keep after the pooling operation.
The pooling ratio i.e, the fraction of r-cells to keep after the pooling operation.
in_channels_0 : int
Number of input channels of the input signal.
signal_pool_activation : Callable
Expand Down Expand Up @@ -323,14 +323,14 @@ def reset_parameters(self) -> None:
init.xavier_uniform_(self.att_pool.data, gain=gain)

def forward( # type: ignore[override]
self, x_0, lower_neighborhood, upper_neighborhood
self, x, lower_neighborhood, upper_neighborhood
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
r"""Forward pass.
Parameters
----------
x_0 : torch.Tensor, shape = (num_nodes, in_channels_0)
Node signal.
x : torch.Tensor, shape = (n_r_cells, in_channels_r)
Input r-cell signal.
lower_neighborhood : torch.Tensor
Lower neighborhood matrix.
upper_neighborhood : torch.Tensor
Expand All @@ -339,7 +339,7 @@ def forward( # type: ignore[override]
Returns
-------
torch.Tensor
Pooled node signal of shape (num_pooled_nodes, in_channels_0).
Pooled r_cell signal of shape (n_r_cells, in_channels_r).
Notes
-----
Expand All @@ -351,21 +351,19 @@ def forward( # type: ignore[override]
= \phi^t(h_x^t, m_{x}^{(r)}), \forall x\in \mathcal C_r^{t+1}
\end{align*}
"""
# Compute the output edge signal by applying the activation function
Zp = torch.einsum("nc,ce->ne", x_0, self.att_pool)
# Apply top-k pooling to the edge signal
# Compute the output r-cell signal by applying the activation function
Zp = torch.einsum("nc,ce->ne", x, self.att_pool)
# Apply top-k pooling to the r-cell signal
_, top_indices = topk(Zp.view(-1), int(self.k_pool * Zp.size(0)))
# Rescale the pooled signal
Zp = self.signal_pool_activation(Zp)
out = x_0[top_indices] * Zp[top_indices]
out = x[top_indices] * Zp[top_indices]

# Readout operation
if self.readout:
out = scatter_add(out, top_indices, dim=0, dim_size=x_0.size(0))[
top_indices
]
out = scatter_add(out, top_indices, dim=0, dim_size=x.size(0))[top_indices]

# Update lower and upper neighborhood matrices with the top-k pooled edges
# Update lower and upper neighborhood matrices with the top-k pooled r-cells
lower_neighborhood_modified = torch.index_select(
lower_neighborhood, 0, top_indices
)
Expand Down
33 changes: 8 additions & 25 deletions topomodelx/nn/cell/ccxn.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@ class CCXN(torch.nn.Module):
Dimension of input features on edges.
in_channels_2 : int
Dimension of input features on faces.
num_classes : int
Number of classes.
n_layers : int
Number of CCXN layers.
att : bool
Expand All @@ -36,7 +34,6 @@ def __init__(
in_channels_0,
in_channels_1,
in_channels_2,
num_classes,
n_layers=2,
att=False,
):
Expand All @@ -52,12 +49,9 @@ def __init__(
)
)
self.layers = torch.nn.ModuleList(layers)
self.lin_0 = torch.nn.Linear(in_channels_0, num_classes)
self.lin_1 = torch.nn.Linear(in_channels_1, num_classes)
self.lin_2 = torch.nn.Linear(in_channels_2, num_classes)

def forward(self, x_0, x_1, neighborhood_0_to_0, neighborhood_1_to_2):
"""Forward computation through layers, then linear layers, then avg pooling.
"""Forward computation through layers.
Parameters
----------
Expand All @@ -72,24 +66,13 @@ def forward(self, x_0, x_1, neighborhood_0_to_0, neighborhood_1_to_2):
Returns
-------
torch.Tensor, shape = (1)
Label assigned to whole complex.
x_0 : torch.Tensor, shape = (n_nodes, in_channels_0)
Final hidden states of the nodes (0-cells).
x_1 : torch.Tensor, shape = (n_edges, in_channels_1)
Final hidden states the edges (1-cells).
x_2 : torch.Tensor, shape = (n_faces, in_channels_2)
Final hidden states of the faces (2-cells).
"""
for layer in self.layers:
x_0, x_1, x_2 = layer(x_0, x_1, neighborhood_0_to_0, neighborhood_1_to_2)
x_0 = self.lin_0(x_0)
x_1 = self.lin_1(x_1)
x_2 = self.lin_2(x_2)
# Take the average of the 2D, 1D, and 0D cell features. If they are NaN, convert them to 0.
two_dimensional_cells_mean = torch.nanmean(x_2, dim=0)
two_dimensional_cells_mean[torch.isnan(two_dimensional_cells_mean)] = 0
one_dimensional_cells_mean = torch.nanmean(x_1, dim=0)
one_dimensional_cells_mean[torch.isnan(one_dimensional_cells_mean)] = 0
zero_dimensional_cells_mean = torch.nanmean(x_0, dim=0)
zero_dimensional_cells_mean[torch.isnan(zero_dimensional_cells_mean)] = 0
# Return the sum of the averages
return (
two_dimensional_cells_mean
+ one_dimensional_cells_mean
+ zero_dimensional_cells_mean
)
return (x_0, x_1, x_2)
36 changes: 7 additions & 29 deletions topomodelx/nn/cell/cwn.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@ class CWN(torch.nn.Module):
Dimension of input features on faces (2-cells).
hid_channels : int
Dimension of hidden features.
num_classes : int
Number of classes.
n_layers : int
Number of CWN layers.
Expand All @@ -38,7 +36,6 @@ def __init__(
in_channels_1,
in_channels_2,
hid_channels,
num_classes,
n_layers,
):
super().__init__()
Expand All @@ -58,10 +55,6 @@ def __init__(
)
self.layers = torch.nn.ModuleList(layers)

self.lin_0 = torch.nn.Linear(hid_channels, num_classes)
self.lin_1 = torch.nn.Linear(hid_channels, num_classes)
self.lin_2 = torch.nn.Linear(hid_channels, num_classes)

def forward(
self,
x_0,
Expand Down Expand Up @@ -90,8 +83,12 @@ def forward(
Returns
-------
torch.Tensor, shape = (1)
Label assigned to whole complex.
x_0 : torch.Tensor, shape = (n_nodes, in_channels_0)
Final hidden states of the nodes (0-cells).
x_1 : torch.Tensor, shape = (n_edges, in_channels_1)
Final hidden states the edges (1-cells).
x_2 : torch.Tensor, shape = (n_edges, in_channels_2)
Final hidden states of the faces (2-cells).
"""
x_0 = F.elu(self.proj_0(x_0))
x_1 = F.elu(self.proj_1(x_1))
Expand All @@ -107,23 +104,4 @@ def forward(
neighborhood_0_to_1,
)

x_0 = self.lin_0(x_0)
x_1 = self.lin_1(x_1)
x_2 = self.lin_2(x_2)

# Take the average of the 2D, 1D, and 0D cell features. If they are NaN, convert them to 0.
two_dimensional_cells_mean = torch.nanmean(x_2, dim=0)
two_dimensional_cells_mean[torch.isnan(two_dimensional_cells_mean)] = 0

one_dimensional_cells_mean = torch.nanmean(x_1, dim=0)
one_dimensional_cells_mean[torch.isnan(one_dimensional_cells_mean)] = 0

zero_dimensional_cells_mean = torch.nanmean(x_0, dim=0)
zero_dimensional_cells_mean[torch.isnan(zero_dimensional_cells_mean)] = 0

# Return the sum of the averages
return (
two_dimensional_cells_mean
+ one_dimensional_cells_mean
+ zero_dimensional_cells_mean
)
return x_0, x_1, x_2
Loading

0 comments on commit a158cbf

Please sign in to comment.