Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

End to end #64

Merged
merged 11 commits into from
Jan 24, 2024
9 changes: 5 additions & 4 deletions src/server/dcp_server/config.cfg
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
{
"setup": {
"segmentation": "GeneralSegmentation",
"model_to_use": "UNet",
"accepted_types": [".jpg", ".jpeg", ".png", ".tiff", ".tif"],
"seg_name_string": "_seg"
},

"service": {
"model_to_use": "CellposePatchCNN",
"save_model_path": "mito",
"runner_name": "cellpose_runner",
"runner_name": "bento_runner",
"bento_model_path": "unetN",
"service_name": "data-centric-platform",
"port": 7010
},
Expand All @@ -20,6 +20,7 @@
"classifier":{
"in_channels": 1,
"num_classes": 3,
"features":[64,128,256,512],
"black_bg": "False",
"include_mask": "False"
}
Expand All @@ -41,7 +42,7 @@
"noise_intensity": 5,
"num_classes": 3
},
"n_epochs": 10,
"n_epochs": 2,
"lr": 0.001,
"batch_size": 1,
"optimizer": "Adam"
Expand Down
73 changes: 53 additions & 20 deletions src/server/dcp_server/fsimagestorage.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import os
import numpy as np
from skimage.io import imread, imsave
from skimage.transform import resize, rescale
import os
from skimage.color import rgb2gray

from dcp_server import utils

# Import configuration
Expand All @@ -10,8 +11,9 @@
class FilesystemImageStorage():
"""Class used to deal with everything related to image storing and processing - loading, saving, transforming...
"""
def __init__(self, data_root):
def __init__(self, data_root, model_used):
self.root_dir = data_root
self.model_used = model_used

