Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

240 review nncell and nncombinatorial models + Fixed SAN bugs #269

Merged
merged 13 commits into from
Feb 20, 2024
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -121,3 +121,6 @@ checks = [
"EX01",
"SA01"
]
exclude = [
'\.undocumented_method$',
]
5 changes: 2 additions & 3 deletions test/nn/cell/test_can.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ def test_forward(self):
out_channels=2,
dropout=0.5,
heads=1,
num_classes=1,
n_layers=2,
att_lift=False,
).to(device)
Expand All @@ -36,5 +35,5 @@ def test_forward(self):
adjacency_2 = adjacency_1.float().to(device)
incidence_2 = adjacency_1.float().to(device)

y = model(x_0, x_1, adjacency_1, adjacency_2, incidence_2)
assert y.shape == torch.Size([1])
x_1 = model(x_0, x_1, adjacency_1, adjacency_2, incidence_2)
assert x_1.shape == torch.Size([1, 2])
7 changes: 4 additions & 3 deletions test/nn/cell/test_ccxn.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ def test_forward(self):
in_channels_0=2,
in_channels_1=2,
in_channels_2=2,
num_classes=1,
n_layers=2,
att=False,
).to(device)
Expand All @@ -33,5 +32,7 @@ def test_forward(self):
adjacency_1 = adjacency_1.float().to(device)
incidence_2 = incidence_2.float().to(device)

y = model(x_0, x_1, adjacency_1, incidence_2)
assert y.shape == torch.Size([1])
x_0, x_1, x_2 = model(x_0, x_1, adjacency_1, incidence_2)
assert x_0.shape == torch.Size([2, 2])
assert x_1.shape == torch.Size([2, 2])
assert x_2.shape == torch.Size([2, 2])
7 changes: 4 additions & 3 deletions test/nn/cell/test_cwn.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ def test_forward(self):
in_channels_1=2,
in_channels_2=2,
hid_channels=16,
num_classes=1,
n_layers=2,
).to(device)

Expand All @@ -36,5 +35,7 @@ def test_forward(self):
incidence_2 = incidence_2.float().to(device)
incidence_1_t = incidence_1_t.float().to(device)

y = model(x_0, x_1, x_2, adjacency_1, incidence_2, incidence_1_t)
assert y.shape == torch.Size([1])
x_0, x_1, x_2 = model(x_0, x_1, x_2, adjacency_1, incidence_2, incidence_1_t)
assert x_0.shape == torch.Size([2, 16])
assert x_1.shape == torch.Size([2, 16])
assert x_2.shape == torch.Size([2, 16])
8 changes: 5 additions & 3 deletions test/nn/combinatorial/test_hmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def test_forward(self):
intermediate_channels = [2, 2, 2]
final_channels = [2, 2, 2]
channels_per_layer = [[in_channels, intermediate_channels, final_channels]]
model = HMC(channels_per_layer, negative_slope=0.2, num_classes=2).to(device)
model = HMC(channels_per_layer, negative_slope=0.2).to(device)

x_0 = torch.rand(2, 2)
x_1 = torch.rand(2, 2)
Expand All @@ -29,7 +29,7 @@ def test_forward(self):
)
adjacency_0 = adjacency_0.float().to(device)

y = model(
x_0, x_1, x_2 = model(
x_0,
x_1,
x_2,
Expand All @@ -39,4 +39,6 @@ def test_forward(self):
adjacency_0,
adjacency_0,
)
assert y.shape == torch.Size([2])
assert x_0.shape == torch.Size([2, 2])
assert x_1.shape == torch.Size([2, 2])
assert x_2.shape == torch.Size([2, 2])
4 changes: 2 additions & 2 deletions test/nn/simplicial/test_san.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,15 @@ def test_forward(self):
in_channels=in_channels,
hidden_channels=hidden_channels,
out_channels=out_channels,
n_layers=1,
n_layers=3,
)
laplacian_down_1 = from_sparse(simplicial_complex.down_laplacian_matrix(rank=1))
laplacian_up_1 = from_sparse(simplicial_complex.up_laplacian_matrix(rank=1))

assert torch.any(
torch.isclose(
model(x, laplacian_up_1, laplacian_down_1)[0],
torch.tensor([2.8254, -0.9797]),
torch.tensor([-2.5604, -3.5924]),
rtol=1e-02,
)
)
Expand Down
26 changes: 9 additions & 17 deletions topomodelx/nn/cell/can.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,10 @@ class CAN(torch.nn.Module):
Number of input channels for the edge-level input.
out_channels : int
Number of output channels.
num_classes : int
Number of output classes.
dropout : float, optional
Dropout probability. Default is 0.5.
heads : int, optional
Number of attention heads. Default is 3.
Number of attention heads. Default is 2.
concat : bool, optional
Whether to concatenate the output channels of attention heads. Default is True.
skip_connection : bool, optional
Expand All @@ -33,6 +31,8 @@ class CAN(torch.nn.Module):
Number of CAN layers.
att_lift : bool, default=True
Whether to apply a lift the signal from node-level to edge-level input.
k_pool : float, default=0.5
The pooling ratio i.e, the fraction of r-cells to keep after the pooling operation.

