diff --git a/src/neuron_proofreader/machine_learning/augmentation.py b/src/neuron_proofreader/machine_learning/augmentation.py index 4f5d12d..0f43a2b 100644 --- a/src/neuron_proofreader/machine_learning/augmentation.py +++ b/src/neuron_proofreader/machine_learning/augmentation.py @@ -88,7 +88,7 @@ def __call__(self, patches): for axis in self.axes: if random.random() > 0.5: patches[0, ...] = np.flip(patches[0, ...], axis=axis) - patches[1, ...] = np.flip(patches[1, ...], axis=axis) + #patches[1, ...] = np.flip(patches[1, ...], axis=axis) class RandomRotation3D: @@ -124,7 +124,7 @@ def __call__(self, patches): if random.random() < 0.5: angle = random.uniform(*self.angles) patches[0, ...] = rotate3d(patches[0, ...], angle, axes) - patches[1, ...] = rotate3d(patches[1, ...], angle, axes, True) + #patches[1, ...] = rotate3d(patches[1, ...], angle, axes, True) class RandomScale3D: @@ -174,7 +174,7 @@ def __call__(self, patches): # Rescale images patches[0, ...] = zoom(patches[0, ...], zoom_factors, order=3) - patches[1, ...] = zoom(patches[1, ...], zoom_factors, order=0) + #patches[1, ...] = zoom(patches[1, ...], zoom_factors, order=0) return patches @@ -207,7 +207,7 @@ def __call__(self, patches): the input image and "patches[1, ...]" is from the segmentation. """ factor = random.uniform(*self.factor_range) - patches[0, ...] = np.clip(patches[0, ...] * factor, 0, 1) + #patches[0, ...] = np.clip(patches[0, ...] * factor, 0, 1) class RandomNoise3D: @@ -240,7 +240,7 @@ def __call__(self, img_patch): std = self.max_std * random.random() noise = np.random.uniform(-std, std, img_patch[0, ...].shape) img_patch[0, ...] += noise - img_patch[0, ...] = np.clip(img_patch[0, ...], 0, 1) + #img_patch[0, ...] = np.clip(img_patch[0, ...], 0, 1) # --- Helpers --- diff --git a/src/neuron_proofreader/machine_learning/geometric_gnn_models.py b/src/neuron_proofreader/machine_learning/geometric_gnn_models.py index 7338b00..cc455ab 100644 --- a/src/neuron_proofreader/machine_learning/geometric_gnn_models.py +++ b/src/neuron_proofreader/machine_learning/geometric_gnn_models.py @@ -13,10 +13,8 @@ import torch -from neuron_proofreader.machine_learning.vision_models import ( - CNN3D, - init_feedforward, -) +from neuron_proofreader.machine_learning.vision_models import CNN3D +from neuron_proofreader.utils import ml_util # --- Multimodal GNN Architectures --- diff --git a/src/neuron_proofreader/machine_learning/point_cloud_models.py b/src/neuron_proofreader/machine_learning/point_cloud_models.py index dd6d5d2..73880d4 100644 --- a/src/neuron_proofreader/machine_learning/point_cloud_models.py +++ b/src/neuron_proofreader/machine_learning/point_cloud_models.py @@ -14,10 +14,8 @@ import torch.nn as nn import torch.nn.functional as F -from neuron_proofreader.machine_learning.vision_models import ( - CNN3D, - init_feedforward, -) +from neuron_proofreader.machine_learning.vision_models import CNN3D +from neuron_proofreader.utils import ml_util # --- Architectures --- @@ -70,7 +68,7 @@ def __init__(self, patch_shape, output_dim=128): output_dim=output_dim, use_double_conv=True, ) - self.output = init_feedforward(2 * output_dim, 1, 3) + self.output = ml_util.init_feedforward(2 * output_dim, 1, 3) def forward(self, x): """ @@ -231,7 +229,7 @@ def __init__(self, patch_shape, output_dim=128): output_dim=output_dim, use_double_conv=True, ) - self.output = init_feedforward(2 * output_dim, 1, 3) + self.output = ml_util.init_feedforward(2 * output_dim, 1, 3) def forward(self, x): """ diff --git a/src/neuron_proofreader/machine_learning/train.py b/src/neuron_proofreader/machine_learning/train.py index f4958ea..0d71c2c 100644 --- a/src/neuron_proofreader/machine_learning/train.py +++ b/src/neuron_proofreader/machine_learning/train.py @@ -378,7 +378,7 @@ def _save_mistake_mips(self, x, y, hat_y, idx_offset): filename = f"{mistake_type}{i + idx_offset}.png" output_path = os.path.join(self.mistakes_dir, filename) img_util.plot_image_and_segmentation_mips( - x[i, 0], 2 * x[i, 1], output_path + x[i, 0] + np.min(x[i, 0]), x[i, 0] + np.min(x[i, 0]), output_path ) def save_model(self, epoch): diff --git a/src/neuron_proofreader/machine_learning/vision_models.py b/src/neuron_proofreader/machine_learning/vision_models.py index 5ef0421..b3cf1c1 100644 --- a/src/neuron_proofreader/machine_learning/vision_models.py +++ b/src/neuron_proofreader/machine_learning/vision_models.py @@ -10,10 +10,13 @@ """ from einops import rearrange +from neurobase.finetune import finetune_model import torch import torch.nn as nn +from neuron_proofreader.utils import ml_util + # --- CNNs --- class CNN3D(nn.Module): @@ -56,12 +59,12 @@ def __init__( # Convolutional layers self.conv_layers = init_cnn3d( - 2, n_feat_channels, n_conv_layers, use_double_conv=use_double_conv + 1, n_feat_channels, n_conv_layers, use_double_conv=use_double_conv ) # Output layer flat_size = self._get_flattened_size() - self.output = init_feedforward(flat_size, output_dim, 3) + self.output = ml_util.init_feedforward(flat_size, output_dim, 3) # Initialize weights self.apply(self.init_weights) @@ -79,7 +82,7 @@ def _get_flattened_size(self): pooling. """ with torch.no_grad(): - x = torch.zeros(1, 2, *self.patch_shape) + x = torch.zeros(1, 1, *self.patch_shape) x = self.conv_layers(x) return x.view(1, -1).size(1) @@ -128,6 +131,42 @@ def forward(self, x): # --- Transformers --- +class MAE3D(nn.Module): + + def __init__(self): + # Call parent closs + super().__init__() + + # Load model + full_model = finetune_model( + checkpoint_path="/home/jupyter/models/best_model-v1_mae_S.ckpt", + model_config="mae_S", + task_head_config="binary_classifier", + freeze_encoder=True + ) + + # Instance attributes + self.encoder = full_model.encoder + self.output = ml_util.init_feedforward(384, 1, 2) + + def forward(self, x): + latent = self.encoder(x) + x = latent["latents"][:, 0, :] + x = self.output(x) + return x + + def forward_old(self, x): + latent0 = self.encoder(x[:, 0:1, ...]) + latent1 = self.encoder(x[:, 1:2, ...]) + + x0 = latent0["latents"][:, 0, :] + x1 = latent1["latents"][:, 0, :] + + x = torch.cat((x0, x1), dim=1) + x = self.output(x) + return x + + class ViT3D(nn.Module): """ A class that implements a 3D Vision transformer. @@ -185,7 +224,7 @@ def __init__( self.norm = nn.LayerNorm(emb_dim) # Output layer - self.output = init_feedforward(emb_dim, output_dim, 2) + self.output = ml_util.init_feedforward(emb_dim, output_dim, 2) # Initialize weights self._init_weights() @@ -486,55 +525,3 @@ def init_conv_layer(in_channels, out_channels, kernel_size, use_double_conv): # Pooling layers.append(nn.MaxPool3d(kernel_size=2)) return nn.Sequential(*layers) - - -def init_feedforward(input_dim, output_dim, n_layers): - """ - Initializes a feed forward neural network. - - Parameters - ---------- - input_dim : int - Dimension of the input. - output_dim : int - Dimension of the output of this network. - n_layers : int - Number of layers in the network. - """ - layers = list() - input_dim_i = input_dim - output_dim_i = input_dim // 2 - for i in range(n_layers): - layers.append(init_mlp(input_dim_i, input_dim_i * 2, output_dim_i)) - input_dim_i = input_dim_i // 2 - output_dim_i = output_dim_i // 2 if i < n_layers - 2 else output_dim - return nn.Sequential(*layers) - - -def init_mlp(input_dim, hidden_dim, output_dim, dropout=0.1): - """ - Initializes a multi-layer perceptron (MLP). - - Parameters - ---------- - input_dim : int - Dimension of input feature vector. - hidden_dim : int - Dimension of embedded feature vector. - output_dim : int - Dimension of output feature vector. - dropout : float, optional - Fraction of values to randomly drop during training. Default is 0.1. - - Returns - ------- - mlp : nn.Sequential - Multi-layer perception network. - """ - mlp = nn.Sequential( - nn.Linear(input_dim, hidden_dim), - nn.GELU(), - nn.Dropout(p=dropout), - nn.Linear(hidden_dim, output_dim), - ) - return mlp diff --git a/src/neuron_proofreader/merge_proofreading/merge_datasets.py b/src/neuron_proofreader/merge_proofreading/merge_datasets.py index f5a0b4b..0249f40 100644 --- a/src/neuron_proofreader/merge_proofreading/merge_datasets.py +++ b/src/neuron_proofreader/merge_proofreading/merge_datasets.py @@ -77,7 +77,7 @@ def __init__( self, merge_sites_df, anisotropy=(1.0, 1.0, 1.0), - brightness_clip=400, + brightness_clip=600, subgraph_radius=100, node_spacing=5, patch_shape=(128, 128, 128), @@ -324,11 +324,11 @@ def __getitem__(self, idx): # Stack image channels try: - patches = np.stack([img_patch, segment_mask], axis=0) + patches = img_patch + 2 * segment_mask except ValueError: img_patch = img_util.pad_to_shape(img_patch, self.patch_shape) - patches = np.stack([img_patch, segment_mask], axis=0) - return patches, subgraph, label + patches = img_patch + segment_mask + return patches[np.newaxis], subgraph, label def sample_brain_id(self): """ @@ -940,7 +940,7 @@ def __init__( # Instance attributes self.is_multimodal = is_multimodal self.modality = modality - self.patches_shape = (2,) + self.dataset.patch_shape + self.patches_shape = (1,) + self.dataset.patch_shape self.use_shuffle = use_shuffle # --- Core Routines --- diff --git a/src/neuron_proofreader/utils/img_util.py b/src/neuron_proofreader/utils/img_util.py index ba38678..fc7045b 100644 --- a/src/neuron_proofreader/utils/img_util.py +++ b/src/neuron_proofreader/utils/img_util.py @@ -597,12 +597,13 @@ def normalize(img): Returns ------- - numpy.ndarray + img : numpy.ndarray Normalized image. """ try: - mn, mx = np.percentile(img, [1, 99.9]) - return np.clip((img - mn) / (mx - mn + 1e-5), 0, 1) + #mn, mx = np.percentile(img, [1, 99.9]) + #return np.clip((img - mn) / (mx - mn + 1e-5), 0, 1) + return (img - img.mean()) / (img.std() + 1e-8) except Exception as e: print("Image Normalization Failed:", e) return np.zeros(img.shape) diff --git a/src/neuron_proofreader/utils/ml_util.py b/src/neuron_proofreader/utils/ml_util.py index 08cffd7..b79b30e 100644 --- a/src/neuron_proofreader/utils/ml_util.py +++ b/src/neuron_proofreader/utils/ml_util.py @@ -14,12 +14,66 @@ import networkx as nx import numpy as np import torch +import torch.nn as nn from neuron_proofreader.utils import util GNN_DEPTH = 2 +# --- Architectures --- +def init_feedforward(input_dim, output_dim, n_layers): + """ + Initializes a feed forward neural network. + + Parameters + ---------- + input_dim : int + Dimension of the input. + output_dim : int + Dimension of the output of this network. + n_layers : int + Number of layers in the network. + """ + layers = list() + input_dim_i = input_dim + output_dim_i = input_dim // 2 + for i in range(n_layers): + layers.append(init_mlp(input_dim_i, input_dim_i * 2, output_dim_i)) + input_dim_i = input_dim_i // 2 + output_dim_i = output_dim_i // 2 if i < n_layers - 2 else output_dim + return nn.Sequential(*layers) + + +def init_mlp(input_dim, hidden_dim, output_dim, dropout=0.1): + """ + Initializes a multi-layer perceptron (MLP). + + Parameters + ---------- + input_dim : int + Dimension of input feature vector. + hidden_dim : int + Dimension of embedded feature vector. + output_dim : int + Dimension of output feature vector. + dropout : float, optional + Fraction of values to randomly drop during training. Default is 0.1. + + Returns + ------- + mlp : nn.Sequential + Multi-layer perception network. + """ + mlp = nn.Sequential( + nn.Linear(input_dim, hidden_dim), + nn.GELU(), + nn.Dropout(p=dropout), + nn.Linear(hidden_dim, output_dim), + ) + return mlp + + # --- Batch Generation --- def get_batch(graph, proposals, batch_size, flagged_proposals=set()): """