diff --git a/topomodelx/base/aggregation.py b/topomodelx/base/aggregation.py
index 56b15a98..4f406ee7 100644
--- a/topomodelx/base/aggregation.py
+++ b/topomodelx/base/aggregation.py
@@ -9,10 +9,10 @@ class Aggregation(torch.nn.Module):
 
     Parameters
     ----------
-    aggr_func : string
+    aggr_func : Literal["mean", "sum"], default="sum"
         Aggregation method. (Inter-neighborhood).
-    update_func : string
+    update_func : Literal["relu", "sigmoid", "tanh", None], default="sigmoid"
         Update method to apply to merged message.
     """
diff --git a/topomodelx/base/conv.py b/topomodelx/base/conv.py
index 89467df3..d31dd711 100644
--- a/topomodelx/base/conv.py
+++ b/topomodelx/base/conv.py
@@ -1,4 +1,5 @@
 """Convolutional layer for message passing."""
+from typing import Literal
 
 import torch
 from torch.nn.parameter import Parameter
@@ -20,11 +21,11 @@ class Conv(MessagePassing):
         Dimension of output features.
     aggr_norm : bool, default=False
         Whether to normalize the aggregated message by the neighborhood size.
-    update_func : string, optional
+    update_func : Literal["relu", "sigmoid"], optional
         Update method to apply to message.
     att : bool, default=False
         Whether to use attention.
-    initialization : string
+    initialization : Literal["xavier_uniform", "xavier_normal"], default="xavier_uniform"
         Initialization method.
     with_linear_transform: bool
         Whether to apply a learnable linear transform.
@@ -36,9 +37,10 @@ def __init__(
         self,
         in_channels,
         out_channels,
         aggr_norm: bool = False,
-        update_func=None,
+        update_func: Literal["relu", "sigmoid"] | None = None,
         att: bool = False,
-        initialization: str = "xavier_uniform",
+        initialization: Literal["xavier_uniform", "xavier_normal"] = "xavier_uniform",
+        initialization_gain: float = 1.414,
         with_linear_transform: bool = True,
     ) -> None:
         super().__init__(
diff --git a/topomodelx/base/message_passing.py b/topomodelx/base/message_passing.py
index b51b5134..81515988 100644
--- a/topomodelx/base/message_passing.py
+++ b/topomodelx/base/message_passing.py
@@ -1,5 +1,5 @@
 """Message passing module."""
-
+from typing import Literal
 
 import torch
@@ -23,11 +23,11 @@ class MessagePassing(torch.nn.Module):
 
     Parameters
     ----------
-    aggr_func : string
+    aggr_func : Literal["sum", "mean", "add"], default="sum"
         Aggregation function to use.
     att : bool, default=False
         Whether to use attention.
-    initialization : string
+    initialization : Literal["xavier_uniform", "xavier_normal"], default="xavier_uniform"
         Initialization method for the weights of the layer.
 
     References
@@ -43,9 +43,10 @@ class MessagePassing(torch.nn.Module):
 
     def __init__(
         self,
-        aggr_func: str = "sum",
+        aggr_func: Literal["sum", "mean", "add"] = "sum",
         att: bool = False,
-        initialization: str = "xavier_uniform",
+        initialization: Literal["xavier_uniform", "xavier_normal"] = "xavier_uniform",
+        initialization_gain: float = 1.414,
     ) -> None:
         super().__init__()
         self.aggr_func = aggr_func
@@ -113,7 +114,7 @@ def message(self, x_source, x_target=None):
     def attention(self, x_source, x_target=None):
         """Compute attention weights for messages.
 
-        This provides a default attention function to the message passing scheme.
+        This provides a default attention function to the message-passing scheme.
 
         Alternatively, users can subclass MessagePassing and overwrite
         the attention method in order to replace it with their own attention mechanism.
diff --git a/topomodelx/nn/cell/can_layer.py b/topomodelx/nn/cell/can_layer.py
index 6c2e5d16..4e0edb87 100644
--- a/topomodelx/nn/cell/can_layer.py
+++ b/topomodelx/nn/cell/can_layer.py
@@ -1,6 +1,6 @@
 """Cell Attention Network layer."""