def load_image(self, cur_selected_img, is_gray=True):
"""Load the image (using skiimage)
Expand Down Expand Up @@ -143,16 +145,22 @@ def rescale_image(self, img, height, width, channel_ax=None, order=2):
:type channel_ax: int
:return: rescaled image
:rtype: ndarray
"""
max_dim = max(height, width)
rescale_factor = max_dim/512
return rescale(img, 1/rescale_factor, order=order, channel_axis=channel_ax)
"""
if self.model_used == "UNet":
height_pad = (height//16 + 1)*16 - height
width_pad = (width//16 + 1)*16 - width
return np.pad(img, ((0, height_pad),(0, width_pad)))
else:
# Cellpose segmentation runs best with 512 size? TODO: check
max_dim = max(height, width)
rescale_factor = max_dim/512
return rescale(img, 1/rescale_factor, order=order, channel_axis=channel_ax)

def resize_image(self, img, height, width, channel_ax=None, order=2):
"""resize image
def resize_mask(self, mask, height, width, channel_ax=None, order=2):
"""resize the mask so it matches the original image size

:param img: image
:type img: ndarray
:param mask: image
:type mask: ndarray
:param height: height of the image
:type height: int
:param width: width of the image
Expand All @@ -161,13 +169,30 @@ def resize_image(self, img, height, width, channel_ax=None, order=2):
:type order: int
:return: resized image
:rtype: ndarray
"""
if channel_ax is not None:
n_channel_dim = img.shape[channel_ax]
output_size = [height, width]
output_size.insert(channel_ax, n_channel_dim)
else: output_size = [height, width]
return resize(img, output_size, order=order)
"""

if self.model_used == "UNet":
# we assume an order C, H, W
if channel_ax is not None and channel_ax==0:
height_pad = mask.shape[1] - height
width_pad = mask.shape[2]- width
return mask[:, :-height_pad, :-width_pad]
elif channel_ax is not None and channel_ax==2:
height_pad = mask.shape[0] - height
width_pad = mask.shape[1]- width
return mask[:-height_pad, :-width_pad, :]
elif channel_ax is not None and channel_ax==1:
height_pad = mask.shape[2] - height
width_pad = mask.shape[0]- width
return mask[:-width_pad, :, :-height_pad]

else:
if channel_ax is not None:
n_channel_dim = mask.shape[channel_ax]
output_size = [height, width]
output_size.insert(channel_ax, n_channel_dim)
else: output_size = [height, width]
return resize(mask, output_size, order=order)

def prepare_images_and_masks_for_training(self, train_img_mask_pairs):
"""Image and mask processing for training.
Expand All @@ -180,6 +205,14 @@ def prepare_images_and_masks_for_training(self, train_img_mask_pairs):
imgs=[]
masks=[]
for img_file, mask_file in train_img_mask_pairs:
imgs.append(self.load_image(img_file))
masks.append(imread(mask_file))
img = self.load_image(img_file)
mask = imread(mask_file)
if self.model_used == "UNet":
# Unet only accepts image sizes divisable by 16
height_pad = (img.shape[0]//16 + 1)*16 - img.shape[0]
width_pad = (img.shape[1]//16 + 1)*16 - img.shape[1]
img = np.pad(img, ((0, height_pad),(0, width_pad)))
mask = np.pad(mask, ((0, 0), (0, height_pad),(0, width_pad)))
imgs.append(img)
masks.append(mask)
return imgs, masks
179 changes: 174 additions & 5 deletions src/server/dcp_server/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from copy import deepcopy
from tqdm import tqdm
import numpy as np
from scipy.ndimage import label

from cellpose.metrics import aggregated_jaccard_index

Expand All @@ -18,7 +19,7 @@ class CustomCellposeModel(models.CellposeModel, nn.Module):
"""Custom cellpose model inheriting the attributes and functions from the original CellposeModel and implementing
additional attributes and methods needed for this project.
"""
def __init__(self, model_config, train_config, eval_config):
def __init__(self, model_config, train_config, eval_config, model_name):
"""Constructs all the necessary attributes for the CustomCellposeModel.
The model inherits all attributes from the parent class, the init allows to pass any other argument that the parent class accepts.
Please, visit here https://cellpose.readthedocs.io/en/latest/api.html#id4 for more details on arguments accepted.
Expand All @@ -38,6 +39,7 @@ def __init__(self, model_config, train_config, eval_config):
self.mkldnn = False # otherwise we get error with saving model
self.train_config = train_config
self.eval_config = eval_config
self.model_name = model_name

def update_configs(self, train_config, eval_config):
self.train_config = train_config
Expand All @@ -64,7 +66,7 @@ def train(self, imgs, masks):
:type masks: List[np.ndarray]
"""

if not isinstance(masks, np.ndarray):
if not isinstance(masks, np.ndarray): # TODO Remove: all these should be taken care of in fsimagestorage
masks = np.array(masks)

if masks[0].shape[0] == 2:
Expand All @@ -73,7 +75,7 @@ def train(self, imgs, masks):
super().train(train_data=deepcopy(imgs), train_labels=masks, **self.train_config["segmentor"])

pred_masks = [self.eval(img) for img in masks]
self.metric = np.mean(aggregated_jaccard_index(masks, pred_masks))
self.metric = np.mean(aggregated_jaccard_index(masks, pred_masks)) # TODO move metric computation
# self.loss = self.loss_fn(masks, pred_masks)

def masks_to_outlines(self, mask):
Expand Down Expand Up @@ -214,17 +216,19 @@ class CellposePatchCNN(nn.Module):
Cellpose & patches of cells and then cnn to classify each patch
"""

def __init__(self, model_config, train_config, eval_config):
def __init__(self, model_config, train_config, eval_config, model_name):
super().__init__()

self.model_config = model_config
self.train_config = train_config
self.eval_config = eval_config
self.model_name = model_name

# Initialize the cellpose model and the classifier
self.segmentor = CustomCellposeModel(self.model_config,
self.train_config,
self.eval_config)
self.eval_config,
"Cellpose")
self.classifier = CellClassifierFCNN(self.model_config,
self.train_config,
self.eval_config)
Expand Down Expand Up @@ -287,6 +291,171 @@ def eval(self, img):
return final_mask



class UNet(nn.Module):

"""
Unet is a convolutional neural network architecture for semantic segmentation.

Args:
in_channels (int): Number of input channels (default: 3).
out_channels (int): Number of output channels (default: 4).
features (list): List of feature channels for each encoder level (default: [64,128,256,512]).
"""

class DoubleConv(nn.Module):
"""
DoubleConv module consists of two consecutive convolutional layers with
batch normalization and ReLU activation functions.
Args:
in_channels (int): Number of input channels.
out_channels (int): Number of output channels.
"""

def __init__(self, in_channels, out_channels):
super().__init__()

self.conv = nn.Sequential(
nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(),
nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(),
)

def forward(self, x):
return self.conv(x)


def __init__(self, model_config, train_config, eval_config, model_name):

super().__init__()
self.model_config = model_config
self.train_config = train_config
self.eval_config = eval_config
self.model_name = model_name
'''
self.in_channels = self.model_config["unet"]["in_channels"]
self.out_channels = self.model_config["unet"]["out_channels"]
self.features = self.model_config["unet"]["features"]
'''
self.in_channels = self.model_config["classifier"]["in_channels"]
self.out_channels = self.model_config["classifier"]["num_classes"] + 1
self.features = self.model_config["classifier"]["features"]

self.encoder = nn.ModuleList()
self.decoder = nn.ModuleList()

self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

# Encoder
for feature in self.features:
self.encoder.append(
UNet.DoubleConv(self.in_channels, feature)
)
self.in_channels = feature

# Decoder
for feature in self.features[::-1]:
self.decoder.append(
nn.ConvTranspose2d(
feature*2, feature, kernel_size=2, stride=2
)
)
self.decoder.append(
UNet.DoubleConv(feature*2, feature)
)

self.bottle_neck = UNet.DoubleConv(self.features[-1], self.features[-1]*2)
self.output_conv = nn.Conv2d(self.features[0], self.out_channels, kernel_size=1)

def forward(self, x):
skip_connections = []
for encoder in self.encoder:
x = encoder(x)
skip_connections.append(x)
x = self.pool(x)

x = self.bottle_neck(x)
skip_connections = skip_connections[::-1]

for i in np.arange(len(self.decoder), step=2):
x = self.decoder[i](x)
skip_connection = skip_connections[i//2]
concatenate_skip = torch.cat((skip_connection, x), dim=1)
x = self.decoder[i+1](concatenate_skip)

return self.output_conv(x)

def train(self, imgs, masks):

lr = self.train_config["classifier"]['lr']
epochs = self.train_config["classifier"]['n_epochs']
batch_size = self.train_config["classifier"]['batch_size']

# Convert input images and labels to tensors
# normalize images
imgs = [(img-np.min(img))/(np.max(img)-np.min(img)) for img in imgs]
# convert to tensor
imgs = torch.stack([torch.from_numpy(img.astype(np.float32)) for img in imgs])
imgs = imgs.unsqueeze(1) if imgs.ndim == 3 else imgs

# Classification label mask
masks = np.array(masks)
masks = torch.stack([torch.from_numpy(mask[1]) for mask in masks])

# Create a training dataset and dataloader
train_dataset = TensorDataset(imgs, masks)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size)

loss_fn = nn.CrossEntropyLoss()
optimizer = Adam(params=self.parameters(), lr=lr)

for _ in tqdm(range(epochs), desc="Running UNet training"):

self.loss = 0

for imgs, masks in train_dataloader:
imgs = imgs.float()
masks = masks.long()

#forward path
preds = self.forward(imgs)
loss = loss_fn(preds, masks)

#backward path
optimizer.zero_grad()
loss.backward()
optimizer.step()

self.loss += loss.detach().mean().item()

self.loss /= len(train_dataloader)

def eval(self, img):
"""
Evaluate the model on the provided image and return the predicted label.
Input:
img: np.ndarray[np.uint8]
Output: y_hat - The predicted label
"""
with torch.no_grad():
# normalise
img = (img-np.min(img))/(np.max(img)-np.min(img))
img = torch.from_numpy(img).float().unsqueeze(0)

img = img.unsqueeze(1) if img.ndim == 3 else img

preds = self.forward(img)
class_mask = torch.argmax(preds, 1).numpy()[0]

instance_mask = label((class_mask > 0).astype(int))[0]

final_mask = np.stack((instance_mask, class_mask), axis=self.eval_config['mask_channel_axis']).astype(np.uint16)

return final_mask


# class CustomSAMModel():
# # https://github.com/facebookresearch/segment-anything/blob/main/notebooks/automatic_mask_generator_example.ipynb
Expand Down
Loading