240 review nncell and nncombinatorial models + Fixed SAN bugs #268

Closed
wants to merge 6 commits
1 change: 1 addition & 0 deletions .gitignore
@@ -18,6 +18,7 @@ __pycache__/
docs/build/
data/
venv_*/
.DS*

TopoNetX/
topomodelx/nn/cell/attcxn_layer.py
5 changes: 5 additions & 0 deletions pyproject.toml
@@ -121,3 +121,8 @@ checks = [
"EX01",
"SA01"
]
exclude = [
'\.undocumented_method$',
'\.__repr__$',
'\.__init__$',
]
24 changes: 8 additions & 16 deletions topomodelx/nn/cell/can.py
@@ -17,8 +17,6 @@ class CAN(torch.nn.Module):
Number of input channels for the edge-level input.
out_channels : int
Number of output channels.
num_classes : int
Number of output classes.
dropout : float, optional
Dropout probability. Default is 0.5.
heads : int, optional
@@ -33,6 +31,8 @@ class CAN(torch.nn.Module):
Number of CAN layers.
att_lift : bool, default=True
Whether to lift the signal from node-level to edge-level input.
k_pool : float, default=0.5
The pooling ratio, i.e., the fraction of r-cells to keep after the pooling operation.

References
----------
@@ -47,14 +47,14 @@ def __init__(
in_channels_0,
in_channels_1,
out_channels,
num_classes,
dropout=0.5,
heads=3,
heads=2,
concat=True,
skip_connection=True,
att_activation=torch.nn.LeakyReLU(0.2),
n_layers=2,
att_lift=True,
k_pool=0.5,
):
super().__init__()

@@ -98,16 +98,14 @@ def __init__(

layers.append(
PoolLayer(
k_pool=0.5,
k_pool=k_pool,
in_channels_0=out_channels * heads,
signal_pool_activation=torch.nn.Sigmoid(),
readout=True,
)
)

self.layers = torch.nn.ModuleList(layers)
self.lin_0 = torch.nn.Linear(heads * out_channels, 128)
self.lin_1 = torch.nn.Linear(128, num_classes)

def forward(
self, x_0, x_1, neighborhood_0_to_0, lower_neighborhood, upper_neighborhood
@@ -129,8 +127,8 @@ def forward(

Returns
-------
torch.Tensor
Output prediction for the cell complex.
x_1 : torch.Tensor, shape = (num_pooled_edges, heads * out_channels)
Final hidden representations of pooled edges.
"""
if hasattr(self, "lift_layer"):
x_1 = self.lift_layer(x_0, neighborhood_0_to_0, x_1)
@@ -144,10 +142,4 @@ def forward(
x_1 = layer(x_1, lower_neighborhood, upper_neighborhood)
x_1 = F.dropout(x_1, p=0.5, training=self.training)

# max pooling over all nodes in each graph
x = x_1.max(dim=0)[0]

# Feed-Foward Neural Network to predict the graph label
out = self.lin_1(torch.nn.functional.relu(self.lin_0(x)))

return out
return x_1
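With the classification head removed, CAN now ends at the pooled edge representations, and downstream code owns the readout. A minimal sketch of how a caller might restore graph-level classification on top of the returned features; the wrapper name is hypothetical, and the 128-unit MLP and max-pooling mirror the code deleted above:

```python
import torch

from topomodelx.nn.cell.can import CAN


class CANClassifier(torch.nn.Module):
    """Hypothetical wrapper restoring the readout that CAN.forward no longer performs."""

    def __init__(self, in_channels_0, in_channels_1, out_channels, num_classes, heads=2):
        super().__init__()
        self.base = CAN(in_channels_0, in_channels_1, out_channels, heads=heads)
        self.lin_0 = torch.nn.Linear(heads * out_channels, 128)
        self.lin_1 = torch.nn.Linear(128, num_classes)

    def forward(self, x_0, x_1, neighborhood_0_to_0, lower_neighborhood, upper_neighborhood):
        # Hidden states of the kept edges: (num_pooled_edges, heads * out_channels)
        x_1 = self.base(x_0, x_1, neighborhood_0_to_0, lower_neighborhood, upper_neighborhood)
        x = x_1.max(dim=0)[0]  # max-pool over edges: one vector per complex
        return self.lin_1(torch.relu(self.lin_0(x)))
```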
24 changes: 11 additions & 13 deletions topomodelx/nn/cell/can_layer.py
@@ -289,7 +289,7 @@ class PoolLayer(MessagePassing):
Parameters
----------
k_pool : float in (0, 1]
The pooling ratio i.e, the fraction of edges to keep after the pooling operation.
The pooling ratio, i.e., the fraction of r-cells to keep after the pooling operation.
in_channels_0 : int
Number of input channels of the input signal.
signal_pool_activation : Callable
@@ -323,14 +323,14 @@ def reset_parameters(self) -> None:
init.xavier_uniform_(self.att_pool.data, gain=gain)

def forward( # type: ignore[override]
self, x_0, lower_neighborhood, upper_neighborhood
self, x, lower_neighborhood, upper_neighborhood
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
r"""Forward pass.

Parameters
----------
x_0 : torch.Tensor, shape = (num_nodes, in_channels_0)
Node signal.
x : torch.Tensor, shape = (n_r_cells, in_channels_r)
Input r-cell signal.
lower_neighborhood : torch.Tensor
Lower neighborhood matrix.
upper_neighborhood : torch.Tensor
@@ -339,7 +339,7 @@ def forward( # type: ignore[override]
Returns
-------
torch.Tensor
Pooled node signal of shape (num_pooled_nodes, in_channels_0).
Pooled r-cell signal of shape (n_pooled_r_cells, in_channels_r).

Notes
-----
@@ -351,21 +351,19 @@
= \phi^t(h_x^t, m_{x}^{(r)}), \forall x\in \mathcal C_r^{t+1}
\end{align*}
"""
# Compute the output edge signal by applying the activation function
Zp = torch.einsum("nc,ce->ne", x_0, self.att_pool)
# Apply top-k pooling to the edge signal
# Compute attention pooling scores for the r-cell signal
Zp = torch.einsum("nc,ce->ne", x, self.att_pool)
# Apply top-k pooling to the r-cell signal
_, top_indices = topk(Zp.view(-1), int(self.k_pool * Zp.size(0)))
# Rescale the pooled signal
Zp = self.signal_pool_activation(Zp)
out = x_0[top_indices] * Zp[top_indices]
out = x[top_indices] * Zp[top_indices]

# Readout operation
if self.readout:
out = scatter_add(out, top_indices, dim=0, dim_size=x_0.size(0))[
top_indices
]
out = scatter_add(out, top_indices, dim=0, dim_size=x.size(0))[top_indices]

# Update lower and upper neighborhood matrices with the top-k pooled edges
# Update lower and upper neighborhood matrices with the top-k pooled r-cells
lower_neighborhood_modified = torch.index_select(
lower_neighborhood, 0, top_indices
)
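For orientation, the forward pass above boils down to: score every r-cell with the learned attention vector, keep the int(k_pool * n) highest-scoring cells, and gate the kept features by their activated scores. A self-contained sketch in plain torch, assuming `topk` behaves like `torch.topk`; the optional readout branch (`scatter_add`) is omitted:

```python
import torch

n_cells, channels, k_pool = 10, 8, 0.5
x = torch.randn(n_cells, channels)           # input r-cell signal
att_pool = torch.randn(channels, 1)          # learned pooling attention vector

Zp = torch.einsum("nc,ce->ne", x, att_pool)  # one score per r-cell, shape (n, 1)
_, top_indices = torch.topk(Zp.view(-1), int(k_pool * n_cells))
Zp = torch.sigmoid(Zp)                       # signal_pool_activation
out = x[top_indices] * Zp[top_indices]       # gated features of the kept cells
print(out.shape)                             # torch.Size([5, 8])
```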
33 changes: 8 additions & 25 deletions topomodelx/nn/cell/ccxn.py
@@ -16,8 +16,6 @@ class CCXN(torch.nn.Module):
Dimension of input features on edges.
in_channels_2 : int
Dimension of input features on faces.
num_classes : int
Number of classes.
n_layers : int
Number of CCXN layers.
att : bool
@@ -36,7 +34,6 @@ def __init__(
in_channels_0,
in_channels_1,
in_channels_2,
num_classes,
n_layers=2,
att=False,
):
@@ -52,12 +49,9 @@ def __init__(
)
)
self.layers = torch.nn.ModuleList(layers)
self.lin_0 = torch.nn.Linear(in_channels_0, num_classes)
self.lin_1 = torch.nn.Linear(in_channels_1, num_classes)
self.lin_2 = torch.nn.Linear(in_channels_2, num_classes)

def forward(self, x_0, x_1, neighborhood_0_to_0, neighborhood_1_to_2):
"""Forward computation through layers, then linear layers, then avg pooling.
"""Forward computation through layers.

Parameters
----------
@@ -72,24 +66,13 @@ def forward(self, x_0, x_1, neighborhood_0_to_0, neighborhood_1_to_2):

Returns
-------
torch.Tensor, shape = (1)
Label assigned to whole complex.
x_0 : torch.Tensor, shape = (n_nodes, in_channels_0)
Final hidden states of the nodes (0-cells).
x_1 : torch.Tensor, shape = (n_edges, in_channels_1)
Final hidden states of the edges (1-cells).
x_2 : torch.Tensor, shape = (n_faces, in_channels_2)
Final hidden states of the faces (2-cells).
"""
for layer in self.layers:
x_0, x_1, x_2 = layer(x_0, x_1, neighborhood_0_to_0, neighborhood_1_to_2)
x_0 = self.lin_0(x_0)
x_1 = self.lin_1(x_1)
x_2 = self.lin_2(x_2)
# Take the average of the 2D, 1D, and 0D cell features. If they are NaN, convert them to 0.
two_dimensional_cells_mean = torch.nanmean(x_2, dim=0)
two_dimensional_cells_mean[torch.isnan(two_dimensional_cells_mean)] = 0
one_dimensional_cells_mean = torch.nanmean(x_1, dim=0)
one_dimensional_cells_mean[torch.isnan(one_dimensional_cells_mean)] = 0
zero_dimensional_cells_mean = torch.nanmean(x_0, dim=0)
zero_dimensional_cells_mean[torch.isnan(zero_dimensional_cells_mean)] = 0
# Return the sum of the averages
return (
two_dimensional_cells_mean
+ one_dimensional_cells_mean
+ zero_dimensional_cells_mean
)
return x_0, x_1, x_2
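Since CCXN now returns the per-rank hidden states, the deleted readout can live in user code. A sketch mirroring the removed logic, where `torch.nan_to_num` replaces the explicit NaN masking:

```python
import torch


def complex_readout(x_0, x_1, x_2, lin_0, lin_1, lin_2):
    """Mean-pool each rank's logits and sum them, as the removed forward code did."""
    logits = torch.zeros(lin_0.out_features)
    for x, lin in ((x_0, lin_0), (x_1, lin_1), (x_2, lin_2)):
        rank_mean = torch.nanmean(lin(x), dim=0)
        logits = logits + torch.nan_to_num(rank_mean)  # all-NaN mean -> 0
    return logits
```

Here `lin_0`, `lin_1`, `lin_2` are the per-rank heads, e.g. `torch.nn.Linear(in_channels_r, num_classes)`, which used to be attributes of the model itself.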
36 changes: 7 additions & 29 deletions topomodelx/nn/cell/cwn.py
@@ -19,8 +19,6 @@ class CWN(torch.nn.Module):
Dimension of input features on faces (2-cells).
hid_channels : int
Dimension of hidden features.
num_classes : int
Number of classes.
n_layers : int
Number of CWN layers.

@@ -38,7 +36,6 @@ def __init__(
in_channels_1,
in_channels_2,
hid_channels,
num_classes,
n_layers,
):
super().__init__()
@@ -58,10 +55,6 @@ def __init__(
)
self.layers = torch.nn.ModuleList(layers)

self.lin_0 = torch.nn.Linear(hid_channels, num_classes)
self.lin_1 = torch.nn.Linear(hid_channels, num_classes)
self.lin_2 = torch.nn.Linear(hid_channels, num_classes)

def forward(
self,
x_0,
@@ -90,8 +83,12 @@ def forward(

Returns
-------
torch.Tensor, shape = (1)
Label assigned to whole complex.
x_0 : torch.Tensor, shape = (n_nodes, hid_channels)
Final hidden states of the nodes (0-cells).
x_1 : torch.Tensor, shape = (n_edges, hid_channels)
Final hidden states of the edges (1-cells).
x_2 : torch.Tensor, shape = (n_faces, hid_channels)
Final hidden states of the faces (2-cells).
"""
x_0 = F.elu(self.proj_0(x_0))
x_1 = F.elu(self.proj_1(x_1))
@@ -107,23 +104,4 @@
neighborhood_0_to_1,
)

x_0 = self.lin_0(x_0)
x_1 = self.lin_1(x_1)
x_2 = self.lin_2(x_2)

# Take the average of the 2D, 1D, and 0D cell features. If they are NaN, convert them to 0.
two_dimensional_cells_mean = torch.nanmean(x_2, dim=0)
two_dimensional_cells_mean[torch.isnan(two_dimensional_cells_mean)] = 0

one_dimensional_cells_mean = torch.nanmean(x_1, dim=0)
one_dimensional_cells_mean[torch.isnan(one_dimensional_cells_mean)] = 0

zero_dimensional_cells_mean = torch.nanmean(x_0, dim=0)
zero_dimensional_cells_mean[torch.isnan(zero_dimensional_cells_mean)] = 0

# Return the sum of the averages
return (
two_dimensional_cells_mean
+ one_dimensional_cells_mean
+ zero_dimensional_cells_mean
)
return x_0, x_1, x_2
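CWN follows the same pattern, so the training loop now owns both the heads and the readout. A hedged sketch of a training step, reusing the `complex_readout` helper sketched above for CCXN; `model`, `heads`, `batch`, and `batch.neighborhoods` are illustrative placeholders, with `batch.neighborhoods` standing in for whatever adjacency and incidence matrices `model.forward` expects:

```python
import torch
import torch.nn.functional as F


def training_step(model, heads, batch, optimizer):
    # model: a CWN instance; heads: three torch.nn.Linear(hid_channels, num_classes)
    x_0, x_1, x_2 = model(batch.x_0, batch.x_1, batch.x_2, *batch.neighborhoods)
    logits = complex_readout(x_0, x_1, x_2, *heads)
    loss = F.cross_entropy(logits.unsqueeze(0), batch.y.view(1))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```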
34 changes: 13 additions & 21 deletions topomodelx/nn/combinatorial/hmc.py
@@ -18,30 +18,31 @@ class HMC(torch.nn.Module):
for each input signal (nodes, edges, and faces) for the k-th layer. The second list
contains the number of intermediate channels for each input signal (nodes, edges, and
faces) for the k-th layer. Finally, the third list contains the number of output channels for
each input signal (nodes, edges, and faces) for the k-th layer .
num_classes : int
Number of classes.
each input signal (nodes, edges, and faces) for the k-th layer.
negative_slope : float
Negative slope for the LeakyReLU activation.
update_func_attention : str
Update function for the attention mechanism. Default is "relu".
update_func_aggregation : str
Update function for the aggregation mechanism. Default is "relu".
"""

def __init__(
self,
channels_per_layer,
num_classes,
negative_slope=0.2,
update_func_attention="relu",
update_func_aggregation="relu",
) -> None:
def check_channels_consistency():
"""Check that the number of input, intermediate, and output channels is consistent."""
assert len(channels_per_layer) > 0
for i in range(len(channels_per_layer) - 1):
assert channels_per_layer[i][2][0] == channels_per_layer[i + 1][0][0]
assert channels_per_layer[i][2][1] == channels_per_layer[i + 1][0][1]
assert channels_per_layer[i][2][2] == channels_per_layer[i + 1][0][2]

super().__init__()
self.num_classes = num_classes
check_channels_consistency()
self.layers = torch.nn.ModuleList(
[
@@ -58,10 +59,6 @@ def check_channels_consistency():
]
)

self.l0 = torch.nn.Linear(channels_per_layer[-1][2][0], num_classes)
self.l1 = torch.nn.Linear(channels_per_layer[-1][2][1], num_classes)
self.l2 = torch.nn.Linear(channels_per_layer[-1][2][2], num_classes)

def forward(
self,
x_0,
@@ -96,8 +93,12 @@ def forward(

Returns
-------
y_hat : torch.Tensor, shape=[num_classes]
Vector embedding that represents the probability of the input mesh to belong to each class.
x_0 : torch.Tensor, shape = (n_nodes, out_channels_0)
Final hidden states of the nodes (0-cells).
x_1 : torch.Tensor, shape = (n_edges, out_channels_1)
Final hidden states of the edges (1-cells).
x_2 : torch.Tensor, shape = (n_faces, out_channels_2)
Final hidden states of the faces (2-cells).
"""
for layer in self.layers:
x_0, x_1, x_2 = layer(
@@ -111,13 +112,4 @@
neighborhood_1_to_2,
)

x_0 = self.l0(x_0)
x_1 = self.l1(x_1)
x_2 = self.l2(x_2)

# Sum all the elements in the dimension zero
x_0 = torch.nanmean(x_0, dim=0)
x_1 = torch.nanmean(x_1, dim=0)
x_2 = torch.nanmean(x_2, dim=0)

return x_0 + x_1 + x_2
return x_0, x_1, x_2
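The channels_per_layer format is easiest to see by example: each layer gets three lists (input, intermediate, and output widths for nodes, edges, and faces), and check_channels_consistency requires layer k's output widths to equal layer k+1's input widths. An illustrative two-layer configuration, with arbitrary widths:

```python
from topomodelx.nn.combinatorial.hmc import HMC

channels_per_layer = [
    # layer 0: [in_0, in_1, in_2], [int_0, int_1, int_2], [out_0, out_1, out_2]
    [[4, 4, 4], [16, 16, 16], [32, 32, 32]],
    # layer 1: input widths must equal layer 0's output widths
    [[32, 32, 32], [32, 32, 32], [64, 64, 64]],
]
model = HMC(channels_per_layer)
# x_0, x_1, x_2 = model(x_0, x_1, x_2, *neighborhoods)  # per-rank hidden states
```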