-from typing import Callable
+from typing import Callable, Literal
 
 import torch
 from torch import Tensor, nn, topk
@@ -396,9 +396,9 @@ class MultiHeadCellAttention(MessagePassing):
         Activation function to use for the attention weights.
     add_self_loops : bool, optional
         Whether to add self-loops to the adjacency matrix.
-    aggr_func : string, optional
-        Aggregation function to use. Options are "sum", "mean", "max".
-    initialization : string, optional
+    aggr_func : Literal["sum", "mean", "add"], default="sum"
+        Aggregation function to use.
+    initialization : Literal["xavier_uniform", "xavier_normal"], default="xavier_uniform"
         Initialization method for the weights of the layer.
 
     Notes
@@ -423,14 +423,10 @@ def __init__(
         concat: bool,
         att_activation: torch.nn.Module,
         add_self_loops: bool = False,
-        aggr_func: str = "sum",
-        initialization: str = "xavier_uniform",
+        aggr_func: Literal["sum", "mean", "add"] = "sum",
+        initialization: Literal["xavier_uniform", "xavier_normal"] = "xavier_uniform",
     ) -> None:
-        super().__init__(
-            att=True,
-            initialization=initialization,
-            aggr_func=aggr_func,
-        )
+        super().__init__(aggr_func=aggr_func, att=True, initialization=initialization)
 
         self.in_channels = in_channels
         self.out_channels = out_channels
@@ -587,9 +583,9 @@ class MultiHeadCellAttention_v2(MessagePassing):
         Activation function to use for the attention weights.
     add_self_loops : bool, optional
         Whether to add self-loops to the adjacency matrix.
-    aggr_func : string, optional
-        Aggregation function to use. Options are "sum", "mean", "max".
-    initialization : string, optional
+    aggr_func : Literal["sum", "mean", "add"], default="sum"
+        Aggregation function to use.
+    initialization : Literal["xavier_uniform", "xavier_normal"], default="xavier_uniform"
         Initialization method for the weights of the layer.
     share_weights : bool, optional
         Whether to share the weights between the attention heads.
@@ -616,14 +612,14 @@ def __init__(
         concat: bool,
         att_activation: torch.nn.Module,
         add_self_loops: bool = False,
-        aggr_func: str = "sum",
-        initialization: str = "xavier_uniform",
+        aggr_func: Literal["sum", "mean", "add"] = "sum",
+        initialization: Literal["xavier_uniform", "xavier_normal"] = "xavier_uniform",
         share_weights: bool = False,
     ) -> None:
         super().__init__(
+            aggr_func=aggr_func,
             att=True,
             initialization=initialization,
-            aggr_func=aggr_func,
         )
 
         self.in_channels = in_channels
@@ -797,28 +793,31 @@ class CANLayer(torch.nn.Module):
         Dimension of input features on n-cells.
     out_channels : int
         Dimension of output
-    heads : int, optional
-        Number of attention heads, by default 1
+    heads : int, default=1
+        Number of attention heads
     dropout : float, optional
-        Dropout probability of the normalized attention coefficients, by default 0.0
-    concat : bool, optional
-        If True, the output of each head is concatenated. Otherwise, the output of each head is averaged, by default True
-    skip_connection : bool, optional
-        If True, skip connection is added, by default True
+        Dropout probability of the normalized attention coefficients.
+    concat : bool, default=True
+        If True, the output of each head is concatenated. Otherwise, the output of each head is averaged.
+    skip_connection : bool, default=True
+        If True, skip connection is added.
     add_self_loops : bool, optional
-        If True, self-loops are added to the neighborhood matrix, by default False
-    att_activation : Callable, optional
-        Activation function applied to the attention coefficients, by default torch.nn.LeakyReLU()
-    aggr_func : str, optional
-        Between-neighborhood aggregation function applied to the messages, by default "sum"
-    update_func : str, optional
-        Update function applied to the messages, by default "relu"
-    version : str, optional
+        If True, self-loops are added to the neighborhood matrix.
+    att_activation : Callable, default=torch.nn.LeakyReLU()
+        Activation function applied to the attention coefficients.
+    aggr_func : Literal["mean", "sum"], default="sum"
+        Between-neighborhood aggregation function applied to the messages.
+    update_func : Literal["relu", "sigmoid", "tanh", None], default="relu"
+        Update function applied to the messages.
+    version : Literal["v1", "v2"], default="v1"
         Version of the layer, by default "v1" which is the same as the original CAN layer.
         While "v2" has the same attetion mechanism as the GATv2 layer.
