Skip to content

Commit

Permalink
Merge branch 'main' into feat-gnn-inference
Browse files Browse the repository at this point in the history
  • Loading branch information
anna-grim authored Apr 27, 2024
2 parents 5278f36 + 649a5b9 commit 37dd6ef
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion src/deep_neurographs/machine_learning/graph_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
class GCN(torch.nn.Module):
def __init__(self, input_channels):
super().__init__()
self.input = Linear(input_channels, input_channels)
self.input = Linear(input_channels, input_channels)
self.conv1 = GCNConv(input_channels, 2 * input_channels)
self.conv2 = GCNConv(2 * input_channels, input_channels)
self.conv3 = GCNConv(input_channels, input_channels // 2)
Expand Down Expand Up @@ -63,6 +63,9 @@ def forward(self, x, edge_index):
x = self.output(x)

return x
#self.resgated = ResGatedGraphConv(CoraDataset.num_features, hidden_channels)
#self.sage = SAGEConv(hidden_channels, 2 * hidden_channels)
#self.transformer = TransformerConv(2 * hidden_channels, 2 * hidden_channels)

# self.sage = SAGEConv(hidden_channels, 2 * hidden_channels)
# self.transformer = TransformerConv(2 * hidden_channels, 2 * hidden_channels)
Expand Down

0 comments on commit 37dd6ef

Please sign in to comment.