diff --git a/test/nn/hypergraph/test_hnhn_layer.py b/test/nn/hypergraph/test_hnhn_layer.py
index 4ddb7998..4003f5e7 100644
--- a/test/nn/hypergraph/test_hnhn_layer.py
+++ b/test/nn/hypergraph/test_hnhn_layer.py
@@ -23,7 +23,20 @@ def template_layer(self):
             incidence_1=incidence_1,
         )
 
-    def test_forward(self, template_layer):
+    @pytest.fixture
+    def template_layer2(self):
+        """Initialize and return an HNHN layer."""
+        self.in_channels = 5
+        self.hidden_channels = 8
+
+        return HNHNLayer(
+            in_channels=self.in_channels,
+            hidden_channels=self.hidden_channels,
+            incidence_1=None,
+            bias_init="xavier_normal",
+        )
+
+    def test_forward(self, template_layer, template_layer2):
         """Test the forward pass of the HNHN layer."""
         n_nodes, n_edges = template_layer.incidence_1.shape
 
@@ -33,6 +46,15 @@ def test_forward(self, template_layer):
         assert x_0_out.shape == (n_nodes, self.hidden_channels)
         assert x_1_out.shape == (n_edges, self.hidden_channels)
 
+        n_nodes = 10
+        n_edges = 20
+        incidence_1 = torch.randint(0, 2, (n_nodes, n_edges)).float()
+
+        x_0_out, x_1_out = template_layer2.forward(x_0, incidence_1)
+
+        assert x_0_out.shape == (n_nodes, self.hidden_channels)
+        assert x_1_out.shape == (n_edges, self.hidden_channels)
+
         return
 
     def test_compute_normalization_matrices(self, template_layer):
diff --git a/test/nn/hypergraph/test_hypersage_layer.py b/test/nn/hypergraph/test_hypersage_layer.py
index 7adcdf51..9679b070 100644
--- a/test/nn/hypergraph/test_hypersage_layer.py
+++ b/test/nn/hypergraph/test_hypersage_layer.py
@@ -15,14 +15,24 @@ def hypersage_layer(self):
         out_channels = 30
         return HyperSAGELayer(in_channels, out_channels)
 
-    def test_forward(self, hypersage_layer):
+    @pytest.fixture
+    def hypersage_layer_alpha(self):
+        """Return a HyperSAGE layer."""
+        in_channels = 10
+        out_channels = 30
+        return HyperSAGELayer(in_channels, out_channels, alpha=1)
+
+    def test_forward(self, hypersage_layer, hypersage_layer_alpha):
         """Test the forward pass of the HyperSAGE layer."""
         x_2 = torch.randn(3, 10)
         incidence_2 = torch.tensor(
             [[1, 0], [0, 1], [1, 1]], dtype=torch.float32
         ).to_sparse()
         output = hypersage_layer.forward(x_2, incidence_2)
+        output2 = hypersage_layer_alpha.forward(x_2, incidence_2)
+        assert output.shape == (3, 30)
+        assert output2.shape == (3, 30)
 
     def test_forward_with_invalid_input(self, hypersage_layer):
         """Test the forward pass of the HyperSAGE layer with invalid input."""
@@ -65,6 +75,13 @@ def test_update_sigmoid(self, hypersage_layer):
         assert torch.is_tensor(updated)
         assert updated.shape == (10, 20)
 
+    def test_update_invalid(self, hypersage_layer):
+        """Test the update function with update_func = "invalid"."""
+        hypersage_layer.update_func = "invalid"
+        inputs = torch.randn(10, 20)
+        with pytest.raises(RuntimeError):
+            hypersage_layer.update(inputs)
+
     def test_aggregation_invald(self, hypersage_layer):
         """Test the aggregation function with invalid mode."""
         x_messages = torch.zeros(3, 10)
diff --git a/test/nn/hypergraph/test_unigcnii_layer.py b/test/nn/hypergraph/test_unigcnii_layer.py
index 1b3fd68e..62af86c3 100644
--- a/test/nn/hypergraph/test_unigcnii_layer.py
+++ b/test/nn/hypergraph/test_unigcnii_layer.py
@@ -18,14 +18,36 @@ def unigcnii_layer(self):
             in_channels=in_channels, hidden_channels=in_channels, alpha=alpha, beta=beta
         )
 
-    def test_forward(self, unigcnii_layer):
+    @pytest.fixture
+    def unigcnii_layer2(self):
+        """Return a uniGCNII layer."""
+        in_channels = 10
+        alpha = 0.1
+        beta = 0.1
+        return UniGCNIILayer(
+            in_channels=in_channels,
+            hidden_channels=in_channels,
+            alpha=alpha,
+            beta=beta,
+            use_norm=True,
+        )
+
+    def test_forward(self, unigcnii_layer, unigcnii_layer2):
         """Test the forward pass."""
         n_nodes, in_channels = 3, 10
         x_0 = torch.randn(n_nodes, in_channels)
-        incidence_1 = torch.tensor([[1, 0], [1, 1], [0, 1]], dtype=torch.float32)
-        x_0, _ = unigcnii_layer.forward(x_0, incidence_1)
+        incidence_1 = torch.tensor(
+            [[1, 1, 0], [1, 1, 1], [0, 1, 1]], dtype=torch.float32
+        )
+        x_0, x_1 = unigcnii_layer.forward(x_0, incidence_1)
 
         assert x_0.shape == torch.Size([n_nodes, in_channels])
+        assert x_1.shape == torch.Size([3, in_channels])
+
+        x_0, x_1 = unigcnii_layer2.forward(x_0, incidence_1)
+
+        assert x_0.shape == torch.Size([n_nodes, in_channels])
+        assert x_1.shape == torch.Size([3, in_channels])
 
     def test_forward_with_skip(self):
         """Test the forward pass where alpha=1 and beta=0.
@@ -45,3 +67,10 @@ def test_forward_with_skip(self):
         x_0, _ = layer(x_0, incidence_1, x_skip)
 
         torch.testing.assert_close(x_0, x_skip, rtol=1e-4, atol=1e-4)
+
+    def test_reset_params(self, unigcnii_layer):
+        """Test reset parameters."""
+        unigcnii_layer.linear.weight.requires_grad = False
+        unigcnii_layer.linear.weight.fill_(0)
+        unigcnii_layer.reset_parameters()
+        assert torch.max(unigcnii_layer.linear.weight) > 0
diff --git a/test/nn/hypergraph/test_unigin_layer.py b/test/nn/hypergraph/test_unigin_layer.py
index bd1d8217..5fc6981e 100644
--- a/test/nn/hypergraph/test_unigin_layer.py
+++ b/test/nn/hypergraph/test_unigin_layer.py
@@ -10,12 +10,18 @@ class TestUniGINLayer:
     """Test the UniGIN layer."""
 
     @pytest.fixture
-    def UniGIN_layer(self):
+    def unigin_layer(self):
         """Return a UniGIN layer."""
         self.in_channels = 10
         return UniGINLayer(in_channels=self.in_channels)
 
-    def test_forward(self, UniGIN_layer):
+    @pytest.fixture
+    def unigin_layer2(self):
+        """Return a UniGIN layer."""
+        self.in_channels = 10
+        return UniGINLayer(in_channels=self.in_channels, use_norm=True)
+
+    def test_forward(self, unigin_layer, unigin_layer2):
         """Test the forward pass of the UniGIN layer."""
         n_nodes, n_edges = 2, 3
         incidence = torch.from_numpy(
@@ -23,7 +29,12 @@ def test_forward(self, UniGIN_layer):
         ).to_sparse()
         incidence = incidence.float()
         x_0 = torch.rand(n_nodes, self.in_channels).float()
-        x_0, x_1 = UniGIN_layer.forward(x_0, incidence)
+        x_0, x_1 = unigin_layer.forward(x_0, incidence)
+
+        assert x_0.shape == torch.Size([n_nodes, self.in_channels])
+        assert x_1.shape == torch.Size([n_edges, self.in_channels])
+
+        x_0, x_1 = unigin_layer2.forward(x_0, incidence)
 
         assert x_0.shape == torch.Size([n_nodes, self.in_channels])
         assert x_1.shape == torch.Size([n_edges, self.in_channels])
diff --git a/test/nn/hypergraph/test_unisage_layer.py b/test/nn/hypergraph/test_unisage_layer.py
index 40c7b31b..de965ad2 100644
--- a/test/nn/hypergraph/test_unisage_layer.py
+++ b/test/nn/hypergraph/test_unisage_layer.py
@@ -9,17 +9,27 @@ class TestUniSAGELayer:
     """Tests for UniSAGE Layer."""
 
     @pytest.fixture
-    def uniSAGE_layer(self):
+    def unisage_layer(self):
         """Fixture for uniSAGE layer."""
         in_channels = 10
         out_channels = 30
         return UniSAGELayer(in_channels, out_channels)
 
-    def test_forward(self, uniSAGE_layer):
+    @pytest.fixture
+    def unisage_layer2(self):
+        """Fixture for uniSAGE layer."""
+        in_channels = 10
+        out_channels = 30
+        return UniSAGELayer(in_channels, out_channels, use_norm=True)
+
+    def test_forward(self, unisage_layer, unisage_layer2):
         """Test forward pass."""
         x = torch.randn(3, 10)
         incidence = torch.tensor([[1, 1, 0], [1, 1, 1], [0, 1, 1]], dtype=torch.float32)
-        x_0, x_1 = uniSAGE_layer.forward(x, incidence)
+        x_0, x_1 = unisage_layer.forward(x, incidence)
+        assert x_0.shape == torch.Size([3, 30])
+        assert x_1.shape == torch.Size([3, 30])
+        x_0, x_1 = unisage_layer2.forward(x, incidence)
         assert x_0.shape == torch.Size([3, 30])
         assert x_1.shape == torch.Size([3, 30])
 
@@ -33,7 +43,7 @@ def test_sum_aggregator(self):
         assert x_0.shape == torch.Size([3, 30])
         assert x_1.shape == torch.Size([3, 30])
 
-    def test_aggregator_validation(self, uniSAGE_layer):
+    def test_aggregator_validation(self, unisage_layer):
         """Test aggregator validation."""
         with pytest.raises(Exception) as exc_info:
             _ = UniSAGELayer(10, 30, e_aggr="invalid_aggregator")
@@ -42,18 +52,9 @@ def test_aggregator_validation(self, uniSAGE_layer):
             == "Unsupported aggregator: invalid_aggregator, should be 'sum', 'mean',"
         )
 
-    def test_reset_params(self, uniSAGE_layer):
+    def test_reset_params(self, unisage_layer):
         """Test reset parameters."""
-        uniSAGE_layer.linear.weight.requires_grad = False
-        uniSAGE_layer.linear.weight.fill_(0)
-        uniSAGE_layer.reset_parameters()
-        assert torch.max(uniSAGE_layer.linear.weight) > 0
-
-    def test_batchnorm(self, uniSAGE_layer):
-        """Test batchnorm."""
-        x = torch.randn(3, 10)
-        incidence = torch.tensor([[1, 1, 0], [1, 1, 1], [0, 1, 1]], dtype=torch.float32)
-        layer = UniSAGELayer(10, 30, e_aggr="sum", use_bn=True)
-        layer(x, incidence)
-        assert layer.bn is not None
-        assert layer.bn.num_batches_tracked.item() == 1
+        unisage_layer.linear.weight.requires_grad = False
+        unisage_layer.linear.weight.fill_(0)
+        unisage_layer.reset_parameters()
+        assert torch.max(unisage_layer.linear.weight) > 0
diff --git a/topomodelx/nn/hypergraph/allset_transformer.py b/topomodelx/nn/hypergraph/allset_transformer.py
index 2f6d3e64..ffaa4279 100644
--- a/topomodelx/nn/hypergraph/allset_transformer.py
+++ b/topomodelx/nn/hypergraph/allset_transformer.py
@@ -24,7 +24,7 @@ class AllSetTransformer(torch.nn.Module):
         Dropout probability.
     mlp_num_layers : int, default: 2
         Number of layers in the MLP.
-    mlp_dropout: float, default: 0.2
+    mlp_dropout : float, default: 0.2
         Dropout probability in the MLP.
 
     References
diff --git a/topomodelx/nn/hypergraph/hmpnn_layer.py b/topomodelx/nn/hypergraph/hmpnn_layer.py
index 00aba147..c144acc9 100644
--- a/topomodelx/nn/hypergraph/hmpnn_layer.py
+++ b/topomodelx/nn/hypergraph/hmpnn_layer.py
@@ -122,10 +122,10 @@ class HMPNNLayer(nn.Module):
     ----------
     in_channels : int
         Dimension of input features.
-    node_to_hyperedge_messaging_func: None
+    node_to_hyperedge_messaging_func : None
        Node messaging function as a callable or nn.Module object. If not given, a linear
        plus sigmoid function is used, according to the paper.
-    hyperedge_to_node_messaging_func: None
+    hyperedge_to_node_messaging_func : None
        Hyperedge messaging function as a callable or nn.Module object. It gets hyperedge input features
        and aggregated messages of nodes as input and returns hyperedge messages. If not given, two inputs
        are concatenated and a linear layer reducing back to in_channels plus sigmoid is applied, according
diff --git a/topomodelx/nn/hypergraph/hnhn.py b/topomodelx/nn/hypergraph/hnhn.py
index bda9c2af..6d8c0dc6 100644
--- a/topomodelx/nn/hypergraph/hnhn.py
+++ b/topomodelx/nn/hypergraph/hnhn.py
@@ -18,6 +18,8 @@ class HNHN(torch.nn.Module):
         Incidence matrix mapping edges to nodes (B_1).
     n_layers : int, default = 2
         Number of HNHN message passing layers.
+    layer_drop : float, default = 0.2
+        Dropout rate for the hidden features.
 
     References
     ----------
@@ -27,7 +29,9 @@ class HNHN(torch.nn.Module):
     https://grlplus.github.io/papers/40.pdf
     """
 
-    def __init__(self, in_channels, hidden_channels, incidence_1, n_layers=2):
+    def __init__(
+        self, in_channels, hidden_channels, incidence_1, n_layers=2, layer_drop=0.2
+    ):
         super().__init__()
 
         self.layers = torch.nn.ModuleList(
@@ -38,8 +42,9 @@ def __init__(self, in_channels, hidden_channels, incidence_1, n_layers=2):
             )
             for i in range(n_layers)
         )
+        self.layer_drop = torch.nn.Dropout(layer_drop)
 
-    def forward(self, x_0):
+    def forward(self, x_0, incidence_1=None):
         """Forward computation.
 
         Parameters
@@ -58,6 +63,7 @@ def forward(self, x_0, incidence_1=None):
             Output hyperedge features.
         """
         for layer in self.layers:
-            x_0, x_1 = layer(x_0)
+            x_0, x_1 = layer(x_0, incidence_1)
+            x_0 = self.layer_drop(x_0)
 
         return x_0, x_1
diff --git a/topomodelx/nn/hypergraph/hnhn_layer.py b/topomodelx/nn/hypergraph/hnhn_layer.py
index 7b162ca0..f8008434 100644
--- a/topomodelx/nn/hypergraph/hnhn_layer.py
+++ b/topomodelx/nn/hypergraph/hnhn_layer.py
@@ -62,7 +62,7 @@
         self,
         in_channels,
         hidden_channels,
-        incidence_1,
+        incidence_1=None,
         use_bias: bool = True,
         use_normalized_incidence: bool = True,
         alpha: float = -1.5,
@@ -76,7 +76,8 @@ def __init__(
         self.bias_gain = bias_gain
         self.use_normalized_incidence = use_normalized_incidence
         self.incidence_1 = incidence_1
-        self.incidence_1_transpose = incidence_1.transpose(1, 0)
+        if incidence_1 is not None:
+            self.incidence_1_transpose = incidence_1.transpose(1, 0)
 
         self.conv_0_to_1 = Conv(
             in_channels=in_channels,
@@ -98,9 +99,10 @@ def __init__(
         if self.use_normalized_incidence:
             self.alpha = alpha
             self.beta = beta
-            self.n_nodes, self.n_edges = self.incidence_1.shape
-            self.compute_normalization_matrices()
-            self.normalize_incidence_matrices()
+            if incidence_1 is not None:
+                self.n_nodes, self.n_edges = self.incidence_1.shape
+                self.compute_normalization_matrices()
+                self.normalize_incidence_matrices()
 
     def compute_normalization_matrices(self) -> None:
         """Compute the normalization matrices for the incidence matrices."""
@@ -158,7 +160,7 @@ def reset_parameters(self) -> None:
         if self.use_bias:
             self.init_biases()
 
-    def forward(self, x_0):
+    def forward(self, x_0, incidence_1=None):
         r"""Forward computation.
 
         The forward pass was initially proposed in [1]_.
@@ -182,8 +184,8 @@
         ----------
         x_0 : torch.Tensor, shape = (n_nodes, channels_node)
             Input features on the hypernodes.
-        x_1 : torch.Tensor, shape = (n_edges, channels_edge)
-            Input features on the hyperedges.
+        incidence_1 : torch.Tensor, shape = (n_nodes, n_edges)
+            Incidence matrix mapping edges to nodes (B_1).
 
         Returns
         -------
@@ -192,6 +194,13 @@
         x_1 : torch.Tensor, shape = (n_edges, channels_edge)
             Output features on the hyperedges.
         """
+        if incidence_1 is not None:
+            self.incidence_1 = incidence_1
+            self.incidence_1_transpose = incidence_1.transpose(1, 0)
+            if self.use_normalized_incidence:
+                self.n_nodes, self.n_edges = incidence_1.shape
+                self.compute_normalization_matrices()
+                self.normalize_incidence_matrices()
         # Move incidence matrices to device
         self.incidence_1 = self.incidence_1.to(x_0.device)
         self.incidence_1_transpose = self.incidence_1_transpose.to(x_0.device)
diff --git a/topomodelx/nn/hypergraph/hypergat.py b/topomodelx/nn/hypergraph/hypergat.py
index 7af59760..e8449cc7 100644
--- a/topomodelx/nn/hypergraph/hypergat.py
+++ b/topomodelx/nn/hypergraph/hypergat.py
@@ -16,6 +16,8 @@ class HyperGAT(torch.nn.Module):
         Dimension of the hidden features.
     n_layers : int, default = 2
         Amount of message passing layers.
+    layer_drop : float, default = 0.2
+        Dropout rate for the hidden features.
 
     References
     ----------
@@ -29,6 +31,7 @@ def __init__(
         in_channels,
         hidden_channels,
         n_layers=2,
+        layer_drop=0.2,
     ):
         super().__init__()
 
@@ -39,6 +42,7 @@ def __init__(
             )
             for i in range(n_layers)
         )
+        self.layer_drop = torch.nn.Dropout(layer_drop)
 
     def forward(self, x_0, incidence_1):
         """Forward computation through layers, then linear layer, then global max pooling.
@@ -59,5 +63,6 @@ def forward(self, x_0, incidence_1):
         """
         for layer in self.layers:
             x_0, x_1 = layer.forward(x_0, incidence_1)
+            x_0 = self.layer_drop(x_0)
 
         return x_0, x_1
diff --git a/topomodelx/nn/hypergraph/hypersage.py b/topomodelx/nn/hypergraph/hypersage.py
index 738c48d0..a9744225 100644
--- a/topomodelx/nn/hypergraph/hypersage.py
+++ b/topomodelx/nn/hypergraph/hypersage.py
@@ -16,6 +16,8 @@ class HyperSAGE(torch.nn.Module):
         Dimension of the hidden features.
     n_layer : int, default = 2
         Amount of message passing layers.
+    alpha : int, default = -1
+        Max number of nodes in a neighborhood to consider. If -1, all the nodes are considered.
 
     References
     ----------
@@ -24,13 +26,14 @@ class HyperSAGE(torch.nn.Module):
         https://arxiv.org/abs/2010.04558
     """
 
-    def __init__(self, in_channels, hidden_channels, n_layers=2, **kwargs):
+    def __init__(self, in_channels, hidden_channels, n_layers=2, alpha=-1, **kwargs):
         super().__init__()
 
         self.layers = torch.nn.ModuleList(
             HyperSAGELayer(
                 in_channels=in_channels if i == 0 else hidden_channels,
                 out_channels=hidden_channels,
+                alpha=alpha,
                 **kwargs,
             )
             for i in range(n_layers)
diff --git a/topomodelx/nn/hypergraph/hypersage_layer.py b/topomodelx/nn/hypergraph/hypersage_layer.py
index df6acf4b..d296569b 100644
--- a/topomodelx/nn/hypergraph/hypersage_layer.py
+++ b/topomodelx/nn/hypergraph/hypersage_layer.py
@@ -49,6 +49,8 @@ class HyperSAGELayer(MessagePassing):
         Dimension of the input features.
     out_channels : int
         Dimension of the output features.
+    alpha : int, default=-1
+        Max number of nodes in a neighborhood to consider. If -1, all the nodes are considered.
     aggr_func_intra : callable, default=GeneralizedMean(p=2)
         Aggregation function. Default is GeneralizedMean(p=2).
     aggr_func_inter : callable, default=GeneralizedMean(p=2)
@@ -77,6 +79,7 @@ def __init__(
         self,
         in_channels: int,
         out_channels: int,
+        alpha: int = -1,
         aggr_func_intra: Aggregation | None = None,
         aggr_func_inter: Aggregation | None = None,
         update_func: Literal["relu", "sigmoid"] = "relu",
@@ -96,6 +99,7 @@ def __init__(
 
         self.in_channels = in_channels
         self.out_channels = out_channels
+        self.alpha = alpha
         self.aggr_func_intra = aggr_func_intra
         self.aggr_func_inter = aggr_func_inter
         self.update_func = update_func
@@ -187,7 +191,7 @@ def forward(self, x: torch.Tensor, incidence: torch.Tensor):  # type: ignore[ove
         """
 
         def nodes_per_edge(e):
-            return (
+            messages = (
                 torch.index_select(
                     input=incidence.to("cpu"), dim=1, index=torch.LongTensor([e])
                 )
@@ -195,6 +199,9 @@ def nodes_per_edge(e):
                 .indices()[0]
                 .to(self.device)
             )
+            if len(messages) <= self.alpha or self.alpha == -1:
+                return messages
+            return messages[torch.randperm(len(messages))[: self.alpha]]
 
         def edges_per_node(v):
             return (
diff --git a/topomodelx/nn/hypergraph/unigcnii.py b/topomodelx/nn/hypergraph/unigcnii.py
index 210f8e61..47d3673c 100644
--- a/topomodelx/nn/hypergraph/unigcnii.py
+++ b/topomodelx/nn/hypergraph/unigcnii.py
@@ -26,6 +26,8 @@ class UniGCNII(torch.nn.Module):
         Dropout rate for the input features.
     layer_drop : float, default=0.2
         Dropout rate for the hidden features.
+    use_norm : bool, default=False
+        Whether to apply row normalization after every layer.
 
     References
     ----------
@@ -44,6 +46,7 @@ def __init__(
         beta=0.5,
         input_drop=0.2,
         layer_drop=0.2,
+        use_norm=False,
     ):
         super().__init__()
         layers = []
@@ -61,6 +64,7 @@ def __init__(
                 hidden_channels=hidden_channels,
                 alpha=alpha,
                 beta=beta,
+                use_norm=use_norm,
             )
         )
 
diff --git a/topomodelx/nn/hypergraph/unigcnii_layer.py b/topomodelx/nn/hypergraph/unigcnii_layer.py
index e1351977..a0b06700 100644
--- a/topomodelx/nn/hypergraph/unigcnii_layer.py
+++ b/topomodelx/nn/hypergraph/unigcnii_layer.py
@@ -1,6 +1,8 @@
 """UniGCNII layer implementation."""
 import torch
 
+from topomodelx.base.conv import Conv
+
 
 class UniGCNIILayer(torch.nn.Module):
     r"""
@@ -16,6 +18,8 @@ class UniGCNIILayer(torch.nn.Module):
         The alpha parameter determining the importance of the self-loop (\theta_2).
     beta : float
         The beta parameter determining the importance of the learned matrix (\theta_1).
+    use_norm : bool, default=False
+        Whether to apply row normalization after the layer.
 
     References
     ----------
@@ -25,12 +29,20 @@ class UniGCNIILayer(torch.nn.Module):
     https://arxiv.org/pdf/2105.00956.pdf
     """
 
-    def __init__(self, in_channels, hidden_channels, alpha: float, beta: float) -> None:
+    def __init__(
+        self, in_channels, hidden_channels, alpha: float, beta: float, use_norm=False
+    ) -> None:
         super().__init__()
         self.alpha = alpha
         self.beta = beta
         self.linear = torch.nn.Linear(in_channels, hidden_channels, bias=False)
+        self.conv = Conv(
+            in_channels=in_channels,
+            out_channels=in_channels,
+            with_linear_transform=False,
+        )
+        self.use_norm = use_norm
 
     def reset_parameters(self) -> None:
         """Reset the parameters of the layer."""
@@ -90,7 +102,7 @@ def forward(self, x_0, incidence_1, x_skip=None):
         incidence_1_transpose = incidence_1.transpose(0, 1)
 
         # First message without any learning or parameters
-        x_1 = torch.sparse.mm(incidence_1_transpose, x_0)
+        x_1 = self.conv(x_0, incidence_1_transpose)
 
         # Compute node and edge degrees for normalization.
         node_degree = torch.sum(incidence_1.to_dense(), dim=1)
@@ -108,11 +120,18 @@ def forward(self, x_0, incidence_1, x_skip=None):
         edge_degree = edge_degree / torch.sum(incidence_1.to_dense(), dim=0)
 
         # Second message normalized with node and edge degrees (using broadcasting)
-        x_0 = (1 / torch.sqrt(node_degree).unsqueeze(-1)) * torch.sparse.mm(
-            incidence_1 @ torch.diag(1 / torch.sqrt(edge_degree)), x_1
+        x_0 = (1 / torch.sqrt(node_degree).unsqueeze(-1)) * self.conv(
+            x_1, incidence_1 @ torch.diag(1 / torch.sqrt(edge_degree))
         )
 
         # Introduce skip connections with hyperparameter alpha and beta
         x_combined = ((1 - self.alpha) * x_0) + (self.alpha * x_skip)
         x_0 = ((1 - self.beta) * x_combined) + self.beta * self.linear(x_combined)
+
+        if self.use_norm:
+            rownorm = x_0.detach().norm(dim=1, keepdim=True)
+            scale = rownorm.pow(-1)
+            scale[torch.isinf(scale)] = 0.0
+            x_0 = x_0 * scale
+
         return x_0, x_1
diff --git a/topomodelx/nn/hypergraph/unigin.py b/topomodelx/nn/hypergraph/unigin.py
index 78a45344..d9549ada 100644
--- a/topomodelx/nn/hypergraph/unigin.py
+++ b/topomodelx/nn/hypergraph/unigin.py
@@ -20,6 +20,12 @@ class UniGIN(torch.nn.Module):
         Dropout rate for the input features.
     layer_drop : float, default=0.2
         Dropout rate for the hidden features.
+    eps : float, default=0
+        Constant in GIN update equation.
+    train_eps : bool, default=False
+        Whether to make eps a trainable parameter.
+    use_norm : bool, default=False
+        Whether to apply row normalization after every layer.
 
     References
     ----------
@@ -37,6 +43,9 @@ def __init__(
         n_layers=2,
         input_drop=0.2,
         layer_drop=0.2,
+        eps=0,
+        train_eps=False,
+        use_norm=False,
     ):
         super().__init__()
 
@@ -48,6 +57,9 @@ def __init__(
         self.layers = torch.nn.ModuleList(
             UniGINLayer(
                 in_channels=hidden_channels,
+                eps=eps,
+                train_eps=train_eps,
+                use_norm=use_norm,
             )
             for _ in range(n_layers)
         )
diff --git a/topomodelx/nn/hypergraph/unigin_layer.py b/topomodelx/nn/hypergraph/unigin_layer.py
index 5add4e29..e92202ac 100644
--- a/topomodelx/nn/hypergraph/unigin_layer.py
+++ b/topomodelx/nn/hypergraph/unigin_layer.py
@@ -1,6 +1,8 @@
 """Implementation of UniGIN layer from Huang et. al.: UniGNN: a Unified Framework for Graph and Hypergraph Neural Networks."""
 import torch
 
+from topomodelx.base.conv import Conv
+
 
 class UniGINLayer(torch.nn.Module):
     """Layer of UniGIN.
@@ -13,8 +15,11 @@ class UniGINLayer(torch.nn.Module):
         Dimension of input features.
     eps : float, default=0.0
         Constant in GIN Update equation.
-    train_eps : boolm, default=False
+    train_eps : bool, default=False
         Whether to make eps a trainable parameter.
+    use_norm : bool, default=False
+        Whether to apply row normalization after the layer.
+
 
     References
     ----------
@@ -35,6 +40,7 @@ def __init__(
         in_channels,
         eps: float = 0.0,
         train_eps: bool = False,
+        use_norm: bool = False,
     ) -> None:
         super().__init__()
 
@@ -46,6 +52,19 @@ def __init__(
 
         self.linear = torch.nn.Linear(in_channels, in_channels)
 
+        self.use_norm = use_norm
+
+        self.vertex2edge = Conv(
+            in_channels=in_channels,
+            out_channels=in_channels,
+            with_linear_transform=False,
+        )
+        self.edge2vertex = Conv(
+            in_channels=in_channels,
+            out_channels=in_channels,
+            with_linear_transform=False,
+        )
+
     def forward(self, x_0, incidence_1):
         r"""[1]_ initially proposed the forward pass.
 
@@ -93,9 +112,16 @@ def forward(self, x_0, incidence_1):
         """
         incidence_1_transpose = incidence_1.to_dense().T.to_sparse()
         # First pass fills in features of edges by adding features of constituent nodes
-        x_1 = torch.sparse.mm(incidence_1_transpose.float(), x_0)
+        x_1 = self.vertex2edge(x_0, incidence_1_transpose)
         # Second pass fills in features of nodes by adding features of the incident edges
-        m_1_0 = torch.sparse.mm(incidence_1.float(), x_1)
+        m_1_0 = self.edge2vertex(x_1, incidence_1)
         # Update node features using GIN update equation
         x_0 = self.linear((1 + self.eps) * x_0 + m_1_0)
+
+        if self.use_norm:
+            rownorm = x_0.detach().norm(dim=1, keepdim=True)
+            scale = rownorm.pow(-1)
+            scale[torch.isinf(scale)] = 0.0
+            x_0 = x_0 * scale
+
         return x_0, x_1
diff --git a/topomodelx/nn/hypergraph/unisage.py b/topomodelx/nn/hypergraph/unisage.py
index 228a440c..a43b89dc 100644
--- a/topomodelx/nn/hypergraph/unisage.py
+++ b/topomodelx/nn/hypergraph/unisage.py
@@ -1,5 +1,7 @@
 """UniSAGE class."""
 
+from typing import Literal
+
 import torch
 
 from topomodelx.nn.hypergraph.unisage_layer import UniSAGELayer
@@ -20,6 +22,13 @@ class UniSAGE(torch.nn.Module):
         Dropout rate for the hidden features.
     n_layers : int, default = 2
         Amount of message passing layers.
+    e_aggr : Literal["sum", "mean",], default="sum"
+        Aggregator function for hyperedges.
+    v_aggr : Literal["sum", "mean",], default="mean"
+        Aggregator function for nodes.
+    use_norm : bool, default=False
+        Whether to apply row normalization after every layer.
+
 
     References
     ----------
@@ -36,6 +45,15 @@ def __init__(
         input_drop=0.2,
         layer_drop=0.2,
         n_layers=2,
+        e_aggr: Literal[
+            "sum",
+            "mean",
+        ] = "sum",
+        v_aggr: Literal[
+            "sum",
+            "mean",
+        ] = "mean",
+        use_norm: bool = False,
     ):
         super().__init__()
 
@@ -46,6 +64,9 @@ def __init__(
             UniSAGELayer(
                 in_channels=in_channels if i == 0 else hidden_channels,
                 hidden_channels=hidden_channels,
+                e_aggr=e_aggr,
+                v_aggr=v_aggr,
+                use_norm=use_norm,
             )
             for i in range(n_layers)
         )
diff --git a/topomodelx/nn/hypergraph/unisage_layer.py b/topomodelx/nn/hypergraph/unisage_layer.py
index 989e783f..58e56545 100644
--- a/topomodelx/nn/hypergraph/unisage_layer.py
+++ b/topomodelx/nn/hypergraph/unisage_layer.py
@@ -20,8 +20,8 @@ class UniSAGELayer(torch.nn.Module):
         Aggregator function for hyperedges.
     v_aggr : Literal["sum", "mean",], default="mean"
         Aggregator function for nodes.
-    use_bn : boolean
-        Whether to use bathnorm after the linear transformation.
+    use_norm : bool, default=False
+        Whether to apply row normalization after every layer.
 
     References
     ----------
@@ -58,12 +58,12 @@ def __init__(
             "sum",
             "mean",
         ] = "mean",
-        use_bn: bool = False,
+        use_norm: bool = False,
     ) -> None:
         super().__init__()
         self.in_channels = in_channels
         self.hidden_channels = hidden_channels
-        self.bn = torch.nn.BatchNorm1d(hidden_channels) if use_bn else None
+        self.use_norm = use_norm
 
         self.linear = torch.nn.Linear(in_channels, hidden_channels)
 
@@ -90,8 +90,6 @@ def __init__(
     def reset_parameters(self) -> None:
         r"""Reset learnable parameters."""
         self.linear.reset_parameters()
-        if self.bn is not None:
-            self.bn.reset_parameters()
 
     def forward(self, x_0, incidence_1):
         r"""[1]_ initially proposed the forward pass.
@@ -139,11 +137,15 @@ def forward(self, x_0, incidence_1):
             Output hyperedge features.
         """
         x_0 = self.linear(x_0)
-        if self.bn is not None:
-            x_0 = self.bn(x_0)
-
         x_1 = self.vertex2edge(x_0, incidence_1.transpose(1, 0))
         m_1_0 = self.edge2vertex(x_1, incidence_1)
         x_0 = x_0 + m_1_0
 
+        if self.use_norm:
+            rownorm = x_0.detach().norm(dim=1, keepdim=True)
+            scale = rownorm.pow(-1)
+            scale[torch.isinf(scale)] = 0.0
+            x_0 = x_0 * scale
+
         return x_0, x_1
diff --git a/tutorials/hypergraph/allset_transformer_train.ipynb b/tutorials/hypergraph/allset_transformer_train.ipynb
index 41b3048c..a7ed18bc 100644
--- a/tutorials/hypergraph/allset_transformer_train.ipynb
+++ b/tutorials/hypergraph/allset_transformer_train.ipynb
@@ -456,8 +456,8 @@
     }
    ],
    "source": [
-    "test_interval = 0\n",
-    "num_epochs = 0\n",
+    "test_interval = 1\n",
+    "num_epochs = 1\n",
     "\n",
     "epoch_loss = []\n",
     "for epoch_i in range(1, num_epochs + 1):\n",
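As a quick, self-contained illustration of the new use_norm option introduced across the UniGNN-style layers, the sketch below mirrors the updated test_unisage_layer.py fixtures; the channel sizes, incidence matrix, and call pattern are taken directly from those tests, and nothing beyond the behavior they exercise is assumed.

import torch

from topomodelx.nn.hypergraph.unisage_layer import UniSAGELayer

# Layer with row normalization enabled, as in the unisage_layer2 fixture.
layer = UniSAGELayer(10, 30, use_norm=True)

# Three nodes and three hyperedges, matching the incidence matrix used in the tests.
x_0 = torch.randn(3, 10)
incidence_1 = torch.tensor([[1, 1, 0], [1, 1, 1], [0, 1, 1]], dtype=torch.float32)

x_0_out, x_1_out = layer(x_0, incidence_1)

# With use_norm=True each non-zero output row is rescaled to unit L2 norm.
assert x_0_out.shape == (3, 30)
assert x_1_out.shape == (3, 30)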