-    share_weights : bool, optional
+    share_weights : bool, default=False
         This option is valid only for "v2". If True, the weights of the linear transformation
         applied to the source and target features are shared, by default False
     """
 
+    lower_att: MultiHeadCellAttention | MultiHeadCellAttention_v2
+    upper_att: MultiHeadCellAttention | MultiHeadCellAttention_v2
+
     def __init__(
         self,
         in_channels: int,
@@ -829,9 +828,9 @@ def __init__(
         skip_connection: bool = True,
         att_activation: torch.nn.Module = torch.nn.LeakyReLU(),
         add_self_loops: bool = False,
-        aggr_func: str = "sum",
-        update_func: str = "relu",
-        version: str = "v1",
+        aggr_func: Literal["mean", "sum"] = "sum",
+        update_func: Literal["relu", "sigmoid", "tanh"] | None = "relu",
+        version: Literal["v1", "v2"] = "v1",
         share_weights: bool = False,
         **kwargs,
     ) -> None:
@@ -841,7 +840,7 @@ def __init__(
         assert out_channels > 0, ValueError("Number of output channels must be > 0")
         assert heads > 0, ValueError("Number of heads must be > 0")
         assert dropout >= 0.0 and dropout <= 1.0, ValueError("Dropout must be in [0,1]")
-        assert version in ["v1", "v2"], ValueError("Version must be 'v1' or 'v2'")
+
         # assert that shared weight is True only if version is v2
         assert share_weights is False or version == "v2", ValueError(
             "Shared weights is valid only for v2"
diff --git a/topomodelx/nn/cell/can_layer_bis.py b/topomodelx/nn/cell/can_layer_bis.py
index a6496495..7a0c4ee7 100644
--- a/topomodelx/nn/cell/can_layer_bis.py
+++ b/topomodelx/nn/cell/can_layer_bis.py
@@ -1,4 +1,6 @@
 """Cellular Attention Network Layer."""
+from typing import Literal
+
 import torch
 from torch.nn.parameter import Parameter
@@ -50,23 +52,24 @@ class CANLayer(torch.nn.Module):
     ----------
     channels : int
         Dimension of input features on edges (1-cells).
-    activation : string
-        Activation function to apply to merged message
+    activation : Literal["relu", "sigmoid", "tanh", None], default="sigmoid"
+        Activation function to apply to merged message.
     att : bool
         Whether to use attention.
     eps : float
         Epsilon used in the attention mechanism.
-    initialization : string
+    initialization : Literal["xavier_uniform", "xavier_normal"], default="xavier_uniform"
         Initialization method.
""" def __init__( self, channels, - activation: str = "sigmoid", + activation: Literal["relu", "sigmoid", "tanh"] | None = "sigmoid", att: bool = True, eps: float = 1e-5, - initialization: str = "xavier_uniform", + initialization: Literal["xavier_uniform", "xavier_normal"] = "xavier_uniform", + initialization_gain: float = 1.414, ) -> None: super().__init__() # Do I need upper and lower convolution layers? Since I think they will have different parameters diff --git a/topomodelx/nn/hypergraph/allset_layer.py b/topomodelx/nn/hypergraph/allset_layer.py index 6b4ccad7..4f945fb8 100644 --- a/topomodelx/nn/hypergraph/allset_layer.py +++ b/topomodelx/nn/hypergraph/allset_layer.py @@ -135,6 +135,9 @@ class AllSetBlock(nn.Module): Type of layer normalization in the MLP. """ + encoder: nn.Module + decoder: nn.Module + def __init__( self, in_channels, @@ -241,7 +244,7 @@ def __init__( bias: bool = False, ) -> None: params = {} if inplace is None else {"inplace": inplace} - layers = [] + layers: list[nn.Module] = [] in_dim = in_channels for hidden_dim in hidden_channels[:-1]: layers.append(nn.Linear(in_dim, hidden_dim, bias=bias)) diff --git a/topomodelx/nn/hypergraph/allset_transformer_layer.py b/topomodelx/nn/hypergraph/allset_transformer_layer.py index adec3000..fc00cbaf 100644 --- a/topomodelx/nn/hypergraph/allset_transformer_layer.py +++ b/topomodelx/nn/hypergraph/allset_transformer_layer.py @@ -1,4 +1,6 @@ """AllSetTransformer Layer Module.""" +from typing import Literal + import torch import torch.nn.functional as F from torch import nn @@ -273,7 +275,7 @@ class MultiHeadAttention(MessagePassing): Number of attention heads. number_queries : int, default=1 Number of queries. - initialization : str, default="xavier_uniform" + initialization : Literal["xavier_uniform", "xavier_normal"], default="xavier_uniform" Initialization method. 
""" @@ -285,7 +287,8 @@ def __init__( update_func=None, heads: int = 4, number_queries: int = 1, - initialization: str = "xavier_uniform", + initialization: Literal["xavier_uniform", "xavier_normal"] = "xavier_uniform", + initialization_gain: float = 1.414, ) -> None: super().__init__( att=True, @@ -429,7 +432,7 @@ def __init__( bias: bool = False, ) -> None: params = {} if inplace is None else {"inplace": inplace} - layers = [] + layers: list[nn.Module] = [] in_dim = in_channels for hidden_dim in hidden_channels[:-1]: layers.append(nn.Linear(in_dim, hidden_dim, bias=bias)) diff --git a/topomodelx/nn/hypergraph/hmpnn_layer.py b/topomodelx/nn/hypergraph/hmpnn_layer.py index 34debbd4..8321ca45 100644 --- a/topomodelx/nn/hypergraph/hmpnn_layer.py +++ b/topomodelx/nn/hypergraph/hmpnn_layer.py @@ -1,4 +1,6 @@ """HMPNN (Hypergraph Message Passing Neural Network) Layer introduced in Heydari et Livi 2022.""" +from typing import Literal + import torch from torch import nn from torch.nn import functional as F @@ -21,7 +23,10 @@ def apply_dropout(self, neighborhood, dropout_rate: float): class _NodeToHyperedgeMessenger(MessagePassing, _AdjacencyDropoutMixin): def __init__( - self, messaging_func, adjacency_dropout: float = 0.7, aggr_func: str = "sum" + self, + messaging_func, + adjacency_dropout: float = 0.7, + aggr_func: Literal["sum", "mean", "add"] = "sum", ) -> None: super().__init__(aggr_func) self.messaging_func = messaging_func @@ -46,7 +51,7 @@ def __init__( self, messaging_func, adjacency_dropout: float = 0.7, - aggr_func: str = "sum", + aggr_func: Literal["sum", "mean", "add"] = "sum", ) -> None: super().__init__(aggr_func) self.messaging_func = messaging_func @@ -116,8 +121,8 @@ class HMPNNLayer(nn.Module): to the paper. adjacency_dropout: 0.7 Adjacency dropout rate. - aggr_func: "sum" - Message aggregation function. A value among "sum", "mean" and "add". + aggr_func: Literal["sum", "mean", "add"], default="sum" + Message aggregation function. updating_dropout: 0.5 Regular dropout rate applied to node and hyperedge features. updating_func: None @@ -132,7 +137,7 @@ def __init__( node_to_hyperedge_messaging_func=None, hyperedge_to_node_messaging_func=None, adjacency_dropout: float = 0.7, - aggr_func: str = "sum", + aggr_func: Literal["sum", "mean", "add"] = "sum", updating_dropout: float = 0.5, updating_func=None, ) -> None: diff --git a/topomodelx/nn/hypergraph/hnhn_layer.py b/topomodelx/nn/hypergraph/hnhn_layer.py index 4660aded..a07772c4 100644 --- a/topomodelx/nn/hypergraph/hnhn_layer.py +++ b/topomodelx/nn/hypergraph/hnhn_layer.py @@ -1,4 +1,6 @@ """Template Layer with two conv passing steps.""" +from typing import Literal + import torch from torch.nn.parameter import Parameter @@ -48,7 +50,7 @@ class HNHNLayer(torch.nn.Module): Scalar controlling the importance of node cardinality. bias_gain : float Gain for the bias initialization. - bias_init : string ["xavier_uniform"|"xavier_normal"] + bias_init : Literal["xavier_uniform", "xavier_normal"], default="xavier_uniform" Controls the bias initialization method. 
""" @@ -62,7 +64,7 @@ def __init__( alpha: float = -1.5, beta: float = -0.5, bias_gain: float = 1.414, - bias_init: str = "xavier_uniform", + bias_init: Literal["xavier_uniform", "xavier_normal"] = "xavier_uniform", ) -> None: super().__init__() self.use_bias = use_bias diff --git a/topomodelx/nn/hypergraph/hypergat_layer.py b/topomodelx/nn/hypergraph/hypergat_layer.py index 08992d48..e90e4aae 100644 --- a/topomodelx/nn/hypergraph/hypergat_layer.py +++ b/topomodelx/nn/hypergraph/hypergat_layer.py @@ -1,4 +1,6 @@ """HyperGAT layer.""" +from typing import Literal + import torch from topomodelx.base.message_passing import MessagePassing @@ -21,8 +23,8 @@ class HyperGATLayer(MessagePassing): Dimension of the output features. update_func : string Update method to apply to message. Default is "relu". - initialization : string - Initialization method. Default is "xavier_uniform". + initialization : Literal["xavier_uniform", "xavier_normal"], default="xavier_uniform" + Initialization method. """ def __init__( @@ -30,7 +32,8 @@ def __init__( in_channels, out_channels, update_func: str = "relu", - initialization: str = "xavier_uniform", + initialization: Literal["xavier_uniform", "xavier_normal"] = "xavier_uniform", + initialization_gain: float = 1.414, ) -> None: super().__init__(initialization=initialization) self.in_channels = in_channels @@ -67,7 +70,12 @@ def reset_parameters(self, gain: float = 1.414): "Should be either xavier_uniform or xavier_normal." ) - def attention(self, x_source, x_target=None, mechanism: str = "node-level"): + def attention( + self, + x_source, + x_target=None, + mechanism: Literal["node-level", "edge-level"] = "node-level", + ): r"""Compute attention weights for messages, as proposed in [DWLLL20]. Parameters @@ -78,7 +86,7 @@ def attention(self, x_source, x_target=None, mechanism: str = "node-level"): x_target : torch.Tensor, shape=[n_target_cells, in_channels] Input features on source cells. Assumes that all source cells have the same rank r. - mechanism: string + mechanism: Literal["node-level", "edge-level"] Attention mechanism as proposed in [DWLLL20]. If set to "node-level", will compute node-level attention, if set to "edge-level", will compute edge-level attention (see [DWLLL20]). Default is "node-level". diff --git a/topomodelx/nn/hypergraph/hypersage_layer.py b/topomodelx/nn/hypergraph/hypersage_layer.py index 76a21e10..54ac8a78 100644 --- a/topomodelx/nn/hypergraph/hypersage_layer.py +++ b/topomodelx/nn/hypergraph/hypersage_layer.py @@ -1,5 +1,6 @@ """HyperSAGE layer.""" import math +from typing import Literal import torch @@ -12,12 +13,12 @@ class GeneralizedMean(Aggregation): Parameters ---------- - power : int. - Power for the generalized mean. Default is 2. + power : int, default=2 + Power for the generalized mean. """ def __init__(self, power: int = 2, **kwargs) -> None: - super().__init__(aggr_func="generalized_mean", **kwargs) + super().__init__(**kwargs) self.power = power def forward(self, x: torch.Tensor): @@ -28,7 +29,7 @@ def forward(self, x: torch.Tensor): x : torch.Tensor """ n = x.size()[-2] - x = torch.sum(torch.pow(x, self.power), axis=-2) / n + x = torch.sum(torch.pow(x, self.power), -2) / n x = torch.pow(x, 1 / self.power) return x @@ -49,16 +50,16 @@ class HyperSAGELayer(MessagePassing): Dimension of the input features. out_channels : int Dimension of the output features. - aggr_func_intra: Aggregation + aggr_func_intra : Aggregation Aggregation function. Default is GeneralizedMean(p=2). 
-    aggr_func_inter: Aggregation
+    aggr_func_inter : Aggregation
         Aggregation function. Default is GeneralizedMean(p=2).
