From e8c28479ade97b01e00c5dafcab7848f0a248333 Mon Sep 17 00:00:00 2001 From: anna-grim Date: Tue, 3 Feb 2026 20:18:34 +0000 Subject: [PATCH 1/2] refactor: optimized split corr runtime --- .../merge_proofreading/merge_inference.py | 2 +- src/neuron_proofreader/proposal_graph.py | 11 +++---- .../split_feature_extraction.py | 33 +++++++++++++++++-- .../split_proofreading/split_inference.py | 6 ++-- src/neuron_proofreader/utils/img_util.py | 12 ++----- 5 files changed, 42 insertions(+), 22 deletions(-) diff --git a/src/neuron_proofreader/merge_proofreading/merge_inference.py b/src/neuron_proofreader/merge_proofreading/merge_inference.py index 3832099..b99026f 100644 --- a/src/neuron_proofreader/merge_proofreading/merge_inference.py +++ b/src/neuron_proofreader/merge_proofreading/merge_inference.py @@ -103,7 +103,7 @@ def predict(self, x_nodes): numpy.ndarray Predicted merge site likelihoods. """ - with torch.no_grad(): + with torch.inference_mode(): x_nodes = x_nodes.to(self.device) y_nodes = sigmoid(self.model(x_nodes)) return np.squeeze(ml_util.to_cpu(y_nodes, to_numpy=True), axis=1) diff --git a/src/neuron_proofreader/proposal_graph.py b/src/neuron_proofreader/proposal_graph.py index 43e8e63..24dbf8c 100644 --- a/src/neuron_proofreader/proposal_graph.py +++ b/src/neuron_proofreader/proposal_graph.py @@ -645,12 +645,11 @@ def edge_attr(self, i, key="xyz", ignore=False): return attrs def edge_length(self, edge): - length = 0 - for i in range(1, len(self.edges[edge]["xyz"])): - length += geometry.dist( - self.edges[edge]["xyz"][i], self.edges[edge]["xyz"][i - 1] - ) - return length + xyz = self.edges[edge]["xyz"] + if len(xyz) < 2: + return 0.0 + else: + return np.linalg.norm(xyz[1:] - xyz[:-1], axis=1).sum() def find_fragments_near_xyz(self, query_xyz, max_dist): hits = dict() diff --git a/src/neuron_proofreader/split_proofreading/split_feature_extraction.py b/src/neuron_proofreader/split_proofreading/split_feature_extraction.py index 53ccd77..1ef0bdc 100644 --- a/src/neuron_proofreader/split_proofreading/split_feature_extraction.py +++ b/src/neuron_proofreader/split_proofreading/split_feature_extraction.py @@ -10,6 +10,8 @@ """ from concurrent.futures import ThreadPoolExecutor, as_completed +from skimage.transform import resize +from time import time from torch_geometric.data import HeteroData import numpy as np @@ -76,7 +78,7 @@ def __call__(self, subgraph): Subgraph of "graph" attribute to extract features for. """ features = FeatureSet(subgraph) - for extractor in self.extractors: + for name, extractor in zip(["skel", "img"], self.extractors): extractor(subgraph, features) return features @@ -264,6 +266,7 @@ def __call__(self, subgraph, features): for thread in as_completed(pending.keys()): proposal = pending.pop(thread) extractor = thread.result() + profiles[proposal] = extractor.get_intensity_profile() patches[proposal] = extractor.get_input_patch() @@ -413,7 +416,7 @@ def get_input_patch(self): raw image data and channel 1 contains segmentation data. """ img = img_util.resize(self.img, self.patch_shape) - mask = img_util.resize(self.mask, self.patch_shape, True) + mask = resize_segmentation(self.mask, self.patch_shape) return np.stack([img, mask], axis=0) def get_intensity_profile(self): @@ -954,3 +957,29 @@ def get_feature_dict(): proposals. """ return {"branch": 2, "proposal": 70} + + +def resize_segmentation(mask, new_shape): + """ + Resizes a segmentation mask to the given new shape. + + Parameters + ---------- + mask : numpy.ndarray + Segmentation mask to be resized. + new_shape : Tuple[int] + New shape of segmentation mask. + + Returns + ------- + mask : numpy.ndarray + Resized segmentation mask. + """ + mask = resize( + mask, + new_shape, + order=0, + preserve_range=True, + anti_aliasing=False, + ).astype(mask.dtype) + return mask diff --git a/src/neuron_proofreader/split_proofreading/split_inference.py b/src/neuron_proofreader/split_proofreading/split_inference.py index bc358ed..eec470d 100644 --- a/src/neuron_proofreader/split_proofreading/split_inference.py +++ b/src/neuron_proofreader/split_proofreading/split_inference.py @@ -29,7 +29,6 @@ """ from time import time -from torch.nn.functional import sigmoid from tqdm import tqdm import networkx as nx @@ -327,10 +326,11 @@ def predict(self, data): Dictionary that maps proposal IDs to model predictions. """ # Generate predictions - with torch.no_grad(): + with torch.inference_mode(): device = self.config.ml.device x = data.get_inputs().to(device) - hat_y = sigmoid(self.model(x)) + with torch.cuda.amp.autocast(enabled=True): + hat_y = torch.sigmoid(self.model(x)) # Reformat predictions idx_to_id = data.idxs_proposals.idx_to_id diff --git a/src/neuron_proofreader/utils/img_util.py b/src/neuron_proofreader/utils/img_util.py index 6f344c8..4a4f04f 100644 --- a/src/neuron_proofreader/utils/img_util.py +++ b/src/neuron_proofreader/utils/img_util.py @@ -730,7 +730,7 @@ def remove_small_segments(segmentation, min_size): return segmentation -def resize(img, new_shape, is_segmentation=False): +def resize(img, new_shape): """ Resize a 3D image to the specified new shape using linear interpolation. @@ -740,22 +740,14 @@ def resize(img, new_shape, is_segmentation=False): Input 3D image array with shape (depth, height, width). new_shape : Tuple[int] Desired output shape as (new_depth, new_height, new_width). - is_segmentation : bool, optional - Indication of whether the image represents a segmentation mask. Returns ------- numpy.ndarray Resized 3D image with shape equal to "new_shape". """ - # Set parameters - order = 0 if is_segmentation else 3 - multiplier = 4 if is_segmentation else 1 zoom_factors = np.array(new_shape) / np.array(img.shape) - - # Resize image - img = zoom(multiplier * img, zoom_factors, order=order) - return img / multiplier + return zoom(img, zoom_factors, order=1, prefilter=False) def to_physical(voxel, anisotropy, offset=(0, 0, 0)): From 64326fd5e22c6cfd32ddc145f2fae5a1d128bed1 Mon Sep 17 00:00:00 2001 From: anna-grim Date: Tue, 3 Feb 2026 20:20:12 +0000 Subject: [PATCH 2/2] remove debug line --- .../split_proofreading/split_feature_extraction.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/neuron_proofreader/split_proofreading/split_feature_extraction.py b/src/neuron_proofreader/split_proofreading/split_feature_extraction.py index 1ef0bdc..88b1e2c 100644 --- a/src/neuron_proofreader/split_proofreading/split_feature_extraction.py +++ b/src/neuron_proofreader/split_proofreading/split_feature_extraction.py @@ -78,7 +78,7 @@ def __call__(self, subgraph): Subgraph of "graph" attribute to extract features for. """ features = FeatureSet(subgraph) - for name, extractor in zip(["skel", "img"], self.extractors): + for extractor in self.extractors: extractor(subgraph, features) return features