From 0aece9f0c5f6a6c618a73fbf1fd9da8b56151299 Mon Sep 17 00:00:00 2001 From: anna-grim Date: Tue, 3 Feb 2026 04:57:28 +0000 Subject: [PATCH 1/2] bug: batch size --- .../machine_learning/subgraph_sampler.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/src/neuron_proofreader/machine_learning/subgraph_sampler.py b/src/neuron_proofreader/machine_learning/subgraph_sampler.py index 44b921b..6453772 100644 --- a/src/neuron_proofreader/machine_learning/subgraph_sampler.py +++ b/src/neuron_proofreader/machine_learning/subgraph_sampler.py @@ -107,7 +107,6 @@ def __iter__(self): self.populate_via_bfs(subgraph, root) # Yield batch - self.populate_attributes(subgraph) yield subgraph def populate_via_bfs(self, subgraph, root): @@ -178,10 +177,6 @@ def visit_flagged_proposal(self, subgraph, queue, visited, proposal): if not (v in visited and v in nodes_added): queue.append((v, 0)) - def populate_attributes(self, subgraph): - # TO DO - pass - # --- Helpers --- def init_subgraph(self): """ @@ -204,10 +199,10 @@ def is_subgraph_full(self, subgraph): class SeededSubgraphSampler(SubgraphSampler): - def __init__(self, graph, max_proposals=200, gnn_depth=2): + def __init__(self, graph, gnn_depth=2, max_proposals=64): # Call parent class super(SeededSubgraphSampler, self).__init__( - graph, max_proposals, gnn_depth + graph, gnn_depth, max_proposals ) # --- Batch Generation --- From b688ac621f0bb60512ddab57c01669a0f499d482 Mon Sep 17 00:00:00 2001 From: anna-grim Date: Tue, 3 Feb 2026 05:02:03 +0000 Subject: [PATCH 2/2] bug: batch size --- src/neuron_proofreader/config.py | 1 - .../machine_learning/gnn_models.py | 12 ++---------- .../machine_learning/vision_models.py | 4 ++-- 3 files changed, 4 insertions(+), 13 deletions(-) diff --git a/src/neuron_proofreader/config.py b/src/neuron_proofreader/config.py index 69e869c..e348f33 100644 --- a/src/neuron_proofreader/config.py +++ b/src/neuron_proofreader/config.py @@ -179,6 +179,5 @@ def save(self, dir_path): dir_path : str Path to directory to save JSON file. """ - self.graph.save(os.path.join(dir_path, "metadata_graph.json")) self.ml.save(os.path.join(dir_path, "metadata_ml.json")) diff --git a/src/neuron_proofreader/machine_learning/gnn_models.py b/src/neuron_proofreader/machine_learning/gnn_models.py index 7512f95..ce18795 100644 --- a/src/neuron_proofreader/machine_learning/gnn_models.py +++ b/src/neuron_proofreader/machine_learning/gnn_models.py @@ -91,16 +91,8 @@ def forward(self, input_dict): x_dict["proposal"] = torch.cat((x_dict["proposal"], x_img), dim=1) # Message passing - try: - x_dict = self.gat1(x_dict, edge_index_dict) - x_dict = self.gat2(x_dict, edge_index_dict) - except: - print("Before...") - print("\n".join(before)) - print("After...") - for key, x in x_dict.items(): - print(key, x.size()) - stop + x_dict = self.gat1(x_dict, edge_index_dict) + x_dict = self.gat2(x_dict, edge_index_dict) return self.output(x_dict["proposal"]) diff --git a/src/neuron_proofreader/machine_learning/vision_models.py b/src/neuron_proofreader/machine_learning/vision_models.py index 11f57ed..dc827c3 100644 --- a/src/neuron_proofreader/machine_learning/vision_models.py +++ b/src/neuron_proofreader/machine_learning/vision_models.py @@ -9,7 +9,7 @@ """ -#from neurobase.finetune import finetune_model +# from neurobase.finetune import finetune_model from einops import rearrange import torch @@ -147,7 +147,7 @@ def __init__(self, checkpoint_path, model_config): # Instance attributes self.encoder = full_model.encoder - self.output = ml_util.init_feedforward(384, 1, 2) + self.output = FeedForwardNet(384, 1, 3) def forward(self, x): latent0 = self.encoder(x[:, 0:1, ...])