-    update_func : string
-        Update method to apply to message. Default is "relu".
-    initialization : string
-        Initialization method. Default is "uniform".
-    device : string
-        Device name to train layer on. Default is "cpu".
+    update_func : Literal["relu", "sigmoid"], default="relu"
+        Update method to apply to message.
+    initialization : Literal["uniform", "xavier_uniform", "xavier_normal"], default="uniform"
+        Initialization method.
+    device : str, default="cpu"
+        Device name to train layer on.
     """
 
     def __init__(
@@ -67,8 +68,10 @@ def __init__(
         out_channels: int,
         aggr_func_intra: Aggregation = GeneralizedMean(power=2, update_func=None),
         aggr_func_inter: Aggregation = GeneralizedMean(power=2, update_func=None),
-        update_func: str = "relu",
-        initialization: str = "uniform",
+        update_func: Literal["relu", "sigmoid"] = "relu",
+        initialization: Literal[
+            "uniform", "xavier_uniform", "xavier_normal"
+        ] = "uniform",
         device: str = "cpu",
     ) -> None:
         super().__init__()
@@ -101,7 +104,7 @@ def reset_parameters(self):
         )
 
     def update(
-        self, x_message_on_target: torch.Tensor, x_target: torch.Tensor = None
+        self, x_message_on_target: torch.Tensor, x_target: torch.Tensor | None = None
     ) -> torch.Tensor:
         r"""Update embeddings on each node (step 4).
