Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def predict(self, x_nodes):
numpy.ndarray
Predicted merge site likelihoods.
"""
with torch.no_grad():
with torch.inference_mode():
x_nodes = x_nodes.to(self.device)
y_nodes = sigmoid(self.model(x_nodes))
return np.squeeze(ml_util.to_cpu(y_nodes, to_numpy=True), axis=1)
Expand Down
11 changes: 5 additions & 6 deletions src/neuron_proofreader/proposal_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -645,12 +645,11 @@ def edge_attr(self, i, key="xyz", ignore=False):
return attrs

def edge_length(self, edge):
length = 0
for i in range(1, len(self.edges[edge]["xyz"])):
length += geometry.dist(
self.edges[edge]["xyz"][i], self.edges[edge]["xyz"][i - 1]
)
return length
xyz = self.edges[edge]["xyz"]
if len(xyz) < 2:
return 0.0
else:
return np.linalg.norm(xyz[1:] - xyz[:-1], axis=1).sum()

def find_fragments_near_xyz(self, query_xyz, max_dist):
hits = dict()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
"""

from concurrent.futures import ThreadPoolExecutor, as_completed
from skimage.transform import resize
from time import time
from torch_geometric.data import HeteroData

import numpy as np
Expand Down Expand Up @@ -264,6 +266,7 @@ def __call__(self, subgraph, features):
for thread in as_completed(pending.keys()):
proposal = pending.pop(thread)
extractor = thread.result()

profiles[proposal] = extractor.get_intensity_profile()
patches[proposal] = extractor.get_input_patch()

Expand Down Expand Up @@ -413,7 +416,7 @@ def get_input_patch(self):
raw image data and channel 1 contains segmentation data.
"""
img = img_util.resize(self.img, self.patch_shape)
mask = img_util.resize(self.mask, self.patch_shape, True)
mask = resize_segmentation(self.mask, self.patch_shape)
return np.stack([img, mask], axis=0)

def get_intensity_profile(self):
Expand Down Expand Up @@ -954,3 +957,29 @@ def get_feature_dict():
proposals.
"""
return {"branch": 2, "proposal": 70}


def resize_segmentation(mask, new_shape):
"""
Resizes a segmentation mask to the given new shape.

Parameters
----------
mask : numpy.ndarray
Segmentation mask to be resized.
new_shape : Tuple[int]
New shape of segmentation mask.

Returns
-------
mask : numpy.ndarray
Resized segmentation mask.
"""
mask = resize(
mask,
new_shape,
order=0,
preserve_range=True,
anti_aliasing=False,
).astype(mask.dtype)
return mask
6 changes: 3 additions & 3 deletions src/neuron_proofreader/split_proofreading/split_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
"""

from time import time
from torch.nn.functional import sigmoid
from tqdm import tqdm

import networkx as nx
Expand Down Expand Up @@ -327,10 +326,11 @@ def predict(self, data):
Dictionary that maps proposal IDs to model predictions.
"""
# Generate predictions
with torch.no_grad():
with torch.inference_mode():
device = self.config.ml.device
x = data.get_inputs().to(device)
hat_y = sigmoid(self.model(x))
with torch.cuda.amp.autocast(enabled=True):
hat_y = torch.sigmoid(self.model(x))

# Reformat predictions
idx_to_id = data.idxs_proposals.idx_to_id
Expand Down
12 changes: 2 additions & 10 deletions src/neuron_proofreader/utils/img_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -730,7 +730,7 @@ def remove_small_segments(segmentation, min_size):
return segmentation


def resize(img, new_shape, is_segmentation=False):
def resize(img, new_shape):
"""
Resize a 3D image to the specified new shape using linear interpolation.

Expand All @@ -740,22 +740,14 @@ def resize(img, new_shape, is_segmentation=False):
Input 3D image array with shape (depth, height, width).
new_shape : Tuple[int]
Desired output shape as (new_depth, new_height, new_width).
is_segmentation : bool, optional
Indication of whether the image represents a segmentation mask.

Returns
-------
numpy.ndarray
Resized 3D image with shape equal to "new_shape".
"""
# Set parameters
order = 0 if is_segmentation else 3
multiplier = 4 if is_segmentation else 1
zoom_factors = np.array(new_shape) / np.array(img.shape)

# Resize image
img = zoom(multiplier * img, zoom_factors, order=order)
return img / multiplier
return zoom(img, zoom_factors, order=1, prefilter=False)


def to_physical(voxel, anisotropy, offset=(0, 0, 0)):
Expand Down
Loading