From 2300e62b91d2e48ac85fccd79f1de7cb14fd2800 Mon Sep 17 00:00:00 2001 From: anna-grim Date: Tue, 3 Feb 2026 02:43:21 +0000 Subject: [PATCH 1/2] refactor: feedforward and normalization --- src/neuron_proofreader/utils/img_util.py | 11 +++++++---- src/neuron_proofreader/utils/ml_util.py | 7 ++++--- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/src/neuron_proofreader/utils/img_util.py b/src/neuron_proofreader/utils/img_util.py index c70d114..6f344c8 100644 --- a/src/neuron_proofreader/utils/img_util.py +++ b/src/neuron_proofreader/utils/img_util.py @@ -659,22 +659,25 @@ def is_precomputed(img_path): return False -def normalize(img): +def normalize(img, percentiles=(1, 99.5)): """ - Normalizes an image so that the minimum and maximum intensity values are 0 - and 1. + Normalizes an image using a percentile-based scheme and clips values to + [0, 1]. Parameters ---------- img : numpy.ndarray Image to be normalized. + percentiles : Tuple[float], optional + Lower and upper percentiles used to normalize the given image. Default + is (1, 99.5). Returns ------- img : numpy.ndarray Normalized image. 
""" - mn, mx = np.percentile(img, [1, 99.9]) + mn, mx = np.percentile(img, percentiles) return np.clip((img - mn) / (mx - mn + 1e-5), 0, 1) diff --git a/src/neuron_proofreader/utils/ml_util.py b/src/neuron_proofreader/utils/ml_util.py index 14cbe77..63eb008 100644 --- a/src/neuron_proofreader/utils/ml_util.py +++ b/src/neuron_proofreader/utils/ml_util.py @@ -37,12 +37,13 @@ def __init__(self, input_dim, output_dim, n_layers): super().__init__() # Instance attributes + assert n_layers > 1 self.net = self.build_network(input_dim, output_dim, n_layers) def build_network(self, input_dim, output_dim, n_layers): # Set input/output dimensions input_dim_i = input_dim - output_dim_i = input_dim // 2 + output_dim_i = max(input_dim // 2, 4) # Build architecture layers = [] @@ -50,9 +51,9 @@ def build_network(self, input_dim, output_dim, n_layers): mlp = init_mlp(input_dim_i, input_dim_i * 2, output_dim_i) layers.append(mlp) - input_dim_i = input_dim_i // 2 + input_dim_i = output_dim_i output_dim_i = ( - output_dim_i // 2 if i < n_layers - 2 else output_dim + max(output_dim_i // 2, 4) if i < n_layers - 2 else output_dim ) # Initialize weights From 9d8723c12391f985b0a800aab25d501f8f4d8894 Mon Sep 17 00:00:00 2001 From: anna-grim Date: Mon, 9 Feb 2026 17:29:22 +0000 Subject: [PATCH 2/2] gnn with mlp comp --- .../machine_learning/gnn_models.py | 43 ++++++++++++++++--- .../machine_learning/subgraph_sampler.py | 2 +- .../split_feature_extraction.py | 3 +- 3 files changed, 39 insertions(+), 9 deletions(-) diff --git a/src/neuron_proofreader/machine_learning/gnn_models.py b/src/neuron_proofreader/machine_learning/gnn_models.py index ce18795..f630031 100644 --- a/src/neuron_proofreader/machine_learning/gnn_models.py +++ b/src/neuron_proofreader/machine_learning/gnn_models.py @@ -34,7 +34,14 @@ class VisionHGAT(torch.nn.Module): str(("proposal", "to", "branch")), ] - def __init__(self, patch_shape, heads=2, hidden_dim=128, n_layers=2): + def __init__( + self, + patch_shape, + 
disable_msg_passing=False, + heads=2, + hidden_dim=128, + n_layers=2, + ): # Call parent class super().__init__() @@ -43,9 +50,15 @@ def __init__(self, patch_shape, heads=2, hidden_dim=128, n_layers=2): self.patch_embedding = init_patch_embedding(patch_shape, hidden_dim // 2) # Message passing layers - self.gat1 = self.init_gat(hidden_dim, hidden_dim, heads) - self.gat2 = self.init_gat(hidden_dim * heads, hidden_dim, heads) - self.output = nn.Linear(hidden_dim * heads ** 2, 1) + self.disable_msg_passing = disable_msg_passing + if self.disable_msg_passing: + self.gat1 = self.init_mlp_layers(hidden_dim, n_layers) + self.gat2 = self.init_mlp_layers(hidden_dim, n_layers) + self.output = nn.Linear(hidden_dim, 1) + else: + self.gat1 = self.init_gat(hidden_dim, hidden_dim, heads) + self.gat2 = self.init_gat(hidden_dim * heads, hidden_dim, heads) + self.output = nn.Linear(hidden_dim * heads ** 2, 1) # Initialize weights self.init_weights() @@ -63,6 +76,18 @@ def init_gat(self, hidden_dim, edge_dim, heads): gat_dict[relation] = init_gat(hidden_dim, edge_dim, heads) return nn_geometric.HeteroConv(gat_dict) + def init_mlp_layers(self, hidden_dim, n_layers=2): + layers = nn.ModuleList() + for _ in range(n_layers): + layers.append( + nn_geometric.HeteroDictLinear( + hidden_dim, + hidden_dim, + types=("branch", "proposal") + ) + ) + return layers + def init_weights(self): """ Initializes linear layers. 
@@ -91,8 +116,14 @@ def forward(self, input_dict): x_dict["proposal"] = torch.cat((x_dict["proposal"], x_img), dim=1) # Message passing - x_dict = self.gat1(x_dict, edge_index_dict) - x_dict = self.gat2(x_dict, edge_index_dict) + if self.disable_msg_passing: + for layer in self.gat1: + x_dict = layer(x_dict) + for layer in self.gat2: + x_dict = layer(x_dict) + else: + x_dict = self.gat1(x_dict, edge_index_dict) + x_dict = self.gat2(x_dict, edge_index_dict) return self.output(x_dict["proposal"]) diff --git a/src/neuron_proofreader/machine_learning/subgraph_sampler.py b/src/neuron_proofreader/machine_learning/subgraph_sampler.py index 6453772..13d0cb7 100644 --- a/src/neuron_proofreader/machine_learning/subgraph_sampler.py +++ b/src/neuron_proofreader/machine_learning/subgraph_sampler.py @@ -202,7 +202,7 @@ class SeededSubgraphSampler(SubgraphSampler): def __init__(self, graph, gnn_depth=2, max_proposals=64): # Call parent class super(SeededSubgraphSampler, self).__init__( - graph, gnn_depth, max_proposals + graph, gnn_depth=gnn_depth, max_proposals=max_proposals ) # --- Batch Generation --- diff --git a/src/neuron_proofreader/split_proofreading/split_feature_extraction.py b/src/neuron_proofreader/split_proofreading/split_feature_extraction.py index d245032..15cf4f2 100644 --- a/src/neuron_proofreader/split_proofreading/split_feature_extraction.py +++ b/src/neuron_proofreader/split_proofreading/split_feature_extraction.py @@ -11,7 +11,6 @@ from concurrent.futures import ThreadPoolExecutor, as_completed from skimage.transform import resize -from time import time from torch_geometric.data import HeteroData import numpy as np @@ -402,7 +401,7 @@ def __init__( node1, node2 = tuple(self.proposal) self.annotate_edge(node1) self.annotate_edge(node2) - self.annotate_proposal() + self.annotate_proposal() # --- Core Routines --- def get_input_patch(self):