@@ -119,6 +122,7 @@ def update(
             return torch.nn.functional.sigmoid(x_message_on_target)
         if self.update_func == "relu":
             return torch.nn.functional.relu(x_message_on_target)
+        raise RuntimeError("Update function not recognized.")
 
     def aggregate(self, x_messages: torch.Tensor, mode: str = "intra"):
         """Aggregate messages on each target cell.
diff --git a/topomodelx/nn/hypergraph/unisage_layer.py b/topomodelx/nn/hypergraph/unisage_layer.py
index 9ea4053c..a45cec65 100644
--- a/topomodelx/nn/hypergraph/unisage_layer.py
+++ b/topomodelx/nn/hypergraph/unisage_layer.py
@@ -1,4 +1,6 @@
 """Implementation of UniSAGE layer from Huang et. al.: UniGNN: a Unified Framework for Graph and Hypergraph Neural Networks."""
+from typing import Literal
+
 import torch
@@ -20,12 +22,10 @@ class UniSAGELayer(torch.nn.Module):
         Dimension of input features.
     out_channels : int
         Dimension of output features.
-    e_aggr : string
-        Aggregator function for hyperedges. Defaults to "sum", other options are
-        "sum", "mean", "amax", or "amin".
-    v_aggr : string
-        Aggregator function for nodes. Defaults to " mean", other options are
-        "sum", "mean", "amax", or "amin".
+    e_aggr : Literal["sum", "mean", "amax", "amin"], default="sum"
+        Aggregator function for hyperedges.
+    v_aggr : Literal["sum", "mean", "amax", "amin"], default="mean"
+        Aggregator function for nodes.
     use_bn : boolean
         Whether to use bathnorm after the linear transformation.
@@ -40,8 +40,8 @@ def __init__(
         self,
         in_channels,
         out_channels,
-        e_aggr: str = "sum",
-        v_aggr: str = "mean",
+        e_aggr: Literal["sum", "mean", "amax", "amin"] = "sum",
+        v_aggr: Literal["sum", "mean", "amax", "amin"] = "mean",
         use_bn: bool = False,
     ) -> None:
         super().__init__()
diff --git a/topomodelx/nn/simplicial/hsn_layer.py b/topomodelx/nn/simplicial/hsn_layer.py
index 46931b18..51b93852 100644
--- a/topomodelx/nn/simplicial/hsn_layer.py
+++ b/topomodelx/nn/simplicial/hsn_layer.py
@@ -25,8 +25,6 @@ class HSNLayer(torch.nn.Module):
     ----------
     channels : int
         Dimension of features on each simplicial cell.
-    initialization : string
-        Initialization method.
""" def __init__( diff --git a/topomodelx/nn/simplicial/san_layer.py b/topomodelx/nn/simplicial/san_layer.py index 63cc5ac4..8ac90ad9 100644 --- a/topomodelx/nn/simplicial/san_layer.py +++ b/topomodelx/nn/simplicial/san_layer.py @@ -1,4 +1,6 @@ """Simplicial Attention Network (SAN) Layer.""" +from typing import Literal + import torch from torch.nn.parameter import Parameter @@ -16,8 +18,8 @@ class SANConv(Conv): Number of output channels. p_filters : int Number of simplicial filters. - initialization : str, optional - Weight initialization method. Defaults to "xavier_uniform". + initialization : Literal["xavier_uniform", "xavier_normal"], default="xavier_uniform" + Weight initialization method. """ def __init__( @@ -25,7 +27,7 @@ def __init__( in_channels, out_channels, p_filters, - initialization: str = "xavier_uniform", + initialization: Literal["xavier_uniform", "xavier_normal"] = "xavier_uniform", ) -> None: super(Conv, self).__init__( att=True, diff --git a/topomodelx/nn/simplicial/sccn_layer.py b/topomodelx/nn/simplicial/sccn_layer.py index e8a95e98..5fdf8fc9 100644 --- a/topomodelx/nn/simplicial/sccn_layer.py +++ b/topomodelx/nn/simplicial/sccn_layer.py @@ -1,4 +1,6 @@ """Simplicial Complex Convolutional Network (SCCN) Layer [Yang et al. LoG 2022].""" +from typing import Literal + import torch from topomodelx.base.aggregation import Aggregation @@ -37,9 +39,9 @@ class SCCNLayer(torch.nn.Module): Dimension of features on each simplicial cell. max_rank : int Maximum rank of the cells in the simplicial complex. - aggr_func : str + aggr_func : Literal["mean", "sum"], default="sum" The function to be used for aggregation. - update_func : str + update_func : Literal["relu", "sigmoid", "tanh", None], default="sigmoid" The activation function. """ @@ -47,8 +49,8 @@ def __init__( self, channels, max_rank, - aggr_func: str = "sum", - update_func: str = "sigmoid", + aggr_func: Literal["mean", "sum"] = "sum", + update_func: Literal["relu", "sigmoid", "tanh"] | None = "sigmoid", ) -> None: super().__init__() self.channels = channels diff --git a/topomodelx/nn/simplicial/scone_layer.py b/topomodelx/nn/simplicial/scone_layer.py index 30ed6a78..22afbff8 100644 --- a/topomodelx/nn/simplicial/scone_layer.py +++ b/topomodelx/nn/simplicial/scone_layer.py @@ -1,4 +1,6 @@ """Simplicial Complex Net Layer.""" +from typing import Literal + import torch from topomodelx.base.aggregation import Aggregation @@ -27,12 +29,15 @@ class SCoNeLayer(torch.nn.Module): Input dimension of features on each edge. out_channels : int Output dimension of features on each edge. - update_func : string + update_func : Literal['relu', 'sigmoid', 'tanh'] Update function to use when updating edge features. """ def __init__( - self, in_channels: int, out_channels: int, update_func: str = "tanh" + self, + in_channels: int, + out_channels: int, + update_func: Literal["relu", "sigmoid", "tanh"] = "tanh", ) -> None: super().__init__() self.in_channels = in_channels diff --git a/topomodelx/nn/simplicial/scone_layer_bis.py b/topomodelx/nn/simplicial/scone_layer_bis.py index 864b61c7..701fa36b 100644 --- a/topomodelx/nn/simplicial/scone_layer_bis.py +++ b/topomodelx/nn/simplicial/scone_layer_bis.py @@ -24,8 +24,6 @@ class SCoNeLayer(torch.nn.Module): ---------- channels : int Dimension of features on each simplicial cell. - initialization : string - Initialization method. """ def __init__(self, channels) -> None: @@ -61,7 +59,7 @@ def reset_parameters(self) -> None: def forward(self, x_0, lap_up, lap_down, iden): r"""Forward pass. 
-        The forward pass was initially proposes in [RGS21]_.
+        The forward pass was initially proposed in [RGS21]_.
         Its equations are given in [TNN23]_ and graphically illustrated in [PSHM23]_.
 
         .. math::