diff --git a/src/deep_neurographs/machine_learning/graph_models.py b/src/deep_neurographs/machine_learning/graph_models.py index 21b581c..20ae5e4 100644 --- a/src/deep_neurographs/machine_learning/graph_models.py +++ b/src/deep_neurographs/machine_learning/graph_models.py @@ -18,7 +18,7 @@ class GCN(torch.nn.Module): def __init__(self, input_channels): super().__init__() - self.input = Linear(input_channels, input_channels) + self.input = Linear(input_channels, input_channels) self.conv1 = GCNConv(input_channels, 2 * input_channels) self.conv2 = GCNConv(2 * input_channels, input_channels) self.conv3 = GCNConv(input_channels, input_channels // 2) @@ -63,6 +63,9 @@ def forward(self, x, edge_index): x = self.output(x) return x + #self.resgated = ResGatedGraphConv(CoraDataset.num_features, hidden_channels) + #self.sage = SAGEConv(hidden_channels, 2 * hidden_channels) + #self.transformer = TransformerConv(2 * hidden_channels, 2 * hidden_channels) # self.sage = SAGEConv(hidden_channels, 2 * hidden_channels) # self.transformer = TransformerConv(2 * hidden_channels, 2 * hidden_channels)