References
----------
Expand All @@ -47,14 +47,14 @@ def __init__(
in_channels_0,
in_channels_1,
out_channels,
num_classes,
dropout=0.5,
heads=3,
heads=2,
gbg141 marked this conversation as resolved.
Show resolved Hide resolved
concat=True,
skip_connection=True,
att_activation=torch.nn.LeakyReLU(0.2),
n_layers=2,
att_lift=True,
k_pool=0.5,
):
super().__init__()

Expand Down Expand Up @@ -98,16 +98,14 @@ def __init__(

layers.append(
PoolLayer(
k_pool=0.5,
k_pool=k_pool,
in_channels_0=out_channels * heads,
signal_pool_activation=torch.nn.Sigmoid(),
readout=True,
)
)

self.layers = torch.nn.ModuleList(layers)
self.lin_0 = torch.nn.Linear(heads * out_channels, 128)
self.lin_1 = torch.nn.Linear(128, num_classes)

def forward(
self, x_0, x_1, neighborhood_0_to_0, lower_neighborhood, upper_neighborhood
Expand All @@ -129,8 +127,8 @@ def forward(

Returns
-------
torch.Tensor
Output prediction for the cell complex.
torch.Tensor, shape = (num_pooled_edges, heads * out_channels)
Final hidden representations of pooled edges.
"""
if hasattr(self, "lift_layer"):
x_1 = self.lift_layer(x_0, neighborhood_0_to_0, x_1)
Expand All @@ -144,10 +142,4 @@ def forward(
x_1 = layer(x_1, lower_neighborhood, upper_neighborhood)
x_1 = F.dropout(x_1, p=0.5, training=self.training)

# max pooling over all nodes in each graph
x = x_1.max(dim=0)[0]

# Feed-Foward Neural Network to predict the graph label
out = self.lin_1(torch.nn.functional.relu(self.lin_0(x)))

return out
return x_1
24 changes: 11 additions & 13 deletions topomodelx/nn/cell/can_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@ class PoolLayer(MessagePassing):
Parameters
----------
k_pool : float in (0, 1]
The pooling ratio i.e, the fraction of edges to keep after the pooling operation.
The pooling ratio i.e, the fraction of r-cells to keep after the pooling operation.
in_channels_0 : int
Number of input channels of the input signal.
signal_pool_activation : Callable
Expand Down Expand Up @@ -323,14 +323,14 @@ def reset_parameters(self) -> None:
init.xavier_uniform_(self.att_pool.data, gain=gain)

def forward( # type: ignore[override]
self, x_0, lower_neighborhood, upper_neighborhood
self, x, lower_neighborhood, upper_neighborhood
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
r"""Forward pass.

Parameters
----------
x_0 : torch.Tensor, shape = (num_nodes, in_channels_0)
Node signal.
x : torch.Tensor, shape = (n_r_cells, in_channels_r)
Input r-cell signal.
lower_neighborhood : torch.Tensor
Lower neighborhood matrix.
upper_neighborhood : torch.Tensor
Expand All @@ -339,7 +339,7 @@ def forward( # type: ignore[override]
Returns
-------
torch.Tensor
Pooled node signal of shape (num_pooled_nodes, in_channels_0).
Pooled r_cell signal of shape (n_r_cells, in_channels_r).

Notes
-----
Expand All @@ -351,21 +351,19 @@ def forward( # type: ignore[override]
= \phi^t(h_x^t, m_{x}^{(r)}), \forall x\in \mathcal C_r^{t+1}
\end{align*}
"""
# Compute the output edge signal by applying the activation function
Zp = torch.einsum("nc,ce->ne", x_0, self.att_pool)
# Apply top-k pooling to the edge signal
# Compute the output r-cell signal by applying the activation function
Zp = torch.einsum("nc,ce->ne", x, self.att_pool)
# Apply top-k pooling to the r-cell signal
_, top_indices = topk(Zp.view(-1), int(self.k_pool * Zp.size(0)))
# Rescale the pooled signal
Zp = self.signal_pool_activation(Zp)
out = x_0[top_indices] * Zp[top_indices]
out = x[top_indices] * Zp[top_indices]

# Readout operation
if self.readout:
out = scatter_add(out, top_indices, dim=0, dim_size=x_0.size(0))[
top_indices
]
out = scatter_add(out, top_indices, dim=0, dim_size=x.size(0))[top_indices]

# Update lower and upper neighborhood matrices with the top-k pooled edges
# Update lower and upper neighborhood matrices with the top-k pooled r-cells
lower_neighborhood_modified = torch.index_select(
lower_neighborhood, 0, top_indices
)
Expand Down
33 changes: 8 additions & 25 deletions topomodelx/nn/cell/ccxn.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@ class CCXN(torch.nn.Module):
Dimension of input features on edges.
in_channels_2 : int
Dimension of input features on faces.
num_classes : int
Number of classes.
n_layers : int
Number of CCXN layers.
att : bool
Expand All @@ -36,7 +34,6 @@ def __init__(
in_channels_0,
in_channels_1,
in_channels_2,
num_classes,
n_layers=2,
att=False,
):
Expand All @@ -52,12 +49,9 @@ def __init__(
)
)
self.layers = torch.nn.ModuleList(layers)
self.lin_0 = torch.nn.Linear(in_channels_0, num_classes)
self.lin_1 = torch.nn.Linear(in_channels_1, num_classes)
self.lin_2 = torch.nn.Linear(in_channels_2, num_classes)

def forward(self, x_0, x_1, neighborhood_0_to_0, neighborhood_1_to_2):
"""Forward computation through layers, then linear layers, then avg pooling.
"""Forward computation through layers.

Parameters
----------
Expand All @@ -72,24 +66,13 @@ def forward(self, x_0, x_1, neighborhood_0_to_0, neighborhood_1_to_2):

Returns
-------
torch.Tensor, shape = (1)
Label assigned to whole complex.
x_0 : torch.Tensor, shape = (n_nodes, in_channels_0)
Final hidden states of the nodes (0-cells).
x_1 : torch.Tensor, shape = (n_edges, in_channels_1)
Final hidden states the edges (1-cells).
x_2 : torch.Tensor, shape = (n_faces, in_channels_2)
Final hidden states of the faces (2-cells).
"""
for layer in self.layers:
x_0, x_1, x_2 = layer(x_0, x_1, neighborhood_0_to_0, neighborhood_1_to_2)
x_0 = self.lin_0(x_0)
x_1 = self.lin_1(x_1)
x_2 = self.lin_2(x_2)
# Take the average of the 2D, 1D, and 0D cell features. If they are NaN, convert them to 0.
two_dimensional_cells_mean = torch.nanmean(x_2, dim=0)
two_dimensional_cells_mean[torch.isnan(two_dimensional_cells_mean)] = 0
one_dimensional_cells_mean = torch.nanmean(x_1, dim=0)
one_dimensional_cells_mean[torch.isnan(one_dimensional_cells_mean)] = 0
zero_dimensional_cells_mean = torch.nanmean(x_0, dim=0)
zero_dimensional_cells_mean[torch.isnan(zero_dimensional_cells_mean)] = 0
# Return the sum of the averages
return (
two_dimensional_cells_mean
+ one_dimensional_cells_mean
+ zero_dimensional_cells_mean
)
return (x_0, x_1, x_2)
36 changes: 7 additions & 29 deletions topomodelx/nn/cell/cwn.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@ class CWN(torch.nn.Module):
Dimension of input features on faces (2-cells).
hid_channels : int
Dimension of hidden features.
num_classes : int
Number of classes.
n_layers : int
Number of CWN layers.

Expand All @@ -38,7 +36,6 @@ def __init__(
in_channels_1,
in_channels_2,
hid_channels,
num_classes,
n_layers,
):
super().__init__()
Expand All @@ -58,10 +55,6 @@ def __init__(
)
self.layers = torch.nn.ModuleList(layers)

self.lin_0 = torch.nn.Linear(hid_channels, num_classes)
self.lin_1 = torch.nn.Linear(hid_channels, num_classes)
self.lin_2 = torch.nn.Linear(hid_channels, num_classes)

def forward(
self,
x_0,
Expand Down Expand Up @@ -90,8 +83,12 @@ def forward(

Returns
-------
torch.Tensor, shape = (1)
Label assigned to whole complex.
x_0 : torch.Tensor, shape = (n_nodes, in_channels_0)
Final hidden states of the nodes (0-cells).
x_1 : torch.Tensor, shape = (n_edges, in_channels_1)
Final hidden states the edges (1-cells).
x_2 : torch.Tensor, shape = (n_edges, in_channels_2)
Final hidden states of the faces (2-cells).
"""
x_0 = F.elu(self.proj_0(x_0))
x_1 = F.elu(self.proj_1(x_1))
Expand All @@ -107,23 +104,4 @@ def forward(
neighborhood_0_to_1,
)

x_0 = self.lin_0(x_0)
x_1 = self.lin_1(x_1)
x_2 = self.lin_2(x_2)

# Take the average of the 2D, 1D, and 0D cell features. If they are NaN, convert them to 0.
two_dimensional_cells_mean = torch.nanmean(x_2, dim=0)
two_dimensional_cells_mean[torch.isnan(two_dimensional_cells_mean)] = 0

one_dimensional_cells_mean = torch.nanmean(x_1, dim=0)
one_dimensional_cells_mean[torch.isnan(one_dimensional_cells_mean)] = 0

zero_dimensional_cells_mean = torch.nanmean(x_0, dim=0)
zero_dimensional_cells_mean[torch.isnan(zero_dimensional_cells_mean)] = 0

# Return the sum of the averages
return (
two_dimensional_cells_mean
+ one_dimensional_cells_mean
+ zero_dimensional_cells_mean
)
return x_0, x_1, x_2
Loading
Loading