Merge pull request #181 from pyt-team/frantzen-typing
More concrete types throughout the project
ninamiolane authored Sep 5, 2023
2 parents dfbf508 + 79f7760 commit 86a0519
Showing 17 changed files with 142 additions and 107 deletions.
4 changes: 2 additions & 2 deletions topomodelx/base/aggregation.py
@@ -9,10 +9,10 @@ class Aggregation(torch.nn.Module):
Parameters
----------
aggr_func : string
aggr_func : Literal["mean", "sum"], default="sum"
Aggregation method.
(Inter-neighborhood).
update_func : string
update_func : Literal["relu", "sigmoid", "tanh", None], default="sigmoid"
Update method to apply to merged message.
"""

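A minimal usage sketch, assuming the constructor mirrors the docstring above and that the forward call accepts a list of per-neighborhood messages (shapes are illustrative):

import torch

from topomodelx.base.aggregation import Aggregation

# Assumed constructor: only the Literal values documented above are accepted,
# so a typo such as "avg" is caught by a static type checker.
agg = Aggregation(aggr_func="mean", update_func="relu")

# Two per-neighborhood message tensors (shapes chosen for illustration).
messages = [torch.randn(5, 8), torch.randn(5, 8)]
merged = agg(messages)  # aggregated across neighborhoods, then passed through ReLU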
10 changes: 6 additions & 4 deletions topomodelx/base/conv.py
@@ -1,4 +1,5 @@
"""Convolutional layer for message passing."""
from typing import Literal

import torch
from torch.nn.parameter import Parameter
@@ -20,11 +21,11 @@ class Conv(MessagePassing):
Dimension of output features.
aggr_norm : bool, default=False
Whether to normalize the aggregated message by the neighborhood size.
update_func : string, optional
update_func : Literal["relu", "sigmoid"], optional
Update method to apply to message.
att : bool, default=False
Whether to use attention.
initialization : string
initialization : Literal["xavier_uniform", "xavier_normal"], default="xavier_uniform"
Initialization method.
with_linear_transform: bool
Whether to apply a learnable linear transform.
@@ -36,9 +37,10 @@ def __init__(
in_channels,
out_channels,
aggr_norm: bool = False,
update_func=None,
update_func: Literal["relu", "sigmoid"] | None = None,
att: bool = False,
initialization: str = "xavier_uniform",
initialization: Literal["xavier_uniform", "xavier_normal"] = "xavier_uniform",
initialization_gain: float = 1.414,
with_linear_transform: bool = True,
) -> None:
super().__init__(
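A small construction sketch using the signature shown in this hunk; the forward call and the tensor shapes are assumptions, not part of the diff:

import torch

from topomodelx.base.conv import Conv

conv = Conv(in_channels=8, out_channels=16, update_func="relu")  # Literal-checked options

x = torch.randn(10, 8)                    # features on 10 source cells (illustrative)
neighborhood = torch.eye(10).to_sparse()  # identity neighborhood matrix (illustrative)
y = conv(x, neighborhood)                 # assumed forward(x_source, neighborhood) -> (10, 16)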
13 changes: 7 additions & 6 deletions topomodelx/base/message_passing.py
@@ -1,5 +1,5 @@
"""Message passing module."""

from typing import Literal

import torch

@@ -23,11 +23,11 @@ class MessagePassing(torch.nn.Module):
Parameters
----------
aggr_func : string
aggr_func : Literal["sum", "mean", "add"], default="sum"
Aggregation function to use.
att : bool, default=False
Whether to use attention.
initialization : string
initialization : Literal["xavier_uniform", "xavier_normal"], default="xavier_uniform"
Initialization method for the weights of the layer.
References
@@ -43,9 +43,10 @@ class MessagePassing(torch.nn.Module):

def __init__(
self,
aggr_func: str = "sum",
aggr_func: Literal["sum", "mean", "add"] = "sum",
att: bool = False,
initialization: str = "xavier_uniform",
initialization: Literal["xavier_uniform", "xavier_normal"] = "xavier_uniform",
initialization_gain: float = 1.414,
) -> None:
super().__init__()
self.aggr_func = aggr_func
@@ -113,7 +114,7 @@ def message(self, x_source, x_target=None):
def attention(self, x_source, x_target=None):
"""Compute attention weights for messages.
This provides a default attention function to the message passing scheme.
This provides a default attention function to the message-passing scheme.
Alternatively, users can subclass MessagePassing and overwrite
the attention method in order to replace it with their own attention mechanism.
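A sketch of a custom layer passing the newly typed keyword arguments through to the base class; the subclass itself is illustrative and not part of this commit:

from typing import Literal

from topomodelx.base.message_passing import MessagePassing


class MyMessagePassing(MessagePassing):
    """Illustrative subclass that forwards the Literal-typed options."""

    def __init__(
        self,
        aggr_func: Literal["sum", "mean", "add"] = "mean",
        initialization: Literal["xavier_uniform", "xavier_normal"] = "xavier_normal",
    ) -> None:
        super().__init__(aggr_func=aggr_func, att=False, initialization=initialization)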
73 changes: 36 additions & 37 deletions topomodelx/nn/cell/can_layer.py
@@ -1,6 +1,6 @@
"""Cell Attention Network layer."""

from typing import Callable
from typing import Callable, Literal

import torch
from torch import Tensor, nn, topk
@@ -396,9 +396,9 @@ class MultiHeadCellAttention(MessagePassing):
Activation function to use for the attention weights.
add_self_loops : bool, optional
Whether to add self-loops to the adjacency matrix.
aggr_func : string, optional
Aggregation function to use. Options are "sum", "mean", "max".
initialization : string, optional
aggr_func : Literal["sum", "mean", "add"], default="sum"
Aggregation function to use.
initialization : Literal["xavier_uniform", "xavier_normal"], default="xavier_uniform"
Initialization method for the weights of the layer.
Notes
@@ -423,14 +423,10 @@ def __init__(
concat: bool,
att_activation: torch.nn.Module,
add_self_loops: bool = False,
aggr_func: str = "sum",
initialization: str = "xavier_uniform",
aggr_func: Literal["sum", "mean", "add"] = "sum",
initialization: Literal["xavier_uniform", "xavier_normal"] = "xavier_uniform",
) -> None:
super().__init__(
att=True,
initialization=initialization,
aggr_func=aggr_func,
)
super().__init__(aggr_func=aggr_func, att=True, initialization=initialization)

self.in_channels = in_channels
self.out_channels = out_channels
@@ -587,9 +583,9 @@ class MultiHeadCellAttention_v2(MessagePassing):
Activation function to use for the attention weights.
add_self_loops : bool, optional
Whether to add self-loops to the adjacency matrix.
aggr_func : string, optional
Aggregation function to use. Options are "sum", "mean", "max".
initialization : string, optional
aggr_func : Literal["sum", "mean", "add"], default="sum"
Aggregation function to use.
initialization : Literal["xavier_uniform", "xavier_normal"], default="xavier_uniform"
Initialization method for the weights of the layer.
share_weights : bool, optional
Whether to share the weights between the attention heads.
@@ -616,14 +612,14 @@ def __init__(
concat: bool,
att_activation: torch.nn.Module,
add_self_loops: bool = False,
aggr_func: str = "sum",
initialization: str = "xavier_uniform",
aggr_func: Literal["sum", "mean", "add"] = "sum",
initialization: Literal["xavier_uniform", "xavier_normal"] = "xavier_uniform",
share_weights: bool = False,
) -> None:
super().__init__(
aggr_func=aggr_func,
att=True,
initialization=initialization,
aggr_func=aggr_func,
)

self.in_channels = in_channels
@@ -797,28 +793,31 @@ class CANLayer(torch.nn.Module):
Dimension of input features on n-cells.
out_channels : int
Dimension of output
heads : int, optional
Number of attention heads, by default 1
heads : int, default=1
Number of attention heads.
dropout : float, optional
Dropout probability of the normalized attention coefficients, by default 0.0
concat : bool, optional
If True, the output of each head is concatenated. Otherwise, the output of each head is averaged, by default True
skip_connection : bool, optional
If True, skip connection is added, by default True
Dropout probability of the normalized attention coefficients.
concat : bool, default=True
If True, the output of each head is concatenated. Otherwise, the output of each head is averaged.
skip_connection : bool, default=True
If True, skip connection is added.
add_self_loops : bool, optional
If True, self-loops are added to the neighborhood matrix, by default False
att_activation : Callable, optional
Activation function applied to the attention coefficients, by default torch.nn.LeakyReLU()
aggr_func : str, optional
Between-neighborhood aggregation function applied to the messages, by default "sum"
update_func : str, optional
Update function applied to the messages, by default "relu"
version : str, optional
If True, self-loops are added to the neighborhood matrix.
att_activation : Callable, default=torch.nn.LeakyReLU()
Activation function applied to the attention coefficients.
aggr_func : Literal["mean", "sum"], default="sum"
Between-neighborhood aggregation function applied to the messages.
update_func : Literal["relu", "sigmoid", "tanh", None], default="relu"
Update function applied to the messages.
version : Literal["v1", "v2"], default="v1"
Version of the layer. "v1" is the same as the original CAN layer, while "v2" uses the same attention mechanism as the GATv2 layer.
share_weights : bool, optional
share_weights : bool, default=False
This option is valid only for "v2". If True, the weights of the linear transformation applied to the source and target features are shared.
"""

lower_att: MultiHeadCellAttention | MultiHeadCellAttention_v2
upper_att: MultiHeadCellAttention | MultiHeadCellAttention_v2

def __init__(
self,
in_channels: int,
@@ -829,9 +828,9 @@ def __init__(
skip_connection: bool = True,
att_activation: torch.nn.Module = torch.nn.LeakyReLU(),
add_self_loops: bool = False,
aggr_func: str = "sum",
update_func: str = "relu",
version: str = "v1",
aggr_func: Literal["mean", "sum"] = "sum",
update_func: Literal["relu", "sigmoid", "tanh"] | None = "relu",
version: Literal["v1", "v2"] = "v1",
share_weights: bool = False,
**kwargs,
) -> None:
@@ -841,7 +840,7 @@ def __init__(
assert out_channels > 0, ValueError("Number of output channels must be > 0")
assert heads > 0, ValueError("Number of heads must be > 0")
assert dropout >= 0.0 and dropout <= 1.0, ValueError("Dropout must be in [0,1]")
assert version in ["v1", "v2"], ValueError("Version must be 'v1' or 'v2'")

# assert that shared weight is True only if version is v2
assert share_weights is False or version == "v2", ValueError(
"Shared weights is valid only for v2"
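A construction sketch restricted to the parameters visible in this diff; defaults are assumed for everything else:

from topomodelx.nn.cell.can_layer import CANLayer

layer = CANLayer(
    in_channels=16,
    out_channels=32,
    heads=2,
    version="v2",        # Literal["v1", "v2"]
    share_weights=True,  # only meaningful together with version="v2"
    update_func="tanh",  # Literal["relu", "sigmoid", "tanh"] | None
)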
13 changes: 8 additions & 5 deletions topomodelx/nn/cell/can_layer_bis.py
@@ -1,4 +1,6 @@
"""Cellular Attention Network Layer."""
from typing import Literal

import torch
from torch.nn.parameter import Parameter

@@ -50,23 +52,24 @@ class CANLayer(torch.nn.Module):
----------
channels : int
Dimension of input features on edges (1-cells).
activation : string
Activation function to apply to merged message
activation : Literal["relu", "sigmoid", "tanh", None], default="sigmoid"
Activation function to apply to merged message.
att : bool
Whether to use attention.
eps : float
Epsilon used in the attention mechanism.
initialization : string
initialization : Literal["xavier_uniform", "xavier_normal"], default="xavier_uniform"
Initialization method.
"""

def __init__(
self,
channels,
activation: str = "sigmoid",
activation: Literal["relu", "sigmoid", "tanh"] | None = "sigmoid",
att: bool = True,
eps: float = 1e-5,
initialization: str = "xavier_uniform",
initialization: Literal["xavier_uniform", "xavier_normal"] = "xavier_uniform",
initialization_gain: float = 1.414,
) -> None:
super().__init__()
# Do I need upper and lower convolution layers? Since I think they will have different parameters
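A construction sketch for the bis variant, assuming the signature shown in the hunk above:

from topomodelx.nn.cell.can_layer_bis import CANLayer

layer = CANLayer(
    channels=32,
    activation="tanh",               # Literal["relu", "sigmoid", "tanh"] | None
    initialization="xavier_normal",  # Literal["xavier_uniform", "xavier_normal"]
)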
5 changes: 4 additions & 1 deletion topomodelx/nn/hypergraph/allset_layer.py
@@ -135,6 +135,9 @@ class AllSetBlock(nn.Module):
Type of layer normalization in the MLP.
"""

encoder: nn.Module
decoder: nn.Module

def __init__(
self,
in_channels,
@@ -241,7 +244,7 @@ def __init__(
bias: bool = False,
) -> None:
params = {} if inplace is None else {"inplace": inplace}
layers = []
layers: list[nn.Module] = []
in_dim = in_channels
for hidden_dim in hidden_channels[:-1]:
layers.append(nn.Linear(in_dim, hidden_dim, bias=bias))
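The `layers: list[nn.Module]` annotation matters because the list starts empty; a minimal standalone illustration (not library code):

from torch import nn

# Without the annotation, mypy reports "Need type annotation" for an empty list;
# annotating it as list[nn.Module] also allows mixing Linear and activation modules.
layers: list[nn.Module] = []
layers.append(nn.Linear(8, 16))
layers.append(nn.ReLU())
mlp = nn.Sequential(*layers)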
9 changes: 6 additions & 3 deletions topomodelx/nn/hypergraph/allset_transformer_layer.py
@@ -1,4 +1,6 @@
"""AllSetTransformer Layer Module."""
from typing import Literal

import torch
import torch.nn.functional as F
from torch import nn
@@ -273,7 +275,7 @@ class MultiHeadAttention(MessagePassing):
Number of attention heads.
number_queries : int, default=1
Number of queries.
initialization : str, default="xavier_uniform"
initialization : Literal["xavier_uniform", "xavier_normal"], default="xavier_uniform"
Initialization method.
"""

@@ -285,7 +287,8 @@ def __init__(
update_func=None,
heads: int = 4,
number_queries: int = 1,
initialization: str = "xavier_uniform",
initialization: Literal["xavier_uniform", "xavier_normal"] = "xavier_uniform",
initialization_gain: float = 1.414,
) -> None:
super().__init__(
att=True,
@@ -429,7 +432,7 @@ def __init__(
bias: bool = False,
) -> None:
params = {} if inplace is None else {"inplace": inplace}
layers = []
layers: list[nn.Module] = []
in_dim = in_channels
for hidden_dim in hidden_channels[:-1]:
layers.append(nn.Linear(in_dim, hidden_dim, bias=bias))
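A hypothetical helper (not part of the library) showing how the same Literal alias can double as a runtime guard for the initialization argument:

from typing import Literal, get_args

InitMethod = Literal["xavier_uniform", "xavier_normal"]


def check_init(name: str) -> None:
    """Raise if name is not one of the allowed initialization methods."""
    if name not in get_args(InitMethod):
        raise ValueError(f"initialization must be one of {get_args(InitMethod)}")


check_init("xavier_normal")  # passes silently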
15 changes: 10 additions & 5 deletions topomodelx/nn/hypergraph/hmpnn_layer.py
@@ -1,4 +1,6 @@
"""HMPNN (Hypergraph Message Passing Neural Network) Layer introduced in Heydari et Livi 2022."""
from typing import Literal

import torch
from torch import nn
from torch.nn import functional as F
@@ -21,7 +23,10 @@ def apply_dropout(self, neighborhood, dropout_rate: float):

class _NodeToHyperedgeMessenger(MessagePassing, _AdjacencyDropoutMixin):
def __init__(
self, messaging_func, adjacency_dropout: float = 0.7, aggr_func: str = "sum"
self,
messaging_func,
adjacency_dropout: float = 0.7,
aggr_func: Literal["sum", "mean", "add"] = "sum",
) -> None:
super().__init__(aggr_func)
self.messaging_func = messaging_func
@@ -46,7 +51,7 @@ def __init__(
self,
messaging_func,
adjacency_dropout: float = 0.7,
aggr_func: str = "sum",
aggr_func: Literal["sum", "mean", "add"] = "sum",
) -> None:
super().__init__(aggr_func)
self.messaging_func = messaging_func
@@ -116,8 +121,8 @@ class HMPNNLayer(nn.Module):
to the paper.
adjacency_dropout: 0.7
Adjacency dropout rate.
aggr_func: "sum"
Message aggregation function. A value among "sum", "mean" and "add".
aggr_func: Literal["sum", "mean", "add"], default="sum"
Message aggregation function.
updating_dropout: 0.5
Regular dropout rate applied to node and hyperedge features.
updating_func: None
@@ -132,7 +137,7 @@ def __init__(
node_to_hyperedge_messaging_func=None,
hyperedge_to_node_messaging_func=None,
adjacency_dropout: float = 0.7,
aggr_func: str = "sum",
aggr_func: Literal["sum", "mean", "add"] = "sum",
updating_dropout: float = 0.5,
updating_func=None,
) -> None:
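A construction sketch; the leading positional argument is an assumption, since it is not visible in this diff:

from topomodelx.nn.hypergraph.hmpnn_layer import HMPNNLayer

layer = HMPNNLayer(
    8,                     # assumed feature-dimension argument, not shown in the hunk
    adjacency_dropout=0.7,
    aggr_func="mean",      # Literal["sum", "mean", "add"]
    updating_dropout=0.5,
)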
6 changes: 4 additions & 2 deletions topomodelx/nn/hypergraph/hnhn_layer.py
@@ -1,4 +1,6 @@
"""Template Layer with two conv passing steps."""
from typing import Literal

import torch
from torch.nn.parameter import Parameter

@@ -48,7 +50,7 @@ class HNHNLayer(torch.nn.Module):
Scalar controlling the importance of node cardinality.
bias_gain : float
Gain for the bias initialization.
bias_init : string ["xavier_uniform"|"xavier_normal"]
bias_init : Literal["xavier_uniform", "xavier_normal"], default="xavier_uniform"
Controls the bias initialization method.
"""

@@ -62,7 +64,7 @@ def __init__(
alpha: float = -1.5,
beta: float = -0.5,
bias_gain: float = 1.414,
bias_init: str = "xavier_uniform",
bias_init: Literal["xavier_uniform", "xavier_normal"] = "xavier_uniform",
) -> None:
super().__init__()
self.use_bias = use_bias
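An illustrative mapping (not taken from the layer's implementation) of how bias_init plausibly selects between torch's two Xavier initializers using bias_gain:

import torch
from torch.nn.parameter import Parameter

bias = Parameter(torch.empty(1, 16))
bias_init = "xavier_normal"  # Literal["xavier_uniform", "xavier_normal"]
bias_gain = 1.414

if bias_init == "xavier_uniform":
    torch.nn.init.xavier_uniform_(bias, gain=bias_gain)
else:
    torch.nn.init.xavier_normal_(bias, gain=bias_gain)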