diff --git a/src/neuron_proofreader/machine_learning/gnn_models.py b/src/neuron_proofreader/machine_learning/gnn_models.py index ce18795..f630031 100644 --- a/src/neuron_proofreader/machine_learning/gnn_models.py +++ b/src/neuron_proofreader/machine_learning/gnn_models.py @@ -34,7 +34,14 @@ class VisionHGAT(torch.nn.Module): str(("proposal", "to", "branch")), ] - def __init__(self, patch_shape, heads=2, hidden_dim=128, n_layers=2): + def __init__( + self, + patch_shape, + disable_msg_passing=False, + heads=2, + hidden_dim=128, + n_layers=2, + ): # Call parent class super().__init__() @@ -43,9 +50,15 @@ def __init__(self, patch_shape, heads=2, hidden_dim=128, n_layers=2): self.patch_embedding = init_patch_embedding(patch_shape, hidden_dim // 2) # Message passing layers - self.gat1 = self.init_gat(hidden_dim, hidden_dim, heads) - self.gat2 = self.init_gat(hidden_dim * heads, hidden_dim, heads) - self.output = nn.Linear(hidden_dim * heads ** 2, 1) + self.disable_msg_passing = disable_msg_passing + if self.disable_msg_passing: + self.gat1 = self.init_mlp_layers(hidden_dim, n_layers) + self.gat2 = self.init_mlp_layers(hidden_dim, n_layers) + self.output = nn.Linear(hidden_dim, 1) + else: + self.gat1 = self.init_gat(hidden_dim, hidden_dim, heads) + self.gat2 = self.init_gat(hidden_dim * heads, hidden_dim, heads) + self.output = nn.Linear(hidden_dim * heads ** 2, 1) # Initialize weights self.init_weights() @@ -63,6 +76,18 @@ def init_gat(self, hidden_dim, edge_dim, heads): gat_dict[relation] = init_gat(hidden_dim, edge_dim, heads) return nn_geometric.HeteroConv(gat_dict) + def init_mlp_layers(self, hidden_dim, n_layers=2): + layers = nn.ModuleList() + for _ in range(n_layers): + layers.append( + nn_geometric.HeteroDictLinear( + hidden_dim, + hidden_dim, + types=("branch", "proposal") + ) + ) + return layers + def init_weights(self): """ Initializes linear layers. @@ -91,8 +116,14 @@ def forward(self, input_dict): x_dict["proposal"] = torch.cat((x_dict["proposal"], x_img), dim=1) # Message passing - x_dict = self.gat1(x_dict, edge_index_dict) - x_dict = self.gat2(x_dict, edge_index_dict) + if self.disable_msg_passing: + for layer in self.gat1: + x_dict = layer(x_dict) + for layer in self.gat2: + x_dict = layer(x_dict) + else: + x_dict = self.gat1(x_dict, edge_index_dict) + x_dict = self.gat2(x_dict, edge_index_dict) return self.output(x_dict["proposal"]) diff --git a/src/neuron_proofreader/machine_learning/subgraph_sampler.py b/src/neuron_proofreader/machine_learning/subgraph_sampler.py index 6453772..13d0cb7 100644 --- a/src/neuron_proofreader/machine_learning/subgraph_sampler.py +++ b/src/neuron_proofreader/machine_learning/subgraph_sampler.py @@ -202,7 +202,7 @@ class SeededSubgraphSampler(SubgraphSampler): def __init__(self, graph, gnn_depth=2, max_proposals=64): # Call parent class super(SeededSubgraphSampler, self).__init__( - graph, gnn_depth, max_proposals + graph, gnn_depth=gnn_depth, max_proposals=max_proposals ) # --- Batch Generation --- diff --git a/src/neuron_proofreader/split_proofreading/split_feature_extraction.py b/src/neuron_proofreader/split_proofreading/split_feature_extraction.py index d245032..15cf4f2 100644 --- a/src/neuron_proofreader/split_proofreading/split_feature_extraction.py +++ b/src/neuron_proofreader/split_proofreading/split_feature_extraction.py @@ -11,7 +11,6 @@ from concurrent.futures import ThreadPoolExecutor, as_completed from skimage.transform import resize -from time import time from torch_geometric.data import HeteroData import numpy as np @@ -402,7 +401,7 @@ def __init__( node1, node2 = tuple(self.proposal) self.annotate_edge(node1) self.annotate_edge(node2) - self.annotate_proposal() + self.annotate_proposal() # --- Core Routines --- def get_input_patch(self):