From 44983b8ca0d0395aef3fc1da45cdf3241b35fc7e Mon Sep 17 00:00:00 2001
From: anna-grim
Date: Tue, 6 Jan 2026 19:43:42 +0000
Subject: [PATCH 1/4] refactor: filter small segments

---
 .../merge_proofreading/merge_datasets.py |  1 +
 src/neuron_proofreader/utils/img_util.py | 26 +++++++++++++++++++
 2 files changed, 27 insertions(+)

diff --git a/src/neuron_proofreader/merge_proofreading/merge_datasets.py b/src/neuron_proofreader/merge_proofreading/merge_datasets.py
index facd54c..c455e6d 100644
--- a/src/neuron_proofreader/merge_proofreading/merge_datasets.py
+++ b/src/neuron_proofreader/merge_proofreading/merge_datasets.py
@@ -501,6 +501,7 @@ def get_segment_mask(self, brain_id, center, subgraph):
             segment_mask = self.segmentation_readers[brain_id].read(
                 center, self.patch_shape
             )
+            segment_mask = img_util.remove_small_segments(segment_mask, 1000)
             segment_mask = 0.5 * (segment_mask > 0).astype(float)
         else:
             segment_mask = np.zeros(self.patch_shape)
diff --git a/src/neuron_proofreader/utils/img_util.py b/src/neuron_proofreader/utils/img_util.py
index 2e481f4..ba38678 100644
--- a/src/neuron_proofreader/utils/img_util.py
+++ b/src/neuron_proofreader/utils/img_util.py
@@ -9,6 +9,7 @@
 """

 from abc import ABC, abstractmethod
+from fastremap import mask_except, renumber, unique
 from matplotlib.colors import ListedColormap
 from scipy.ndimage import zoom

@@ -631,6 +632,31 @@ def pad_to_shape(img, target_shape, pad_value=0):
     return np.pad(img, pads, mode='constant', constant_values=pad_value)


+def remove_small_segments(segmentation, min_size):
+    """
+    Removes small segments from a segmentation.
+
+    Parameters
+    ----------
+    segmentation : numpy.ndarray
+        Integer array representing a segmentation mask. Each unique
+        nonzero value corresponds to a distinct segment.
+    min_size : int
+        Minimum number of voxels a segment must exceed to be kept.
+
+    Returns
+    -------
+    segmentation : numpy.ndarray
+        New segmentation of the same shape as the input, with only the
+        retained segments renumbered contiguously.
+    """
+    ids, cnts = unique(segmentation, return_counts=True)
+    ids = [i for i, cnt in zip(ids, cnts) if cnt > min_size and i != 0]
+    segmentation = mask_except(segmentation, ids)
+    segmentation, _ = renumber(segmentation, preserve_zero=True, in_place=True)
+    return segmentation
+
+
 def resize(img, new_shape):
     """
     Resize a 3D image to the specified new shape using linear interpolation.
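
A note on what PATCH 1 does in practice: remove_small_segments drops every label whose voxel count does not exceed min_size, then renumbers the survivors so the label IDs stay contiguous. A minimal sketch of the behavior on a toy array (the values and min_size below are illustrative, not from the repo; assumes the fastremap package is installed):

    import numpy as np
    from neuron_proofreader.utils import img_util

    # Label 1 covers 2 voxels; label 2 covers 6 voxels.
    seg = np.array([[1, 1, 2, 2],
                    [2, 2, 2, 2]], dtype=np.uint32)

    out = img_util.remove_small_segments(seg, min_size=2)
    # Label 1 is removed (2 voxels, not > 2); label 2 is renumbered to 1:
    # [[0, 0, 1, 1],
    #  [1, 1, 1, 1]]
    print(out)

In get_segment_mask the filtered result is immediately binarized to 0.5 * (mask > 0), so the 1000-voxel threshold simply strips speckle from the segmentation before it is used as the mask channel.
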
""" # Create sub-dataset subset_dataset = base_dataset.subset(self.__class__, idxs) self.__dict__.update(subset_dataset.__dict__) # Instance attributes + self.negative_bias = negative_bias self.transform = ImageTransforms() # --- Getters --- @@ -689,8 +692,10 @@ def get_site(self, idx): return self.get_indexed_positive_site(idx) elif np.random.random() < self.random_negative_example_prob: return self.get_random_negative_site() - else: + elif abs(idx) < len(self): return self.get_indexed_negative_site(abs(idx)) + else: + return self.get_random_negative_site() # --- Helpers --- def get_idxs(self): @@ -702,7 +707,8 @@ def get_idxs(self): numpy.ndarray Example indices to iterate over. """ - return np.arange(-len(self) + 1, len(self)) + n_negative_examples = int(len(self) * (1 + self.negative_bias)) + return np.arange(-n_negative_examples + 1, len(self)) class MergeSiteValDataset(MergeSiteDataset): @@ -908,6 +914,7 @@ def __init__( dataset, batch_size=32, is_multimodal=False, + modality=None, sampler=None, use_shuffle=True ): @@ -928,9 +935,11 @@ def __init__( """ # Call parent class super().__init__(dataset, batch_size=batch_size, sampler=sampler) + assert modality in [None, "graph", "pointcloud"] # Instance attributes self.is_multimodal = is_multimodal + self.modality = modality self.patches_shape = (2,) + self.dataset.patch_shape self.use_shuffle = use_shuffle @@ -952,10 +961,12 @@ def __iter__(self): # Iterate over indices for start in range(0, len(idxs), self.batch_size): end = min(start + self.batch_size, len(idxs)) - if self.is_multimodal: - yield self._load_multimodal_batch(idxs[start: end]) + if self.is_multimodal and self.modality == "graph": + yield self._load_image_graph_batch(idxs[start: end]) + elif self.is_multimodal and self.modality == "pointcloud": + yield self._load_image_pc_batch(idxs[start: end]) else: - yield self._load_batch(idxs[start: end]) + yield self._load_image_batch(idxs[start: end]) def _load_image_batch(self, batch_idxs): """ @@ -970,8 +981,8 @@ def _load_image_batch(self, batch_idxs): ------- patches : torch.Tensor Image patches for the batch. - labels : torch.Tensor - Labels corresponding to each patch. + targets : torch.Tensor + Target labels corresponding to each patch. """ with ThreadPoolExecutor() as executor: # Assign threads @@ -982,11 +993,11 @@ def _load_image_batch(self, batch_idxs): # Store results patches = np.zeros((len(batch_idxs),) + self.patches_shape) - labels = np.zeros((len(batch_idxs), 1)) + targets = np.zeros((len(batch_idxs), 1)) for thread in as_completed(pending.keys()): i = pending.pop(thread) - patches[i], _, labels[i] = thread.result() - return ml_util.to_tensor(patches), ml_util.to_tensor(labels) + patches[i], _, targets[i] = thread.result() + return ml_util.to_tensor(patches), ml_util.to_tensor(targets) def _load_image_pc_batch(self, batch_idxs): """ @@ -1001,8 +1012,8 @@ def _load_image_pc_batch(self, batch_idxs): ------- batch : Dict[str, torch.Tensor] Dictionary that maps modality names to batch features. - labels : torch.Tensor - Labels corresponding to each patch. + targets : torch.Tensor + Target labels corresponding to each patch. 
""" with ThreadPoolExecutor() as executor: # Assign threads @@ -1013,11 +1024,11 @@ def _load_image_pc_batch(self, batch_idxs): # Store results patches = np.zeros((len(batch_idxs),) + self.patches_shape) - labels = np.zeros((len(batch_idxs), 1)) + targets = np.zeros((len(batch_idxs), 1)) point_clouds = np.zeros((len(batch_idxs), 3, 3600)) for thread in as_completed(pending.keys()): i = pending.pop(thread) - patches[i], subgraph, labels[i] = thread.result() + patches[i], subgraph, targets[i] = thread.result() point_clouds[i] = subgraph_to_point_cloud(subgraph) # Set batch dictionary @@ -1027,7 +1038,7 @@ def _load_image_pc_batch(self, batch_idxs): "point_cloud": ml_util.to_tensor(point_clouds), } ) - return batch, ml_util.to_tensor(labels) + return batch, ml_util.to_tensor(targets) def _load_image_graph_batch(self, idxs): """ @@ -1042,8 +1053,8 @@ def _load_image_graph_batch(self, idxs): ------- batch : Dict[str, torch.Tensor] Dictionary that maps modality names to batch features. - labels : torch.Tensor - Labels corresponding to each patch. + targets : torch.Tensor + Target labels corresponding to each patch. """ with ThreadPoolExecutor() as executor: # Assign threads @@ -1052,12 +1063,12 @@ def _load_image_graph_batch(self, idxs): threads.append(executor.submit(self.dataset.__getitem__, idx)) # Store results - labels = np.zeros((len(idxs), 1)) + targets = np.zeros((len(idxs), 1)) patches = np.zeros((len(idxs),) + self.patches_shape) h, x, edge_index, batches = list(), list(), list(), list() node_offset = 0 for i, thread in enumerate(as_completed(threads)): - patches[i], subgraph, labels[i] = thread.result() + patches[i], subgraph, targets[i] = thread.result() h_i, x_i, edge_index_i = subgraph_to_data(subgraph) n_i = h_i.size(0) @@ -1084,4 +1095,4 @@ def _load_image_graph_batch(self, idxs): "graph": (h, x, edge_index, batches) } ) - return batch, ml_util.to_tensor(labels) + return batch, ml_util.to_tensor(targets) From 8dc7e91792d12e78285b69075337ff2d72d7ad54 Mon Sep 17 00:00:00 2001 From: anna-grim Date: Wed, 7 Jan 2026 23:55:06 +0000 Subject: [PATCH 3/4] refactor: model loading --- .../machine_learning/geometric_gnn_models.py | 8 +- .../machine_learning/point_cloud_models.py | 10 +-- .../machine_learning/vision_models.py | 89 ++++++++----------- src/neuron_proofreader/utils/img_util.py | 7 +- src/neuron_proofreader/utils/ml_util.py | 54 +++++++++++ 5 files changed, 100 insertions(+), 68 deletions(-) diff --git a/src/neuron_proofreader/machine_learning/geometric_gnn_models.py b/src/neuron_proofreader/machine_learning/geometric_gnn_models.py index ba8241f..503b1e7 100644 --- a/src/neuron_proofreader/machine_learning/geometric_gnn_models.py +++ b/src/neuron_proofreader/machine_learning/geometric_gnn_models.py @@ -13,10 +13,8 @@ import torch -from neuron_proofreader.machine_learning.vision_models import ( - CNN3D, - init_feedforward, -) +from neuron_proofreader.machine_learning.vision_models import CNN3D +from neuron_proofreader.utils import ml_util # --- Architectures --- @@ -328,7 +326,7 @@ def __init__(self, patch_shape, output_dim=128): output_dim=output_dim, use_double_conv=True, ) - self.output = init_feedforward(2 * output_dim + 3, 1, 3) + self.output = iml_util.nit_feedforward(2 * output_dim + 3, 1, 3) def forward(self, x): """ diff --git a/src/neuron_proofreader/machine_learning/point_cloud_models.py b/src/neuron_proofreader/machine_learning/point_cloud_models.py index a6f2a6a..d0fc19b 100644 --- a/src/neuron_proofreader/machine_learning/point_cloud_models.py +++ 
@@ -14,10 +14,8 @@
 import torch.nn as nn
 import torch.nn.functional as F

-from neuron_proofreader.machine_learning.vision_models import (
-    CNN3D,
-    init_feedforward,
-)
+from neuron_proofreader.machine_learning.vision_models import CNN3D
+from neuron_proofreader.utils import ml_util


 # --- Architectures ---
@@ -70,7 +68,7 @@ def __init__(self, patch_shape, output_dim=128):
             output_dim=output_dim,
             use_double_conv=True,
         )
-        self.output = init_feedforward(2 * output_dim, 1, 3)
+        self.output = ml_util.init_feedforward(2 * output_dim, 1, 3)

     def forward(self, x):
         """
@@ -231,7 +229,7 @@ def __init__(self, patch_shape, output_dim=128):
             output_dim=output_dim,
             use_double_conv=True,
         )
-        self.output = init_feedforward(2 * output_dim, 1, 3)
+        self.output = ml_util.init_feedforward(2 * output_dim, 1, 3)

     def forward(self, x):
         """
diff --git a/src/neuron_proofreader/machine_learning/vision_models.py b/src/neuron_proofreader/machine_learning/vision_models.py
index 5ef0421..44e3fef 100644
--- a/src/neuron_proofreader/machine_learning/vision_models.py
+++ b/src/neuron_proofreader/machine_learning/vision_models.py
@@ -10,10 +10,13 @@
 """

 from einops import rearrange
+from neurobase.finetune import finetune_model

 import torch
 import torch.nn as nn

+from neuron_proofreader.utils import ml_util
+

 # --- CNNs ---
 class CNN3D(nn.Module):
@@ -61,7 +64,7 @@ def __init__(

         # Output layer
         flat_size = self._get_flattened_size()
-        self.output = init_feedforward(flat_size, output_dim, 3)
+        self.output = ml_util.init_feedforward(flat_size, output_dim, 3)

         # Initialize weights
         self.apply(self.init_weights)
@@ -128,6 +131,36 @@ def forward(self, x):


 # --- Transformers ---
+class MAE3D(nn.Module):
+
+    def __init__(self):
+        # Call parent class
+        super().__init__()
+
+        # Load model
+        full_model = finetune_model(
+            checkpoint_path="/home/jupyter/models/best_model-v1_mae_S.ckpt",
+            model_config="mae_S",
+            task_head_config="binary_classifier",
+            freeze_encoder=True
+        )
+
+        # Instance attributes
+        self.encoder = full_model.encoder
+        self.output = ml_util.init_feedforward(2 * 384, 1, 2)
+
+    def forward(self, x):
+        latent0 = self.encoder(x[:, 0:1, ...])
+        latent1 = self.encoder(x[:, 1:2, ...])
+
+        x0 = latent0["latents"][:, 0, :]
+        x1 = latent1["latents"][:, 0, :]
+
+        x = torch.cat((x0, x1), dim=1)
+        x = self.output(x)
+        return x
+
+
 class ViT3D(nn.Module):
     """
     A class that implements a 3D Vision transformer.
@@ -185,7 +218,7 @@ def __init__(
         self.norm = nn.LayerNorm(emb_dim)

         # Output layer
-        self.output = init_feedforward(emb_dim, output_dim, 2)
+        self.output = ml_util.init_feedforward(emb_dim, output_dim, 2)

         # Initialize weights
         self._init_weights()
@@ -486,55 +519,3 @@ def init_conv_layer(in_channels, out_channels, kernel_size, use_double_conv):
         # Pooling
         layers.append(nn.MaxPool3d(kernel_size=2))
     return nn.Sequential(*layers)
-
-
-def init_feedforward(input_dim, output_dim, n_layers):
-    """
-    Initializes a feed forward neural network.
-
-    Parameters
-    ----------
-    input_dim : int
-        Dimension of the input.
-    output_dim : int
-        Dimension of the output of this network.
-    n_layers : int
-        Number of layers in the network.
-    """
-    layers = list()
-    input_dim_i = input_dim
-    output_dim_i = input_dim // 2
-    for i in range(n_layers):
-        layers.append(init_mlp(input_dim_i, input_dim_i * 2, output_dim_i))
-        input_dim_i = input_dim_i // 2
-        output_dim_i = output_dim_i // 2 if i < n_layers - 2 else output_dim
-    return nn.Sequential(*layers)
-
-
-def init_mlp(input_dim, hidden_dim, output_dim, dropout=0.1):
-    """
-    Initializes a multi-layer perceptron (MLP).
-
-    Parameters
-    ----------
-    input_dim : int
-        Dimension of input feature vector.
-    hidden_dim : int
-        Dimension of embedded feature vector.
-    output_dim : int
-        Dimension of output feature vector.
-    dropout : float, optional
-        Fraction of values to randomly drop during training. Default is 0.1.
-
-    Returns
-    -------
-    mlp : nn.Sequential
-        Multi-layer perceptron network.
-    """
-    mlp = nn.Sequential(
-        nn.Linear(input_dim, hidden_dim),
-        nn.GELU(),
-        nn.Dropout(p=dropout),
-        nn.Linear(hidden_dim, output_dim),
-    )
-    return mlp
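
For context on the MAE3D head added above: each of the two input channels is pushed through the frozen MAE encoder separately, and the two CLS-token latents (384-dim each under the mae_S config) are concatenated before the feed-forward head. A shape sketch under those assumptions; the dict-shaped encoder output below is a stand-in for the neurobase encoder, whose "latents" tensor is assumed to carry the CLS token at position 0, as the forward pass indexes it:

    import torch

    batch_size, n_tokens, dim = 8, 65, 384   # n_tokens is illustrative

    latent0 = {"latents": torch.randn(batch_size, n_tokens, dim)}  # channel 0
    latent1 = {"latents": torch.randn(batch_size, n_tokens, dim)}  # channel 1

    x0 = latent0["latents"][:, 0, :]   # CLS token, shape (8, 384)
    x1 = latent1["latents"][:, 0, :]
    x = torch.cat((x0, x1), dim=1)     # shape (8, 768)
    # matches the head: ml_util.init_feedforward(2 * 384, 1, 2)
    print(x.shape)
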
diff --git a/src/neuron_proofreader/utils/img_util.py b/src/neuron_proofreader/utils/img_util.py
index ba38678..fc7045b 100644
--- a/src/neuron_proofreader/utils/img_util.py
+++ b/src/neuron_proofreader/utils/img_util.py
@@ -597,12 +597,13 @@ def normalize(img):

     Returns
     -------
-    numpy.ndarray
+    img : numpy.ndarray
         Normalized image.
     """
     try:
-        mn, mx = np.percentile(img, [1, 99.9])
-        return np.clip((img - mn) / (mx - mn + 1e-5), 0, 1)
+        #mn, mx = np.percentile(img, [1, 99.9])
+        #return np.clip((img - mn) / (mx - mn + 1e-5), 0, 1)
+        return (img - img.mean()) / (img.std() + 1e-8)
     except Exception as e:
         print("Image Normalization Failed:", e)
         return np.zeros(img.shape)
diff --git a/src/neuron_proofreader/utils/ml_util.py b/src/neuron_proofreader/utils/ml_util.py
index 6b8a2d9..ae00238 100644
--- a/src/neuron_proofreader/utils/ml_util.py
+++ b/src/neuron_proofreader/utils/ml_util.py
@@ -14,12 +14,66 @@
 import networkx as nx
 import numpy as np
 import torch
+import torch.nn as nn

 from neuron_proofreader.utils import util

 GNN_DEPTH = 2


+# --- Architectures ---
+def init_feedforward(input_dim, output_dim, n_layers):
+    """
+    Initializes a feed forward neural network.
+
+    Parameters
+    ----------
+    input_dim : int
+        Dimension of the input.
+    output_dim : int
+        Dimension of the output of this network.
+    n_layers : int
+        Number of layers in the network.
+    """
+    layers = list()
+    input_dim_i = input_dim
+    output_dim_i = input_dim // 2
+    for i in range(n_layers):
+        layers.append(init_mlp(input_dim_i, input_dim_i * 2, output_dim_i))
+        input_dim_i = input_dim_i // 2
+        output_dim_i = output_dim_i // 2 if i < n_layers - 2 else output_dim
+    return nn.Sequential(*layers)
+
+
+def init_mlp(input_dim, hidden_dim, output_dim, dropout=0.1):
+    """
+    Initializes a multi-layer perceptron (MLP).
+
+    Parameters
+    ----------
+    input_dim : int
+        Dimension of input feature vector.
+    hidden_dim : int
+        Dimension of embedded feature vector.
+    output_dim : int
+        Dimension of output feature vector.
+    dropout : float, optional
+        Fraction of values to randomly drop during training. Default is 0.1.
+
+    Returns
+    -------
+    mlp : nn.Sequential
+        Multi-layer perceptron network.
+    """
+    mlp = nn.Sequential(
+        nn.Linear(input_dim, hidden_dim),
+        nn.GELU(),
+        nn.Dropout(p=dropout),
+        nn.Linear(hidden_dim, output_dim),
+    )
+    return mlp
+
+
 # --- Batch Generation ---
 def get_batch(graph, proposals, batch_size, flagged_proposals=set()):
     """
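
The relocated init_feedforward stacks init_mlp blocks whose widths halve layer by layer: block i maps input_dim_i up to a hidden width of 2 * input_dim_i and back down to input_dim_i // 2, except the final block, which maps to output_dim. A quick trace with illustrative dimensions:

    import torch
    from neuron_proofreader.utils import ml_util

    net = ml_util.init_feedforward(input_dim=256, output_dim=1, n_layers=3)
    # Block widths: 256 -> 128, 128 -> 64, 64 -> 1

    x = torch.randn(4, 256)
    print(net(x).shape)   # torch.Size([4, 1])
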
From 736ec31befdbddbfbed4cae4083df99838687906 Mon Sep 17 00:00:00 2001
From: anna-grim
Date: Tue, 27 Jan 2026 18:48:40 +0000
Subject: [PATCH 4/4] provisional branch

---
 .../machine_learning/augmentation.py             | 10 +++++-----
 src/neuron_proofreader/machine_learning/train.py |  2 +-
 .../machine_learning/vision_models.py            | 12 +++++++++---
 .../merge_proofreading/merge_datasets.py         | 10 +++++-----
 4 files changed, 20 insertions(+), 14 deletions(-)

diff --git a/src/neuron_proofreader/machine_learning/augmentation.py b/src/neuron_proofreader/machine_learning/augmentation.py
index 4f5d12d..0f43a2b 100644
--- a/src/neuron_proofreader/machine_learning/augmentation.py
+++ b/src/neuron_proofreader/machine_learning/augmentation.py
@@ -88,7 +88,7 @@ def __call__(self, patches):
         for axis in self.axes:
             if random.random() > 0.5:
                 patches[0, ...] = np.flip(patches[0, ...], axis=axis)
-                patches[1, ...] = np.flip(patches[1, ...], axis=axis)
+                #patches[1, ...] = np.flip(patches[1, ...], axis=axis)


 class RandomRotation3D:
@@ -124,7 +124,7 @@ def __call__(self, patches):
         if random.random() < 0.5:
             angle = random.uniform(*self.angles)
             patches[0, ...] = rotate3d(patches[0, ...], angle, axes)
-            patches[1, ...] = rotate3d(patches[1, ...], angle, axes, True)
+            #patches[1, ...] = rotate3d(patches[1, ...], angle, axes, True)


 class RandomScale3D:
@@ -174,7 +174,7 @@ def __call__(self, patches):

         # Rescale images
         patches[0, ...] = zoom(patches[0, ...], zoom_factors, order=3)
-        patches[1, ...] = zoom(patches[1, ...], zoom_factors, order=0)
+        #patches[1, ...] = zoom(patches[1, ...], zoom_factors, order=0)
         return patches


@@ -207,7 +207,7 @@ def __call__(self, patches):
             the input image and "patches[1, ...]" is from the segmentation.
         """
         factor = random.uniform(*self.factor_range)
-        patches[0, ...] = np.clip(patches[0, ...] * factor, 0, 1)
+        #patches[0, ...] = np.clip(patches[0, ...] * factor, 0, 1)


 class RandomNoise3D:
@@ -240,7 +240,7 @@ def __call__(self, img_patch):
         std = self.max_std * random.random()
         noise = np.random.uniform(-std, std, img_patch[0, ...].shape)
         img_patch[0, ...] += noise
-        img_patch[0, ...] = np.clip(img_patch[0, ...], 0, 1)
+        #img_patch[0, ...] = np.clip(img_patch[0, ...], 0, 1)


 # --- Helpers ---
diff --git a/src/neuron_proofreader/machine_learning/train.py b/src/neuron_proofreader/machine_learning/train.py
index f4958ea..0d71c2c 100644
--- a/src/neuron_proofreader/machine_learning/train.py
+++ b/src/neuron_proofreader/machine_learning/train.py
@@ -378,7 +378,7 @@ def _save_mistake_mips(self, x, y, hat_y, idx_offset):
             filename = f"{mistake_type}{i + idx_offset}.png"
             output_path = os.path.join(self.mistakes_dir, filename)
             img_util.plot_image_and_segmentation_mips(
-                x[i, 0], 2 * x[i, 1], output_path
+                x[i, 0] - np.min(x[i, 0]), x[i, 0] - np.min(x[i, 0]), output_path
             )

     def save_model(self, epoch):
diff --git a/src/neuron_proofreader/machine_learning/vision_models.py b/src/neuron_proofreader/machine_learning/vision_models.py
index 44e3fef..b3cf1c1 100644
--- a/src/neuron_proofreader/machine_learning/vision_models.py
+++ b/src/neuron_proofreader/machine_learning/vision_models.py
@@ -59,7 +59,7 @@ def __init__(

         # Convolutional layers
         self.conv_layers = init_cnn3d(
-            2, n_feat_channels, n_conv_layers, use_double_conv=use_double_conv
+            1, n_feat_channels, n_conv_layers, use_double_conv=use_double_conv
         )

         # Output layer
@@ -82,7 +82,7 @@ def _get_flattened_size(self):
             pooling.
         """
         with torch.no_grad():
-            x = torch.zeros(1, 2, *self.patch_shape)
+            x = torch.zeros(1, 1, *self.patch_shape)
             x = self.conv_layers(x)
             return x.view(1, -1).size(1)

@@ -147,9 +147,15 @@ def __init__(self):

         # Instance attributes
         self.encoder = full_model.encoder
-        self.output = ml_util.init_feedforward(2 * 384, 1, 2)
+        self.output = ml_util.init_feedforward(384, 1, 2)

     def forward(self, x):
+        latent = self.encoder(x)
+        x = latent["latents"][:, 0, :]
+        x = self.output(x)
+        return x
+
+    def forward_old(self, x):
         latent0 = self.encoder(x[:, 0:1, ...])
         latent1 = self.encoder(x[:, 1:2, ...])

diff --git a/src/neuron_proofreader/merge_proofreading/merge_datasets.py b/src/neuron_proofreader/merge_proofreading/merge_datasets.py
index f5a0b4b..0249f40 100644
--- a/src/neuron_proofreader/merge_proofreading/merge_datasets.py
+++ b/src/neuron_proofreader/merge_proofreading/merge_datasets.py
@@ -77,7 +77,7 @@ def __init__(
         self,
         merge_sites_df,
         anisotropy=(1.0, 1.0, 1.0),
-        brightness_clip=400,
+        brightness_clip=600,
         subgraph_radius=100,
         node_spacing=5,
         patch_shape=(128, 128, 128),
@@ -324,11 +324,11 @@ def __getitem__(self, idx):

         # Stack image channels
         try:
-            patches = np.stack([img_patch, segment_mask], axis=0)
+            patches = img_patch + 2 * segment_mask
         except ValueError:
             img_patch = img_util.pad_to_shape(img_patch, self.patch_shape)
-            patches = np.stack([img_patch, segment_mask], axis=0)
-        return patches, subgraph, label
+            patches = img_patch + 2 * segment_mask
+        return patches[np.newaxis], subgraph, label

     def sample_brain_id(self):
         """
@@ -940,7 +940,7 @@ def __init__(
         # Instance attributes
         self.is_multimodal = is_multimodal
         self.modality = modality
-        self.patches_shape = (2,) + self.dataset.patch_shape
+        self.patches_shape = (1,) + self.dataset.patch_shape
         self.use_shuffle = use_shuffle

     # --- Core Routines ---
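
Net effect of PATCH 4 on the training input: the image and segmentation are no longer stacked as two channels; the patch is a single channel equal to img_patch + 2 * segment_mask, and since segment_mask is 0.5 * (mask > 0) after PATCH 1, flagged voxels get a +1 offset on top of the (now z-scored) image. A sketch of the composition with toy data:

    import numpy as np

    patch_shape = (128, 128, 128)
    img_patch = np.random.randn(*patch_shape)     # z-scored by normalize()
    raw_mask = np.random.randint(0, 3, patch_shape)
    segment_mask = 0.5 * (raw_mask > 0).astype(float)

    patches = img_patch + 2 * segment_mask        # single fused channel
    patches = patches[np.newaxis]                 # shape (1, 128, 128, 128)
    print(patches.shape)

This is why CNN3D's input channels drop from 2 to 1 and the loader's patches_shape becomes (1,) + patch_shape.
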