From c5b734c56244f1ac0ac25ef233b063bcf2b3ee71 Mon Sep 17 00:00:00 2001 From: Christina Bukas Date: Mon, 26 Feb 2024 10:41:27 +0100 Subject: [PATCH 01/26] split models file, still have circular import --- .gitignore | 2 +- src/server/dcp_server/config.cfg | 1 + src/server/dcp_server/models.py | 816 ------------------ src/server/dcp_server/models/__init__.py | 11 + .../dcp_server/models/cellpose_patchCNN.py | 365 ++++++++ .../dcp_server/models/custom_cellpose.py | 121 +++ src/server/dcp_server/models/model.py | 27 + src/server/dcp_server/models/multicellpose.py | 130 +++ src/server/dcp_server/models/unet.py | 211 +++++ 9 files changed, 867 insertions(+), 817 deletions(-) delete mode 100644 src/server/dcp_server/models.py create mode 100644 src/server/dcp_server/models/__init__.py create mode 100644 src/server/dcp_server/models/cellpose_patchCNN.py create mode 100644 src/server/dcp_server/models/custom_cellpose.py create mode 100644 src/server/dcp_server/models/model.py create mode 100644 src/server/dcp_server/models/multicellpose.py create mode 100644 src/server/dcp_server/models/unet.py diff --git a/.gitignore b/.gitignore index 8d5b07d1..605cde5e 100644 --- a/.gitignore +++ b/.gitignore @@ -155,4 +155,4 @@ docs/ test-napari.pub data/ BentoML/ -models/ + diff --git a/src/server/dcp_server/config.cfg b/src/server/dcp_server/config.cfg index 24f44ea0..b2d05cff 100644 --- a/src/server/dcp_server/config.cfg +++ b/src/server/dcp_server/config.cfg @@ -15,6 +15,7 @@ "model": { "segmentor": { + "model_class": "CustomCellposeModel", "model_type": "cyto" }, diff --git a/src/server/dcp_server/models.py b/src/server/dcp_server/models.py deleted file mode 100644 index 187f6395..00000000 --- a/src/server/dcp_server/models.py +++ /dev/null @@ -1,816 +0,0 @@ -from cellpose import models, utils -import torch -from torch import nn -from torch.optim import Adam -from torch.utils.data import TensorDataset, DataLoader -from torchmetrics import F1Score -from copy import deepcopy -from tqdm import tqdm -import numpy as np -from scipy.ndimage import label -from skimage.measure import label as label_mask - - -from sklearn.ensemble import RandomForestClassifier -from sklearn.metrics import f1_score, log_loss -from sklearn.exceptions import NotFittedError - -from cellpose.metrics import aggregated_jaccard_index -from cellpose.dynamics import labels_to_flows -#from segment_anything import SamPredictor, sam_model_registry -#from segment_anything.automatic_mask_generator import SamAutomaticMaskGenerator - -from dcp_server.utils import get_centered_patches, find_max_patch_size, create_patch_dataset, create_dataset_for_rf - -class CustomCellposeModel(models.CellposeModel, nn.Module): - """Custom cellpose model inheriting the attributes and functions from the original CellposeModel and implementing - additional attributes and methods needed for this project. - """ - def __init__(self, model_config, train_config, eval_config, model_name): - """Constructs all the necessary attributes for the CustomCellposeModel. - The model inherits all attributes from the parent class, the init allows to pass any other argument that the parent class accepts. - Please, visit here https://cellpose.readthedocs.io/en/latest/api.html#id4 for more details on arguments accepted. 
- - :param model_config: dictionary passed from the config file with all the arguments for the __init__ function and model initialization - :type model_config: dict - :param train_config: dictionary passed from the config file with all the arguments for training function - :type train_config: dict - :param eval_config: dictionary passed from the config file with all the arguments for eval function - :type eval_config: dict - """ - - # Initialize the cellpose model - #super().__init__(**model_config["segmentor"]) - nn.Module.__init__(self) - models.CellposeModel.__init__(self, **model_config["segmentor"]) - self.mkldnn = False # otherwise we get error with saving model - self.train_config = train_config - self.eval_config = eval_config - self.loss = 1e6 - self.model_name = model_name - - def update_configs(self, train_config, eval_config): - """Update the training and evaluation configurations. - - :param train_config: Dictionary containing the training configuration. - :type train_config: dict - :param eval_config: Dictionary containing the evaluation configuration. - :type eval_config: dict - """ - self.train_config = train_config - self.eval_config = eval_config - - def eval_all_outputs(self, img): - """Get all outputs of the model when running eval. - - :param img: Input image for segmentation. - :type img: numpy.ndarray - :return: Probability mask for the input image. - :rtype: numpy.ndarray - """ - - return super().eval(x=img, **self.eval_config["segmentor"]) - - def eval(self, img): - """Evaluate the model - find mask of the given image - Calls the original eval function. - - :param img: image to evaluate on - :type img: np.ndarray - :return: mask of the image, list of 2D arrays, or single 3D array (if do_3D=True) labelled image. - :rtype: np.ndarray - """ - return super().eval(x=img, **self.eval_config["segmentor"])[0] # 0 to take only mask - - def train(self, imgs, masks): - """Trains the given model - Calls the original train function. 
- - :param imgs: images to train on (training data) - :type imgs: List[np.ndarray] - :param masks: masks of the given images (training labels) - :type masks: List[np.ndarray] - """ - - if not isinstance(masks, np.ndarray): # TODO Remove: all these should be taken care of in fsimagestorage - masks = np.array(masks) - - if masks[0].shape[0] == 2: - masks = list(masks[:,0,...]) - super().train(train_data=deepcopy(imgs), train_labels=masks, **self.train_config["segmentor"]) - - # compute loss and metric - true_bin_masks = [mask>0 for mask in masks] # get binary masks - true_flows = labels_to_flows(masks) # get cellpose flows - # get predicted flows and cell probability - pred_masks = [] - pred_flows = [] - true_lbl = [] - for idx, img in enumerate(imgs): - mask, flows, _ = super().eval(x=img, **self.eval_config["segmentor"]) - pred_masks.append(mask) - pred_flows.append(np.stack([flows[1][0], flows[1][1], flows[2]])) # stack cell probability map, horizontal and vertical flow - true_lbl.append(np.stack([true_bin_masks[idx], true_flows[idx][2], true_flows[idx][3]])) - - true_lbl = np.stack(true_lbl) - pred_flows=np.stack(pred_flows) - pred_flows = torch.from_numpy(pred_flows).float().to('cpu') - # compute loss, combination of mse for flows and bce for cell probability - self.loss = self.loss_fn(true_lbl, pred_flows) - self.metric = np.mean(aggregated_jaccard_index(masks, pred_masks)) - - def masks_to_outlines(self, mask): - """ get outlines of masks as a 0-1 array - Calls the original cellpose.utils.masks_to_outlines function - - :param mask: int, 2D or 3D array, mask of an image - :type mask: ndarray - :return: outlines - :rtype: ndarray - """ - return utils.masks_to_outlines(mask) #[True, False] outputs - - -class CellClassifierFCNN(nn.Module): - - """Fully convolutional classifier for cell images. NOTE -> This model cannot be used as a standalone model in DCP - - :param model_config: Model configuration. - :type model_config: dict - :param train_config: Training configuration. - :type train_config: dict - :param eval_config: Evaluation configuration. - :type eval_config: dict - - """ - - def __init__(self, model_config, train_config, eval_config): - """Initialize the fully convolutional classifier. - - :param model_config: Model configuration. - :type model_config: dict - :param train_config: Training configuration. - :type train_config: dict - :param eval_config: Evaluation configuration. - :type eval_config: dict - """ - super().__init__() - - self.in_channels = model_config["classifier"].get("in_channels",1) - self.num_classes = model_config["classifier"].get("num_classes",3) - - self.train_config = train_config["classifier"] - self.eval_config = eval_config["classifier"] - - self.include_mask = model_config["classifier"]["include_mask"] - self.in_channels = self.in_channels + 1 if self.include_mask else self.in_channels - - self.layer1 = nn.Sequential( - nn.Conv2d(self.in_channels, 16, 3, 2, 5), - nn.BatchNorm2d(16), - nn.ReLU(), - nn.Dropout2d(p=0.2), - ) - - self.layer2 = nn.Sequential( - nn.Conv2d(16, 64, 3, 1, 1), - nn.BatchNorm2d(64), - nn.ReLU(), - nn.Dropout2d(p=0.2), - ) - - self.layer3 = nn.Sequential( - nn.Conv2d(64, 128, 3, 2, 4), - nn.BatchNorm2d(128), - nn.ReLU(), - nn.Dropout2d(p=0.2), - ) - self.final_conv = nn.Conv2d(128, self.num_classes, 1) - self.pooling = nn.AdaptiveMaxPool2d(1) - - self.metric_fn = F1Score(num_classes=self.num_classes, task="multiclass") - - def update_configs(self, train_config, eval_config): - """ - Update the training and evaluation configurations. 
- - :param train_config: Dictionary containing the training configuration. - :type train_config: dict - :param eval_config: Dictionary containing the evaluation configuration. - :type eval_config: dict - """ - self.train_config = train_config - self.eval_config = eval_config - - def forward(self, x): - """ Performs forward pass of the CellClassifierFCNN. - - :param x: Input tensor. - :type x: torch.Tensor - :return: Output tensor after passing through the network. - :rtype: torch.Tensor - """ - - x = self.layer1(x) - x = self.layer2(x) - x = self.layer3(x) - - x = self.final_conv(x) - x = self.pooling(x) - x = x.view(x.size(0), -1) - return x - - def train (self, imgs, labels): - """Trains the given model - - :param imgs: List of input images with shape (3, dx, dy). - :type imgs: List[np.ndarray[np.uint8]] - :param labels: List of classification labels. - :type labels: List[int] - """ - - lr = self.train_config['lr'] - epochs = self.train_config['n_epochs'] - batch_size = self.train_config['batch_size'] - # optimizer_class = self.train_config['optimizer'] - - # Convert input images and labels to tensors - - # normalize images - imgs = [(img-np.min(img))/(np.max(img)-np.min(img)) for img in imgs] - # convert to tensor - imgs = torch.stack([torch.from_numpy(img.astype(np.float32)) for img in imgs]) - imgs = torch.permute(imgs, (0, 3, 1, 2)) - # Your classification label mask - labels = torch.LongTensor([label for label in labels]) - - # Create a training dataset and dataloader - train_dataset = TensorDataset(imgs, labels) - train_dataloader = DataLoader(train_dataset, batch_size=batch_size) - - loss_fn = nn.CrossEntropyLoss() - optimizer = Adam(params=self.parameters(), lr=lr) #eval(f'{optimizer_class}(params={self.parameters()}, lr={lr})') - # TODO check if we should replace self.parameters with super.parameters() - - for _ in tqdm(range(epochs), desc="Running CellClassifierFCNN training"): - self.loss, self.metric = 0, 0 - for data in train_dataloader: - imgs, labels = data - - optimizer.zero_grad() - preds = self.forward(imgs) - - l = loss_fn(preds, labels) - l.backward() - optimizer.step() - self.loss += l.item() - - self.metric += self.metric_fn(preds, labels) - - self.loss /= len(train_dataloader) - self.metric /= len(train_dataloader) - - def eval(self, img): - """Evaluates the model on the provided image and return the predicted label. - - :param img: Input image for evaluation. - :type img: np.ndarray[np.uint8] - :return: y_hat - predicted label. - :rtype: torch.Tensor - """ - # normalise - img = (img-np.min(img))/(np.max(img)-np.min(img)) - # convert to tensor - img = torch.permute(torch.tensor(img.astype(np.float32)), (2, 0, 1)).unsqueeze(0) - preds = self.forward(img) - y_hat = torch.argmax(preds, 1) - return y_hat - - -class CellposePatchCNN(nn.Module): - """ - Cellpose & patches of cells and then cnn to classify each patch - """ - - def __init__(self, model_config, train_config, eval_config, model_name): - """Constructs all the necessary attributes for the CellposePatchCNN - - :param model_config: Model configuration. - :type model_config: dict - :param train_config: Training configuration. - :type train_config: dict - :param eval_config: Evaluation configuration. - :type eval_config: dict - :param model_name: Name of the model. 
- :type model_name: str - """ - super().__init__() - - self.model_config = model_config - self.train_config = train_config - self.eval_config = eval_config - self.include_mask = self.model_config["classifier"]["include_mask"] - self.model_name = model_name - self.classifier_class = self.model_config.get("classifier").get("model_class", "CellClassifierFCNN") - - # Initialize the cellpose model and the classifier - self.segmentor = CustomCellposeModel(self.model_config, - self.train_config, - self.eval_config, - "Cellpose") - - if self.classifier_class == "FCNN": - self.classifier = CellClassifierFCNN(self.model_config, - self.train_config, - self.eval_config) - - elif self.classifier_class == "RandomForest": - self.classifier = CellClassifierShallowModel(self.model_config, - self.train_config, - self.eval_config) - # make sure include mask is set to False if we are using the random forest model - self.include_mask = False - - def update_configs(self, train_config, eval_config): - """Update the training and evaluation configurations. - - :param train_config: Dictionary containing the training configuration. - :type train_config: dict - :param eval_config: Dictionary containing the evaluation configuration. - :type eval_config: dict - """ - self.train_config = train_config - self.eval_config = eval_config - - def train(self, imgs, masks): - """Trains the given model. First trains the segmentor and then the clasiffier. - - :param imgs: images to train on (training data) - :type imgs: List[np.ndarray] - :param masks: masks of the given images (training labels) - :type masks: List[np.ndarray] of same shape as output of eval, i.e. one channel instances, - second channel classes, so [2, H, W] or [2, 3, H, W] for 3D - """ - # train cellpose - masks = np.array(masks) - masks_instances = list(masks[:,0,...]) #[mask.sum(-1) for mask in masks] if masks[0].ndim == 3 else masks - self.segmentor.train(deepcopy(imgs), masks_instances) - # create patch dataset to train classifier - masks_classes = list(masks[:,1,...]) #[((mask > 0) * np.arange(1, 4)).sum(-1) for mask in masks] - patches, patch_masks, labels = create_patch_dataset(imgs, - masks_classes, - masks_instances, - noise_intensity = self.train_config["classifier"]["train_data"]["noise_intensity"], - max_patch_size = self.train_config["classifier"]["train_data"]["patch_size"], - include_mask = self.include_mask) - x = patches - if self.classifier_class == "RandomForest": - x = create_dataset_for_rf(patches, patch_masks) - # train classifier - self.classifier.train(x, labels) - # and compute metric and loss - self.metric = (self.segmentor.metric + self.classifier.metric) / 2 - self.loss = (self.segmentor.loss + self.classifier.loss)/2 - - def eval(self, img): - """Evaluate the model on the provided image and return the final mask. - - :param img: Input image for evaluation. - :type img: np.ndarray[np.uint8] - :return: Final mask containing instance mask and class masks. 
- :rtype: np.ndarray[np.uint16] - """ - # TBD we assume image is 2D [H, W] (see fsimage storage) - # The final mask which is returned should have - # first channel the output of cellpose and the rest are the class channels - with torch.no_grad(): - # get instance mask from segmentor - instance_mask = self.segmentor.eval(img) - # find coordinates of detected objects - class_mask = np.zeros(instance_mask.shape) - - max_patch_size = self.eval_config["classifier"]["data"]["patch_size"] - if max_patch_size is None: max_patch_size = find_max_patch_size(instance_mask) - noise_intensity = self.eval_config["classifier"]["data"]["noise_intensity"] - - # get patches centered around detected objects - patches, patch_masks, instance_labels, _ = get_centered_patches(img, - instance_mask, - max_patch_size, - noise_intensity=noise_intensity, - include_mask=self.include_mask) - x = patches - if self.classifier_class == "RandomForest": - x = create_dataset_for_rf(patches, patch_masks) - # loop over patches and create classification mask - for idx in range(len(x)): - patch_class = self.classifier.eval(x[idx]) - # Assign predicted class to corresponding location in final_mask - patch_class = patch_class.item() if isinstance(patch_class, torch.Tensor) else patch_class - class_mask[instance_mask==instance_labels[idx]] = patch_class + 1 - # Apply mask to final_mask, retaining only regions where cellpose_mask is greater than 0 - final_mask = np.stack((instance_mask, class_mask), axis=self.eval_config['mask_channel_axis']).astype(np.uint16) # size 2xHxW - - return final_mask - -class CellClassifierShallowModel: - """ - This class implements a shallow model for cell classification using scikit-learn. - """ - - def __init__(self, model_config, train_config, eval_config): - """Constructs all the necessary attributes for the CellClassifierShallowModel - - :param model_config: Model configuration. - :type model_config: dict - :param train_config: Training configuration. - :type train_config: dict - :param eval_config: Evaluation configuration. - :type eval_config: dict - """ - - self.model_config = model_config - self.train_config = train_config - self.eval_config = eval_config - - self.model = RandomForestClassifier() # TODO chnage config so RandomForestClassifier accepts input params - - - def train(self, X_train, y_train): - """Trains the model using the provided training data. - - :param X_train: Features of the training data. - :type X_train: numpy.ndarray - :param y_train: Labels of the training data. - :type y_train: numpy.ndarray - """ - - self.model.fit(X_train,y_train) - - y_hat = self.model.predict(X_train) - y_hat_proba = self.model.predict_proba(X_train) - - self.metric = f1_score(y_train, y_hat, average='micro') - # Binary Cross Entrop Loss - self.loss = log_loss(y_train, y_hat_proba) - - - def eval(self, X_test): - """Evaluates the model on the provided test data. - - :param X_test: Features of the test data. - :type X_test: numpy.ndarray - :return: y_hat - predicted labels. - :rtype: numpy.ndarray - """ - - X_test = X_test.reshape(1,-1) - - try: - y_hat = self.model.predict(X_test) - except NotFittedError as e: - y_hat = np.zeros(X_test.shape[0]) - - return y_hat - -class UNet(nn.Module): - - """ - Unet is a convolutional neural network architecture for semantic segmentation. - - :param in_channels: Number of input channels (default: 3). - :type in_channels: int - :param out_channels: Number of output channels (default: 4). 
- :type out_channels: int - :param features: List of feature channels for each encoder level (default: [64,128,256,512]). - :type features: list - """ - - class DoubleConv(nn.Module): - """ - DoubleConv module consists of two consecutive convolutional layers with - batch normalization and ReLU activation functions. - """ - - def __init__(self, in_channels, out_channels): - """ - Initialize DoubleConv module. - - :param in_channels: Number of input channels. - :type in_channels: int - :param out_channels: Number of output channels. - :type out_channels: int - """ - - super().__init__() - - self.conv = nn.Sequential( - nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False), - nn.BatchNorm2d(out_channels), - nn.ReLU(), - nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False), - nn.BatchNorm2d(out_channels), - nn.ReLU(), - ) - - def forward(self, x): - """Forward pass through the DoubleConv module. - - :param x: Input tensor. - :type x: torch.Tensor - """ - return self.conv(x) - - - def __init__(self, model_config, train_config, eval_config, model_name): - """Constructs all the necessary attributes for the UNet model. - - - :param model_config: Model configuration. - :type model_config: dict - :param train_config: Training configuration. - :type train_config: dict - :param eval_config: Evaluation configuration. - :type eval_config: dict - :param model_name: Name of the model. - :type model_name: str - """ - - super().__init__() - self.model_config = model_config - self.train_config = train_config - self.eval_config = eval_config - self.model_name = model_name - ''' - self.in_channels = self.model_config["unet"]["in_channels"] - self.out_channels = self.model_config["unet"]["out_channels"] - self.features = self.model_config["unet"]["features"] - ''' - self.in_channels = self.model_config["classifier"]["in_channels"] - self.out_channels = self.model_config["classifier"]["num_classes"] + 1 - self.features = self.model_config["classifier"]["features"] - - self.encoder = nn.ModuleList() - self.decoder = nn.ModuleList() - - self.pool = nn.MaxPool2d(kernel_size=2, stride=2) - - # Encoder - for feature in self.features: - self.encoder.append( - UNet.DoubleConv(self.in_channels, feature) - ) - self.in_channels = feature - - # Decoder - for feature in self.features[::-1]: - self.decoder.append( - nn.ConvTranspose2d( - feature*2, feature, kernel_size=2, stride=2 - ) - ) - self.decoder.append( - UNet.DoubleConv(feature*2, feature) - ) - - self.bottle_neck = UNet.DoubleConv(self.features[-1], self.features[-1]*2) - self.output_conv = nn.Conv2d(self.features[0], self.out_channels, kernel_size=1) - - def forward(self, x): - """ - Forward pass of the UNet model. - - :param x: Input tensor. - :type x: torch.Tensor - :return: Output tensor. - :rtype: torch.Tensor - """ - skip_connections = [] - for encoder in self.encoder: - x = encoder(x) - skip_connections.append(x) - x = self.pool(x) - - x = self.bottle_neck(x) - skip_connections = skip_connections[::-1] - - for i in np.arange(len(self.decoder), step=2): - x = self.decoder[i](x) - skip_connection = skip_connections[i//2] - concatenate_skip = torch.cat((skip_connection, x), dim=1) - x = self.decoder[i+1](concatenate_skip) - - return self.output_conv(x) - - def train(self, imgs, masks): - """ - Trains the UNet model using the provided images and masks. - - :param imgs: Input images for training. - :type imgs: list[numpy.ndarray] - :param masks: Masks corresponding to the input images. 
- :type masks: list[numpy.ndarray] - """ - - lr = self.train_config["classifier"]['lr'] - epochs = self.train_config["classifier"]['n_epochs'] - batch_size = self.train_config["classifier"]['batch_size'] - - # Convert input images and labels to tensors - # normalize images - imgs = [(img-np.min(img))/(np.max(img)-np.min(img)) for img in imgs] - # convert to tensor - imgs = torch.stack([torch.from_numpy(img.astype(np.float32)) for img in imgs]) - imgs = imgs.unsqueeze(1) if imgs.ndim == 3 else imgs - - # Classification label mask - masks = np.array(masks) - masks = torch.stack([torch.from_numpy(mask[1].astype(np.int16)) for mask in masks]) - - # Create a training dataset and dataloader - train_dataset = TensorDataset(imgs, masks) - train_dataloader = DataLoader(train_dataset, batch_size=batch_size) - - loss_fn = nn.CrossEntropyLoss() - optimizer = Adam(params=self.parameters(), lr=lr) - - for _ in tqdm(range(epochs), desc="Running UNet training"): - - self.loss = 0 - - for imgs, masks in train_dataloader: - imgs = imgs.float() - masks = masks.long() - - #forward path - preds = self.forward(imgs) - loss = loss_fn(preds, masks) - - #backward path - optimizer.zero_grad() - loss.backward() - optimizer.step() - - self.loss += loss.detach().mean().item() - - self.loss /= len(train_dataloader) - - def eval(self, img): - """ - Evaluate the model on the provided image and return the predicted label. - - :param img: Input image for evaluation. - :type img: np.ndarray[np.uint8] - :return: predicted mask consists of instance and class masks - :rtype: numpy.ndarray - """ - with torch.no_grad(): - # normalise - img = (img-np.min(img))/(np.max(img)-np.min(img)) - img = torch.from_numpy(img).float().unsqueeze(0) - - img = img.unsqueeze(1) if img.ndim == 3 else img - - preds = self.forward(img) - class_mask = torch.argmax(preds, 1).numpy()[0] - - instance_mask = label((class_mask > 0).astype(int))[0] - - final_mask = np.stack((instance_mask, class_mask), axis=self.eval_config['mask_channel_axis']).astype(np.uint16) - - return final_mask - -class CellposeMultichannel(): - ''' - Multichannel image segmentation model. - Run the separate CustomCellposeModel models for each channel return the mask corresponding to each object type. - ''' - - def __init__(self, model_config, train_config, eval_config, model_name="Cellpose"): - """Constructs all the necessary attributes for the CellposeMultichannel model. - - :param model_config: Model configuration. - :type model_config: dict - :param train_config: Training configuration. - :type train_config: dict - :param eval_config: Evaluation configuration. - :type eval_config: dict - :param model_name: Name of the model. - :type model_name: str - """ - - self.model_config = model_config - self.train_config = train_config - self.eval_config = eval_config - self.model_name = model_name - self.num_of_channels = self.model_config["classifier"]["num_classes"] - - self.cellpose_models = [ - CustomCellposeModel(self.model_config, - self.train_config, - self.eval_config, - self.model_name - ) for _ in range(self.num_of_channels) - ] - - def train(self, imgs, masks): - """ - Train the model on the provided images and masks. - - :param imgs: Input images for training. - :type imgs: list[numpy.ndarray] - :param masks: Masks corresponding to the input images. 
- :type masks: list[numpy.ndarray] - """ - - for i in range(self.num_of_channels): - - masks_class = [] - - for mask in masks: - mask_class = mask.copy() - # set all instances in the instance mask not corresponding to the class in question to zero - mask_class[0][mask_class[1]!=(i+1)] = 0 - masks_class.append(mask_class) - - self.cellpose_models[i].train(imgs, masks_class) - - self.metric = np.mean([self.cellpose_models[i].metric for i in range(self.num_of_channels)]) - self.loss = np.mean([self.cellpose_models[i].loss for i in range(self.num_of_channels)]) - - - def eval(self, img): - """Evaluate the model on the provided image. The instance mask are computed as the union of the predicted model outputs, while the class of - each object is assigned based on majority voting between the models. - - :param img: Input image for evaluation. - :type img: np.ndarray[np.uint8] - :return: predicted mask consists of instance and class masks - :rtype: numpy.ndarray - """ - - instance_masks, class_masks, model_confidences = [], [], [] - - for i in range(self.num_of_channels): - # get the instance mask and pixel-wise cell probability mask - instance_mask, probs, _ = self.cellpose_models[i].eval_all_outputs(img) - confidence = probs[2] - # assign the appropriate class to all objects detected by this model - class_mask = np.zeros_like(instance_mask) - class_mask[instance_mask>0]=(i + 1) - - instance_masks.append(instance_mask) - class_masks.append(class_mask) - model_confidences.append(confidence) - # merge the outputs of the different models using the pixel-wise cell probability mask - merged_mask_instances, class_mask = self.merge_masks(instance_masks, class_masks, model_confidences) - # set all connected components to the same label in the instance mask - instance_mask = label_mask(merged_mask_instances>0) - # and set the class with the most pixels to that object - for inst_id in np.unique(instance_mask)[1:]: - where_inst_id = np.where(instance_mask==inst_id) - vals, counts = np.unique(class_mask[where_inst_id], return_counts=True) - class_mask[where_inst_id] = vals[np.argmax(counts)] - # take the final mask by stancking instance and class mask - final_mask = np.stack((instance_mask, class_mask), axis=self.eval_config['mask_channel_axis']).astype(np.uint16) - - return final_mask - - def merge_masks(self, inst_masks, class_masks, probabilities): - """Merges the instance and class masks resulting from the different models using the pixel-wise cell probability. The output of the model - with the maximum probability is selected for each pixel. - - :param inst_masks: List of predicted instance masks from each model. - :type inst_masks: List[np.array] - :param class_masks: List of corresponding class masks from each model. 
- :type class_masks: List[np.array] - :param probabilities: List of corresponding pixel-wise cell probability masks - :type probabilities: List[np.array] - :return: A tuple containing the following elements: - - final_mask_inst (numpy.ndarray): A single instance mask where for each pixel the output of the model with the highest probability is selected - - final_mask_class (numpy.ndarray): A single class mask where for each pixel the output of the model with the highest probability is selected - :rtype: tuple - """ - # Convert lists to numpy arrays - inst_masks = np.array(inst_masks) - class_masks = np.array(class_masks) - probabilities = np.array(probabilities) - - # Find the index of the mask with the maximum probability for each pixel - max_prob_indices = np.argmax(probabilities, axis=0) - - # Use the index to select the corresponding mask for each pixel - final_mask_inst = inst_masks[max_prob_indices, np.arange(inst_masks.shape[1])[:, None], np.arange(inst_masks.shape[2])] - final_mask_class = class_masks[max_prob_indices, np.arange(class_masks.shape[1])[:, None], np.arange(class_masks.shape[2])] - - return final_mask_inst, final_mask_class - - - - - - -# class CustomSAMModel(): -# # https://github.com/facebookresearch/segment-anything/blob/main/notebooks/automatic_mask_generator_example.ipynb -# def __init__(self): -# pass diff --git a/src/server/dcp_server/models/__init__.py b/src/server/dcp_server/models/__init__.py new file mode 100644 index 00000000..2e134350 --- /dev/null +++ b/src/server/dcp_server/models/__init__.py @@ -0,0 +1,11 @@ +# dcp_server.models/__init__.py + +from .custom_cellpose import CustomCellposeModel +from .cellpose_patchCNN import CellposePatchCNN +from .multicellpose import MultiCellpose +from .unet import UNet + +__all__ = ['CustomCellposeModel', + 'CellposePatchCNN', + 'MultiCellpose', + 'UNet'] diff --git a/src/server/dcp_server/models/cellpose_patchCNN.py b/src/server/dcp_server/models/cellpose_patchCNN.py new file mode 100644 index 00000000..c983e643 --- /dev/null +++ b/src/server/dcp_server/models/cellpose_patchCNN.py @@ -0,0 +1,365 @@ +from copy import deepcopy +from tqdm import tqdm +import numpy as np + +import torch +from torch import nn +from torch.optim import Adam +from torch.utils.data import TensorDataset, DataLoader +from torchmetrics import F1Score + +from sklearn.ensemble import RandomForestClassifier +from sklearn.metrics import f1_score, log_loss +from sklearn.exceptions import NotFittedError + +from dcp_server.models import Model, CustomCellposeModel +from dcp_server.utils import ( + get_centered_patches, + find_max_patch_size, + create_patch_dataset, + create_dataset_for_rf +) + + +class CellposePatchCNN(Model): + """ + Cellpose & patches of cells and then cnn to classify each patch + """ + + def __init__(self, model_config, train_config, eval_config, model_name): + """Constructs all the necessary attributes for the CellposePatchCNN + + :param model_config: Model configuration. + :type model_config: dict + :param train_config: Training configuration. + :type train_config: dict + :param eval_config: Evaluation configuration. + :type eval_config: dict + :param model_name: Name of the model. 
+        :type model_name: str
+        """
+        super().__init__()
+
+        self.model_config = model_config
+        self.train_config = train_config
+        self.eval_config = eval_config
+        self.model_name = model_name
+        self.include_mask = self.model_config["classifier"]["include_mask"]
+        self.classifier_class = self.model_config.get("classifier").get("model_class", "CellClassifierFCNN")
+
+        # Initialize the cellpose model and the classifier
+        self.segmentor = CustomCellposeModel(self.model_config,
+                                             self.train_config,
+                                             self.eval_config,
+                                             "Cellpose")
+
+        if self.classifier_class == "FCNN":
+            self.classifier = CellClassifierFCNN(self.model_config,
+                                                 self.train_config,
+                                                 self.eval_config)
+
+        elif self.classifier_class == "RandomForest":
+            self.classifier = CellClassifierShallowModel(self.model_config,
+                                                         self.train_config,
+                                                         self.eval_config)
+            # make sure include mask is set to False if we are using the random forest model
+            self.include_mask = False
+
+    def update_configs(self, train_config, eval_config):
+        """Update the training and evaluation configurations.
+
+        :param train_config: Dictionary containing the training configuration.
+        :type train_config: dict
+        :param eval_config: Dictionary containing the evaluation configuration.
+        :type eval_config: dict
+        """
+        self.train_config = train_config
+        self.eval_config = eval_config
+
+    def train(self, imgs, masks):
+        """Trains the given model. First trains the segmentor and then the classifier.
+
+        :param imgs: images to train on (training data)
+        :type imgs: List[np.ndarray]
+        :param masks: masks of the given images (training labels)
+        :type masks: List[np.ndarray] of same shape as output of eval, i.e. one channel instances,
+            second channel classes, so [2, H, W] or [2, 3, H, W] for 3D
+        """
+        # train cellpose
+        masks = np.array(masks)
+        masks_instances = list(masks[:,0,...]) #[mask.sum(-1) for mask in masks] if masks[0].ndim == 3 else masks
+        self.segmentor.train(deepcopy(imgs), masks_instances)
+        # create patch dataset to train classifier
+        masks_classes = list(masks[:,1,...]) #[((mask > 0) * np.arange(1, 4)).sum(-1) for mask in masks]
+        patches, patch_masks, labels = create_patch_dataset(imgs,
+                                                            masks_classes,
+                                                            masks_instances,
+                                                            noise_intensity = self.train_config["classifier"]["train_data"]["noise_intensity"],
+                                                            max_patch_size = self.train_config["classifier"]["train_data"]["patch_size"],
+                                                            include_mask = self.include_mask)
+        x = patches
+        if self.classifier_class == "RandomForest":
+            x = create_dataset_for_rf(patches, patch_masks)
+        # train classifier
+        self.classifier.train(x, labels)
+        # and compute metric and loss
+        self.metric = (self.segmentor.metric + self.classifier.metric) / 2
+        self.loss = (self.segmentor.loss + self.classifier.loss)/2
+
+    def eval(self, img):
+        """Evaluate the model on the provided image and return the final mask.
+
+        :param img: Input image for evaluation.
+        :type img: np.ndarray[np.uint8]
+        :return: Final mask containing instance mask and class masks.
+ :rtype: np.ndarray[np.uint16] + """ + # TBD we assume image is 2D [H, W] (see fsimage storage) + # The final mask which is returned should have + # first channel the output of cellpose and the rest are the class channels + with torch.no_grad(): + # get instance mask from segmentor + instance_mask = self.segmentor.eval(img) + # find coordinates of detected objects + class_mask = np.zeros(instance_mask.shape) + + max_patch_size = self.eval_config["classifier"]["data"]["patch_size"] + if max_patch_size is None: max_patch_size = find_max_patch_size(instance_mask) + noise_intensity = self.eval_config["classifier"]["data"]["noise_intensity"] + + # get patches centered around detected objects + patches, patch_masks, instance_labels, _ = get_centered_patches(img, + instance_mask, + max_patch_size, + noise_intensity=noise_intensity, + include_mask=self.include_mask) + x = patches + if self.classifier_class == "RandomForest": + x = create_dataset_for_rf(patches, patch_masks) + # loop over patches and create classification mask + for idx in range(len(x)): + patch_class = self.classifier.eval(x[idx]) + # Assign predicted class to corresponding location in final_mask + patch_class = patch_class.item() if isinstance(patch_class, torch.Tensor) else patch_class + class_mask[instance_mask==instance_labels[idx]] = patch_class + 1 + # Apply mask to final_mask, retaining only regions where cellpose_mask is greater than 0 + final_mask = np.stack((instance_mask, class_mask), axis=self.eval_config['mask_channel_axis']).astype(np.uint16) # size 2xHxW + + return final_mask + + +class CellClassifierFCNN(nn.Module): + + """Fully convolutional classifier for cell images. NOTE -> This model cannot be used as a standalone model in DCP + + :param model_config: Model configuration. + :type model_config: dict + :param train_config: Training configuration. + :type train_config: dict + :param eval_config: Evaluation configuration. + :type eval_config: dict + + """ + + def __init__(self, model_config, train_config, eval_config): + """Initialize the fully convolutional classifier. + + :param model_config: Model configuration. + :type model_config: dict + :param train_config: Training configuration. + :type train_config: dict + :param eval_config: Evaluation configuration. + :type eval_config: dict + """ + super().__init__() + + self.in_channels = model_config["classifier"].get("in_channels",1) + self.num_classes = model_config["classifier"].get("num_classes",3) + + self.train_config = train_config["classifier"] + self.eval_config = eval_config["classifier"] + + self.include_mask = model_config["classifier"]["include_mask"] + self.in_channels = self.in_channels + 1 if self.include_mask else self.in_channels + + self.layer1 = nn.Sequential( + nn.Conv2d(self.in_channels, 16, 3, 2, 5), + nn.BatchNorm2d(16), + nn.ReLU(), + nn.Dropout2d(p=0.2), + ) + + self.layer2 = nn.Sequential( + nn.Conv2d(16, 64, 3, 1, 1), + nn.BatchNorm2d(64), + nn.ReLU(), + nn.Dropout2d(p=0.2), + ) + + self.layer3 = nn.Sequential( + nn.Conv2d(64, 128, 3, 2, 4), + nn.BatchNorm2d(128), + nn.ReLU(), + nn.Dropout2d(p=0.2), + ) + self.final_conv = nn.Conv2d(128, self.num_classes, 1) + self.pooling = nn.AdaptiveMaxPool2d(1) + + self.metric_fn = F1Score(num_classes=self.num_classes, task="multiclass") + + def update_configs(self, train_config, eval_config): + """ + Update the training and evaluation configurations. + + :param train_config: Dictionary containing the training configuration. 
+ :type train_config: dict + :param eval_config: Dictionary containing the evaluation configuration. + :type eval_config: dict + """ + self.train_config = train_config + self.eval_config = eval_config + + def forward(self, x): + """ Performs forward pass of the CellClassifierFCNN. + + :param x: Input tensor. + :type x: torch.Tensor + :return: Output tensor after passing through the network. + :rtype: torch.Tensor + """ + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + + x = self.final_conv(x) + x = self.pooling(x) + x = x.view(x.size(0), -1) + return x + + def train (self, imgs, labels): + """Trains the given model + + :param imgs: List of input images with shape (3, dx, dy). + :type imgs: List[np.ndarray[np.uint8]] + :param labels: List of classification labels. + :type labels: List[int] + """ + + lr = self.train_config['lr'] + epochs = self.train_config['n_epochs'] + batch_size = self.train_config['batch_size'] + # optimizer_class = self.train_config['optimizer'] + + # Convert input images and labels to tensors + + # normalize images + imgs = [(img-np.min(img))/(np.max(img)-np.min(img)) for img in imgs] + # convert to tensor + imgs = torch.stack([torch.from_numpy(img.astype(np.float32)) for img in imgs]) + imgs = torch.permute(imgs, (0, 3, 1, 2)) + # Your classification label mask + labels = torch.LongTensor([label for label in labels]) + + # Create a training dataset and dataloader + train_dataset = TensorDataset(imgs, labels) + train_dataloader = DataLoader(train_dataset, batch_size=batch_size) + + loss_fn = nn.CrossEntropyLoss() + optimizer = Adam(params=self.parameters(), lr=lr) #eval(f'{optimizer_class}(params={self.parameters()}, lr={lr})') + # TODO check if we should replace self.parameters with super.parameters() + + for _ in tqdm(range(epochs), desc="Running CellClassifierFCNN training"): + self.loss, self.metric = 0, 0 + for data in train_dataloader: + imgs, labels = data + + optimizer.zero_grad() + preds = self.forward(imgs) + + l = loss_fn(preds, labels) + l.backward() + optimizer.step() + self.loss += l.item() + + self.metric += self.metric_fn(preds, labels) + + self.loss /= len(train_dataloader) + self.metric /= len(train_dataloader) + + def eval(self, img): + """Evaluates the model on the provided image and return the predicted label. + + :param img: Input image for evaluation. + :type img: np.ndarray[np.uint8] + :return: y_hat - predicted label. + :rtype: torch.Tensor + """ + # normalise + img = (img-np.min(img))/(np.max(img)-np.min(img)) + # convert to tensor + img = torch.permute(torch.tensor(img.astype(np.float32)), (2, 0, 1)).unsqueeze(0) + preds = self.forward(img) + y_hat = torch.argmax(preds, 1) + return y_hat + + +class CellClassifierShallowModel: + """ + This class implements a shallow model for cell classification using scikit-learn. + """ + + def __init__(self, model_config, train_config, eval_config): + """Constructs all the necessary attributes for the CellClassifierShallowModel + + :param model_config: Model configuration. + :type model_config: dict + :param train_config: Training configuration. + :type train_config: dict + :param eval_config: Evaluation configuration. + :type eval_config: dict + """ + + self.model_config = model_config + self.train_config = train_config + self.eval_config = eval_config + + self.model = RandomForestClassifier() # TODO chnage config so RandomForestClassifier accepts input params + + + def train(self, X_train, y_train): + """Trains the model using the provided training data. 
+
+        :param X_train: Features of the training data.
+        :type X_train: numpy.ndarray
+        :param y_train: Labels of the training data.
+        :type y_train: numpy.ndarray
+        """
+
+        self.model.fit(X_train, y_train)
+
+        y_hat = self.model.predict(X_train)
+        y_hat_proba = self.model.predict_proba(X_train)
+
+        self.metric = f1_score(y_train, y_hat, average='micro')
+        # Binary Cross Entropy Loss
+        self.loss = log_loss(y_train, y_hat_proba)
+
+
+    def eval(self, X_test):
+        """Evaluates the model on the provided test data.
+
+        :param X_test: Features of the test data.
+        :type X_test: numpy.ndarray
+        :return: y_hat - predicted labels.
+        :rtype: numpy.ndarray
+        """
+
+        X_test = X_test.reshape(1, -1)
+
+        try:
+            y_hat = self.model.predict(X_test)
+        except NotFittedError:
+            # model has not been trained yet - fall back to the zero class for all inputs
+            y_hat = np.zeros(X_test.shape[0])
+
+        return y_hat
diff --git a/src/server/dcp_server/models/custom_cellpose.py b/src/server/dcp_server/models/custom_cellpose.py
new file mode 100644
index 00000000..fcb8d8ab
--- /dev/null
+++ b/src/server/dcp_server/models/custom_cellpose.py
@@ -0,0 +1,121 @@
+from copy import deepcopy
+import numpy as np
+
+import torch
+from torch import nn
+
+from cellpose import models, utils
+from cellpose.metrics import aggregated_jaccard_index
+from cellpose.dynamics import labels_to_flows
+
+from dcp_server.models import Model
+
+class CustomCellposeModel(models.CellposeModel, Model):
+    """Custom cellpose model inheriting the attributes and functions from the original CellposeModel and implementing
+    additional attributes and methods needed for this project.
+    """
+    def __init__(self, model_config, train_config, eval_config, model_name):
+        """Constructs all the necessary attributes for the CustomCellposeModel.
+        The model inherits all attributes from the parent class, the init allows to pass any other argument that the parent class accepts.
+        Please, visit here https://cellpose.readthedocs.io/en/latest/api.html#id4 for more details on arguments accepted.
+
+        :param model_config: dictionary passed from the config file with all the arguments for the __init__ function and model initialization
+        :type model_config: dict
+        :param train_config: dictionary passed from the config file with all the arguments for training function
+        :type train_config: dict
+        :param eval_config: dictionary passed from the config file with all the arguments for eval function
+        :type eval_config: dict
+        :param model_name: name of the model
+        :type model_name: str
+        """
+
+        # Initialize the cellpose model
+        #super().__init__(**model_config["segmentor"])
+        nn.Module.__init__(self)
+        models.CellposeModel.__init__(self, **model_config["segmentor"])
+        self.model_config = model_config
+        self.train_config = train_config
+        self.eval_config = eval_config
+        self.model_name = model_name
+        self.mkldnn = False # otherwise we get error with saving model
+        self.loss = 1e6
+
+
+    def update_configs(self, train_config, eval_config):
+        """Update the training and evaluation configurations.
+
+        :param train_config: Dictionary containing the training configuration.
+        :type train_config: dict
+        :param eval_config: Dictionary containing the evaluation configuration.
+        :type eval_config: dict
+        """
+        self.train_config = train_config
+        self.eval_config = eval_config
+
+    def eval_all_outputs(self, img):
+        """Get all outputs of the model when running eval.
+
+        :param img: Input image for segmentation.
+        :type img: numpy.ndarray
+        :return: mask, flows and styles - all outputs of the underlying eval function.
+        :rtype: tuple
+        """
+
+        return super().eval(x=img, **self.eval_config["segmentor"])
+
+    def eval(self, img):
+        """Evaluate the model - find mask of the given image
+        Calls the original eval function.
+
+        :param img: image to evaluate on
+        :type img: np.ndarray
+        :return: mask of the image, list of 2D arrays, or single 3D array (if do_3D=True) labelled image.
+        :rtype: np.ndarray
+        """
+        return super().eval(x=img, **self.eval_config["segmentor"])[0] # 0 to take only mask
+
+    def train(self, imgs, masks):
+        """Trains the given model.
+        Calls the original train function.
+
+        :param imgs: images to train on (training data)
+        :type imgs: List[np.ndarray]
+        :param masks: masks of the given images (training labels)
+        :type masks: List[np.ndarray]
+        """
+
+        if not isinstance(masks, np.ndarray): # TODO Remove: all these should be taken care of in fsimagestorage
+            masks = np.array(masks)
+
+        if masks[0].shape[0] == 2:
+            masks = list(masks[:,0,...])
+        super().train(train_data=deepcopy(imgs), train_labels=masks, **self.train_config["segmentor"])
+
+        # compute loss and metric
+        true_bin_masks = [mask>0 for mask in masks] # get binary masks
+        true_flows = labels_to_flows(masks) # get cellpose flows
+        # get predicted flows and cell probability
+        pred_masks = []
+        pred_flows = []
+        true_lbl = []
+        for idx, img in enumerate(imgs):
+            mask, flows, _ = super().eval(x=img, **self.eval_config["segmentor"])
+            pred_masks.append(mask)
+            pred_flows.append(np.stack([flows[1][0], flows[1][1], flows[2]])) # stack the two flow components and the cell probability map
+            true_lbl.append(np.stack([true_bin_masks[idx], true_flows[idx][2], true_flows[idx][3]]))
+
+        true_lbl = np.stack(true_lbl)
+        pred_flows = np.stack(pred_flows)
+        pred_flows = torch.from_numpy(pred_flows).float().to('cpu')
+        # compute loss, combination of mse for flows and bce for cell probability
+        self.loss = self.loss_fn(true_lbl, pred_flows)
+        self.metric = np.mean(aggregated_jaccard_index(masks, pred_masks))
+
+    def masks_to_outlines(self, mask):
+        """ get outlines of masks as a 0-1 array
+        Calls the original cellpose.utils.masks_to_outlines function
+
+        :param mask: int, 2D or 3D array, mask of an image
+        :type mask: ndarray
+        :return: outlines
+        :rtype: ndarray
+        """
+        return utils.masks_to_outlines(mask) #[True, False] outputs
diff --git a/src/server/dcp_server/models/model.py b/src/server/dcp_server/models/model.py
new file mode 100644
index 00000000..47f6dc74
--- /dev/null
+++ b/src/server/dcp_server/models/model.py
@@ -0,0 +1,27 @@
+from abc import ABC, abstractmethod
+from typing import List
+import numpy as np
+
+class Model(ABC):
+    """Abstract base class defining the interface (train and eval) that every DCP model must implement."""
+
+    def __init__(self, model_config=None, train_config=None, eval_config=None, model_name=None):
+        # The arguments default to None so that subclasses may call super().__init__()
+        # without arguments and assign these attributes themselves afterwards. The
+        # cooperative super().__init__() call below ensures that subclasses which also
+        # inherit from nn.Module (e.g. UNet) get nn.Module initialized through the MRO.
+        super().__init__()
+
+        self.model_config = model_config
+        self.train_config = train_config
+        self.eval_config = eval_config
+        self.model_name = model_name
+
+    @abstractmethod
+    def train(self, imgs: List[np.ndarray], masks: List[np.ndarray]) -> None:
+        pass
+
+    @abstractmethod
+    def eval(self, img: np.ndarray) -> np.ndarray:
+        pass
+
+
+#from segment_anything import SamPredictor, sam_model_registry
+#from segment_anything.automatic_mask_generator import SamAutomaticMaskGenerator
+# class CustomSAMModel():
+# # https://github.com/facebookresearch/segment-anything/blob/main/notebooks/automatic_mask_generator_example.ipynb
+#     def __init__(self):
+#         pass
diff --git a/src/server/dcp_server/models/multicellpose.py b/src/server/dcp_server/models/multicellpose.py
new file mode 100644
index 00000000..840aa210
--- /dev/null
+++ b/src/server/dcp_server/models/multicellpose.py
@@ -0,0 +1,130 @@
+import numpy as np
+from skimage.measure import label as label_mask
+
+from dcp_server.models import Model, CustomCellposeModel
+
+class MultiCellpose(Model):
+    '''
+    Multichannel image segmentation model.
+    Runs a separate CustomCellposeModel for each channel and returns the mask corresponding to each object type.
+    '''
+
+    def __init__(self, model_config, train_config, eval_config, model_name="Cellpose"):
+        """Constructs all the necessary attributes for the MultiCellpose model.
+
+        :param model_config: Model configuration.
+        :type model_config: dict
+        :param train_config: Training configuration.
+        :type train_config: dict
+        :param eval_config: Evaluation configuration.
+        :type eval_config: dict
+        :param model_name: Name of the model.
+        :type model_name: str
+        """
+
+        self.model_config = model_config
+        self.train_config = train_config
+        self.eval_config = eval_config
+        self.model_name = model_name
+        self.num_of_channels = self.model_config["classifier"]["num_classes"]
+
+        self.cellpose_models = [
+            CustomCellposeModel(self.model_config,
+                                self.train_config,
+                                self.eval_config,
+                                self.model_name
+                                ) for _ in range(self.num_of_channels)
+        ]
+
+    def train(self, imgs, masks):
+        """
+        Train the model on the provided images and masks.
+
+        :param imgs: Input images for training.
+        :type imgs: list[numpy.ndarray]
+        :param masks: Masks corresponding to the input images.
+        :type masks: list[numpy.ndarray]
+        """
+
+        for i in range(self.num_of_channels):
+
+            masks_class = []
+
+            for mask in masks:
+                mask_class = mask.copy()
+                # set all instances in the instance mask not corresponding to the class in question to zero
+                mask_class[0][mask_class[1]!=(i+1)] = 0
+                masks_class.append(mask_class)
+
+            self.cellpose_models[i].train(imgs, masks_class)
+
+        self.metric = np.mean([self.cellpose_models[i].metric for i in range(self.num_of_channels)])
+        self.loss = np.mean([self.cellpose_models[i].loss for i in range(self.num_of_channels)])
+
+
+    def eval(self, img):
+        """Evaluate the model on the provided image. The instance mask is computed as the union of the predicted model outputs, while the class of
+        each object is assigned based on majority voting between the models.
+
+        :param img: Input image for evaluation.
+        :type img: np.ndarray[np.uint8]
+        :return: predicted mask consisting of instance and class masks
+        :rtype: numpy.ndarray
+        """
+
+        instance_masks, class_masks, model_confidences = [], [], []
+
+        for i in range(self.num_of_channels):
+            # get the instance mask and pixel-wise cell probability mask
+            instance_mask, probs, _ = self.cellpose_models[i].eval_all_outputs(img)
+            confidence = probs[2]
+            # assign the appropriate class to all objects detected by this model
+            class_mask = np.zeros_like(instance_mask)
+            class_mask[instance_mask > 0] = (i + 1)
+
+            instance_masks.append(instance_mask)
+            class_masks.append(class_mask)
+            model_confidences.append(confidence)
+        # merge the outputs of the different models using the pixel-wise cell probability mask
+        merged_mask_instances, class_mask = self.merge_masks(instance_masks, class_masks, model_confidences)
+        # set all connected components to the same label in the instance mask
+        instance_mask = label_mask(merged_mask_instances>0)
+        # and assign to each object the class with the most pixels
+        for inst_id in np.unique(instance_mask)[1:]:
+            where_inst_id = np.where(instance_mask==inst_id)
+            vals, counts = np.unique(class_mask[where_inst_id], return_counts=True)
+            class_mask[where_inst_id] = vals[np.argmax(counts)]
+        # obtain the final mask by stacking the instance and class masks
+        final_mask = np.stack((instance_mask, class_mask), axis=self.eval_config['mask_channel_axis']).astype(np.uint16)
+
+        return final_mask
+
+    def merge_masks(self, inst_masks, class_masks, probabilities):
+        """Merges the instance and class masks resulting from the different models using the pixel-wise cell probability. The output of the model
+        with the maximum probability is selected for each pixel.
+
+        :param inst_masks: List of predicted instance masks from each model.
+        :type inst_masks: List[np.array]
+        :param class_masks: List of corresponding class masks from each model.
+ :type class_masks: List[np.array] + :param probabilities: List of corresponding pixel-wise cell probability masks + :type probabilities: List[np.array] + :return: A tuple containing the following elements: + - final_mask_inst (numpy.ndarray): A single instance mask where for each pixel the output of the model with the highest probability is selected + - final_mask_class (numpy.ndarray): A single class mask where for each pixel the output of the model with the highest probability is selected + :rtype: tuple + """ + # Convert lists to numpy arrays + inst_masks = np.array(inst_masks) + class_masks = np.array(class_masks) + probabilities = np.array(probabilities) + + # Find the index of the mask with the maximum probability for each pixel + max_prob_indices = np.argmax(probabilities, axis=0) + + # Use the index to select the corresponding mask for each pixel + final_mask_inst = inst_masks[max_prob_indices, np.arange(inst_masks.shape[1])[:, None], np.arange(inst_masks.shape[2])] + final_mask_class = class_masks[max_prob_indices, np.arange(class_masks.shape[1])[:, None], np.arange(class_masks.shape[2])] + + return final_mask_inst, final_mask_class + diff --git a/src/server/dcp_server/models/unet.py b/src/server/dcp_server/models/unet.py new file mode 100644 index 00000000..c738f8c3 --- /dev/null +++ b/src/server/dcp_server/models/unet.py @@ -0,0 +1,211 @@ +import torch +from torch import nn +from torch.optim import Adam +from torch.utils.data import TensorDataset, DataLoader +from tqdm import tqdm +import numpy as np +from scipy.ndimage import label + +from dcp_server.models import Model + +class UNet(Model, nn.Module): + + """ + Unet is a convolutional neural network architecture for semantic segmentation. + + :param in_channels: Number of input channels (default: 3). + :type in_channels: int + :param out_channels: Number of output channels (default: 4). + :type out_channels: int + :param features: List of feature channels for each encoder level (default: [64,128,256,512]). + :type features: list + """ + + class DoubleConv(nn.Module): + """ + DoubleConv module consists of two consecutive convolutional layers with + batch normalization and ReLU activation functions. + """ + + def __init__(self, in_channels, out_channels): + """ + Initialize DoubleConv module. + + :param in_channels: Number of input channels. + :type in_channels: int + :param out_channels: Number of output channels. + :type out_channels: int + """ + + super().__init__() + + self.conv = nn.Sequential( + nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False), + nn.BatchNorm2d(out_channels), + nn.ReLU(), + nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False), + nn.BatchNorm2d(out_channels), + nn.ReLU(), + ) + + def forward(self, x): + """Forward pass through the DoubleConv module. + + :param x: Input tensor. + :type x: torch.Tensor + """ + return self.conv(x) + + + def __init__(self, model_config, train_config, eval_config, model_name): + """Constructs all the necessary attributes for the UNet model. + + + :param model_config: Model configuration. + :type model_config: dict + :param train_config: Training configuration. + :type train_config: dict + :param eval_config: Evaluation configuration. + :type eval_config: dict + :param model_name: Name of the model. 
+ :type model_name: str + """ + + super().__init__() + self.model_config = model_config + self.train_config = train_config + self.eval_config = eval_config + self.model_name = model_name + self.in_channels = self.model_config["classifier"]["in_channels"] + self.out_channels = self.model_config["classifier"]["num_classes"] + 1 + self.features = self.model_config["classifier"]["features"] + + self.encoder = nn.ModuleList() + self.decoder = nn.ModuleList() + + self.pool = nn.MaxPool2d(kernel_size=2, stride=2) + + # Encoder + for feature in self.features: + self.encoder.append( + UNet.DoubleConv(self.in_channels, feature) + ) + self.in_channels = feature + + # Decoder + for feature in self.features[::-1]: + self.decoder.append( + nn.ConvTranspose2d( + feature*2, feature, kernel_size=2, stride=2 + ) + ) + self.decoder.append( + UNet.DoubleConv(feature*2, feature) + ) + + self.bottle_neck = UNet.DoubleConv(self.features[-1], self.features[-1]*2) + self.output_conv = nn.Conv2d(self.features[0], self.out_channels, kernel_size=1) + + def forward(self, x): + """ + Forward pass of the UNet model. + + :param x: Input tensor. + :type x: torch.Tensor + :return: Output tensor. + :rtype: torch.Tensor + """ + skip_connections = [] + for encoder in self.encoder: + x = encoder(x) + skip_connections.append(x) + x = self.pool(x) + + x = self.bottle_neck(x) + skip_connections = skip_connections[::-1] + + for i in np.arange(len(self.decoder), step=2): + x = self.decoder[i](x) + skip_connection = skip_connections[i//2] + concatenate_skip = torch.cat((skip_connection, x), dim=1) + x = self.decoder[i+1](concatenate_skip) + + return self.output_conv(x) + + def train(self, imgs, masks): + """ + Trains the UNet model using the provided images and masks. + + :param imgs: Input images for training. + :type imgs: list[numpy.ndarray] + :param masks: Masks corresponding to the input images. + :type masks: list[numpy.ndarray] + """ + + lr = self.train_config["classifier"]['lr'] + epochs = self.train_config["classifier"]['n_epochs'] + batch_size = self.train_config["classifier"]['batch_size'] + + # Convert input images and labels to tensors + # normalize images + imgs = [(img-np.min(img))/(np.max(img)-np.min(img)) for img in imgs] + # convert to tensor + imgs = torch.stack([torch.from_numpy(img.astype(np.float32)) for img in imgs]) + imgs = imgs.unsqueeze(1) if imgs.ndim == 3 else imgs + + # Classification label mask + masks = np.array(masks) + masks = torch.stack([torch.from_numpy(mask[1].astype(np.int16)) for mask in masks]) + + # Create a training dataset and dataloader + train_dataset = TensorDataset(imgs, masks) + train_dataloader = DataLoader(train_dataset, batch_size=batch_size) + + loss_fn = nn.CrossEntropyLoss() + optimizer = Adam(params=self.parameters(), lr=lr) + + for _ in tqdm(range(epochs), desc="Running UNet training"): + + self.loss = 0 + + for imgs, masks in train_dataloader: + imgs = imgs.float() + masks = masks.long() + + #forward path + preds = self.forward(imgs) + loss = loss_fn(preds, masks) + + #backward path + optimizer.zero_grad() + loss.backward() + optimizer.step() + + self.loss += loss.detach().mean().item() + + self.loss /= len(train_dataloader) + + def eval(self, img): + """ + Evaluate the model on the provided image and return the predicted label. + + :param img: Input image for evaluation. 
+ :type img: np.ndarray[np.uint8] + :return: predicted mask consists of instance and class masks + :rtype: numpy.ndarray + """ + with torch.no_grad(): + # normalise + img = (img-np.min(img))/(np.max(img)-np.min(img)) + img = torch.from_numpy(img).float().unsqueeze(0) + + img = img.unsqueeze(1) if img.ndim == 3 else img + + preds = self.forward(img) + class_mask = torch.argmax(preds, 1).numpy()[0] + + instance_mask = label((class_mask > 0).astype(int))[0] + + final_mask = np.stack((instance_mask, class_mask), axis=self.eval_config['mask_channel_axis']).astype(np.uint16) + + return final_mask From f8e9df268d91a50ddb10291529029002f655ab26 Mon Sep 17 00:00:00 2001 From: Christina Bukas Date: Mon, 26 Feb 2024 12:37:19 +0100 Subject: [PATCH 02/26] specified types of arguments --- .../dcp_server/models/cellpose_patchCNN.py | 66 +++++++++++++------ .../dcp_server/models/custom_cellpose.py | 41 ++++++------ src/server/dcp_server/models/model.py | 30 ++++++++- src/server/dcp_server/models/multicellpose.py | 23 +++++-- src/server/dcp_server/models/unet.py | 33 +++++++--- 5 files changed, 138 insertions(+), 55 deletions(-) diff --git a/src/server/dcp_server/models/cellpose_patchCNN.py b/src/server/dcp_server/models/cellpose_patchCNN.py index c983e643..a5c5a349 100644 --- a/src/server/dcp_server/models/cellpose_patchCNN.py +++ b/src/server/dcp_server/models/cellpose_patchCNN.py @@ -1,5 +1,6 @@ from copy import deepcopy from tqdm import tqdm +from typing import List import numpy as np import torch @@ -26,7 +27,12 @@ class CellposePatchCNN(Model): Cellpose & patches of cells and then cnn to classify each patch """ - def __init__(self, model_config, train_config, eval_config, model_name): + def __init__(self, + model_config: dict, + train_config: dict, + eval_config:dict, + model_name:str + ) -> None: """Constructs all the necessary attributes for the CellposePatchCNN :param model_config: Model configuration. @@ -65,18 +71,11 @@ def __init__(self, model_config, train_config, eval_config, model_name): # make sure include mask is set to False if we are using the random forest model self.include_mask = False - def update_configs(self, train_config, eval_config): - """Update the training and evaluation configurations. - - :param train_config: Dictionary containing the training configuration. - :type train_config: dict - :param eval_config: Dictionary containing the evaluation configuration. - :type eval_config: dict - """ - self.train_config = train_config - self.eval_config = eval_config - def train(self, imgs, masks): + def train(self, + imgs: List[np.ndarray], + masks: List[np.ndarray] + ) -> None: """Trains the given model. First trains the segmentor and then the clasiffier. :param imgs: images to train on (training data) @@ -106,7 +105,9 @@ def train(self, imgs, masks): self.metric = (self.segmentor.metric + self.classifier.metric) / 2 self.loss = (self.segmentor.loss + self.classifier.loss)/2 - def eval(self, img): + def eval(self, + img: np.ndarray + ) -> np.ndarray: """Evaluate the model on the provided image and return the final mask. :param img: Input image for evaluation. @@ -161,7 +162,11 @@ class CellClassifierFCNN(nn.Module): """ - def __init__(self, model_config, train_config, eval_config): + def __init__(self, + model_config: dict, + train_config: dict, + eval_config: dict + ) -> None: """Initialize the fully convolutional classifier. :param model_config: Model configuration. 
@@ -207,7 +212,10 @@ def __init__(self, model_config, train_config, eval_config): self.metric_fn = F1Score(num_classes=self.num_classes, task="multiclass") - def update_configs(self, train_config, eval_config): + def update_configs(self, + train_config: dict, + eval_config: dict + ) -> None: """ Update the training and evaluation configurations. @@ -219,7 +227,9 @@ def update_configs(self, train_config, eval_config): self.train_config = train_config self.eval_config = eval_config - def forward(self, x): + def forward(self, + x: torch.Tensor + ) -> torch.Tensor: """ Performs forward pass of the CellClassifierFCNN. :param x: Input tensor. @@ -237,7 +247,10 @@ def forward(self, x): x = x.view(x.size(0), -1) return x - def train (self, imgs, labels): + def train (self, + imgs: List[np.ndarray], + labels: List[np.ndarray] + ) -> None: """Trains the given model :param imgs: List of input images with shape (3, dx, dy). @@ -287,7 +300,9 @@ def train (self, imgs, labels): self.loss /= len(train_dataloader) self.metric /= len(train_dataloader) - def eval(self, img): + def eval(self, + img: np.ndarray + ) -> torch.Tensor: """Evaluates the model on the provided image and return the predicted label. :param img: Input image for evaluation. @@ -309,7 +324,11 @@ class CellClassifierShallowModel: This class implements a shallow model for cell classification using scikit-learn. """ - def __init__(self, model_config, train_config, eval_config): + def __init__(self, + model_config: dict, + train_config: dict, + eval_config: dict + ) -> None: """Constructs all the necessary attributes for the CellClassifierShallowModel :param model_config: Model configuration. @@ -327,7 +346,10 @@ def __init__(self, model_config, train_config, eval_config): self.model = RandomForestClassifier() # TODO chnage config so RandomForestClassifier accepts input params - def train(self, X_train, y_train): + def train(self, + X_train: np.ndarray, + y_train: np.ndarray + ) -> None: """Trains the model using the provided training data. :param X_train: Features of the training data. @@ -346,7 +368,9 @@ def train(self, X_train, y_train): self.loss = log_loss(y_train, y_hat_proba) - def eval(self, X_test): + def eval(self, + X_test: np.ndarray + ) -> np.ndarray: """Evaluates the model on the provided test data. :param X_test: Features of the test data. diff --git a/src/server/dcp_server/models/custom_cellpose.py b/src/server/dcp_server/models/custom_cellpose.py index fcb8d8ab..c7aefc10 100644 --- a/src/server/dcp_server/models/custom_cellpose.py +++ b/src/server/dcp_server/models/custom_cellpose.py @@ -1,4 +1,5 @@ from copy import deepcopy +from typing import List import numpy as np import torch @@ -14,7 +15,12 @@ class CustomCellposeModel(models.CellposeModel, Model): """Custom cellpose model inheriting the attributes and functions from the original CellposeModel and implementing additional attributes and methods needed for this project. """ - def __init__(self, model_config, train_config, eval_config, model_name): + def __init__(self, + model_config: dict, + train_config: dict, + eval_config: dict, + model_name: str + ) -> None: """Constructs all the necessary attributes for the CustomCellposeModel. The model inherits all attributes from the parent class, the init allows to pass any other argument that the parent class accepts. Please, visit here https://cellpose.readthedocs.io/en/latest/api.html#id4 for more details on arguments accepted. 
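As context for the typed signatures this commit adds: the three config dicts come straight from the corresponding sections of config.cfg, and the last argument is a free-form model name. A minimal usage sketch, not part of the patch (the random image is a stand-in for real data):

    import numpy as np
    from dcp_server.utils import read_config
    from dcp_server.models import CustomCellposeModel

    model_config = read_config("model", config_path="config.cfg")
    train_config = read_config("train", config_path="config.cfg")
    eval_config = read_config("eval", config_path="config.cfg")

    model = CustomCellposeModel(model_config, train_config, eval_config, "Cellpose")
    img = np.random.rand(256, 256)          # placeholder image
    mask = model.eval(img)                  # mask only
    outputs = model.eval_all_outputs(img)   # full (masks, flows, styles, ...) tuple
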
@@ -37,31 +43,23 @@ def __init__(self, model_config, train_config, eval_config, model_name): self.model_name = model_name self.mkldnn = False # otherwise we get error with saving model self.loss = 1e6 - - - def update_configs(self, train_config, eval_config): - """Update the training and evaluation configurations. - - :param train_config: Dictionary containing the training configuration. - :type train_config: dict - :param eval_config: Dictionary containing the evaluation configuration. - :type eval_config: dict - """ - self.train_config = train_config - self.eval_config = eval_config - def eval_all_outputs(self, img): + def eval_all_outputs(self, + img: np.ndarray + ) -> tuple: """Get all outputs of the model when running eval. :param img: Input image for segmentation. :type img: numpy.ndarray - :return: Probability mask for the input image. - :rtype: numpy.ndarray + :return: mask, flows, styles etc. Returns the same as cellpose.models.CellposeModel.eval - see Cellpose API Guide for more details. + :rtype: tuple """ return super().eval(x=img, **self.eval_config["segmentor"]) - def eval(self, img): + def eval(self, + img: np.ndarray + ) -> np.ndarray: """Evaluate the model - find mask of the given image Calls the original eval function. @@ -72,7 +70,10 @@ def eval(self, img): """ return super().eval(x=img, **self.eval_config["segmentor"])[0] # 0 to take only mask - def train(self, imgs, masks): + def train(self, + imgs: List[np.ndarray], + masks: List[np.ndarray] + ) -> None: """Trains the given model Calls the original train function. @@ -109,7 +110,9 @@ def train(self, imgs, masks): self.loss = self.loss_fn(true_lbl, pred_flows) self.metric = np.mean(aggregated_jaccard_index(masks, pred_masks)) - def masks_to_outlines(self, mask): + def masks_to_outlines(self, + mask: np.ndarray + ) -> np.ndarray: """ get outlines of masks as a 0-1 array Calls the original cellpose.utils.masks_to_outlines function diff --git a/src/server/dcp_server/models/model.py b/src/server/dcp_server/models/model.py index 47f6dc74..db584214 100644 --- a/src/server/dcp_server/models/model.py +++ b/src/server/dcp_server/models/model.py @@ -3,19 +3,43 @@ import numpy as np class Model(ABC): - def __init__(self, model_config, train_config, eval_config, model_name): + def __init__(self, + model_config: dict, + train_config: dict, + eval_config: dict, + model_name: str + ) -> None: self.model_config = model_config self.train_config = train_config self.eval_config = eval_config self.model_name = model_name + def update_configs(self, + train_config: dict, + eval_config: dict + ) -> None: + """ Update the training and evaluation configurations. + + :param train_config: Dictionary containing the training configuration. + :type train_config: dict + :param eval_config: Dictionary containing the evaluation configuration. 
+ :type eval_config: dict + """ + self.train_config = train_config + self.eval_config = eval_config + @abstractmethod - def train(self, imgs: List[np.array], masks: List[np.array]) -> None: + def train(self, + imgs: List[np.array], + masks: List[np.array] + ) -> None: pass @abstractmethod - def eval(self, img: np.array) -> np.array: + def eval(self, + img: np.array + ) -> np.array: pass diff --git a/src/server/dcp_server/models/multicellpose.py b/src/server/dcp_server/models/multicellpose.py index 840aa210..c12a1e6b 100644 --- a/src/server/dcp_server/models/multicellpose.py +++ b/src/server/dcp_server/models/multicellpose.py @@ -1,3 +1,4 @@ +from typing import List import numpy as np from skimage.measure import label as label_mask @@ -9,7 +10,12 @@ class MultiCellpose(Model): Run the separate CustomCellposeModel models for each channel return the mask corresponding to each object type. ''' - def __init__(self, model_config, train_config, eval_config, model_name="Cellpose"): + def __init__(self, + model_config: dict, + train_config: dict, + eval_config: dict, + model_name="Cellpose" + ) -> None: """Constructs all the necessary attributes for the MultiCellpose model. :param model_config: Model configuration. @@ -36,7 +42,10 @@ def __init__(self, model_config, train_config, eval_config, model_name="Cellpose ) for _ in range(self.num_of_channels) ] - def train(self, imgs, masks): + def train(self, + imgs: List[np.ndarray], + masks: List[np.ndarray] + ) -> None: """ Train the model on the provided images and masks. @@ -62,7 +71,9 @@ def train(self, imgs, masks): self.loss = np.mean([self.cellpose_models[i].loss for i in range(self.num_of_channels)]) - def eval(self, img): + def eval(self, + img: np.ndarray + ) -> np.ndarray: """Evaluate the model on the provided image. The instance mask are computed as the union of the predicted model outputs, while the class of each object is assigned based on majority voting between the models. @@ -99,7 +110,11 @@ def eval(self, img): return final_mask - def merge_masks(self, inst_masks, class_masks, probabilities): + def merge_masks(self, + inst_masks: List[np.ndarray], + class_masks: List[np.ndarray], + probabilities: List[np.ndarray] + ) -> tuple: """Merges the instance and class masks resulting from the different models using the pixel-wise cell probability. The output of the model with the maximum probability is selected for each pixel. diff --git a/src/server/dcp_server/models/unet.py b/src/server/dcp_server/models/unet.py index c738f8c3..dd88e33e 100644 --- a/src/server/dcp_server/models/unet.py +++ b/src/server/dcp_server/models/unet.py @@ -1,3 +1,4 @@ +from typing import List import torch from torch import nn from torch.optim import Adam @@ -27,7 +28,10 @@ class DoubleConv(nn.Module): batch normalization and ReLU activation functions. """ - def __init__(self, in_channels, out_channels): + def __init__(self, + in_channels: int, + out_channels: int + ) -> None: """ Initialize DoubleConv module. @@ -48,7 +52,9 @@ def __init__(self, in_channels, out_channels): nn.ReLU(), ) - def forward(self, x): + def forward(self, + x: torch.Tensor + ) -> torch.Tensor: """Forward pass through the DoubleConv module. :param x: Input tensor. @@ -57,10 +63,14 @@ def forward(self, x): return self.conv(x) - def __init__(self, model_config, train_config, eval_config, model_name): + def __init__(self, + model_config: dict, + train_config: dict, + eval_config: dict, + model_name: str + ) -> None: """Constructs all the necessary attributes for the UNet model. 
- - + :param model_config: Model configuration. :type model_config: dict :param train_config: Training configuration. @@ -106,7 +116,9 @@ def __init__(self, model_config, train_config, eval_config, model_name): self.bottle_neck = UNet.DoubleConv(self.features[-1], self.features[-1]*2) self.output_conv = nn.Conv2d(self.features[0], self.out_channels, kernel_size=1) - def forward(self, x): + def forward(self, + x: torch.Tensor + ) -> torch.Tensor: """ Forward pass of the UNet model. @@ -132,7 +144,10 @@ def forward(self, x): return self.output_conv(x) - def train(self, imgs, masks): + def train(self, + imgs: List[np.ndarray], + masks: List[np.ndarray] + ) -> None: """ Trains the UNet model using the provided images and masks. @@ -185,7 +200,9 @@ def train(self, imgs, masks): self.loss /= len(train_dataloader) - def eval(self, img): + def eval(self, + img: np.ndarray + ) -> np.ndarray: """ Evaluate the model on the provided image and return the predicted label. From 11ad601f8fdecf19ba21f033a7ea765f7313e118 Mon Sep 17 00:00:00 2001 From: Christina Bukas Date: Mon, 26 Feb 2024 12:56:27 +0100 Subject: [PATCH 03/26] edited format --- .../dcp_server/models/cellpose_patchCNN.py | 79 +++++++++++-------- .../dcp_server/models/custom_cellpose.py | 14 +++- src/server/dcp_server/models/multicellpose.py | 31 +++++--- src/server/dcp_server/models/unet.py | 25 ++++-- 4 files changed, 96 insertions(+), 53 deletions(-) diff --git a/src/server/dcp_server/models/cellpose_patchCNN.py b/src/server/dcp_server/models/cellpose_patchCNN.py index a5c5a349..8b134383 100644 --- a/src/server/dcp_server/models/cellpose_patchCNN.py +++ b/src/server/dcp_server/models/cellpose_patchCNN.py @@ -54,20 +54,19 @@ def __init__(self, self.classifier_class = self.model_config.get("classifier").get("model_class", "CellClassifierFCNN") # Initialize the cellpose model and the classifier - self.segmentor = CustomCellposeModel(self.model_config, - self.train_config, - self.eval_config, - "Cellpose") + self.segmentor = CustomCellposeModel( + self.model_config, self.train_config, self.eval_config, "Cellpose" + ) if self.classifier_class == "FCNN": - self.classifier = CellClassifierFCNN(self.model_config, - self.train_config, - self.eval_config) + self.classifier = CellClassifierFCNN( + self.model_config, self.train_config, self.eval_config + ) elif self.classifier_class == "RandomForest": - self.classifier = CellClassifierShallowModel(self.model_config, - self.train_config, - self.eval_config) + self.classifier = CellClassifierShallowModel( + self.model_config, self.train_config, self.eval_config + ) # make sure include mask is set to False if we are using the random forest model self.include_mask = False @@ -89,13 +88,17 @@ def train(self, masks_instances = list(masks[:,0,...]) #[mask.sum(-1) for mask in masks] if masks[0].ndim == 3 else masks self.segmentor.train(deepcopy(imgs), masks_instances) # create patch dataset to train classifier - masks_classes = list(masks[:,1,...]) #[((mask > 0) * np.arange(1, 4)).sum(-1) for mask in masks] - patches, patch_masks, labels = create_patch_dataset(imgs, - masks_classes, - masks_instances, - noise_intensity = self.train_config["classifier"]["train_data"]["noise_intensity"], - max_patch_size = self.train_config["classifier"]["train_data"]["patch_size"], - include_mask = self.include_mask) + masks_classes = list( + masks[:,1,...] 
+ ) #[((mask > 0) * np.arange(1, 4)).sum(-1) for mask in masks] + patches, patch_masks, labels = create_patch_dataset( + imgs, + masks_classes, + masks_instances, + noise_intensity = self.train_config["classifier"]["train_data"]["noise_intensity"], + max_patch_size = self.train_config["classifier"]["train_data"]["patch_size"], + include_mask = self.include_mask + ) x = patches if self.classifier_class == "RandomForest": x = create_dataset_for_rf(patches, patch_masks) @@ -125,15 +128,18 @@ def eval(self, class_mask = np.zeros(instance_mask.shape) max_patch_size = self.eval_config["classifier"]["data"]["patch_size"] - if max_patch_size is None: max_patch_size = find_max_patch_size(instance_mask) + if max_patch_size is None: + max_patch_size = find_max_patch_size(instance_mask) noise_intensity = self.eval_config["classifier"]["data"]["noise_intensity"] # get patches centered around detected objects - patches, patch_masks, instance_labels, _ = get_centered_patches(img, - instance_mask, - max_patch_size, - noise_intensity=noise_intensity, - include_mask=self.include_mask) + patches, patch_masks, instance_labels, _ = get_centered_patches( + img, + instance_mask, + max_patch_size, + noise_intensity=noise_intensity, + include_mask=self.include_mask + ) x = patches if self.classifier_class == "RandomForest": x = create_dataset_for_rf(patches, patch_masks) @@ -142,9 +148,15 @@ def eval(self, patch_class = self.classifier.eval(x[idx]) # Assign predicted class to corresponding location in final_mask patch_class = patch_class.item() if isinstance(patch_class, torch.Tensor) else patch_class - class_mask[instance_mask==instance_labels[idx]] = patch_class + 1 + class_mask[instance_mask==instance_labels[idx]] = ( + patch_class + 1 + ) # Apply mask to final_mask, retaining only regions where cellpose_mask is greater than 0 - final_mask = np.stack((instance_mask, class_mask), axis=self.eval_config['mask_channel_axis']).astype(np.uint16) # size 2xHxW + final_mask = np.stack( + (instance_mask, class_mask), axis=self.eval_config['mask_channel_axis'] + ).astype( + np.uint16 + ) # size 2xHxW return final_mask @@ -259,15 +271,15 @@ def train (self, :type labels: List[int] """ - lr = self.train_config['lr'] - epochs = self.train_config['n_epochs'] - batch_size = self.train_config['batch_size'] - # optimizer_class = self.train_config['optimizer'] + lr = self.train_config["lr"] + epochs = self.train_config["n_epochs"] + batch_size = self.train_config["batch_size"] + # optimizer_class = self.train_config["optimizer"] # Convert input images and labels to tensors # normalize images - imgs = [(img-np.min(img))/(np.max(img)-np.min(img)) for img in imgs] + imgs = [(img - np.min(img)) / (np.max(img) - np.min(img)) for img in imgs] # convert to tensor imgs = torch.stack([torch.from_numpy(img.astype(np.float32)) for img in imgs]) imgs = torch.permute(imgs, (0, 3, 1, 2)) @@ -279,7 +291,10 @@ def train (self, train_dataloader = DataLoader(train_dataset, batch_size=batch_size) loss_fn = nn.CrossEntropyLoss() - optimizer = Adam(params=self.parameters(), lr=lr) #eval(f'{optimizer_class}(params={self.parameters()}, lr={lr})') + optimizer = Adam( + params=self.parameters(), + lr=lr + ) #eval(f'{optimizer_class}(params={self.parameters()}, lr={lr})') # TODO check if we should replace self.parameters with super.parameters() for _ in tqdm(range(epochs), desc="Running CellClassifierFCNN training"): @@ -311,7 +326,7 @@ def eval(self, :rtype: torch.Tensor """ # normalise - img = (img-np.min(img))/(np.max(img)-np.min(img)) + img = (img - 
np.min(img)) / (np.max(img) - np.min(img)) # convert to tensor img = torch.permute(torch.tensor(img.astype(np.float32)), (2, 0, 1)).unsqueeze(0) preds = self.forward(img) diff --git a/src/server/dcp_server/models/custom_cellpose.py b/src/server/dcp_server/models/custom_cellpose.py index c7aefc10..79757ced 100644 --- a/src/server/dcp_server/models/custom_cellpose.py +++ b/src/server/dcp_server/models/custom_cellpose.py @@ -34,7 +34,7 @@ def __init__(self, """ # Initialize the cellpose model - #super().__init__(**model_config["segmentor"]) + # super().__init__(**model_config["segmentor"]) nn.Module.__init__(self) models.CellposeModel.__init__(self, **model_config["segmentor"]) self.model_config = model_config @@ -68,7 +68,9 @@ def eval(self, :return: mask of the image, list of 2D arrays, or single 3D array (if do_3D=True) labelled image. :rtype: np.ndarray """ - return super().eval(x=img, **self.eval_config["segmentor"])[0] # 0 to take only mask + return super().eval(x=img, **self.eval_config["segmentor"])[ + 0 + ] # 0 to take only mask def train(self, imgs: List[np.ndarray], @@ -88,7 +90,11 @@ def train(self, if masks[0].shape[0] == 2: masks = list(masks[:,0,...]) - super().train(train_data=deepcopy(imgs), train_labels=masks, **self.train_config["segmentor"]) + super().train( + train_data=deepcopy(imgs), + train_labels=masks, + **self.train_config["segmentor"] + ) # compute loss and metric true_bin_masks = [mask>0 for mask in masks] # get binary masks @@ -121,4 +127,4 @@ def masks_to_outlines(self, :return: outlines :rtype: ndarray """ - return utils.masks_to_outlines(mask) #[True, False] outputs + return utils.masks_to_outlines(mask) # [True, False] outputs diff --git a/src/server/dcp_server/models/multicellpose.py b/src/server/dcp_server/models/multicellpose.py index c12a1e6b..8ec01a30 100644 --- a/src/server/dcp_server/models/multicellpose.py +++ b/src/server/dcp_server/models/multicellpose.py @@ -35,10 +35,11 @@ def __init__(self, self.num_of_channels = self.model_config["classifier"]["num_classes"] self.cellpose_models = [ - CustomCellposeModel(self.model_config, - self.train_config, - self.eval_config, - self.model_name + CustomCellposeModel( + self.model_config, + self.train_config, + self.eval_config, + self.model_name ) for _ in range(self.num_of_channels) ] @@ -62,7 +63,9 @@ def train(self, for mask in masks: mask_class = mask.copy() # set all instances in the instance mask not corresponding to the class in question to zero - mask_class[0][mask_class[1]!=(i+1)] = 0 + mask_class[0][ + mask_class[1]!=(i+1) + ] = 0 masks_class.append(mask_class) self.cellpose_models[i].train(imgs, masks_class) @@ -97,7 +100,9 @@ def eval(self, class_masks.append(class_mask) model_confidences.append(confidence) # merge the outputs of the different models using the pixel-wise cell probability mask - merged_mask_instances, class_mask = self.merge_masks(instance_masks, class_masks, model_confidences) + merged_mask_instances, class_mask = self.merge_masks( + instance_masks, class_masks, model_confidences + ) # set all connected components to the same label in the instance mask instance_mask = label_mask(merged_mask_instances>0) # and set the class with the most pixels to that object @@ -106,7 +111,11 @@ def eval(self, vals, counts = np.unique(class_mask[where_inst_id], return_counts=True) class_mask[where_inst_id] = vals[np.argmax(counts)] # take the final mask by stancking instance and class mask - final_mask = np.stack((instance_mask, class_mask), 
axis=self.eval_config['mask_channel_axis']).astype(np.uint16) + final_mask = np.stack( + (instance_mask, class_mask), axis=self.eval_config['mask_channel_axis'] + ).astype( + np.uint16 + ) return final_mask @@ -138,8 +147,12 @@ def merge_masks(self, max_prob_indices = np.argmax(probabilities, axis=0) # Use the index to select the corresponding mask for each pixel - final_mask_inst = inst_masks[max_prob_indices, np.arange(inst_masks.shape[1])[:, None], np.arange(inst_masks.shape[2])] - final_mask_class = class_masks[max_prob_indices, np.arange(class_masks.shape[1])[:, None], np.arange(class_masks.shape[2])] + final_mask_inst = inst_masks[ + max_prob_indices, np.arange(inst_masks.shape[1])[:, None], np.arange(inst_masks.shape[2]) + ] + final_mask_class = class_masks[ + max_prob_indices, np.arange(class_masks.shape[1])[:, None], np.arange(class_masks.shape[2]) + ] return final_mask_inst, final_mask_class diff --git a/src/server/dcp_server/models/unet.py b/src/server/dcp_server/models/unet.py index dd88e33e..58b33583 100644 --- a/src/server/dcp_server/models/unet.py +++ b/src/server/dcp_server/models/unet.py @@ -157,20 +157,24 @@ def train(self, :type masks: list[numpy.ndarray] """ - lr = self.train_config["classifier"]['lr'] - epochs = self.train_config["classifier"]['n_epochs'] - batch_size = self.train_config["classifier"]['batch_size'] + lr = self.train_config["classifier"]["lr"] + epochs = self.train_config["classifier"]["n_epochs"] + batch_size = self.train_config["classifier"]["batch_size"] # Convert input images and labels to tensors # normalize images - imgs = [(img-np.min(img))/(np.max(img)-np.min(img)) for img in imgs] + imgs = [(img - np.min(img)) / (np.max(img) - np.min(img)) for img in imgs] # convert to tensor - imgs = torch.stack([torch.from_numpy(img.astype(np.float32)) for img in imgs]) + imgs = torch.stack([ + torch.from_numpy(img.astype(np.float32)) for img in imgs + ]) imgs = imgs.unsqueeze(1) if imgs.ndim == 3 else imgs # Classification label mask masks = np.array(masks) - masks = torch.stack([torch.from_numpy(mask[1].astype(np.int16)) for mask in masks]) + masks = torch.stack([ + torch.from_numpy(mask[1].astype(np.int16)) for mask in masks + ]) # Create a training dataset and dataloader train_dataset = TensorDataset(imgs, masks) @@ -213,7 +217,7 @@ def eval(self, """ with torch.no_grad(): # normalise - img = (img-np.min(img))/(np.max(img)-np.min(img)) + img = (img - np.min(img)) / (np.max(img) - np.min(img)) img = torch.from_numpy(img).float().unsqueeze(0) img = img.unsqueeze(1) if img.ndim == 3 else img @@ -223,6 +227,11 @@ def eval(self, instance_mask = label((class_mask > 0).astype(int))[0] - final_mask = np.stack((instance_mask, class_mask), axis=self.eval_config['mask_channel_axis']).astype(np.uint16) + final_mask = np.stack( + (instance_mask, class_mask), + axis=self.eval_config['mask_channel_axis'] + ).astype( + np.uint16 + ) return final_mask From 70997b52ee188d0d3202f02df465490046b587fe Mon Sep 17 00:00:00 2001 From: Christina Bukas Date: Mon, 26 Feb 2024 17:40:09 +0100 Subject: [PATCH 04/26] moved files to new structure --- .../dcp_server/models/cellpose_patchCNN.py | 7 +-- src/server/dcp_server/models/unet.py | 10 +++-- .../dcp_server/{ => utils}/fsimagestorage.py | 0 src/server/dcp_server/utils/helpers.py | 34 ++++++++++++++ .../{utils.py => utils/processing.py} | 44 ++++--------------- 5 files changed, 53 insertions(+), 42 deletions(-) rename src/server/dcp_server/{ => utils}/fsimagestorage.py (100%) create mode 100644 src/server/dcp_server/utils/helpers.py 
rename src/server/dcp_server/{utils.py => utils/processing.py} (91%)

diff --git a/src/server/dcp_server/models/cellpose_patchCNN.py b/src/server/dcp_server/models/cellpose_patchCNN.py
index 8b134383..233e2738 100644
--- a/src/server/dcp_server/models/cellpose_patchCNN.py
+++ b/src/server/dcp_server/models/cellpose_patchCNN.py
@@ -14,7 +14,8 @@
 from sklearn.exceptions import NotFittedError
 
 from dcp_server.models import Model, CustomCellposeModel
-from dcp_server.utils import (
+from dcp_server.utils.processing import (
+    normalise,
     get_centered_patches,
     find_max_patch_size,
     create_patch_dataset,
@@ -279,7 +280,7 @@ def train (self,
 
         # Convert input images and labels to tensors
 
         # normalize images
-        imgs = [(img - np.min(img)) / (np.max(img) - np.min(img)) for img in imgs]
+        imgs = [normalise(img) for img in imgs]
         # convert to tensor
         imgs = torch.stack([torch.from_numpy(img.astype(np.float32)) for img in imgs])
         imgs = torch.permute(imgs, (0, 3, 1, 2))
@@ -326,7 +327,7 @@ def eval(self,
         :rtype: torch.Tensor
         """
         # normalise
-        img = (img - np.min(img)) / (np.max(img) - np.min(img))
+        img = normalise(img)
         # convert to tensor
         img = torch.permute(torch.tensor(img.astype(np.float32)), (2, 0, 1)).unsqueeze(0)
         preds = self.forward(img)
diff --git a/src/server/dcp_server/models/unet.py b/src/server/dcp_server/models/unet.py
index 58b33583..4b3f7fe8 100644
--- a/src/server/dcp_server/models/unet.py
+++ b/src/server/dcp_server/models/unet.py
@@ -1,13 +1,15 @@
 from typing import List
+from tqdm import tqdm
+import numpy as np
+from scipy.ndimage import label
+
 import torch
 from torch import nn
 from torch.optim import Adam
 from torch.utils.data import TensorDataset, DataLoader
-from tqdm import tqdm
-import numpy as np
-from scipy.ndimage import label
 
 from dcp_server.models import Model
+from dcp_server.utils.processing import normalise
 
 class UNet(Model, nn.Module):
 
@@ -163,7 +165,7 @@ def train(self,
 
         # Convert input images and labels to tensors
 
         # normalize images
-        imgs = [(img - np.min(img)) / (np.max(img) - np.min(img)) for img in imgs]
+        imgs = [normalise(img) for img in imgs]
         # convert to tensor
         imgs = torch.stack([
             torch.from_numpy(img.astype(np.float32)) for img in imgs
diff --git a/src/server/dcp_server/fsimagestorage.py b/src/server/dcp_server/utils/fsimagestorage.py
similarity index 100%
rename from src/server/dcp_server/fsimagestorage.py
rename to src/server/dcp_server/utils/fsimagestorage.py
diff --git a/src/server/dcp_server/utils/helpers.py b/src/server/dcp_server/utils/helpers.py
new file mode 100644
index 00000000..f706b5d0
--- /dev/null
+++ b/src/server/dcp_server/utils/helpers.py
@@ -0,0 +1,34 @@
+from pathlib import Path
+import json
+
+
+def read_config(name, config_path = 'config.cfg') -> dict:
+    """Reads the configuration file
+
+    :param name: name of the section you want to read (e.g. 'setup','train')
+    :type name: string
+    :param config_path: path to the configuration file, defaults to 'config.cfg'
+    :type config_path: str, optional
+    :return: dictionary from the config section given by name
+    :rtype: dict
+    """
+    with open(config_path) as config_file:
+        config_dict = json.load(config_file)
+        # Check if config file has main mandatory keys
+        assert all([i in config_dict.keys() for i in ['setup', 'service', 'model', 'train', 'eval']])
+        return config_dict[name]
+
+def get_path_stem(filepath): return str(Path(filepath).stem)
+
+
+def get_path_name(filepath): return str(Path(filepath).name)
+
+
+def get_path_parent(filepath): return str(Path(filepath).parent)
+
+
+def join_path(root_dir, filepath): return str(Path(root_dir, filepath))
+
+
+def get_file_extension(file): return str(Path(file).suffix)
+
diff --git a/src/server/dcp_server/utils.py b/src/server/dcp_server/utils/processing.py
similarity index 91%
rename from src/server/dcp_server/utils.py
rename to src/server/dcp_server/utils/processing.py
index 7e6818e8..9ca7e770 100644
--- a/src/server/dcp_server/utils.py
+++ b/src/server/dcp_server/utils/processing.py
@@ -1,5 +1,3 @@
-from pathlib import Path
-import json
 from copy import deepcopy
 import numpy as np
 from scipy.ndimage import find_objects
@@ -8,45 +6,23 @@
 import SimpleITK as sitk
 from radiomics import shape2D
 
-def read_config(name, config_path = 'config.cfg') -> dict:
-    """Reads the configuration file
-
-    :param name: name of the section you want to read (e.g. 'setup','train')
-    :type name: string
-    :param config_path: path to the configuration file, defaults to 'config.cfg'
-    :type config_path: str, optional
-    :return: dictionary from the config section given by name
-    :rtype: dict
-    """
-    with open(config_path) as config_file:
-        config_dict = json.load(config_file)
-        # Check if config file has main mandatory keys
-        assert all([i in config_dict.keys() for i in ['setup', 'service', 'model', 'train', 'eval']])
-        return config_dict[name]
-
-def get_path_stem(filepath): return str(Path(filepath).stem)
-
-
-def get_path_name(filepath): return str(Path(filepath).name)
-
-
-def get_path_parent(filepath): return str(Path(filepath).parent)
-
-
-def join_path(root_dir, filepath): return str(Path(root_dir, filepath))
-
-
-def get_file_extension(file): return str(Path(file).suffix)
-
-
+def normalise(img, norm='min-max'):
+    """ Normalises the image based on the chosen method. Currently available methods are:
+    - min-max normalisation
+
+    :param img: image to normalise
+    :param norm: normalisation method to use, defaults to 'min-max'
+    """
+    if norm=='min-max':
+        return (img - np.min(img)) / (np.max(img) - np.min(img))
+
 def crop_centered_padded_patch(img: np.ndarray,
                                patch_center_xy,
                                patch_size,
                                obj_label,
                                mask: np.ndarray=None,
                                noise_intensity=None) -> np.ndarray:
-    """
-    Crop a patch from an array `x` centered at coordinates `c` with size `p`, and apply padding if necessary.
+    """ Crop a patch from an array `x` centered at coordinates `c` with size `p`, and apply padding if necessary.
 
     Args:
         img (np.ndarray): The input array from which the patch will be cropped.
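Before the next commit, a quick sketch of the reorganised utils package this commit introduces; the array values and file path below are illustrative only:

    import numpy as np
    from dcp_server.utils.helpers import read_config, get_path_stem
    from dcp_server.utils.processing import normalise

    img = np.array([[0.0, 5.0], [10.0, 20.0]], dtype=np.float32)
    scaled = normalise(img)          # min-max: (img - min) / (max - min)
    assert scaled.min() == 0.0 and scaled.max() == 1.0

    model_config = read_config("model", config_path="config.cfg")
    stem = get_path_stem("data/sample.tiff")   # "sample"
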
From 6e2ac5183fa9384d30923810ab5fce1f7c2a4ce7 Mon Sep 17 00:00:00 2001 From: Christina Bukas Date: Tue, 5 Mar 2024 11:45:54 +0100 Subject: [PATCH 05/26] changed imports and removed model from inheritance --- src/server/dcp_server/config.cfg | 7 +++---- src/server/dcp_server/main.py | 2 +- src/server/dcp_server/models/cellpose_patchCNN.py | 4 ++-- src/server/dcp_server/models/custom_cellpose.py | 6 +++--- src/server/dcp_server/models/multicellpose.py | 4 ++-- src/server/dcp_server/models/unet.py | 4 ++-- src/server/dcp_server/segmentationclasses.py | 12 ++++++------ src/server/dcp_server/service.py | 5 +++-- src/server/dcp_server/utils/fsimagestorage.py | 14 +++++++------- 9 files changed, 29 insertions(+), 29 deletions(-) diff --git a/src/server/dcp_server/config.cfg b/src/server/dcp_server/config.cfg index b2d05cff..f288e791 100644 --- a/src/server/dcp_server/config.cfg +++ b/src/server/dcp_server/config.cfg @@ -1,21 +1,20 @@ { "setup": { "segmentation": "GeneralSegmentation", - "model_to_use": "CellposeMultichannel", + "model_to_use": "MultiCellpose", "accepted_types": [".jpg", ".jpeg", ".png", ".tiff", ".tif"], "seg_name_string": "_seg" }, "service": { "runner_name": "bento_runner", - "bento_model_path": "cp-multi", + "bento_model_path": "test3", "service_name": "data-centric-platform", "port": 7010 }, "model": { "segmentor": { - "model_class": "CustomCellposeModel", "model_type": "cyto" }, @@ -35,7 +34,7 @@ "train":{ "segmentor":{ - "n_epochs": 5, + "n_epochs": 3, "channels": [0,0], "min_train_masks": 1 }, diff --git a/src/server/dcp_server/main.py b/src/server/dcp_server/main.py index 9add94b8..2e1772b2 100644 --- a/src/server/dcp_server/main.py +++ b/src/server/dcp_server/main.py @@ -1,7 +1,7 @@ import subprocess from os import path import sys -from utils import read_config +from dcp_server.utils.helpers import read_config def main(): '''entry point to bentoml diff --git a/src/server/dcp_server/models/cellpose_patchCNN.py b/src/server/dcp_server/models/cellpose_patchCNN.py index 233e2738..dc04d143 100644 --- a/src/server/dcp_server/models/cellpose_patchCNN.py +++ b/src/server/dcp_server/models/cellpose_patchCNN.py @@ -13,7 +13,7 @@ from sklearn.metrics import f1_score, log_loss from sklearn.exceptions import NotFittedError -from dcp_server.models import Model, CustomCellposeModel +from dcp_server.models import CustomCellposeModel # Model, from dcp_server.utils.processing import ( normalise, get_centered_patches, @@ -23,7 +23,7 @@ ) -class CellposePatchCNN(Model): +class CellposePatchCNN(): #Model): """ Cellpose & patches of cells and then cnn to classify each patch """ diff --git a/src/server/dcp_server/models/custom_cellpose.py b/src/server/dcp_server/models/custom_cellpose.py index 79757ced..7730624d 100644 --- a/src/server/dcp_server/models/custom_cellpose.py +++ b/src/server/dcp_server/models/custom_cellpose.py @@ -9,9 +9,9 @@ from cellpose.metrics import aggregated_jaccard_index from cellpose.dynamics import labels_to_flows -from dcp_server.models import Model +#from dcp_server.models import Model -class CustomCellposeModel(models.CellposeModel, Model): +class CustomCellposeModel(models.CellposeModel): #, Model): """Custom cellpose model inheriting the attributes and functions from the original CellposeModel and implementing additional attributes and methods needed for this project. 
""" @@ -35,7 +35,7 @@ def __init__(self, # Initialize the cellpose model # super().__init__(**model_config["segmentor"]) - nn.Module.__init__(self) + #nn.Module.__init__(self) models.CellposeModel.__init__(self, **model_config["segmentor"]) self.model_config = model_config self.train_config = train_config diff --git a/src/server/dcp_server/models/multicellpose.py b/src/server/dcp_server/models/multicellpose.py index 8ec01a30..352a8c45 100644 --- a/src/server/dcp_server/models/multicellpose.py +++ b/src/server/dcp_server/models/multicellpose.py @@ -2,9 +2,9 @@ import numpy as np from skimage.measure import label as label_mask -from dcp_server.models import Model, CustomCellposeModel +from dcp_server.models import CustomCellposeModel # Model, -class MultiCellpose(Model): +class MultiCellpose(): #Model): ''' Multichannel image segmentation model. Run the separate CustomCellposeModel models for each channel return the mask corresponding to each object type. diff --git a/src/server/dcp_server/models/unet.py b/src/server/dcp_server/models/unet.py index 4b3f7fe8..690632ed 100644 --- a/src/server/dcp_server/models/unet.py +++ b/src/server/dcp_server/models/unet.py @@ -8,10 +8,10 @@ from torch.optim import Adam from torch.utils.data import TensorDataset, DataLoader -from dcp_server.models import Model +#from dcp_server.models import Model from dcp_server.utils.processing import normalise -class UNet(Model, nn.Module): +class UNet(nn.Module): # Model """ Unet is a convolutional neural network architecture for semantic segmentation. diff --git a/src/server/dcp_server/segmentationclasses.py b/src/server/dcp_server/segmentationclasses.py index ccb5fff8..f6763c5d 100644 --- a/src/server/dcp_server/segmentationclasses.py +++ b/src/server/dcp_server/segmentationclasses.py @@ -1,8 +1,8 @@ -from dcp_server import utils import os +from dcp_server.utils import helpers # Import configuration -setup_config = utils.read_config('setup', config_path = 'config.cfg') +setup_config = helpers.read_config('setup', config_path = 'config.cfg') class GeneralSegmentation(): """Segmentation class. Defining the main functions needed for this project and served by service - segment image and train on images. 
@@ -35,7 +35,7 @@ async def segment_image(self, input_path, list_of_images): # Load the image img = self.imagestorage.load_image(img_filepath) # Get size properties - height, width, z_axis = self.imagestorage.get_image_size_properties(img, utils.get_file_extension(img_filepath)) + height, width, z_axis = self.imagestorage.get_image_size_properties(img, helpers.get_file_extension(img_filepath)) img = self.imagestorage.rescale_image(img, height, width) # Add channel ax into the model's evaluation parameters dictionary self.model.eval_config['segmentor']['z_axis'] = z_axis @@ -44,7 +44,7 @@ async def segment_image(self, input_path, list_of_images): # Resize the mask mask = self.imagestorage.resize_mask(mask, height, width, self.model.eval_config['mask_channel_axis'], order=0) # Save segmentation - seg_name = utils.get_path_stem(img_filepath) + setup_config['seg_name_string'] + '.tiff' + seg_name = helpers.get_path_stem(img_filepath) + setup_config['seg_name_string'] + '.tiff' self.imagestorage.save_image(os.path.join(input_path, seg_name), mask) async def train(self, input_path): @@ -108,7 +108,7 @@ async def segment_image(self, input_path, list_of_images): # Load the image img = self.imagestorage.load_image(img_filepath) # Get size properties - height, width, channel_ax = self.imagestorage.get_image_size_properties(img, utils.get_file_extension(img_filepath)) + height, width, channel_ax = self.imagestorage.get_image_size_properties(img, helpers.get_file_extension(img_filepath)) img = self.imagestorage.rescale_image(img, height, width, channel_ax) # Add channel ax into the model's evaluation parameters dictionary @@ -128,5 +128,5 @@ async def segment_image(self, input_path, list_of_images): new_mask[outlines==True] = 1 # Save segmentation - seg_name = utils.get_path_stem(img_filepath) + setup_config['seg_name_string'] + '.tiff' + seg_name = helpers.get_path_stem(img_filepath) + setup_config['seg_name_string'] + '.tiff' self.imagestorage.save_image(os.path.join(input_path, seg_name), new_mask) diff --git a/src/server/dcp_server/service.py b/src/server/dcp_server/service.py index 0be4fb0d..cd4ead6e 100644 --- a/src/server/dcp_server/service.py +++ b/src/server/dcp_server/service.py @@ -1,9 +1,10 @@ from __future__ import annotations import bentoml import typing as t -from dcp_server.fsimagestorage import FilesystemImageStorage from dcp_server.serviceclasses import CustomBentoService, CustomRunnable -from dcp_server.utils import read_config + +from dcp_server.utils.fsimagestorage import FilesystemImageStorage +from dcp_server.utils.helpers import read_config import sys, inspect diff --git a/src/server/dcp_server/utils/fsimagestorage.py b/src/server/dcp_server/utils/fsimagestorage.py index f4fbe8ef..66061c73 100644 --- a/src/server/dcp_server/utils/fsimagestorage.py +++ b/src/server/dcp_server/utils/fsimagestorage.py @@ -3,10 +3,10 @@ from skimage.io import imread, imsave from skimage.transform import resize, rescale -from dcp_server import utils +from dcp_server.utils import helpers # Import configuration -setup_config = utils.read_config('setup', config_path = 'config.cfg') +setup_config = helpers.read_config('setup', config_path = 'config.cfg') class FilesystemImageStorage(): """Class used to deal with everything related to image storing and processing - loading, saving, transforming... 
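In the search_segs hunk below, matching on the exact path stem (or an exact prefix) rather than substring containment is what avoids the 1_seg/11_seg collision mentioned in the code comment. A small sketch of the difference, with hypothetical filenames:

    from pathlib import Path

    search_string = "1_seg"
    files = ["1_seg.tiff", "11_seg.tiff"]
    # substring match returns both files:
    [f for f in files if search_string in f]
    # stem/prefix match keeps only segmentations of image "1":
    [f for f in files if Path(f).stem == search_string or f.startswith(search_string)]
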
@@ -49,7 +49,7 @@ def search_images(self, directory): directory = os.path.join(self.root_dir, directory) seg_files = [file_name for file_name in os.listdir(directory) if setup_config['seg_name_string'] in file_name] # Take the image files - difference between the list of all the files in the directory and the list of seg files and only file extensions currently accepted - image_files = [os.path.join(directory, file_name) for file_name in os.listdir(directory) if (file_name not in seg_files) and (utils.get_file_extension(file_name) in setup_config['accepted_types'])] + image_files = [os.path.join(directory, file_name) for file_name in os.listdir(directory) if (file_name not in seg_files) and (helpers.get_file_extension(file_name) in setup_config['accepted_types'])] return image_files def search_segs(self, cur_selected_img): @@ -61,13 +61,13 @@ def search_segs(self, cur_selected_img): :rtype: list """ # Check the directory the image was selected from: - img_directory = utils.get_path_parent(os.path.join(self.root_dir, cur_selected_img)) + img_directory = helpers.get_path_parent(os.path.join(self.root_dir, cur_selected_img)) # Take all segmentations of the image from the current directory: - search_string = utils.get_path_stem(cur_selected_img) + setup_config['seg_name_string'] + search_string = helpers.get_path_stem(cur_selected_img) + setup_config['seg_name_string'] #seg_files = [os.path.join(img_directory, file_name) for file_name in os.listdir(img_directory) if search_string in file_name] # TODO: check where this is used - copied the command from app's search_segs function (to fix the 1_seg and 11_seg bug) - seg_files = [os.path.join(img_directory, file_name) for file_name in os.listdir(img_directory) if (search_string == utils.get_path_stem(file_name) or str(file_name).startswith(search_string))] + seg_files = [os.path.join(img_directory, file_name) for file_name in os.listdir(img_directory) if (search_string == helpers.get_path_stem(file_name) or str(file_name).startswith(search_string))] return seg_files @@ -98,7 +98,7 @@ def get_unsupported_files(self, directory): """ return [file_name for file_name in os.listdir(os.path.join(self.root_dir, directory)) - if not file_name.startswith('.') and utils.get_file_extension(file_name) not in setup_config['accepted_types']] + if not file_name.startswith('.') and helpers.get_file_extension(file_name) not in setup_config['accepted_types']] def get_image_size_properties(self, img, file_extension): """Get properties of the image size From 5ee68bf15777e03b329f670b881f63d7a3b7914c Mon Sep 17 00:00:00 2001 From: Christina Bukas Date: Tue, 5 Mar 2024 18:37:55 +0100 Subject: [PATCH 06/26] restructuring --- src/server/dcp_server/config.cfg | 23 +- src/server/dcp_server/config_instance.cfg | 2 +- src/server/dcp_server/models/__init__.py | 8 +- .../dcp_server/models/cellpose_patchCNN.py | 405 ------------------ src/server/dcp_server/models/classifiers.py | 250 +++++++++++ .../dcp_server/models/custom_cellpose.py | 93 ++-- .../dcp_server/models/inst_to_multi_seg.py | 173 ++++++++ src/server/dcp_server/models/model.py | 9 +- src/server/dcp_server/models/multicellpose.py | 29 +- src/server/dcp_server/models/unet.py | 190 ++++---- src/server/dcp_server/segmentationclasses.py | 31 +- src/server/dcp_server/service.py | 8 +- src/server/dcp_server/utils/fsimagestorage.py | 131 ++++-- src/server/dcp_server/utils/processing.py | 46 +- 14 files changed, 760 insertions(+), 638 deletions(-) delete mode 100644 src/server/dcp_server/models/cellpose_patchCNN.py create 
mode 100644 src/server/dcp_server/models/classifiers.py
 create mode 100644 src/server/dcp_server/models/inst_to_multi_seg.py

diff --git a/src/server/dcp_server/config.cfg b/src/server/dcp_server/config.cfg
index f288e791..8b7c1039 100644
--- a/src/server/dcp_server/config.cfg
+++ b/src/server/dcp_server/config.cfg
@@ -1,14 +1,14 @@
 {
     "setup": {
         "segmentation": "GeneralSegmentation",
-        "model_to_use": "MultiCellpose",
+        "model_to_use": "Inst2MultiSeg",
         "accepted_types": [".jpg", ".jpeg", ".png", ".tiff", ".tif"],
         "seg_name_string": "_seg"
     },
 
     "service": {
         "runner_name": "bento_runner",
-        "bento_model_path": "test3",
+        "bento_model_path": "test5",
         "service_name": "data-centric-platform",
         "port": 7010
     },
@@ -19,9 +19,9 @@
         },
 
         "classifier":{
-            "model_class": "RandomForest",
+            "model_class": "PatchClassifier",
             "in_channels": 1,
             "num_classes": 2,
             "features":[64,128,256,512],
             "black_bg": "False",
             "include_mask": "False"
@@ -29,7 +29,11 @@
     },
 
     "data": {
-        "data_root": "data"
+        "data_root": "data",
+        "patch_size": 64,
+        "noise_intensity": 5,
+        "gray": "True",
+        "rescale": "True"
     },
 
     "train":{
@@ -39,11 +43,6 @@
             "min_train_masks": 1
         },
         "classifier":{
-            "train_data":{
-                "patch_size": 64,
-                "noise_intensity": 5,
-                "num_classes": 3
-            },
             "n_epochs": 2,
             "lr": 0.001,
             "batch_size": 1,
@@ -59,10 +58,6 @@
             "batch_size": 1
         },
         "classifier": {
-            "data":{
-                "patch_size": 64,
-                "noise_intensity": 5
-            }
         },
         "mask_channel_axis": 0
     }
diff --git a/src/server/dcp_server/config_instance.cfg b/src/server/dcp_server/config_instance.cfg
index da9cfd84..22b89745 100644
--- a/src/server/dcp_server/config_instance.cfg
+++ b/src/server/dcp_server/config_instance.cfg
@@ -6,7 +6,7 @@
     },
 
     "service": {
-        "model_to_use": "CustomCellposeModel",
+        "model_to_use": "CustomCellpose",
         "save_model_path": "cells",
         "runner_name": "cellpose_runner",
         "service_name": "data-centric-platform",
diff --git a/src/server/dcp_server/models/__init__.py b/src/server/dcp_server/models/__init__.py
index 2e134350..ba003253 100644
--- a/src/server/dcp_server/models/__init__.py
+++ b/src/server/dcp_server/models/__init__.py
@@ -1,11 +1,11 @@
 # dcp_server.models/__init__.py
 
-from .custom_cellpose import CustomCellposeModel
-from .cellpose_patchCNN import CellposePatchCNN
+from .custom_cellpose import CustomCellpose
+from .inst_to_multi_seg import Inst2MultiSeg
 from .multicellpose import MultiCellpose
 from .unet import UNet
 
-__all__ = ['CustomCellposeModel',
-           'CellposePatchCNN',
+__all__ = ['CustomCellpose',
+           'Inst2MultiSeg',
            'MultiCellpose',
            'UNet']
diff --git a/src/server/dcp_server/models/cellpose_patchCNN.py b/src/server/dcp_server/models/cellpose_patchCNN.py
deleted file mode 100644
index dc04d143..00000000
--- a/src/server/dcp_server/models/cellpose_patchCNN.py
+++ /dev/null
@@ -1,405 +0,0 @@
-from copy import deepcopy
-from tqdm import tqdm
-from typing import List
-import numpy as np
-
-import torch
-from torch import nn
-from torch.optim import Adam
-from torch.utils.data import TensorDataset, DataLoader
-from torchmetrics import F1Score
-
-from sklearn.ensemble import RandomForestClassifier
-from sklearn.metrics import f1_score, log_loss
-from sklearn.exceptions import NotFittedError
-
-from dcp_server.models import CustomCellposeModel # Model,
-from dcp_server.utils.processing import (
-    normalise,
-    get_centered_patches,
-    find_max_patch_size,
-    create_patch_dataset,
-    create_dataset_for_rf
-)
-
-
-class CellposePatchCNN(): #Model):
-    """
-    Cellpose & patches of cells and then cnn to classify each patch
-    """
-
-    def
__init__(self, - model_config: dict, - train_config: dict, - eval_config:dict, - model_name:str - ) -> None: - """Constructs all the necessary attributes for the CellposePatchCNN - - :param model_config: Model configuration. - :type model_config: dict - :param train_config: Training configuration. - :type train_config: dict - :param eval_config: Evaluation configuration. - :type eval_config: dict - :param model_name: Name of the model. - :type model_name: str - """ - super().__init__() - - self.model_config = model_config - self.train_config = train_config - self.eval_config = eval_config - self.model_name = model_name - self.include_mask = self.model_config["classifier"]["include_mask"] - self.classifier_class = self.model_config.get("classifier").get("model_class", "CellClassifierFCNN") - - # Initialize the cellpose model and the classifier - self.segmentor = CustomCellposeModel( - self.model_config, self.train_config, self.eval_config, "Cellpose" - ) - - if self.classifier_class == "FCNN": - self.classifier = CellClassifierFCNN( - self.model_config, self.train_config, self.eval_config - ) - - elif self.classifier_class == "RandomForest": - self.classifier = CellClassifierShallowModel( - self.model_config, self.train_config, self.eval_config - ) - # make sure include mask is set to False if we are using the random forest model - self.include_mask = False - - - def train(self, - imgs: List[np.ndarray], - masks: List[np.ndarray] - ) -> None: - """Trains the given model. First trains the segmentor and then the clasiffier. - - :param imgs: images to train on (training data) - :type imgs: List[np.ndarray] - :param masks: masks of the given images (training labels) - :type masks: List[np.ndarray] of same shape as output of eval, i.e. one channel instances, - second channel classes, so [2, H, W] or [2, 3, H, W] for 3D - """ - # train cellpose - masks = np.array(masks) - masks_instances = list(masks[:,0,...]) #[mask.sum(-1) for mask in masks] if masks[0].ndim == 3 else masks - self.segmentor.train(deepcopy(imgs), masks_instances) - # create patch dataset to train classifier - masks_classes = list( - masks[:,1,...] - ) #[((mask > 0) * np.arange(1, 4)).sum(-1) for mask in masks] - patches, patch_masks, labels = create_patch_dataset( - imgs, - masks_classes, - masks_instances, - noise_intensity = self.train_config["classifier"]["train_data"]["noise_intensity"], - max_patch_size = self.train_config["classifier"]["train_data"]["patch_size"], - include_mask = self.include_mask - ) - x = patches - if self.classifier_class == "RandomForest": - x = create_dataset_for_rf(patches, patch_masks) - # train classifier - self.classifier.train(x, labels) - # and compute metric and loss - self.metric = (self.segmentor.metric + self.classifier.metric) / 2 - self.loss = (self.segmentor.loss + self.classifier.loss)/2 - - def eval(self, - img: np.ndarray - ) -> np.ndarray: - """Evaluate the model on the provided image and return the final mask. - - :param img: Input image for evaluation. - :type img: np.ndarray[np.uint8] - :return: Final mask containing instance mask and class masks. 
- :rtype: np.ndarray[np.uint16] - """ - # TBD we assume image is 2D [H, W] (see fsimage storage) - # The final mask which is returned should have - # first channel the output of cellpose and the rest are the class channels - with torch.no_grad(): - # get instance mask from segmentor - instance_mask = self.segmentor.eval(img) - # find coordinates of detected objects - class_mask = np.zeros(instance_mask.shape) - - max_patch_size = self.eval_config["classifier"]["data"]["patch_size"] - if max_patch_size is None: - max_patch_size = find_max_patch_size(instance_mask) - noise_intensity = self.eval_config["classifier"]["data"]["noise_intensity"] - - # get patches centered around detected objects - patches, patch_masks, instance_labels, _ = get_centered_patches( - img, - instance_mask, - max_patch_size, - noise_intensity=noise_intensity, - include_mask=self.include_mask - ) - x = patches - if self.classifier_class == "RandomForest": - x = create_dataset_for_rf(patches, patch_masks) - # loop over patches and create classification mask - for idx in range(len(x)): - patch_class = self.classifier.eval(x[idx]) - # Assign predicted class to corresponding location in final_mask - patch_class = patch_class.item() if isinstance(patch_class, torch.Tensor) else patch_class - class_mask[instance_mask==instance_labels[idx]] = ( - patch_class + 1 - ) - # Apply mask to final_mask, retaining only regions where cellpose_mask is greater than 0 - final_mask = np.stack( - (instance_mask, class_mask), axis=self.eval_config['mask_channel_axis'] - ).astype( - np.uint16 - ) # size 2xHxW - - return final_mask - - -class CellClassifierFCNN(nn.Module): - - """Fully convolutional classifier for cell images. NOTE -> This model cannot be used as a standalone model in DCP - - :param model_config: Model configuration. - :type model_config: dict - :param train_config: Training configuration. - :type train_config: dict - :param eval_config: Evaluation configuration. - :type eval_config: dict - - """ - - def __init__(self, - model_config: dict, - train_config: dict, - eval_config: dict - ) -> None: - """Initialize the fully convolutional classifier. - - :param model_config: Model configuration. - :type model_config: dict - :param train_config: Training configuration. - :type train_config: dict - :param eval_config: Evaluation configuration. - :type eval_config: dict - """ - super().__init__() - - self.in_channels = model_config["classifier"].get("in_channels",1) - self.num_classes = model_config["classifier"].get("num_classes",3) - - self.train_config = train_config["classifier"] - self.eval_config = eval_config["classifier"] - - self.include_mask = model_config["classifier"]["include_mask"] - self.in_channels = self.in_channels + 1 if self.include_mask else self.in_channels - - self.layer1 = nn.Sequential( - nn.Conv2d(self.in_channels, 16, 3, 2, 5), - nn.BatchNorm2d(16), - nn.ReLU(), - nn.Dropout2d(p=0.2), - ) - - self.layer2 = nn.Sequential( - nn.Conv2d(16, 64, 3, 1, 1), - nn.BatchNorm2d(64), - nn.ReLU(), - nn.Dropout2d(p=0.2), - ) - - self.layer3 = nn.Sequential( - nn.Conv2d(64, 128, 3, 2, 4), - nn.BatchNorm2d(128), - nn.ReLU(), - nn.Dropout2d(p=0.2), - ) - self.final_conv = nn.Conv2d(128, self.num_classes, 1) - self.pooling = nn.AdaptiveMaxPool2d(1) - - self.metric_fn = F1Score(num_classes=self.num_classes, task="multiclass") - - def update_configs(self, - train_config: dict, - eval_config: dict - ) -> None: - """ - Update the training and evaluation configurations. 
- - :param train_config: Dictionary containing the training configuration. - :type train_config: dict - :param eval_config: Dictionary containing the evaluation configuration. - :type eval_config: dict - """ - self.train_config = train_config - self.eval_config = eval_config - - def forward(self, - x: torch.Tensor - ) -> torch.Tensor: - """ Performs forward pass of the CellClassifierFCNN. - - :param x: Input tensor. - :type x: torch.Tensor - :return: Output tensor after passing through the network. - :rtype: torch.Tensor - """ - - x = self.layer1(x) - x = self.layer2(x) - x = self.layer3(x) - - x = self.final_conv(x) - x = self.pooling(x) - x = x.view(x.size(0), -1) - return x - - def train (self, - imgs: List[np.ndarray], - labels: List[np.ndarray] - ) -> None: - """Trains the given model - - :param imgs: List of input images with shape (3, dx, dy). - :type imgs: List[np.ndarray[np.uint8]] - :param labels: List of classification labels. - :type labels: List[int] - """ - - lr = self.train_config["lr"] - epochs = self.train_config["n_epochs"] - batch_size = self.train_config["batch_size"] - # optimizer_class = self.train_config["optimizer"] - - # Convert input images and labels to tensors - - # normalize images - imgs = [normalise_image(img) for img in imgs] - # convert to tensor - imgs = torch.stack([torch.from_numpy(img.astype(np.float32)) for img in imgs]) - imgs = torch.permute(imgs, (0, 3, 1, 2)) - # Your classification label mask - labels = torch.LongTensor([label for label in labels]) - - # Create a training dataset and dataloader - train_dataset = TensorDataset(imgs, labels) - train_dataloader = DataLoader(train_dataset, batch_size=batch_size) - - loss_fn = nn.CrossEntropyLoss() - optimizer = Adam( - params=self.parameters(), - lr=lr - ) #eval(f'{optimizer_class}(params={self.parameters()}, lr={lr})') - # TODO check if we should replace self.parameters with super.parameters() - - for _ in tqdm(range(epochs), desc="Running CellClassifierFCNN training"): - self.loss, self.metric = 0, 0 - for data in train_dataloader: - imgs, labels = data - - optimizer.zero_grad() - preds = self.forward(imgs) - - l = loss_fn(preds, labels) - l.backward() - optimizer.step() - self.loss += l.item() - - self.metric += self.metric_fn(preds, labels) - - self.loss /= len(train_dataloader) - self.metric /= len(train_dataloader) - - def eval(self, - img: np.ndarray - ) -> torch.Tensor: - """Evaluates the model on the provided image and return the predicted label. - - :param img: Input image for evaluation. - :type img: np.ndarray[np.uint8] - :return: y_hat - predicted label. - :rtype: torch.Tensor - """ - # normalise - img = normalise(img) - # convert to tensor - img = torch.permute(torch.tensor(img.astype(np.float32)), (2, 0, 1)).unsqueeze(0) - preds = self.forward(img) - y_hat = torch.argmax(preds, 1) - return y_hat - - -class CellClassifierShallowModel: - """ - This class implements a shallow model for cell classification using scikit-learn. - """ - - def __init__(self, - model_config: dict, - train_config: dict, - eval_config: dict - ) -> None: - """Constructs all the necessary attributes for the CellClassifierShallowModel - - :param model_config: Model configuration. - :type model_config: dict - :param train_config: Training configuration. - :type train_config: dict - :param eval_config: Evaluation configuration. 
- :type eval_config: dict - """ - - self.model_config = model_config - self.train_config = train_config - self.eval_config = eval_config - - self.model = RandomForestClassifier() # TODO chnage config so RandomForestClassifier accepts input params - - - def train(self, - X_train: np.ndarray, - y_train: np.ndarray - ) -> None: - """Trains the model using the provided training data. - - :param X_train: Features of the training data. - :type X_train: numpy.ndarray - :param y_train: Labels of the training data. - :type y_train: numpy.ndarray - """ - - self.model.fit(X_train,y_train) - - y_hat = self.model.predict(X_train) - y_hat_proba = self.model.predict_proba(X_train) - - self.metric = f1_score(y_train, y_hat, average='micro') - # Binary Cross Entrop Loss - self.loss = log_loss(y_train, y_hat_proba) - - - def eval(self, - X_test: np.ndarray - ) -> np.ndarray: - """Evaluates the model on the provided test data. - - :param X_test: Features of the test data. - :type X_test: numpy.ndarray - :return: y_hat - predicted labels. - :rtype: numpy.ndarray - """ - - X_test = X_test.reshape(1,-1) - - try: - y_hat = self.model.predict(X_test) - except NotFittedError as e: - y_hat = np.zeros(X_test.shape[0]) - - return y_hat diff --git a/src/server/dcp_server/models/classifiers.py b/src/server/dcp_server/models/classifiers.py new file mode 100644 index 00000000..6b4e9fa1 --- /dev/null +++ b/src/server/dcp_server/models/classifiers.py @@ -0,0 +1,250 @@ +from tqdm import tqdm +from typing import List +import numpy as np + +import torch +from torch import nn +from torch.optim import Adam +from torch.utils.data import TensorDataset, DataLoader +from torchmetrics import F1Score + +from sklearn.ensemble import RandomForestClassifier +from sklearn.metrics import f1_score, log_loss +from sklearn.exceptions import NotFittedError + + +class PatchClassifier(nn.Module): + + """ Fully convolutional classifier for cell images. NOTE -> This model cannot be used as a standalone model in DCP + """ + + def __init__(self, + model_name: str, + model_config: dict, + data_config: dict, + train_config: dict, + eval_config: dict + ) -> None: + """ Initialize the fully convolutional classifier. + + :param model_name: Name of the model. + :type model_name: str + :param model_config: Model configuration. + :type model_config: dict + :param data_config: Data configuration. + :type data_config: dict + :param train_config: Training configuration. + :type train_config: dict + :param eval_config: Evaluation configuration. + :type eval_config: dict + """ + super().__init__() + + + self.model_name = model_name + self.model_config = model_config["classifier"] + self.data_config = data_config + self.train_config = train_config["classifier"] + self.eval_config = eval_config["classifier"] + + self.build_model() + + def train (self, + imgs: List[np.ndarray], + labels: List[np.ndarray] + ) -> None: + """ Trains the given model + + :param imgs: List of input images with shape (3, dx, dy). + :type imgs: List[np.ndarray[np.uint8]] + :param labels: List of classification labels. 
+        :type labels: List[int]
+        """
+
+        # Convert input images and labels to tensors
+        imgs = torch.stack([torch.from_numpy(img.astype(np.float32)) for img in imgs])
+        imgs = torch.permute(imgs, (0, 3, 1, 2))
+        # convert the classification labels to a tensor
+        labels = torch.LongTensor([label for label in labels])
+
+        # Create a training dataset and dataloader
+        train_dataloader = DataLoader(
+            TensorDataset(imgs, labels),
+            batch_size=self.train_config["batch_size"])
+
+        loss_fn = nn.CrossEntropyLoss()
+        optimizer = Adam(
+            params=self.parameters(),
+            lr=self.train_config["lr"]
+        )
+        # optimizer_class = self.train_config["optimizer"]
+        #eval(f'{optimizer_class}(params={self.parameters()}, lr={lr})')
+
+        # TODO check if we should replace self.parameters with super.parameters()
+
+        for _ in tqdm(
+            range(self.train_config["n_epochs"]),
+            desc="Running PatchClassifier training"
+        ):
+
+            self.loss, self.metric = 0, 0
+            for data in train_dataloader:
+                imgs, labels = data
+
+                optimizer.zero_grad()
+                preds = self.forward(imgs)
+
+                l = loss_fn(preds, labels)
+                l.backward()
+                optimizer.step()
+                self.loss += l.item()
+
+                self.metric += self.metric_fn(preds, labels)
+
+            self.loss /= len(train_dataloader)
+            self.metric /= len(train_dataloader)
+
+    def eval(self,
+             img: np.ndarray
+             ) -> torch.Tensor:
+        """ Evaluates the model on the provided image and returns the predicted label.
+
+        :param img: Input image for evaluation.
+        :type img: np.ndarray[np.uint8]
+        :return: y_hat - predicted label.
+        :rtype: torch.Tensor
+        """
+        # convert to tensor
+        img = torch.permute(torch.tensor(img.astype(np.float32)), (2, 0, 1)).unsqueeze(0)
+        preds = self.forward(img)
+        y_hat = torch.argmax(preds, 1)
+        return y_hat
+
+    def build_model(self):
+        """ Builds the PatchClassifier.
+        """
+        in_channels = self.model_config["in_channels"]
+        in_channels = in_channels + 1 if self.model_config["include_mask"] else in_channels
+
+        self.layer1 = nn.Sequential(
+            nn.Conv2d(in_channels, 16, 3, 2, 5),
+            nn.BatchNorm2d(16),
+            nn.ReLU(),
+            nn.Dropout2d(p=0.2),
+        )
+
+        self.layer2 = nn.Sequential(
+            nn.Conv2d(16, 64, 3, 1, 1),
+            nn.BatchNorm2d(64),
+            nn.ReLU(),
+            nn.Dropout2d(p=0.2),
+        )
+
+        self.layer3 = nn.Sequential(
+            nn.Conv2d(64, 128, 3, 2, 4),
+            nn.BatchNorm2d(128),
+            nn.ReLU(),
+            nn.Dropout2d(p=0.2),
+        )
+        self.final_conv = nn.Conv2d(128,
+                                    self.model_config["num_classes"],
+                                    1)
+        self.pooling = nn.AdaptiveMaxPool2d(1)
+
+        self.metric_fn = F1Score(num_classes=self.model_config["num_classes"],
+                                 task="multiclass")
+
+    def forward(self,
+                x: torch.Tensor
+                ) -> torch.Tensor:
+        """ Performs forward pass of the PatchClassifier.
+
+        :param x: Input tensor.
+        :type x: torch.Tensor
+        :return: Output tensor after passing through the network.
+        :rtype: torch.Tensor
+        """
+
+        x = self.layer1(x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+
+        x = self.final_conv(x)
+        x = self.pooling(x)
+        x = x.view(x.size(0), -1)
+        return x
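As a quick orientation for the new module layout, a minimal sketch of exercising PatchClassifier on random data; the config dictionaries below are illustrative stand-ins for the values normally read from config.cfg:

    import numpy as np

    model_config = {"classifier": {"in_channels": 1, "num_classes": 3, "include_mask": False}}
    train_config = {"classifier": {"lr": 0.001, "n_epochs": 1, "batch_size": 2, "optimizer": "Adam"}}

    clf = PatchClassifier("PatchClassifier", model_config, {}, train_config, {"classifier": {}})
    patches = [np.random.rand(64, 64, 1) for _ in range(4)]  # HWC patches, as train() expects
    clf.train(patches, [0, 1, 2, 1])
    y_hat = clf.eval(np.random.rand(64, 64, 1))  # tensor holding the predicted class index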
+
+
+class FeatureClassifier:
+    """ This class implements a shallow model for cell classification using scikit-learn.
+    """
+
+    def __init__(self,
+                 model_name: str,
+                 model_config: dict,
+                 data_config: dict,
+                 train_config: dict,
+                 eval_config: dict
+                 ) -> None:
+        """ Constructs all the necessary attributes for the FeatureClassifier
+
+        :param model_name: Name of the model.
+        :type model_name: str
+        :param model_config: Model configuration.
+        :type model_config: dict
+        :param data_config: Data configuration.
+        :type data_config: dict
+        :param train_config: Training configuration.
+        :type train_config: dict
+        :param eval_config: Evaluation configuration.
+        :type eval_config: dict
+        """
+
+        self.model_name = model_name
+        self.model_config = model_config # use for initialising model
+        self.data_config = data_config
+        self.train_config = train_config
+        self.eval_config = eval_config
+
+        self.model = RandomForestClassifier() # TODO change config so RandomForestClassifier accepts input params
+
+
+    def train(self,
+              X_train: np.ndarray,
+              y_train: np.ndarray
+              ) -> None:
+        """ Trains the model using the provided training data.
+
+        :param X_train: Features of the training data.
+        :type X_train: numpy.ndarray
+        :param y_train: Labels of the training data.
+        :type y_train: numpy.ndarray
+        """
+
+        self.model.fit(X_train,y_train)
+
+        y_hat = self.model.predict(X_train)
+        y_hat_proba = self.model.predict_proba(X_train)
+
+        # Binary cross-entropy loss
+        self.loss = log_loss(y_train, y_hat_proba)
+        self.metric = f1_score(y_train, y_hat, average='micro')
+
+
+    def eval(self,
+             X_test: np.ndarray
+             ) -> np.ndarray:
+        """ Evaluates the model on the provided test data.
+
+        :param X_test: Features of the test data.
+        :type X_test: numpy.ndarray
+        :return: y_hat - predicted labels.
+        :rtype: numpy.ndarray
+        """
+
+        X_test = X_test.reshape(1,-1)
+
+        try:
+            y_hat = self.model.predict(X_test)
+        except NotFittedError:
+            y_hat = np.zeros(X_test.shape[0])
+
+        return y_hat
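And the same for the scikit-learn based FeatureClassifier; the feature dimension (12 here) is arbitrary, and eval falls back to a zero label when the forest has not been fitted yet:

    import numpy as np

    rf = FeatureClassifier("RandomForest", {"classifier": {}}, {}, {}, {})
    X_train = np.random.rand(10, 12)       # 10 objects x 12 features
    y_train = np.array([0, 1] * 5)         # binary class labels
    rf.train(X_train, y_train)
    y_hat = rf.eval(np.random.rand(12))    # reshaped internally to (1, 12)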
diff --git a/src/server/dcp_server/models/custom_cellpose.py b/src/server/dcp_server/models/custom_cellpose.py
index 7730624d..8e5ad970 100644
--- a/src/server/dcp_server/models/custom_cellpose.py
+++ b/src/server/dcp_server/models/custom_cellpose.py
@@ -11,22 +11,27 @@
 #from dcp_server.models import Model
 
-class CustomCellposeModel(models.CellposeModel): #, Model):
+class CustomCellpose(models.CellposeModel): #, Model):
     """Custom cellpose model inheriting the attributes and functions from the original CellposeModel and implementing
     additional attributes and methods needed for this project.
     """
     def __init__(self,
+                 model_name: str,
                  model_config: dict,
+                 data_config: dict,
                  train_config: dict,
                  eval_config: dict,
-                 model_name: str
                  ) -> None:
-        """Constructs all the necessary attributes for the CustomCellposeModel.
+        """Constructs all the necessary attributes for the CustomCellpose.
         The model inherits all attributes from the parent class, the init allows to pass any other argument that the parent class accepts.
         Please, visit here https://cellpose.readthedocs.io/en/latest/api.html#id4 for more details on arguments accepted.
 
+        :param model_name: The name of the current model
+        :type model_name: str
         :param model_config: dictionary passed from the config file with all the arguments for the __init__ function and model initialization
         :type model_config: dict
+        :param data_config: dictionary passed from the config file with all the data configurations
+        :type data_config: dict
         :param train_config: dictionary passed from the config file with all the arguments for training function
         :type train_config: dict
         :param eval_config: dictionary passed from the config file with all the arguments for eval function
@@ -38,24 +43,35 @@ def __init__(self,
         #nn.Module.__init__(self)
         models.CellposeModel.__init__(self, **model_config["segmentor"])
         self.model_config = model_config
+        self.data_config = data_config
         self.train_config = train_config
         self.eval_config = eval_config
         self.model_name = model_name
         self.mkldnn = False # otherwise we get error with saving model
         self.loss = 1e6
+        self.metric = 0
 
-    def eval_all_outputs(self,
-                         img: np.ndarray
-                         ) -> tuple:
-        """Get all outputs of the model when running eval.
-
-        :param img: Input image for segmentation.
-        :type img: numpy.ndarray
-        :return: mask, flows, styles etc. Returns the same as cellpose.models.CellposeModel.eval - see Cellpose API Guide for more details.
-        :rtype: tuple
-        """
+    def train(self,
+              imgs: List[np.ndarray],
+              masks: List[np.ndarray]
+              ) -> None:
+        """Trains the given model
+        Calls the original train function.
 
-        return super().eval(x=img, **self.eval_config["segmentor"])
+        :param imgs: images to train on (training data)
+        :type imgs: List[np.ndarray]
+        :param masks: masks of the given images (training labels)
+        :type masks: List[np.ndarray]
+        """
+        super().train(
+            train_data=deepcopy(imgs),
+            train_labels=masks,
+            **self.train_config["segmentor"]
+        )
+        pred_masks, pred_flows, true_flows = self.compute_masks_flows(imgs, masks)
+        # get loss, combination of mse for flows and bce for cell probability
+        self.loss = self.loss_fn(true_flows, pred_flows)
+        self.metric = np.mean(aggregated_jaccard_index(masks, pred_masks))
 
     def eval(self,
              img: np.ndarray
@@ -68,35 +84,38 @@ def eval(self,
         :return: mask of the image, list of 2D arrays, or single 3D array (if do_3D=True) labelled image.
         :rtype: np.ndarray
         """
+        # 0 to take only mask - in line with other models, eval should always return the final mask
         return super().eval(x=img, **self.eval_config["segmentor"])[
             0
-        ]  # 0 to take only mask
+        ]
 
-    def train(self,
-              imgs: List[np.ndarray],
-              masks: List[np.ndarray]
-              ) -> None:
-        """Trains the given model
-        Calls the original train function.
+    def eval_all_outputs(self,
+                         img: np.ndarray
+                         ) -> tuple:
+        """Get all outputs of the model when running eval.
+
+        :param img: Input image for segmentation.
+        :type img: numpy.ndarray
+        :return: mask, flows, styles etc. Returns the same as cellpose.models.CellposeModel.eval - see Cellpose API Guide for more details.
+        :rtype: tuple
+        """
+
+        return super().eval(x=img, **self.eval_config["segmentor"])
+
+    def compute_masks_flows(self, imgs, masks):
+        """ Computes instance mask, binary mask and flows in x and y - needed for loss and metric computations
 
         :param imgs: images to train on (training data)
         :type imgs: List[np.ndarray]
         :param masks: masks of the given images (training labels)
         :type masks: List[np.ndarray]
+        :return: A tuple containing the following elements:
+            - pred_masks List [np.ndarray]: A list of predicted instance masks
+            - pred_flows (torch.Tensor): A tensor holding the stacked predicted cell probability map, horizontal and vertical flows for all images
+            - true_lbl (np.ndarray): A numpy array holding the stacked true binary mask, horizontal and vertical flows for all images
+        :rtype: tuple
         """
-
-        if not isinstance(masks, np.ndarray): # TODO Remove: all these should be taken care of in fsimagestorage
-            masks = np.array(masks)
-
-        if masks[0].shape[0] == 2:
-            masks = list(masks[:,0,...])
-        super().train(
-            train_data=deepcopy(imgs),
-            train_labels=masks,
-            **self.train_config["segmentor"]
-        )
-
-        # compute loss and metric
+        # compute quantities needed for loss and metric
         true_bin_masks = [mask>0 for mask in masks] # get binary masks
         true_flows = labels_to_flows(masks) # get cellpose flows
         # get predicted flows and cell probability
@@ -112,10 +131,8 @@ def train(self,
         true_lbl = np.stack(true_lbl)
         pred_flows=np.stack(pred_flows)
         pred_flows = torch.from_numpy(pred_flows).float().to('cpu')
-        # compute loss, combination of mse for flows and bce for cell probability
-        self.loss = self.loss_fn(true_lbl, pred_flows)
-        self.metric = np.mean(aggregated_jaccard_index(masks, pred_masks))
-
+        return pred_masks, pred_flows, true_lbl
+
     def masks_to_outlines(self,
                           mask: np.ndarray
                           ) -> np.ndarray:
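A rough usage sketch for the reordered constructor, with config values mirroring those in config.cfg; the dictionaries here are illustrative only:

    import numpy as np

    model = CustomCellpose(
        model_name="CustomCellpose",
        model_config={"segmentor": {"model_type": "cyto"}},
        data_config={"data_root": "data"},
        train_config={"segmentor": {"n_epochs": 5, "channels": [0, 0], "min_train_masks": 1}},
        eval_config={"segmentor": {}, "mask_channel_axis": 0},
    )
    mask = model.eval(np.random.rand(256, 256))  # 2D instance mask, one label per object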
diff --git a/src/server/dcp_server/models/inst_to_multi_seg.py b/src/server/dcp_server/models/inst_to_multi_seg.py
new file mode 100644
index 00000000..bad68a34
--- /dev/null
+++ b/src/server/dcp_server/models/inst_to_multi_seg.py
@@ -0,0 +1,173 @@
+from copy import deepcopy
+from typing import List
+
+import numpy as np
+import torch
+
+from dcp_server.models import CustomCellpose # Model,
+from dcp_server.models.classifiers import PatchClassifier, FeatureClassifier
+from dcp_server.utils.processing import (
+    get_centered_patches,
+    find_max_patch_size,
+    create_patch_dataset,
+    create_dataset_for_rf
+)
+
+# Dictionary mapping class names to their corresponding classes
+
+segmentor_mapping = {
+    "Cellpose": CustomCellpose
+}
+classifier_mapping = {
+    "PatchClassifier": PatchClassifier,
+    "RandomForest": FeatureClassifier
+}
+
+
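The two dictionaries implement a simple name-to-class dispatch; a component is chosen by looking its name up and instantiating the result with the shared config objects, roughly:

    name = "PatchClassifier"
    classifier_cls = classifier_mapping.get(name)
    if classifier_cls is None:
        raise ValueError(f"Unknown classifier '{name}', expected one of {list(classifier_mapping)}")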
+class Inst2MultiSeg(): #Model):
+    """ A two-stage model: 1. instance segmentation, followed by 2. object-wise classification
+    """
+
+    def __init__(self,
+                 model_name:str,
+                 model_config: dict,
+                 data_config: dict,
+                 train_config: dict,
+                 eval_config:dict
+                 ) -> None:
+        """ Constructs all the necessary attributes for the Inst2MultiSeg
+
+        :param model_name: Name of the model.
+        :type model_name: str
+        :param model_config: Model configuration.
+        :type model_config: dict
+        :param data_config: Data configurations
+        :type data_config: dict
+        :param train_config: Training configuration.
+        :type train_config: dict
+        :param eval_config: Evaluation configuration.
+        :type eval_config: dict
+        """
+        super().__init__()
+
+        self.model_name = model_name
+        self.model_config = model_config
+        self.data_config = data_config
+        self.train_config = train_config
+        self.eval_config = eval_config
+
+        self.segmentor_class = self.model_config.get("segmentor").get("model_class", "Cellpose")
+        self.classifier_class = self.model_config.get("classifier").get("model_class", "PatchClassifier")
+
+        # Initialize the cellpose model and the classifier
+        segmentor = segmentor_mapping.get(self.segmentor_class)
+        self.segmentor = segmentor(
+            self.segmentor_class, self.model_config, self.data_config, self.train_config, self.eval_config
+        )
+        '''
+        if self.classifier_class == "PatchClassifier":
+            self.classifier = PatchClassifier(
+                self.classifier_class, self.model_config, self.data_config, self.train_config, self.eval_config
+            )
+
+        elif self.classifier_class == "RandomForest":
+            self.classifier = FeatureClassifier(
+                self.classifier_class, self.model_config, self.data_config, self.train_config, self.eval_config
+            )
+        '''
+        classifier = classifier_mapping.get(self.classifier_class)
+        self.classifier = classifier(
+            self.classifier_class, self.model_config, self.data_config, self.train_config, self.eval_config
+        )
+        # make sure include mask is set to False if we are using the random forest model
+        if self.model_config["classifier"]["include_mask"] == True and self.classifier_class=="RandomForest":
+            #print("Include mask=True was found, but for Random Forest, this parameter must be set to False. Doing this now.")
+            self.model_config["classifier"]["include_mask"] = False
+
+    def train(self,
+              imgs: List[np.ndarray],
+              masks: List[np.ndarray]
+              ) -> None:
+        """ Trains the given model. First trains the segmentor and then the classifier.
+
+        :param imgs: images to train on (training data)
+        :type imgs: List[np.ndarray]
+        :param masks: masks of the given images (training labels)
+        :type masks: List[np.ndarray] of same shape as output of eval, i.e. one channel instances,
+            second channel classes, so [2, H, W] or [2, 3, H, W] for 3D
+        """
+        # train cellpose
+        masks_instances = [mask[0] for mask in masks]
+        #masks_instances = list(np.array(masks)[:,0,...]) #[mask.sum(-1) for mask in masks] if masks[0].ndim == 3 else masks
+        self.segmentor.train(deepcopy(imgs), masks_instances)
+        masks_classes = [mask[1] for mask in masks]
+        # create patch dataset to train classifier
+        #masks_classes = list(
+        #    masks[:,1,...]
+        #) #[((mask > 0) * np.arange(1, 4)).sum(-1) for mask in masks]
+        x, patch_masks, labels = create_patch_dataset(
+            imgs,
+            masks_classes,
+            masks_instances,
+            noise_intensity = self.data_config["noise_intensity"],
+            max_patch_size = self.data_config["patch_size"],
+            include_mask = self.model_config["classifier"]["include_mask"]
+        )
+        # additionally extract features from the patches if using the RF model
+        if self.classifier_class == "RandomForest":
+            x = create_dataset_for_rf(x, patch_masks)
+        # train classifier
+        self.classifier.train(x, labels)
+        # and compute metric and loss
+        self.metric = (self.segmentor.metric + self.classifier.metric) / 2
+        self.loss = (self.segmentor.loss + self.classifier.loss)/2
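Each training mask carries the instance labels in channel 0 and the class labels in channel 1; a small illustration of the expected [2, H, W] layout with made-up data:

    import numpy as np

    instance_mask = np.zeros((128, 128), dtype=np.uint16)
    class_mask = np.zeros((128, 128), dtype=np.uint16)
    instance_mask[10:30, 10:30] = 1   # object with instance id 1
    class_mask[10:30, 10:30] = 2      # ...belonging to class 2
    mask = np.stack([instance_mask, class_mask])  # shape (2, 128, 128)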
+
+    def eval(self,
+             img: np.ndarray
+             ) -> np.ndarray:
+        """ Evaluate the model on the provided image and return the final mask.
+
+        :param img: Input image for evaluation.
+        :type img: np.ndarray[np.uint8]
+        :return: Final mask containing instance mask and class masks.
+        :rtype: np.ndarray[np.uint16]
+        """
+        # TBD we assume image is 2D [H, W] (see fsimage storage)
+        # The final mask which is returned should have
+        # first channel the output of cellpose and the rest are the class channels
+        with torch.no_grad():
+            # get instance mask from segmentor
+            instance_mask = self.segmentor.eval(img)
+            # initialise the class mask
+            class_mask = np.zeros(instance_mask.shape)
+
+            max_patch_size = self.data_config["patch_size"]
+            if max_patch_size is None:
+                max_patch_size = find_max_patch_size(instance_mask)
+
+            # get patches centered around detected objects
+            x, patch_masks, instance_labels, _ = get_centered_patches(
+                img,
+                instance_mask,
+                max_patch_size,
+                noise_intensity=self.data_config["noise_intensity"],
+                include_mask=self.model_config["classifier"]["include_mask"]
+            )
+            if self.classifier_class == "RandomForest":
+                x = create_dataset_for_rf(x, patch_masks)
+            # loop over patches and create classification mask
+            for idx in range(len(x)):
+                patch_class = self.classifier.eval(x[idx])
+                # Assign predicted class to corresponding location in final_mask
+                patch_class = patch_class.item() if isinstance(patch_class, torch.Tensor) else patch_class
+                class_mask[instance_mask==instance_labels[idx]] = (
+                    patch_class + 1
+                )
+            # stack the instance mask and the class mask into the final mask
+            final_mask = np.stack(
+                (instance_mask, class_mask), axis=self.eval_config['mask_channel_axis']
+            ).astype(
+                np.uint16
+            ) # size 2xHxW
+
+        return final_mask
diff --git a/src/server/dcp_server/models/model.py b/src/server/dcp_server/models/model.py
index db584214..65811905 100644
--- a/src/server/dcp_server/models/model.py
+++ b/src/server/dcp_server/models/model.py
@@ -4,16 +4,21 @@
 class Model(ABC):
 
     def __init__(self,
+                 model_name: str,
                  model_config: dict,
+                 data_config: dict,
                  train_config: dict,
                  eval_config: dict,
-                 model_name: str
                  ) -> None:
 
+        self.model_name = model_name
         self.model_config = model_config
+        self.data_config = data_config
         self.train_config = train_config
         self.eval_config = eval_config
-        self.model_name = model_name
+
+        self.loss = 1e6
+        self.metric = 0
 
     def update_configs(self,
                        train_config: dict,
diff --git a/src/server/dcp_server/models/multicellpose.py b/src/server/dcp_server/models/multicellpose.py
index 352a8c45..5d7ef0e0 100644
--- a/src/server/dcp_server/models/multicellpose.py
+++ b/src/server/dcp_server/models/multicellpose.py
@@ -2,44 +2,47 @@
 import numpy as np
 from skimage.measure import label as label_mask
 
-from dcp_server.models import CustomCellposeModel # Model,
+from dcp_server.models import CustomCellpose # Model,
 
 
 class MultiCellpose(): #Model):
     '''
     Multichannel image segmentation model.
-    Run the separate CustomCellposeModel models for each channel return the mask corresponding to each object type.
+    Run the separate CustomCellpose models for each channel and return the mask corresponding to each object type.
     '''
     def __init__(self,
+                 model_name: str,
                  model_config: dict,
+                 data_config: dict,
                  train_config: dict,
                  eval_config: dict,
-                 model_name="Cellpose"
                  ) -> None:
         """Constructs all the necessary attributes for the MultiCellpose model.
-
+
+        :param model_name: Name of the model.
+        :type model_name: str
+        :param data_config: Data configuration.
+        :type data_config: dict
         :param model_config: Model configuration.
         :type model_config: dict
         :param train_config: Training configuration.
         :type train_config: dict
         :param eval_config: Evaluation configuration.
         :type eval_config: dict
-        :param model_name: Name of the model.
- :type model_name: str """ self.model_config = model_config + self.data_config = data_config self.train_config = train_config self.eval_config = eval_config self.model_name = model_name self.num_of_channels = self.model_config["classifier"]["num_classes"] self.cellpose_models = [ - CustomCellposeModel( + CustomCellpose( + "Cellpose", self.model_config, + self.data_config, self.train_config, self.eval_config, - self.model_name ) for _ in range(self.num_of_channels) ] @@ -55,13 +58,13 @@ def train(self, :param masks: Masks corresponding to the input images. :type masks: list[numpy.ndarray] """ - + for i in range(self.num_of_channels): - + masks_class = [] for mask in masks: - mask_class = mask.copy() + mask_class = mask.copy() # TODO - Do we need copy?? # set all instances in the instance mask not corresponding to the class in question to zero mask_class[0][ mask_class[1]!=(i+1) @@ -91,14 +94,14 @@ def eval(self, for i in range(self.num_of_channels): # get the instance mask and pixel-wise cell probability mask instance_mask, probs, _ = self.cellpose_models[i].eval_all_outputs(img) - confidence = probs[2] + confidence_map = probs[2] # assign the appropriate class to all objects detected by this model class_mask = np.zeros_like(instance_mask) class_mask[instance_mask>0]=(i + 1) instance_masks.append(instance_mask) class_masks.append(class_mask) - model_confidences.append(confidence) + model_confidences.append(confidence_map) # merge the outputs of the different models using the pixel-wise cell probability mask merged_mask_instances, class_mask = self.merge_masks( instance_masks, class_masks, model_confidences diff --git a/src/server/dcp_server/models/unet.py b/src/server/dcp_server/models/unet.py index 690632ed..47e3fdc0 100644 --- a/src/server/dcp_server/models/unet.py +++ b/src/server/dcp_server/models/unet.py @@ -9,7 +9,8 @@ from torch.utils.data import TensorDataset, DataLoader #from dcp_server.models import Model -from dcp_server.utils.processing import normalise +from dcp_server.utils.processing import convert_to_tensor + class UNet(nn.Module): # Model @@ -66,85 +67,38 @@ def forward(self, def __init__(self, + model_name: str, model_config: dict, + data_config: dict, train_config: dict, eval_config: dict, - model_name: str ) -> None: - """Constructs all the necessary attributes for the UNet model. + """ Constructs all the necessary attributes for the UNet model. + :param model_name: Name of the model. + :type model_name: str :param model_config: Model configuration. :type model_config: dict + :param data_config: Data configurations + :type data_config: dict :param train_config: Training configuration. :type train_config: dict :param eval_config: Evaluation configuration. :type eval_config: dict - :param model_name: Name of the model. 
- :type model_name: str """ super().__init__() + + self.model_name = model_name self.model_config = model_config + self.data_config = data_config self.train_config = train_config self.eval_config = eval_config - self.model_name = model_name - self.in_channels = self.model_config["classifier"]["in_channels"] - self.out_channels = self.model_config["classifier"]["num_classes"] + 1 - self.features = self.model_config["classifier"]["features"] - - self.encoder = nn.ModuleList() - self.decoder = nn.ModuleList() - - self.pool = nn.MaxPool2d(kernel_size=2, stride=2) - - # Encoder - for feature in self.features: - self.encoder.append( - UNet.DoubleConv(self.in_channels, feature) - ) - self.in_channels = feature - - # Decoder - for feature in self.features[::-1]: - self.decoder.append( - nn.ConvTranspose2d( - feature*2, feature, kernel_size=2, stride=2 - ) - ) - self.decoder.append( - UNet.DoubleConv(feature*2, feature) - ) - self.bottle_neck = UNet.DoubleConv(self.features[-1], self.features[-1]*2) - self.output_conv = nn.Conv2d(self.features[0], self.out_channels, kernel_size=1) - - def forward(self, - x: torch.Tensor - ) -> torch.Tensor: - """ - Forward pass of the UNet model. - - :param x: Input tensor. - :type x: torch.Tensor - :return: Output tensor. - :rtype: torch.Tensor - """ - skip_connections = [] - for encoder in self.encoder: - x = encoder(x) - skip_connections.append(x) - x = self.pool(x) - - x = self.bottle_neck(x) - skip_connections = skip_connections[::-1] - - for i in np.arange(len(self.decoder), step=2): - x = self.decoder[i](x) - skip_connection = skip_connections[i//2] - concatenate_skip = torch.cat((skip_connection, x), dim=1) - x = self.decoder[i+1](concatenate_skip) + self.loss = 1e6 + self.metric = 0 - return self.output_conv(x) + self.build_model() def train(self, imgs: List[np.ndarray], @@ -159,43 +113,32 @@ def train(self, :type masks: list[numpy.ndarray] """ - lr = self.train_config["classifier"]["lr"] - epochs = self.train_config["classifier"]["n_epochs"] - batch_size = self.train_config["classifier"]["batch_size"] - - # Convert input images and labels to tensors - # normalize images - imgs = [normalise(img) for img in imgs] - # convert to tensor - imgs = torch.stack([ - torch.from_numpy(img.astype(np.float32)) for img in imgs - ]) - imgs = imgs.unsqueeze(1) if imgs.ndim == 3 else imgs - - # Classification label mask - masks = np.array(masks) - masks = torch.stack([ - torch.from_numpy(mask[1].astype(np.int16)) for mask in masks - ]) + imgs = convert_to_tensor(imgs, np.float32) + masks = convert_to_tensor(masks, np.int16) # Create a training dataset and dataloader - train_dataset = TensorDataset(imgs, masks) - train_dataloader = DataLoader(train_dataset, batch_size=batch_size) + train_dataloader = DataLoader( + TensorDataset(imgs, masks), + batch_size=self.train_config["classifier"]["batch_size"]) loss_fn = nn.CrossEntropyLoss() - optimizer = Adam(params=self.parameters(), lr=lr) + optimizer = Adam( + params=self.parameters(), + lr=self.train_config["classifier"]["lr"] + ) - for _ in tqdm(range(epochs), desc="Running UNet training"): + for _ in tqdm( + range(self.train_config["classifier"]["n_epochs"]), + desc="Running UNet training" + ): self.loss = 0 for imgs, masks in train_dataloader: - imgs = imgs.float() - masks = masks.long() #forward path - preds = self.forward(imgs) - loss = loss_fn(preds, masks) + preds = self.forward(imgs.float()) + loss = loss_fn(preds, masks.long()) #backward path optimizer.zero_grad() @@ -209,8 +152,7 @@ def train(self, def eval(self, img: 
np.ndarray ) -> np.ndarray: - """ - Evaluate the model on the provided image and return the predicted label. + """ Evaluate the model on the provided image and return the predicted label. :param img: Input image for evaluation. :type img: np.ndarray[np.uint8] @@ -218,15 +160,14 @@ def eval(self, :rtype: numpy.ndarray """ with torch.no_grad(): - # normalise - img = (img - np.min(img)) / (np.max(img) - np.min(img)) - img = torch.from_numpy(img).float().unsqueeze(0) - img = img.unsqueeze(1) if img.ndim == 3 else img + #img = torch.from_numpy(img).float().unsqueeze(0) + #img = img.unsqueeze(1) if img.ndim == 3 else img + img = convert_to_tensor([img], np.float32) preds = self.forward(img) class_mask = torch.argmax(preds, 1).numpy()[0] - + # TODO - make instance mask calculation optional instance_mask = label((class_mask > 0).astype(int))[0] final_mask = np.stack( @@ -237,3 +178,64 @@ def eval(self, ) return final_mask + + def build_model(self): + """ Builds the UNet. + """ + in_channels = self.model_config["classifier"]["in_channels"] + out_channels = self.model_config["classifier"]["num_classes"] + 1 + features = self.model_config["classifier"]["features"] + + self.encoder = nn.ModuleList() + self.decoder = nn.ModuleList() + + self.pool = nn.MaxPool2d(kernel_size=2, stride=2) + + # Encoder + for feature in features: + self.encoder.append( + UNet.DoubleConv(in_channels, feature) + ) + in_channels = feature + + # Decoder + for feature in features[::-1]: + self.decoder.append( + nn.ConvTranspose2d( + feature*2, feature, kernel_size=2, stride=2 + ) + ) + self.decoder.append( + UNet.DoubleConv(feature*2, feature) + ) + + self.bottle_neck = UNet.DoubleConv(features[-1], features[-1]*2) + self.output_conv = nn.Conv2d(features[0], out_channels, kernel_size=1) + + def forward(self, + x: torch.Tensor + ) -> torch.Tensor: + """ + Forward pass of the UNet model. + + :param x: Input tensor. + :type x: torch.Tensor + :return: Output tensor. + :rtype: torch.Tensor + """ + skip_connections = [] + for encoder in self.encoder: + x = encoder(x) + skip_connections.append(x) + x = self.pool(x) + + x = self.bottle_neck(x) + skip_connections = skip_connections[::-1] + + for i in np.arange(len(self.decoder), step=2): + x = self.decoder[i](x) + skip_connection = skip_connections[i//2] + concatenate_skip = torch.cat((skip_connection, x), dim=1) + x = self.decoder[i+1](concatenate_skip) + + return self.output_conv(x) diff --git a/src/server/dcp_server/segmentationclasses.py b/src/server/dcp_server/segmentationclasses.py index f6763c5d..4c94004c 100644 --- a/src/server/dcp_server/segmentationclasses.py +++ b/src/server/dcp_server/segmentationclasses.py @@ -5,10 +5,10 @@ setup_config = helpers.read_config('setup', config_path = 'config.cfg') class GeneralSegmentation(): - """Segmentation class. Defining the main functions needed for this project and served by service - segment image and train on images. + """ Segmentation class. Defining the main functions needed for this project and served by service - segment image and train on images. """ def __init__(self, imagestorage, runner, model): - """Constructs all the necessary attributes for the GeneralSegmentation. + """ Constructs all the necessary attributes for the GeneralSegmentation. 
:param imagestorage: imagestorage system used (see fsimagestorage.py) :type imagestorage: FilesystemImageStorage class object @@ -23,32 +23,28 @@ def __init__(self, imagestorage, runner, model): self.no_files_msg = "No image-label pairs found in curated directory" async def segment_image(self, input_path, list_of_images): - """Segments images from the given directory + """ Segments images from the given directory - :param input_path: directory where the images are saved + :param input_path: directory where the images are saved and where segmentation results will be saved :type input_path: str :param list_of_images: list of image objects from the directory that are currently supported :type list_of_images: list """ for img_filepath in list_of_images: - # Load the image - img = self.imagestorage.load_image(img_filepath) - # Get size properties - height, width, z_axis = self.imagestorage.get_image_size_properties(img, helpers.get_file_extension(img_filepath)) - img = self.imagestorage.rescale_image(img, height, width) + img = self.imagestorage.prepare_img_for_eval(img_filepath) # Add channel ax into the model's evaluation parameters dictionary - self.model.eval_config['segmentor']['z_axis'] = z_axis + self.model.eval_config['segmentor']['channel_axis'] = self.imagestorage.channel_ax # Evaluate the model mask = await self.runner.evaluate.async_run(img = img) - # Resize the mask - mask = self.imagestorage.resize_mask(mask, height, width, self.model.eval_config['mask_channel_axis'], order=0) + # And prepare the mask for saving + mask = self.imagestorage.prepare_mask_for_save(mask, self.model.eval_config['mask_channel_axis']) # Save segmentation seg_name = helpers.get_path_stem(img_filepath) + setup_config['seg_name_string'] + '.tiff' self.imagestorage.save_image(os.path.join(input_path, seg_name), mask) async def train(self, input_path): - """train model on images and masks in the given input directory. + """ Train model on images and masks in the given input directory. Calls the runner's train function. :param input_path: directory where the images are saved @@ -59,8 +55,7 @@ async def train(self, input_path): train_img_mask_pairs = self.imagestorage.get_image_seg_pairs(input_path) - if not train_img_mask_pairs: - return self.no_files_msg + if not train_img_mask_pairs: return self.no_files_msg imgs, masks = self.imagestorage.prepare_images_and_masks_for_training(train_img_mask_pairs) model_save_path = await self.runner.train.async_run(imgs, masks) @@ -78,11 +73,11 @@ async def segment_image(self, input_path, list_of_images): class MitoProjectSegmentation(GeneralSegmentation): - """Segmentation class inheriting the attributes and functions from the original GeneralSegmentation and implementing + """ Segmentation class inheriting the attributes and functions from the original GeneralSegmentation and implementing additional attributes and methods needed for this project. """ def __init__(self, imagestorage, runner, model): - """Constructs all the necessary attributes for the MitoProjectSegmentation. Inherits all from the GeneralSegmentation + """ Constructs all the necessary attributes for the MitoProjectSegmentation. Inherits all from the GeneralSegmentation :param imagestorage: imagestorage system used (see fsimagestorage.py) :type imagestorage: FilesystemImageStorage class object @@ -95,7 +90,7 @@ def __init__(self, imagestorage, runner, model): # The only difference is in segment image async def segment_image(self, input_path, list_of_images): - """Segments images from the given directory. 
+ """ Segments images from the given directory. The function differs from the parent class' function in obtaining the outlines of the masks. :param input_path: directory where the images are saved diff --git a/src/server/dcp_server/service.py b/src/server/dcp_server/service.py index cd4ead6e..61193a86 100644 --- a/src/server/dcp_server/service.py +++ b/src/server/dcp_server/service.py @@ -22,14 +22,18 @@ # instantiate the model model_class = getattr(models_module, setup_config['model_to_use']) -model = model_class(model_config = model_config, train_config = train_config, eval_config = eval_config, model_name=setup_config['model_to_use']) +model = model_class(model_name=setup_config['model_to_use'], + model_config = model_config, + data_config = data_config, + train_config = train_config, + eval_config = eval_config) custom_model_runner = t.cast( "CustomRunner", bentoml.Runner(CustomRunnable, name=service_config['runner_name'], runnable_init_params={"model": model, "save_model_path": service_config['bento_model_path']}) ) # instantiate the segmentation type segm_class = getattr(segmentation_module, setup_config['segmentation']) -fsimagestorage = FilesystemImageStorage(data_config['data_root'], setup_config['model_to_use']) +fsimagestorage = FilesystemImageStorage(data_config, setup_config['model_to_use']) segmentation = segm_class(imagestorage=fsimagestorage, runner = custom_model_runner, model = model) diff --git a/src/server/dcp_server/utils/fsimagestorage.py b/src/server/dcp_server/utils/fsimagestorage.py index 66061c73..bcd02c8b 100644 --- a/src/server/dcp_server/utils/fsimagestorage.py +++ b/src/server/dcp_server/utils/fsimagestorage.py @@ -4,27 +4,36 @@ from skimage.transform import resize, rescale from dcp_server.utils import helpers +from dcp_server.utils.processing import pad_image, normalise # Import configuration -setup_config = helpers.read_config('setup', config_path = 'config.cfg') +setup_config = helpers.read_config("setup", config_path = "config.cfg") class FilesystemImageStorage(): """Class used to deal with everything related to image storing and processing - loading, saving, transforming... 
""" - def __init__(self, data_root, model_used): - self.root_dir = data_root + def __init__(self, data_config, model_used): + self.root_dir = data_config["data_root"] + self.gray = bool(data_config["gray"]) + self.rescale = bool(data_config["rescale"]) self.model_used = model_used + self.channel_ax = None + self.img_height = None + self.img_width = None - def load_image(self, cur_selected_img, is_gray=True): + def load_image(self, cur_selected_img, gray=None): """Load the image (using skiimage) :param cur_selected_img: full path of the image that needs to be loaded :type cur_selected_img: str + :param gray: whether to load the image as a grayscale or not + :type gray: bool, default=False :return: loaded image :rtype: ndarray - """ + """ + if gray is None: gray = self.gray try: - return imread(os.path.join(self.root_dir , cur_selected_img), as_gray=is_gray) + return imread(os.path.join(self.root_dir , cur_selected_img), as_gray=gray) except ValueError: return None def save_image(self, to_save_path, img): @@ -101,40 +110,39 @@ def get_unsupported_files(self, directory): if not file_name.startswith('.') and helpers.get_file_extension(file_name) not in setup_config['accepted_types']] def get_image_size_properties(self, img, file_extension): - """Get properties of the image size + """Set properties of the image size :param img: image (numpy array) :type img: ndarray :param file_extension: file extension of the image as saved in the directory :type file_extension: str - :return: size properties: - - height - - width - - z_axis - """ - + # TODO simplify! + orig_size = img.shape # png and jpeg will be RGB by default and 2D # tif can be grayscale 2D or 3D [Z, H, W] - # image channels have already been removed in imread with is_gray=True - if file_extension in (".jpg", ".jpeg", ".png"): - height, width = orig_size[0], orig_size[1] - z_axis = None + # image channels have already been removed in imread if self.gray=True + # skimage.imread reads RGB or RGBA images in always with channel axis in dim=2 + if file_extension in (".jpg", ".jpeg", ".png") and self.gray==False: + self.img_height, self.img_width = orig_size[0], orig_size[1] + self.channel_ax = 2 + elif file_extension in (".jpg", ".jpeg", ".png") and self.gray==True: + self.img_height, self.img_width = orig_size[0], orig_size[1] + self.channel_ax = None elif file_extension in (".tiff", ".tif") and len(orig_size)==2: - height, width = orig_size[0], orig_size[1] - z_axis = None + self.img_height, self.img_width = orig_size[0], orig_size[1] + self.channel_ax = None # if we have 3 dimensions the [Z, H, W] elif file_extension in (".tiff", ".tif") and len(orig_size)==3: - print('Warning: 3D image stack found. We are assuming your first dimension is your stack dimension. Please cross check this.') - height, width = orig_size[1], orig_size[2] - z_axis = 0 + print('Warning: 3D image stack found. We are assuming your last dimension is your channel dimension. Please cross check this.') + self.img_height, self.img_width = orig_size[0], orig_size[1] + self.channel_ax = 2 else: print('File not currently supported. 
 
-    def rescale_image(self, img, height, width, channel_ax=None, order=2):
+    def rescale_image(self, img, order=2):
         """rescale image
 
         :param img: image
         :type img: ndarray
-        :param height: height of the image
-        :type height: int
-        :param width: width of the image
-        :type width: int
         :return: rescaled image
         :rtype: ndarray
         """
         if self.model_used == "UNet":
-            height_pad = (height//16 + 1)*16 - height
-            width_pad = (width//16 + 1)*16 - width
-            return np.pad(img, ((0, height_pad),(0, width_pad)))
+            return pad_image(img, self.img_height, self.img_width, self.channel_ax, dividable=16)
         else:
             # Cellpose segmentation runs best with 512 size? TODO: check
-            max_dim = max(height, width)
+            max_dim = max(self.img_height, self.img_width)
             rescale_factor = max_dim/512
-            return rescale(img, 1/rescale_factor, order=order, channel_axis=channel_ax)
+            return rescale(img, 1/rescale_factor, order=order, channel_axis=self.channel_ax)
 
-    def resize_mask(self, mask, height, width, channel_ax=None, order=2):
+    def resize_mask(self, mask, channel_ax=None, order=0):
         """resize the mask so it matches the original image size
 
         :param mask: image
         :type mask: ndarray
         :return: resized mask
         :rtype: ndarray
         """
         if self.model_used == "UNet":
             # we assume an order C, H, W
             if channel_ax is not None and channel_ax==0:
-                height_pad = mask.shape[1] - height
-                width_pad = mask.shape[2]- width
+                height_pad = mask.shape[1] - self.img_height
+                width_pad = mask.shape[2] - self.img_width
                 return mask[:, :-height_pad, :-width_pad]
             elif channel_ax is not None and channel_ax==2:
-                height_pad = mask.shape[0] - height
-                width_pad = mask.shape[1]- width
+                height_pad = mask.shape[0] - self.img_height
+                width_pad = mask.shape[1] - self.img_width
                 return mask[:-height_pad, :-width_pad, :]
             elif channel_ax is not None and channel_ax==1:
-                height_pad = mask.shape[2] - height
-                width_pad = mask.shape[0]- width
+                height_pad = mask.shape[2] - self.img_height
+                width_pad = mask.shape[0] - self.img_width
                 return mask[:-width_pad, :, :-height_pad]
         else:
             if channel_ax is not None:
                 n_channel_dim = mask.shape[channel_ax]
-                output_size = [height, width]
+                output_size = [self.img_height, self.img_width]
                 output_size.insert(channel_ax, n_channel_dim)
-            else: output_size = [height, width]
+            else: output_size = [self.img_height, self.img_width]
             return resize(mask, output_size, order=order)
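For the UNet path, rescale_image pads up to the next multiple of 16 and resize_mask crops the padding away again; a round-trip sketch using the pad_image helper (added to processing.py further below):

    import numpy as np

    h, w = 100, 130
    padded = pad_image(np.random.rand(h, w), h, w, channel_ax=None, dividable=16)
    print(padded.shape)  # (112, 144): 100 -> 112 and 130 -> 144
    # a (2, 112, 144) mask is later cropped back via mask[:, :-12, :-14] -> (2, 100, 130)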
 
     def prepare_images_and_masks_for_training(self, train_img_mask_pairs):
         """Convert images and masks to how we need them for training
 
         :param train_img_mask_pairs: list pairs of (image, mask)
         :type train_img_mask_pairs: list
         :return: images and masks ready to be used for training
         :rtype: list, list
         """
         imgs=[]
         masks=[]
         for img_file, mask_file in train_img_mask_pairs:
             img = self.load_image(img_file)
-            mask = imread(mask_file)
+            img = normalise(img)
+            mask = self.load_image(mask_file, gray=False)
+            self.get_image_size_properties(img, helpers.get_file_extension(img_file))
+            # Unet only accepts image sizes divisible by 16
             if self.model_used == "UNet":
-                # Unet only accepts image sizes divisable by 16
-                height_pad = (img.shape[0]//16 + 1)*16 - img.shape[0]
-                width_pad = (img.shape[1]//16 + 1)*16 - img.shape[1]
-                img = np.pad(img, ((0, height_pad),(0, width_pad)))
-                mask = np.pad(mask, ((0, 0), (0, height_pad),(0, width_pad)))
+                img = pad_image(img, self.img_height, self.img_width, channel_ax=self.channel_ax, dividable=16)
+                mask = pad_image(mask, self.img_height, self.img_width, channel_ax=0, dividable=16)
+            if self.model_used == "CustomCellpose" and len(mask.shape)==3:
+                # if we also have a class mask, drop it
+                mask = mask[0] # assuming mask_channel_axis=0
             imgs.append(img)
             masks.append(mask)
-        return imgs, masks
\ No newline at end of file
+        return imgs, masks
+
+    def prepare_img_for_eval(self, img_file):
+        """Image processing for model inference.
+
+        :param img_file: the path to the image
+        :type img_file: str
+        :return: the loaded and processed image
+        :rtype: np.ndarray
+        """
+        # Load and normalise the image
+        img = self.load_image(img_file)
+        img = normalise(img)
+        # Get size properties
+        self.get_image_size_properties(img, helpers.get_file_extension(img_file))
+        if self.rescale:
+            img = self.rescale_image(img)
+        return img
+
+    def prepare_mask_for_save(self, mask, channel_ax):
+        """Prepares the mask output of the model to be saved.
+
+        :param mask: the mask
+        :type mask: np.ndarray
+        :param channel_ax: the channel dimension of the mask
+        :type channel_ax: int
+        :return: the ready to save mask
+        :rtype: np.ndarray
+        """
+        # Resize the mask if rescaling took place before
+        if self.rescale:
+            return self.resize_mask(mask, channel_ax)
+        else: return mask
\ No newline at end of file
diff --git a/src/server/dcp_server/utils/processing.py b/src/server/dcp_server/utils/processing.py
index 9ca7e770..42f50948 100644
--- a/src/server/dcp_server/utils/processing.py
+++ b/src/server/dcp_server/utils/processing.py
@@ -5,15 +5,57 @@
 from copy import deepcopy
 import SimpleITK as sitk
 from radiomics import shape2D
+import torch
 
-def normalise(img, norm='min-max'):
+def normalise(img, norm='min-max') -> np.ndarray:
     """ Normalises the image based on the chosen method. Currently available methods are:
     - min max normalisation
 
-    param
+    :param img: image to be normalised
+    :type img: np.ndarray
+    :param norm: the normalisation method to apply
+    :type norm: str
+    :return: the normalised image
+    :rtype: np.ndarray
     """
     if norm=='min-max':
         return (img - np.min(img)) / (np.max(img) - np.min(img))
+
+def pad_image(img, height, width, channel_ax=None, dividable=16) -> np.ndarray:
+    """ Pads the image so that its dimensions are divisible by a given number.
+
+    :param img: image to be padded
+    :type img: np.ndarray
+    :param height: image height
+    :type height: int
+    :param width: image width
+    :type width: int
+    :param channel_ax: the channel axis of the image, if any
+    :type channel_ax: int or None
+    :param dividable: the number the padded image dimensions should be divisible by
+    :type dividable: int
+    :return: the padded image
+    :rtype: np.ndarray
+    """
+    height_pad = (height//dividable + 1)*dividable - height
+    width_pad = (width//dividable + 1)*dividable - width
+    if channel_ax==0:
+        img = np.pad(img, ((0, 0), (0, height_pad), (0, width_pad)))
+    elif channel_ax==2:
+        img = np.pad(img, ((0, height_pad), (0, width_pad), (0, 0)))
+    else:
+        img = np.pad(img, ((0, height_pad), (0, width_pad)))
+    return img
+
+def convert_to_tensor(imgs, dtype):
+    # Convert a list of numpy arrays to a stacked torch tensor
+    imgs = torch.stack([
+        torch.from_numpy(img.astype(dtype)) for img in imgs
+    ])
+    imgs = imgs.unsqueeze(1) if imgs.ndim == 3 else imgs
+    return imgs
+
 def crop_centered_padded_patch(img: np.ndarray,
                                patch_center_xy,
                                patch_size,

From 71fc91a70759e8e60f355f0d0d56c16711890152 Mon Sep 17 00:00:00 2001
From: Christina Bukas
Date: Wed, 6 Mar 2024 17:29:26 +0100
Subject: [PATCH 07/26] changes after testing and convert cfg files to yaml

---
 .../dcp_client/{config.cfg => config.yaml}    |  0
 .../{config_remote.cfg => config_remote.yaml} |  0
 src/client/dcp_client/gui/napari_window.py    |  4 +-
 src/client/dcp_client/main.py                 |  2 +-
 src/client/dcp_client/utils/utils.py          | 30 ++++++------
 .../dcp_server/{config.cfg => config.yaml}    | 21 +++++----
 ...nfig_instance.cfg => config_instance.yaml} | 12 +++--
 src/server/dcp_server/config_semantic.yaml    | 46 +++++++++++++++++++
 src/server/dcp_server/main.py                 |  2 +-
 src/server/dcp_server/models/classifiers.py   | 13 +++---
 .../dcp_server/models/custom_cellpose.py      |  1 +
 .../dcp_server/models/inst_to_multi_seg.py    | 22 +++------
 src/server/dcp_server/models/multicellpose.py |  4 +-
 src/server/dcp_server/models/unet.py          | 23 +++++-----
 src/server/dcp_server/segmentationclasses.py  |  5 +-
 src/server/dcp_server/service.py              | 12 ++---
 src/server/dcp_server/utils/fsimagestorage.py | 11 +++--
 src/server/dcp_server/utils/helpers.py        |  9 ++--
 src/server/dcp_server/utils/processing.py     | 13 ++----
 19 files changed, 137 insertions(+), 93 deletions(-)
 rename src/client/dcp_client/{config.cfg => config.yaml} (100%)
 rename src/client/dcp_client/{config_remote.cfg => config_remote.yaml} (100%)
 rename src/server/dcp_server/{config.cfg => config.yaml} (75%)
 rename src/server/dcp_server/{config_instance.cfg => config_instance.yaml} (80%)
 create mode 100644 src/server/dcp_server/config_semantic.yaml

diff --git a/src/client/dcp_client/config.cfg b/src/client/dcp_client/config.yaml
similarity index 100%
rename from src/client/dcp_client/config.cfg
rename to src/client/dcp_client/config.yaml

diff --git a/src/client/dcp_client/config_remote.cfg b/src/client/dcp_client/config_remote.yaml
similarity index 100%
rename from src/client/dcp_client/config_remote.cfg
rename to src/client/dcp_client/config_remote.yaml

diff --git a/src/client/dcp_client/gui/napari_window.py b/src/client/dcp_client/gui/napari_window.py
index a87d4b97..35dbc9af 100644
--- a/src/client/dcp_client/gui/napari_window.py
+++ b/src/client/dcp_client/gui/napari_window.py
@@ -39,7 +39,7 @@ def __init__(self, app: Application):
         layout.addWidget(main_window, 0, 0, 1, 4)
 
         # select the first seg as the currently selected layer if there are any segs
-        if len(self.seg_files):
+        if len(self.seg_files) and len(self.viewer.layers[get_path_stem(self.seg_files[0])].data.shape) > 2:
             self.cur_selected_seg = self.viewer.layers.selection.active.name
             self.layer = self.viewer.layers[self.cur_selected_seg]
             self.viewer.layers.selection.events.changed.connect(self.on_seg_channel_changed)
@@ -63,7 +63,7 @@ def __init__(self, app: Application):
 
             self.qctrl = self.viewer.window.qt_viewer.controls.widgets[self.layer]
 
-        if self.layer.data.shape[0] >= 2:
+        if len(self.layer.data.shape) > 2:
             # User hint
             message_label = QLabel('Choose an active mask')
             message_label.setAlignment(Qt.AlignRight)
diff --git a/src/client/dcp_client/main.py b/src/client/dcp_client/main.py
index 0f9da389..978eed80 100644
--- a/src/client/dcp_client/main.py
+++ b/src/client/dcp_client/main.py
@@ -17,7 +17,7 @@ def main():
     settings.init()
 
     dir_name = path.dirname(path.abspath(sys.argv[0]))
-    server_config = read_config('server', config_path = path.join(dir_name, 'config.cfg'))
+    server_config = read_config('server', config_path = path.join(dir_name, 'config.yaml'))
 
     image_storage = FilesystemImageStorage()
     ml_model = BentomlModel()
diff --git a/src/client/dcp_client/utils/utils.py b/src/client/dcp_client/utils/utils.py
index 7ba4451f..040c8fca 100644
--- a/src/client/dcp_client/utils/utils.py
+++ b/src/client/dcp_client/utils/utils.py
@@ -6,7 +6,7 @@
 from skimage.draw import polygon_perimeter
 
 from pathlib import Path, PurePath
-import json
+import yaml
 
 from dcp_client.utils import settings
 
@@ -27,18 +27,18 @@ def icon(self, type: 'QFileIconProvider.IconType'):
         else:
             return super().icon(type)
 
-def read_config(name, config_path = 'config.cfg') -> dict:
+def read_config(name, config_path = 'config.yaml') -> dict:
     """Reads the configuration file
 
     :param name: name of the section you want to read (e.g. 'setup','train')
    :type name: string
-    :param config_path: path to the configuration file, defaults to 'config.cfg'
+    :param config_path: path to the configuration file, defaults to 'config.yaml'
     :type config_path: str, optional
     :return: dictionary from the config section given by name
     :rtype: dict
     """
     with open(config_path) as config_file:
-        config_dict = json.load(config_file)
+        config_dict = yaml.safe_load(config_file) # json.load(config_file) for .cfg file
         # Check if config file has main mandatory keys
         assert all([i in config_dict.keys() for i in ['server']])
         return config_dict[name]
@@ -77,16 +77,18 @@ def get_contours(instance_mask, contours_level=None):
         # get a binary mask only of object
         single_obj_mask = np.zeros_like(instance_mask)
         single_obj_mask[instance_mask==instance_id] = 1
-        # compute contours for mask
-        contours = find_contours(single_obj_mask, contours_level)
-        # sometimes little dots appeas as additional contours so remove these
-        if len(contours)>1:
-            contour_sizes = [contour.shape[0] for contour in contours]
-            contour = contours[contour_sizes.index(max(contour_sizes))].astype(int)
-        else: contour = contours[0]
-        # and draw onto contours mask
-        rr, cc = polygon_perimeter(contour[:, 0], contour[:, 1], contour_mask.shape)
-        contour_mask[rr, cc] = instance_id
+        try:
+            # compute contours for mask
+            contours = find_contours(single_obj_mask, contours_level)
+            # sometimes little dots appear as additional contours so remove these
+            if len(contours)>1:
+                contour_sizes = [contour.shape[0] for contour in contours]
+                contour = contours[contour_sizes.index(max(contour_sizes))].astype(int)
+            else: contour = contours[0]
+            # and draw onto contours mask
+            rr, cc = polygon_perimeter(contour[:, 0], contour[:, 1], contour_mask.shape)
+            contour_mask[rr, cc] = instance_id
+        except Exception: print("Could not create contour for instance id", instance_id)
     return contour_mask
 
     @staticmethod
diff --git a/src/server/dcp_server/config.cfg b/src/server/dcp_server/config.yaml
similarity index 75%
rename from src/server/dcp_server/config.cfg
rename to src/server/dcp_server/config.yaml
index 8b7c1039..7f807753 100644
--- a/src/server/dcp_server/config.cfg
+++ b/src/server/dcp_server/config.yaml
@@ -1,30 +1,31 @@
 {
     "setup": {
         "segmentation": "GeneralSegmentation",
-        "model_to_use": "Inst2MultiSeg",
+        "model_to_use": "MultiCellpose",
         "accepted_types": [".jpg", ".jpeg", ".png", ".tiff", ".tif"],
         "seg_name_string": "_seg"
     },
 
     "service": {
         "runner_name": "bento_runner",
-        "bento_model_path": "test5",
+        "bento_model_path": "test1w",
         "service_name": "data-centric-platform",
         "port": 7010
     },
 
     "model": {
+        "segmentor_name": "Cellpose",
         "segmentor": {
             "model_type": "cyto"
         },
+        "classifier_name": "PatchClassifier",
         "classifier":{
-            "model_class": "PatchClassifier",
             "in_channels": 1,
-            "num_classses": 2,
+            "num_classes": 2,
             "features":[64,128,256,512],
-            "black_bg": "False",
-            "include_mask": "False"
+            "black_bg": False,
+            "include_mask": True
         }
    },
 
    "data": {
        "data_root": "data",
        "patch_size": 64,
        "noise_intensity": 5,
-        "gray": "True",
-        "rescale": "True"
+        "gray": True,
+        "rescale": True
    },
 
    "train":{
        "segmentor":{
-            "n_epochs": 3,
+            "n_epochs": 5,
            "channels": [0,0],
            "min_train_masks": 1
        },
        "classifier":{
-            "n_epochs": 2,
+            "n_epochs": 100,
            "lr": 0.001,
            "batch_size": 1,
            "optimizer": "Adam"
diff --git a/src/server/dcp_server/config_instance.cfg b/src/server/dcp_server/config_instance.yaml
similarity index 80%
rename from src/server/dcp_server/config_instance.cfg
rename to src/server/dcp_server/config_instance.yaml
index 22b89745..1af6b1eb 100644
--- a/src/server/dcp_server/config_instance.cfg
+++ b/src/server/dcp_server/config_instance.yaml
@@ -1,14 +1,14 @@
 {
     "setup": {
         "segmentation": "GeneralSegmentation",
+        "model_to_use": "CustomCellpose",
         "accepted_types": [".jpg", ".jpeg", ".png", ".tiff", ".tif"],
         "seg_name_string": "_seg"
     },
 
     "service": {
-        "model_to_use": "CustomCellpose",
-        "save_model_path": "cells",
-        "runner_name": "cellpose_runner",
+        "runner_name": "bento_runner",
+        "bento_model_path": "cells",
         "service_name": "data-centric-platform",
         "port": 7010
     },
@@ -20,12 +20,14 @@
     },
 
     "data": {
-        "data_root": "data"
+        "data_root": "data",
+        "gray": True,
+        "rescale": True
     },
 
     "train":{
         "segmentor":{
-            "n_epochs": 10,
+            "n_epochs": 5,
             "channels": [0,0],
             "min_train_masks": 1
         }
diff --git a/src/server/dcp_server/config_semantic.yaml b/src/server/dcp_server/config_semantic.yaml
new file mode 100644
index 00000000..928eb931
--- /dev/null
+++ b/src/server/dcp_server/config_semantic.yaml
@@ -0,0 +1,46 @@
+{
+    "setup": {
+        "segmentation": "GeneralSegmentation",
+        "model_to_use": "UNet",
+        "accepted_types": [".jpg", ".jpeg", ".png", ".tiff", ".tif"],
+        "seg_name_string": "_seg"
+    },
+
+    "service": {
+        "runner_name": "bento_runner",
+        "bento_model_path": "semantic-Unet",
+        "service_name": "data-centric-platform",
+        "port": 7010
+    },
+
+    "model": {
+        "classifier":{
+            "in_channels": 1,
+            "num_classes": 2,
+            "features":[64,128,256,512]
+        }
+    },
+
+    "data": {
+        "data_root": "data",
+        "gray": True,
+        "rescale": True
+    },
+
+    "train":{
+        "classifier":{
+            "n_epochs": 2,
+            "lr": 0.001,
+            "batch_size": 1,
+            "optimizer": "Adam"
+        }
+    },
+
+    "eval":{
+        "classifier": {
+
+        },
+        "compute_instance": True,
+        "mask_channel_axis": 0
+    }
+}
\ No newline at end of file
diff --git a/src/server/dcp_server/main.py b/src/server/dcp_server/main.py
index 2e1772b2..2afa577f 100644
--- a/src/server/dcp_server/main.py
+++ b/src/server/dcp_server/main.py
@@ -16,7 +16,7 @@ def main():
 
     local_path = path.join(__file__, '..')
     dir_name = path.dirname(path.abspath(sys.argv[0]))
-    service_config = read_config('service', config_path = path.join(dir_name, 'config.cfg'))
+    service_config = read_config('service', config_path = path.join(dir_name, 'config.yaml'))
     port = str(service_config['port'])
 
     subprocess.run([
diff --git a/src/server/dcp_server/models/classifiers.py b/src/server/dcp_server/models/classifiers.py
index 6b4e9fa1..bd154597 100644
--- a/src/server/dcp_server/models/classifiers.py
+++ b/src/server/dcp_server/models/classifiers.py
@@ -81,7 +81,6 @@ def train (self,
         #eval(f'{optimizer_class}(params={self.parameters()}, lr={lr})')
 
         # TODO check if we should replace self.parameters with super.parameters()
-
         for _ in tqdm(
             range(self.train_config["n_epochs"]),
             desc="Running PatchClassifier training"
         ):
@@ -125,7 +124,7 @@ def build_model(self):
         """
         in_channels = self.model_config["in_channels"]
         in_channels = in_channels + 1 if self.model_config["include_mask"] else in_channels
-
+
         self.layer1 = nn.Sequential(
             nn.Conv2d(in_channels, 16, 3, 2, 5),
             nn.BatchNorm2d(16),
             nn.ReLU(),
             nn.Dropout2d(p=0.2),
         )
@@ -199,12 +198,12 @@ def __init__(self,
         """
 
         self.model_name = model_name
-        self.model_config = model_config # use for initialising model
+        self.model_config = model_config["classifier"] # use for initialising model
model + # self.data_config = data_config + # self.train_config = train_config + # self.eval_config = eval_config - self.model = RandomForestClassifier() # TODO chnage config so RandomForestClassifier accepts input params + self.model = RandomForestClassifier(**self.model_config) # TODO chnage config so RandomForestClassifier accepts input params def train(self, diff --git a/src/server/dcp_server/models/custom_cellpose.py b/src/server/dcp_server/models/custom_cellpose.py index 8e5ad970..6babf12a 100644 --- a/src/server/dcp_server/models/custom_cellpose.py +++ b/src/server/dcp_server/models/custom_cellpose.py @@ -63,6 +63,7 @@ def train(self, :param masks: masks of the given images (training labels) :type masks: List[np.ndarray] """ + if self.train_config["segmentor"]["n_epochs"]==0: return super().train( train_data=deepcopy(imgs), train_labels=masks, diff --git a/src/server/dcp_server/models/inst_to_multi_seg.py b/src/server/dcp_server/models/inst_to_multi_seg.py index bad68a34..0b6ebcf2 100644 --- a/src/server/dcp_server/models/inst_to_multi_seg.py +++ b/src/server/dcp_server/models/inst_to_multi_seg.py @@ -56,33 +56,25 @@ def __init__(self, self.train_config = train_config self.eval_config = eval_config - self.segmentor_class = self.model_config.get("classifier").get("model_class", "Cellpose") - self.classifier_class = self.model_config.get("classifier").get("model_class", "PatchClassifier") + self.segmentor_class = self.model_config.get("segmentor_name", "Cellpose") + self.classifier_class = self.model_config.get("classifier_name", "PatchClassifier") # Initialize the cellpose model and the classifier segmentor = segmentor_mapping.get(self.segmentor_class) self.segmentor = segmentor( self.segmentor_class, self.model_config, self.data_config, self.train_config, self.eval_config ) - ''' - if self.classifier_class == "PatchClassifier": - self.classifier = PatchClassifier( - self.classifier_class, self.model_config, self.data_config, self.train_config, self.eval_config - ) - - elif self.classifier_class == "RandomForest": - self.classifier = FeatureClassifier( - self.classifier_class, self.model_config, self.data_config, self.train_config, self.eval_config - ) - ''' classifier = classifier_mapping.get(self.classifier_class) self.classifier = classifier( self.classifier_class, self.model_config, self.data_config, self.train_config, self.eval_config ) + # make sure include mask is set to False if we are using the random forest model - if self.model_config["classifier"]["include_mask"] == True and self.classifier_class=="RandomForest": + if self.classifier_class=="RandomForest": + if "include_mask" not in self.model_config["classifier"].keys() or self.model_config["classifier"]["include_mask"] is True: #print("Include mask=True was found, but for Random Forest, this parameter must be set to False. Doing this now.") - self.model_config["classifier"]["include_mask"] = False + self.model_config["classifier"]["include_mask"] = False + def train(self, imgs: List[np.ndarray], diff --git a/src/server/dcp_server/models/multicellpose.py b/src/server/dcp_server/models/multicellpose.py index 5d7ef0e0..f1686e18 100644 --- a/src/server/dcp_server/models/multicellpose.py +++ b/src/server/dcp_server/models/multicellpose.py @@ -64,13 +64,13 @@ def train(self, masks_class = [] for mask in masks: - mask_class = mask.copy() # TODO - Do we need copy?? + mask_class = mask[0].copy() # TODO - Do we need copy?? 
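# A minimal sketch of the intent of the masking below, assuming each mask has shape
# (2, H, W) with channel 0 holding instance ids and channel 1 holding class ids
# (hypothetical variable names, not the committed code):
#
#     per_class = mask.copy()
#     per_class[0][per_class[1] != (i + 1)] = 0   # zero instances of every other class
#     masks_class.append(per_class)               # cellpose model i trains on class i+1 only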
# set all instances in the instance mask not corresponding to the class in question to zero mask_class[0][ mask_class[1]!=(i+1) ] = 0 masks_class.append(mask_class) - + print(masks_class[0].shape) self.cellpose_models[i].train(imgs, masks_class) self.metric = np.mean([self.cellpose_models[i].metric for i in range(self.num_of_channels)]) diff --git a/src/server/dcp_server/models/unet.py b/src/server/dcp_server/models/unet.py index 47e3fdc0..e4dab2b2 100644 --- a/src/server/dcp_server/models/unet.py +++ b/src/server/dcp_server/models/unet.py @@ -114,7 +114,7 @@ def train(self, """ imgs = convert_to_tensor(imgs, np.float32) - masks = convert_to_tensor(masks, np.int16) + masks = convert_to_tensor([mask[1] for mask in masks], np.int16, unsqueeze=False) # Create a training dataset and dataloader train_dataloader = DataLoader( @@ -135,7 +135,6 @@ def train(self, self.loss = 0 for imgs, masks in train_dataloader: - #forward path preds = self.forward(imgs.float()) loss = loss_fn(preds, masks.long()) @@ -166,16 +165,16 @@ def eval(self, img = convert_to_tensor([img], np.float32) preds = self.forward(img) - class_mask = torch.argmax(preds, 1).numpy()[0] - # TODO - make instance mask calculation optional - instance_mask = label((class_mask > 0).astype(int))[0] - - final_mask = np.stack( - (instance_mask, class_mask), - axis=self.eval_config['mask_channel_axis'] - ).astype( - np.uint16 - ) + class_mask = torch.argmax(preds, 1).numpy()[0] + if self.eval_config["compute_instance"] is True: + instance_mask = label((class_mask > 0).astype(int))[0] + final_mask = np.stack( + [instance_mask, class_mask], + axis=self.eval_config['mask_channel_axis'] + ).astype( + np.uint16 + ) + else: final_mask = class_mask.astype(np.uint16) return final_mask diff --git a/src/server/dcp_server/segmentationclasses.py b/src/server/dcp_server/segmentationclasses.py index 4c94004c..326bf069 100644 --- a/src/server/dcp_server/segmentationclasses.py +++ b/src/server/dcp_server/segmentationclasses.py @@ -2,7 +2,7 @@ from dcp_server.utils import helpers # Import configuration -setup_config = helpers.read_config('setup', config_path = 'config.cfg') +setup_config = helpers.read_config('setup', config_path = 'config.yaml') class GeneralSegmentation(): """ Segmentation class. Defining the main functions needed for this project and served by service - segment image and train on images. 
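The unet.py hunk above makes instance-mask computation optional at eval time: when compute_instance is true, connected components of the foreground class mask become instance ids and are stacked with the class mask. A minimal sketch of that post-processing step on toy data (not the full eval path):

    import numpy as np
    from scipy.ndimage import label

    class_mask = np.array([[1, 1, 0],
                           [0, 0, 0],
                           [0, 2, 2]])                      # per-pixel class labels
    instance_mask = label((class_mask > 0).astype(int))[0]  # connected components -> 1, 2, ...
    final_mask = np.stack([instance_mask, class_mask], axis=0).astype(np.uint16)
    # final_mask[0]: instance ids, final_mask[1]: class ids, as the eval hunk stacks them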
@@ -34,7 +34,8 @@ async def segment_image(self, input_path, list_of_images): for img_filepath in list_of_images: img = self.imagestorage.prepare_img_for_eval(img_filepath) # Add channel ax into the model's evaluation parameters dictionary - self.model.eval_config['segmentor']['channel_axis'] = self.imagestorage.channel_ax + if self.imagestorage.model_used!="UNet": + self.model.eval_config['segmentor']['channel_axis'] = self.imagestorage.channel_ax # Evaluate the model mask = await self.runner.evaluate.async_run(img = img) # And prepare the mask for saving diff --git a/src/server/dcp_server/service.py b/src/server/dcp_server/service.py index 61193a86..cb560bed 100644 --- a/src/server/dcp_server/service.py +++ b/src/server/dcp_server/service.py @@ -12,12 +12,12 @@ segmentation_module = __import__("segmentationclasses") # Import configuration -service_config = read_config('service', config_path = 'config.cfg') -model_config = read_config('model', config_path = 'config.cfg') -data_config = read_config('data', config_path = 'config.cfg') -train_config = read_config('train', config_path = 'config.cfg') -eval_config = read_config('eval', config_path = 'config.cfg') -setup_config = read_config('setup', config_path = 'config.cfg') +service_config = read_config('service', config_path = 'config.yaml') +model_config = read_config('model', config_path = 'config.yaml') +data_config = read_config('data', config_path = 'config.yaml') +train_config = read_config('train', config_path = 'config.yaml') +eval_config = read_config('eval', config_path = 'config.yaml') +setup_config = read_config('setup', config_path = 'config.yaml') # instantiate the model diff --git a/src/server/dcp_server/utils/fsimagestorage.py b/src/server/dcp_server/utils/fsimagestorage.py index bcd02c8b..725c417d 100644 --- a/src/server/dcp_server/utils/fsimagestorage.py +++ b/src/server/dcp_server/utils/fsimagestorage.py @@ -7,7 +7,7 @@ from dcp_server.utils.processing import pad_image, normalise # Import configuration -setup_config = helpers.read_config("setup", config_path = "config.cfg") +setup_config = helpers.read_config("setup", config_path = "config.yaml") class FilesystemImageStorage(): """Class used to deal with everything related to image storing and processing - loading, saving, transforming... 
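With the service.py hunk above, every section of the server configuration now comes from a single config.yaml via read_config. A short usage sketch mirroring what main.py and service.py do (values shown are those in the shipped config):

    from dcp_server.utils.helpers import read_config

    service_config = read_config('service', config_path='config.yaml')
    setup_config = read_config('setup', config_path='config.yaml')
    port = str(service_config['port'])           # '7010' in the default config
    model_name = setup_config['model_to_use']    # e.g. 'Inst2MultiSeg'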
@@ -43,7 +43,7 @@ def save_image(self, to_save_path, img): :type to_save_path: str :param img: image you wish to save :type img: ndarray - """ + """ imsave(os.path.join(self.root_dir, to_save_path), img) def search_images(self, directory): @@ -193,6 +193,10 @@ def resize_mask(self, mask, channel_ax=None, order=0): height_pad = mask.shape[2] - self.img_height width_pad = mask.shape[0]- self.img_width return mask[:-width_pad, :, :-height_pad] + else: + height_pad = mask.shape[0] - self.img_height + width_pad = mask.shape[1]-self.img_width + return mask[:-height_pad,:-width_pad] else: if channel_ax is not None: @@ -256,6 +260,7 @@ def prepare_mask_for_save(self, mask, channel_ax): :rtype: np.ndarray """ # Resize the mask if rescaling took place before - if self.rescale: + if self.rescale is True: + if len(mask.shape)<3: channel_ax=None return self.resize_mask(mask, channel_ax) else: return mask \ No newline at end of file diff --git a/src/server/dcp_server/utils/helpers.py b/src/server/dcp_server/utils/helpers.py index f706b5d0..72d52a78 100644 --- a/src/server/dcp_server/utils/helpers.py +++ b/src/server/dcp_server/utils/helpers.py @@ -1,19 +1,18 @@ from pathlib import Path -import json +import yaml - -def read_config(name, config_path = 'config.cfg') -> dict: +def read_config(name, config_path = 'config.yaml') -> dict: """Reads the configuration file :param name: name of the section you want to read (e.g. 'setup','train') :type name: string - :param config_path: path to the configuration file, defaults to 'config.cfg' + :param config_path: path to the configuration file, defaults to 'config.yaml' :type config_path: str, optional :return: dictionary from the config section given by name :rtype: dict """ with open(config_path) as config_file: - config_dict = json.load(config_file) + config_dict = yaml.safe_load(config_file) # json.load(config_file) for .cfg file # Check if config file has main mandatory keys assert all([i in config_dict.keys() for i in ['setup', 'service', 'model', 'train', 'eval']]) return config_dict[name] diff --git a/src/server/dcp_server/utils/processing.py b/src/server/dcp_server/utils/processing.py index 42f50948..1892e01f 100644 --- a/src/server/dcp_server/utils/processing.py +++ b/src/server/dcp_server/utils/processing.py @@ -48,12 +48,12 @@ def pad_image(img, height, width, channel_ax=None, dividable = 16) -> np.ndarray img = np.pad(img, ((0, height_pad), (0, width_pad))) return img -def convert_to_tensor(imgs, dtype): +def convert_to_tensor(imgs, dtype, unsqueeze=True): # Convert images tensors imgs = torch.stack([ torch.from_numpy(img.astype(dtype)) for img in imgs ]) - imgs = imgs.unsqueeze(1) if imgs.ndim == 3 else imgs + imgs = imgs.unsqueeze(1) if imgs.ndim == 3 and unsqueeze is True else imgs return imgs def crop_centered_padded_patch(img: np.ndarray, @@ -77,7 +77,6 @@ def crop_centered_padded_patch(img: np.ndarray, Returns: np.ndarray: The cropped patch with applied padding. 
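A few hunks up, convert_to_tensor gains an unsqueeze flag so semantic-mask batches (one 2D array per image) can skip the channel dimension that image batches still need. A quick sketch of both paths, assuming 2D grayscale inputs:

    import numpy as np
    import torch

    imgs = [np.zeros((64, 64)), np.ones((64, 64))]
    stacked = torch.stack([torch.from_numpy(i.astype(np.float32)) for i in imgs])
    as_images = stacked.unsqueeze(1)   # (2, 1, 64, 64): unsqueeze=True, network input
    as_masks = stacked                 # (2, 64, 64): unsqueeze=False, UNet training targets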
""" - height, width = patch_size # Size of the patch img_height, img_width = img.shape[0], img.shape[1] # Size of the input image @@ -99,7 +98,8 @@ def crop_centered_padded_patch(img: np.ndarray, # crop the mask mask = mask[max(top, 0):min(bottom, img_height), max(left, 0):min(right, img_width), :] - patch = img[max(top, 0):min(bottom, img_height), max(left, 0):min(right, img_width), :] + patch = img[max(top, 0):min(bottom, img_height), max(left, 0):min(right, img_width), :] + # Calculate the required padding amounts and apply padding if necessary if left < 0: patch = np.hstack(( @@ -136,7 +136,6 @@ def crop_centered_padded_patch(img: np.ndarray, mask = np.vstack(( mask, np.zeros((bottom - img_height, mask.shape[1], mask.shape[2])).astype(np.uint8))) - return patch, mask @@ -202,7 +201,7 @@ def get_centered_patches(img, obj_label, mask=deepcopy(mask), noise_intensity=noise_intensity) - if include_mask: + if include_mask is True: patch_mask = 255 * (patch_mask > 0).astype(np.uint8) patch = np.concatenate((patch, patch_mask), axis=-1) @@ -296,7 +295,6 @@ def get_shape_features(img, mask): """ mask = 255 * ((mask) > 0).astype(np.uint8) - image = sitk.GetImageFromArray(img.squeeze()) roi_mask = sitk.GetImageFromArray(mask.squeeze()) @@ -351,7 +349,6 @@ def create_dataset_for_rf(imgs, masks): """ X = [] for img, mask in zip(imgs, masks): - shape_features = get_shape_features(img, mask) intensity_features = extract_intensity_features(img, mask) features_list = np.concatenate((shape_features, intensity_features), axis=0) From 0de002f417b4a0b85dfefff88b315c8f5a83a6ba Mon Sep 17 00:00:00 2001 From: Christina Bukas Date: Wed, 6 Mar 2024 17:35:43 +0100 Subject: [PATCH 08/26] add docs update --- .github/workflows/test.yml | 46 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 3c45b11f..d6398963 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -114,4 +114,50 @@ jobs: files: src/server/coverage.xml env: CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} + + update_docs: + name: Update docs + runs-on: ubuntu-latest + steps: + - name: Update + run: make clean && make html + + deploy: + # this will run when you have tagged a commit, starting with "v*" + # and requires that you have put your twine API key in your + # github secrets (see readme for details) + needs: [test_client, test_server, update_docs] + runs-on: ubuntu-latest + if: contains(github.ref, 'tags') + steps: + - uses: actions/checkout@v3 + - name: Set up Python + uses: actions/setup-python@v3 + with: + python-version: "3.x" + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install twine + pip install build + + - name: Build and publish dcp_client + env: + TWINE_USERNAME: __token__ + TWINE_PASSWORD: ${{ secrets.TWINE_API_KEY }} + run: | + git tag + python -m build . + twine upload dist/* + working-directory: src/client + + - name: Build and publish dcp_server + env: + TWINE_USERNAME: __token__ + TWINE_PASSWORD: ${{ secrets.TWINE_API_KEY }} + run: | + git tag + python -m build . 
+ twine upload dist/* + working-directory: src/server From 4b5315e43e6d97fd843b56abb9e22e915ef20f2d Mon Sep 17 00:00:00 2001 From: Christina Bukas Date: Thu, 7 Mar 2024 22:11:54 +0100 Subject: [PATCH 09/26] added inherant model class --- src/server/dcp_server/models/custom_cellpose.py | 4 ++-- src/server/dcp_server/models/inst_to_multi_seg.py | 5 +++-- src/server/dcp_server/models/multicellpose.py | 5 +++-- src/server/dcp_server/models/unet.py | 4 ++-- 4 files changed, 10 insertions(+), 8 deletions(-) diff --git a/src/server/dcp_server/models/custom_cellpose.py b/src/server/dcp_server/models/custom_cellpose.py index 6babf12a..71fe3a2b 100644 --- a/src/server/dcp_server/models/custom_cellpose.py +++ b/src/server/dcp_server/models/custom_cellpose.py @@ -9,9 +9,9 @@ from cellpose.metrics import aggregated_jaccard_index from cellpose.dynamics import labels_to_flows -#from dcp_server.models import Model +from .model import Model -class CustomCellpose(models.CellposeModel): #, Model): +class CustomCellpose(models.CellposeModel, Model): """Custom cellpose model inheriting the attributes and functions from the original CellposeModel and implementing additional attributes and methods needed for this project. """ diff --git a/src/server/dcp_server/models/inst_to_multi_seg.py b/src/server/dcp_server/models/inst_to_multi_seg.py index 0b6ebcf2..a270cfa3 100644 --- a/src/server/dcp_server/models/inst_to_multi_seg.py +++ b/src/server/dcp_server/models/inst_to_multi_seg.py @@ -4,7 +4,8 @@ import numpy as np import torch -from dcp_server.models import CustomCellpose # Model, +from .model import Model +from .custom_cellpose import CustomCellpose from dcp_server.models.classifiers import PatchClassifier, FeatureClassifier from dcp_server.utils.processing import ( get_centered_patches, @@ -24,7 +25,7 @@ } -class Inst2MultiSeg(): #Model): +class Inst2MultiSeg(Model): """ A two stage model for: 1. instance segmentation and 2. object wise classification """ diff --git a/src/server/dcp_server/models/multicellpose.py b/src/server/dcp_server/models/multicellpose.py index f1686e18..f24d101e 100644 --- a/src/server/dcp_server/models/multicellpose.py +++ b/src/server/dcp_server/models/multicellpose.py @@ -2,9 +2,10 @@ import numpy as np from skimage.measure import label as label_mask -from dcp_server.models import CustomCellpose # Model, +from .model import Model +from .custom_cellpose import CustomCellpose -class MultiCellpose(): #Model): +class MultiCellpose(Model): ''' Multichannel image segmentation model. Run the separate CustomCellpose models for each channel return the mask corresponding to each object type. diff --git a/src/server/dcp_server/models/unet.py b/src/server/dcp_server/models/unet.py index e4dab2b2..ef728932 100644 --- a/src/server/dcp_server/models/unet.py +++ b/src/server/dcp_server/models/unet.py @@ -8,11 +8,11 @@ from torch.optim import Adam from torch.utils.data import TensorDataset, DataLoader -#from dcp_server.models import Model +from .model import Model from dcp_server.utils.processing import convert_to_tensor -class UNet(nn.Module): # Model +class UNet(nn.Module, Model): """ Unet is a convolutional neural network architecture for semantic segmentation. 
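Patch 09 threads the shared Model base class through every model. Because CustomCellpose and UNet also inherit from library classes (cellpose's CellposeModel, torch's nn.Module), the patches initialise each parent explicitly rather than relying on super(). A minimal sketch of that pattern, with a hypothetical LibraryParent standing in for the real library base:

    from dcp_server.models.model import Model

    class LibraryParent:                  # stand-in for nn.Module / CellposeModel
        def __init__(self) -> None:
            pass

    class MyModel(LibraryParent, Model):
        def __init__(self, model_name, model_config, data_config, train_config, eval_config):
            Model.__init__(self, model_name, model_config, data_config, train_config, eval_config)
            LibraryParent.__init__(self)  # each base initialised explicitly, no super()

        def train(self, imgs, masks):     # satisfy Model's abstract interface
            pass

        def eval(self, img):
            pass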
From b54ab150255a1fdf61bc4cba11f668db3ddc236d Mon Sep 17 00:00:00 2001 From: Christina Bukas Date: Thu, 7 Mar 2024 22:12:22 +0100 Subject: [PATCH 10/26] update configs --- src/server/dcp_server/config.yaml | 4 +-- ...st_config_fcnn.cfg => test_config_RF.yaml} | 27 ++++++-------- ...st_config_RF.cfg => test_config_fcnn.yaml} | 36 +++++++++---------- 3 files changed, 29 insertions(+), 38 deletions(-) rename src/server/test/configs/{test_config_fcnn.cfg => test_config_RF.yaml} (67%) rename src/server/test/configs/{test_config_RF.cfg => test_config_fcnn.yaml} (60%) diff --git a/src/server/dcp_server/config.yaml b/src/server/dcp_server/config.yaml index 7f807753..70b0011f 100644 --- a/src/server/dcp_server/config.yaml +++ b/src/server/dcp_server/config.yaml @@ -1,14 +1,14 @@ { "setup": { "segmentation": "GeneralSegmentation", - "model_to_use": "MultiCellpose", + "model_to_use": "Inst2MultiSeg", "accepted_types": [".jpg", ".jpeg", ".png", ".tiff", ".tif"], "seg_name_string": "_seg" }, "service": { "runner_name": "bento_runner", - "bento_model_path": "test1w", + "bento_model_path": "cells", "service_name": "data-centric-platform", "port": 7010 }, diff --git a/src/server/test/configs/test_config_fcnn.cfg b/src/server/test/configs/test_config_RF.yaml similarity index 67% rename from src/server/test/configs/test_config_fcnn.cfg rename to src/server/test/configs/test_config_RF.yaml index 02039f68..89f60818 100644 --- a/src/server/test/configs/test_config_fcnn.cfg +++ b/src/server/test/configs/test_config_RF.yaml @@ -1,34 +1,37 @@ { "setup": { "segmentation": "GeneralSegmentation", + "model_to_use": "Inst2MultiSeg", "accepted_types": [".jpg", ".jpeg", ".png", ".tiff", ".tif"], "seg_name_string": "_seg" }, "service": { - "model_to_use": "CustomCellposeModel", - "save_model_path": "mito", - "runner_name": "cellpose_runner", + "runner_name": "bento_runner", + "bento_model_path": "test", "service_name": "data-centric-platform", "port": 7010 }, "model": { + "segmentor_name": Cellpose, "segmentor": { "model_type": "cyto" }, + "classifier_name": "RandomForest", "classifier":{ - "model_class": "FCNN", "in_channels": 1, "num_classes": 3, "features":[64,128,256,512], - "black_bg": "False", - "include_mask": "False" + "black_bg": False, + "include_mask": False } }, "data": { - "data_root": "data" + "data_root": "data", + "gray": True, + "rescale": True }, "train":{ @@ -39,11 +42,6 @@ "learning_rate":0.01 }, "classifier":{ - "train_data":{ - "patch_size": 64, - "noise_intensity": 5, - "num_classes": 3 - }, "n_epochs": 20, "lr": 0.005, "batch_size": 5, @@ -59,10 +57,7 @@ "batch_size": 1 }, "classifier": { - "data":{ - "patch_size": 64, - "noise_intensity": 5 - } + }, "mask_channel_axis": 0 } diff --git a/src/server/test/configs/test_config_RF.cfg b/src/server/test/configs/test_config_fcnn.yaml similarity index 60% rename from src/server/test/configs/test_config_RF.cfg rename to src/server/test/configs/test_config_fcnn.yaml index c09c6af5..385819b7 100644 --- a/src/server/test/configs/test_config_RF.cfg +++ b/src/server/test/configs/test_config_fcnn.yaml @@ -1,34 +1,39 @@ { "setup": { "segmentation": "GeneralSegmentation", - "model_to_use": "CustomCellposeModel", + "model_to_use": "Inst2MultiSeg", "accepted_types": [".jpg", ".jpeg", ".png", ".tiff", ".tif"], "seg_name_string": "_seg" }, "service": { - "runner_name": "cellpose_runner", - "bento_model_path": "mito", + "runner_name": "bento_runner", + "save_model_path": "test", "service_name": "data-centric-platform", "port": 7010 }, "model": { + "segmentor_name": 
Cellpose, "segmentor": { "model_type": "cyto" }, - "classifier":{ - "model_class": "RandomForest", + "classifier_name": "PatchClassifier", + "classifier":{ "in_channels": 1, "num_classes": 3, "features":[64,128,256,512], - "black_bg": "False", - "include_mask": "False" + "black_bg": False, + "include_mask": False } }, "data": { - "data_root": "data" + "data_root": "data", + "patch_size": 64, + "noise_intensity": 5, + "gray": True, + "rescale": True }, "train":{ @@ -39,14 +44,9 @@ "learning_rate":0.01 }, "classifier":{ - "train_data":{ - "patch_size": 64, - "noise_intensity": 5, - "num_classes": 3 - }, - "n_epochs": 10, - "lr": 0.001, - "batch_size": 1, + "n_epochs": 20, + "lr": 0.005, + "batch_size": 5, "optimizer": "Adam" } }, @@ -59,10 +59,6 @@ "batch_size": 1 }, "classifier": { - "data":{ - "patch_size": 64, - "noise_intensity": 5 - } }, "mask_channel_axis": 0 } From 97a00b4bb3be450713186fcca9d61c9300d5fa3e Mon Sep 17 00:00:00 2001 From: Christina Bukas Date: Thu, 7 Mar 2024 22:12:49 +0100 Subject: [PATCH 11/26] update documentation --- .github/workflows/test.yml | 9 +-------- docs/source/dcp_server_installation.rst | 13 +++++++------ docs/source/index.rst | 2 +- 3 files changed, 9 insertions(+), 15 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 505fbfef..22e2297e 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -114,19 +114,12 @@ jobs: files: src/server/coverage.xml env: CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} - - update_docs: - name: Update docs - runs-on: ubuntu-latest - steps: - - name: Update - run: make clean && make html deploy: # this will run when you have tagged a commit, starting with "v*" # and requires that you have put your twine API key in your # github secrets (see readme for details) - needs: [test_client, test_server, update_docs] + needs: [test_client, test_server] runs-on: ubuntu-latest if: contains(github.ref, 'tags') steps: diff --git a/docs/source/dcp_server_installation.rst b/docs/source/dcp_server_installation.rst index a34bd6a8..a005ad3e 100644 --- a/docs/source/dcp_server_installation.rst +++ b/docs/source/dcp_server_installation.rst @@ -83,17 +83,18 @@ The models are currently integrated into DCP: - **Instance** Segmentation: - - ``CustomCellposeModel``: Inherits from cellpose.models.CellposeModel, see `here `__ for more information. + - ``CustomCellpose``: Inherits from cellpose.models.CellposeModel, see `here `__ for more information. - **Semantic** Segmentation: - ``UNet``: A vanilla U-Net model, trained on the full images -- **Panoptic** Segmentation: +- **Multi Class Instance** Segmentation: - - ``CellposePatchCNN``: Includes a segmentor for instance segmentation, sequentially followed by a classifier for semantic segmentation. The segmentor can only be ``CustomCellposeModel`` model, while the classifier can be one of: + - ``Inst2MultiSeg``: Includes a segmentor for instance segmentation, sequentially followed by a classifier for semantic segmentation. The segmentor can only be ``CustomCellposeModel`` model, while the classifier can be one of: - - ``CellClassifierFCNN`` or "FCNN" (in config): A CNN model for obtaining class labels, trained on images patches of individual objects, extarcted using the instance mask from the previous step - - ``CellClassifierShallowModel`` or "RandomForest" (in config): A Random Forest model for obtaining class labels, trained on shape and intensity features of the objects, extracted using the instance mask from the previous step. 
- - UNet: If the post-processing argument is set, then the instance mask is deduced from the labels mask. Will not be able to handle touching objects + - ``PatchClassifier`` or "FCNN" (in config): A CNN model for obtaining class labels, trained on images patches of individual objects, extarcted using the instance mask from the previous step + - ``FeatureClassifier`` or "RandomForest" (in config): A Random Forest model for obtaining class labels, trained on shape and intensity features of the objects, extracted using the instance mask from the previous step. + - ``MultiCellpose``: Includes **n** CustomCellpose models, where n equals the number of classes, stacked such that each model predicts only the object corresponding to each class. + - ``UNet``: If the post-processing argument is set, then the instance mask is deduced from the labels mask. Will not be able to handle touching objects Running with Docker diff --git a/docs/source/index.rst b/docs/source/index.rst index f220d007..c6532633 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -31,7 +31,7 @@ DCP handles all kinds of **segmentation tasks**! Try it out if you need to do: - **Instance** segmentation - **Semantic** segmentation -- **Panoptic** segmentation +- **Multi-class instance** segmentation Toy data -------- From 726a9af4292f3e61ed41f99545bfe1b1ee8090bc Mon Sep 17 00:00:00 2001 From: Christina Bukas Date: Fri, 8 Mar 2024 00:21:27 +0100 Subject: [PATCH 12/26] update tests --- .../configs/test_config_CustomCellpose.yaml | 47 +++++ ...config_Inst2MultiSeg_PatchClassifier.yaml} | 4 +- ...yaml => test_config_Inst2MultiSeg_RF.yaml} | 14 +- .../configs/test_config_MultiCellpose.yaml | 50 ++++++ src/server/test/configs/test_config_UNet.yaml | 46 +++++ src/server/test/test_integration.py | 168 +++++++++--------- src/server/test/test_models.py | 21 ++- src/server/test/test_utils.py | 2 +- 8 files changed, 244 insertions(+), 108 deletions(-) create mode 100644 src/server/test/configs/test_config_CustomCellpose.yaml rename src/server/test/configs/{test_config_fcnn.yaml => test_config_Inst2MultiSeg_PatchClassifier.yaml} (95%) rename src/server/test/configs/{test_config_RF.yaml => test_config_Inst2MultiSeg_RF.yaml} (77%) create mode 100644 src/server/test/configs/test_config_MultiCellpose.yaml create mode 100644 src/server/test/configs/test_config_UNet.yaml diff --git a/src/server/test/configs/test_config_CustomCellpose.yaml b/src/server/test/configs/test_config_CustomCellpose.yaml new file mode 100644 index 00000000..5e3c0436 --- /dev/null +++ b/src/server/test/configs/test_config_CustomCellpose.yaml @@ -0,0 +1,47 @@ +{ + "setup": { + "segmentation": "GeneralSegmentation", + "model_to_use": "CustomCellpose", + "accepted_types": [".jpg", ".jpeg", ".png", ".tiff", ".tif"], + "seg_name_string": "_seg" + }, + + "service": { + "runner_name": "bento_runner", + "bento_model_path": "cells", + "service_name": "data-centric-platform", + "port": 7010 + }, + + "model": { + "segmentor_name": Cellpose, + "segmentor": { + "model_type": "cyto" + } + }, + + "data": { + "data_root": "data", + "gray": True, + "rescale": True + }, + + "train":{ + "segmentor":{ + "n_epochs": 20, + "channels": [0,0], + "min_train_masks": 1, + "learning_rate":0.01 + } + }, + + "eval":{ + "segmentor": { + "z_axis": null, + "channel_axis": null, + "rescale": 1, + "batch_size": 1 + }, + "mask_channel_axis": null + } +} \ No newline at end of file diff --git a/src/server/test/configs/test_config_fcnn.yaml 
b/src/server/test/configs/test_config_Inst2MultiSeg_PatchClassifier.yaml similarity index 95% rename from src/server/test/configs/test_config_fcnn.yaml rename to src/server/test/configs/test_config_Inst2MultiSeg_PatchClassifier.yaml index 385819b7..20e5c96a 100644 --- a/src/server/test/configs/test_config_fcnn.yaml +++ b/src/server/test/configs/test_config_Inst2MultiSeg_PatchClassifier.yaml @@ -8,7 +8,7 @@ "service": { "runner_name": "bento_runner", - "save_model_path": "test", + "bento_model_path": "cells", "service_name": "data-centric-platform", "port": 7010 }, @@ -38,7 +38,7 @@ "train":{ "segmentor":{ - "n_epochs": 20, + "n_epochs": 10, "channels": [0,0], "min_train_masks": 1, "learning_rate":0.01 diff --git a/src/server/test/configs/test_config_RF.yaml b/src/server/test/configs/test_config_Inst2MultiSeg_RF.yaml similarity index 77% rename from src/server/test/configs/test_config_RF.yaml rename to src/server/test/configs/test_config_Inst2MultiSeg_RF.yaml index 89f60818..0734bcf7 100644 --- a/src/server/test/configs/test_config_RF.yaml +++ b/src/server/test/configs/test_config_Inst2MultiSeg_RF.yaml @@ -20,32 +20,25 @@ }, "classifier_name": "RandomForest", "classifier":{ - "in_channels": 1, - "num_classes": 3, - "features":[64,128,256,512], - "black_bg": False, - "include_mask": False } }, "data": { "data_root": "data", + "patch_size": 64, + "noise_intensity": 5, "gray": True, "rescale": True }, "train":{ "segmentor":{ - "n_epochs": 20, + "n_epochs": 10, "channels": [0,0], "min_train_masks": 1, "learning_rate":0.01 }, "classifier":{ - "n_epochs": 20, - "lr": 0.005, - "batch_size": 5, - "optimizer": "Adam" } }, @@ -57,7 +50,6 @@ "batch_size": 1 }, "classifier": { - }, "mask_channel_axis": 0 } diff --git a/src/server/test/configs/test_config_MultiCellpose.yaml b/src/server/test/configs/test_config_MultiCellpose.yaml new file mode 100644 index 00000000..b74476fe --- /dev/null +++ b/src/server/test/configs/test_config_MultiCellpose.yaml @@ -0,0 +1,50 @@ +{ + "setup": { + "segmentation": "GeneralSegmentation", + "model_to_use": "MultiCellpose", + "accepted_types": [".jpg", ".jpeg", ".png", ".tiff", ".tif"], + "seg_name_string": "_seg" + }, + + "service": { + "runner_name": "bento_runner", + "bento_model_path": "cells", + "service_name": "data-centric-platform", + "port": 7010 + }, + + "model": { + "segmentor_name": Cellpose, + "segmentor": { + "model_type": "cyto" + }, + "classifier":{ + "num_classes": 3 + } + }, + + "data": { + "data_root": "data", + "gray": True, + "rescale": True + }, + + "train":{ + "segmentor":{ + "n_epochs": 20, + "channels": [0,0], + "min_train_masks": 1, + "learning_rate":0.01 + } + }, + + "eval":{ + "segmentor": { + "z_axis": null, + "channel_axis": null, + "rescale": 1, + "batch_size": 1 + }, + "mask_channel_axis": 0 + } +} \ No newline at end of file diff --git a/src/server/test/configs/test_config_UNet.yaml b/src/server/test/configs/test_config_UNet.yaml new file mode 100644 index 00000000..f6ee29bc --- /dev/null +++ b/src/server/test/configs/test_config_UNet.yaml @@ -0,0 +1,46 @@ +{ + "setup": { + "segmentation": "GeneralSegmentation", + "model_to_use": "UNet", + "accepted_types": [".jpg", ".jpeg", ".png", ".tiff", ".tif"], + "seg_name_string": "_seg" + }, + + "service": { + "runner_name": "bento_runner", + "bento_model_path": "cells", + "service_name": "data-centric-platform", + "port": 7010 + }, + + "model": { + "classifier":{ + "in_channels": 1, + "num_classes": 3, + "features":[64,128,256,512] + } + }, + + "data": { + "data_root": "data", + "gray": True, + 
"rescale": True + }, + + "train":{ + "classifier":{ + "n_epochs": 20, + "lr": 0.005, + "batch_size": 5, + "optimizer": "Adam" + } + }, + + "eval":{ + "classifier": { + + }, + compute_instance: True, + "mask_channel_axis": 0 + } +} \ No newline at end of file diff --git a/src/server/test/test_integration.py b/src/server/test/test_integration.py index 3a28a3e2..3ef619fc 100644 --- a/src/server/test/test_integration.py +++ b/src/server/test/test_integration.py @@ -1,55 +1,51 @@ import sys +sys.path.append(".") + from glob import glob -import inspect +import pytest +#import inspect import random import numpy as np - import torch from torchmetrics import JaccardIndex -# from importlib.machinery import SourceFileLoader - -sys.path.append(".") - -import dcp_server.models as models -from dcp_server.utils import read_config +from dcp_server.models import * +from dcp_server.utils.helpers import read_config from synthetic_dataset import get_synthetic_dataset -import pytest - seed_value = 2023 random.seed(seed_value) torch.manual_seed(seed_value) np.random.seed(seed_value) -# retrieve models names -model_classes = [ - cls_obj for cls_name, cls_obj in inspect.getmembers(models) \ - if inspect.isclass(cls_obj) \ - and cls_obj.__module__ == models.__name__ \ - and not cls_name.startswith("CellClassifier") - ] - -config_paths = glob("test/configs/*.cfg") +model_mapping = { + "CustomCellpose": CustomCellpose, + "Inst2MultiSeg": Inst2MultiSeg, + "MultiCellpose": MultiCellpose, + "UNet": UNet +} -@pytest.fixture(params=model_classes) -def model_class(request): - return request.param +config_paths = glob("test/configs/*.yaml") @pytest.fixture(params=config_paths) def config_path(request): return request.param @pytest.fixture() -def model(model_class, config_path): +#def model(model_class, config_path): +def model(config_path): + setup_config = read_config('setup', config_path=config_path) model_config = read_config('model', config_path=config_path) + data_config = read_config('data', config_path=config_path) train_config = read_config('train', config_path=config_path) eval_config = read_config('eval', config_path=config_path) - - model = model_class(model_config, train_config, eval_config, str(model_class)) - + + model_name = setup_config["model_to_use"] + model_class = model_mapping.get(model_name) + model = model_class(model_name, model_config, data_config, train_config, eval_config) + # str(model_class) return model @pytest.fixture @@ -68,6 +64,67 @@ def data_eval(): msk_ = np.stack((msk.sum(-1), ((msk > 0) * np.arange(1, 4)).sum(-1)), axis=0).transpose(1,0,2,3) return img, msk_ +def test_train_eval_run(data_train, data_eval, model): + """ + Performs testing, training, and evaluation with the provided data and model. 
+ """ + + images, masks = data_train + if model.model_name == "CustomCellpose": masks = [mask[0] for mask in masks] + model.train(images, masks) + + imgs_test, masks_test = data_eval + if model.model_name == "CustomCellpose": masks = [mask[0] for mask in masks_test] + + jaccard_index_instances = 0 + jaccard_index_classes = 0 + + jaccard_metric_binary = JaccardIndex(task="multiclass", num_classes=2, average="macro", ignore_index=0) + jaccard_metric_multi = JaccardIndex(task="multiclass", num_classes=4, average="macro", ignore_index=0) + + for img, mask in zip(imgs_test, masks_test): + + #mask - instance segmentation mask + classes (2, 512, 512) + #pred_mask - tuple of cellpose (512, 512), patch net multiclass segmentation (512, 512, 2) + + pred_mask = model.eval(img) + + if pred_mask.ndim > 2: + pred_mask_bin = torch.tensor((pred_mask[0]>0).astype(bool).astype(int)) + else: + pred_mask_bin = torch.tensor((pred_mask > 0).astype(bool).astype(int)) + + bin_mask = torch.tensor((mask[0]>0).astype(bool).astype(int)) + + jaccard_index_instances += jaccard_metric_binary( + pred_mask_bin, + bin_mask + ) + + if pred_mask.ndim > 2: + + jaccard_index_classes += jaccard_metric_multi( + torch.tensor(pred_mask[1].astype(int)), + torch.tensor(mask[1].astype(int)) + ) + + jaccard_index_instances /= len(imgs_test) + assert(jaccard_index_instances>0.2) + + # retrieve the attribute names of the class of the current model + attrs = model.__dict__.keys() + + if "metric" in attrs: + assert(model.metric>0.1) + if "loss" in attrs: + assert(model.loss<0.83) + + # for PatchCNN model + if pred_mask.ndim > 2: + + jaccard_index_classes /= len(imgs_test) + assert(jaccard_index_classes>0.1) + # def test_train_run(data_train, model): # images, masks = data_train @@ -131,62 +188,3 @@ def data_eval(): # jaccard_index_classes /= len(imgs_test) # assert(jaccard_index_classes>0.1) - -def test_train_eval_run(data_train, data_eval, model): - """ - Performs testing, training, and evaluation with the provided data and model. 
- """ - - images, masks = data_train - model.train(images, masks) - - imgs_test, masks_test = data_eval - - jaccard_index_instances = 0 - jaccard_index_classes = 0 - - jaccard_metric_binary = JaccardIndex(task="multiclass", num_classes=2, average="macro", ignore_index=0) - jaccard_metric_multi = JaccardIndex(task="multiclass", num_classes=4, average="macro", ignore_index=0) - - for img, mask in zip(imgs_test, masks_test): - - #mask - instance segmentation mask + classes (2, 512, 512) - #pred_mask - tuple of cellpose (512, 512), patch net multiclass segmentation (512, 512, 2) - - pred_mask = model.eval(img) - - if pred_mask.ndim > 2: - pred_mask_bin = torch.tensor((pred_mask[0]>0).astype(bool).astype(int)) - else: - pred_mask_bin = torch.tensor((pred_mask > 0).astype(bool).astype(int)) - - bin_mask = torch.tensor((mask[0]>0).astype(bool).astype(int)) - - jaccard_index_instances += jaccard_metric_binary( - pred_mask_bin, - bin_mask - ) - - if pred_mask.ndim > 2: - - jaccard_index_classes += jaccard_metric_multi( - torch.tensor(pred_mask[1].astype(int)), - torch.tensor(mask[1].astype(int)) - ) - - jaccard_index_instances /= len(imgs_test) - assert(jaccard_index_instances>0.2) - - # retrieve the attribute names of the class of the current model - attrs = model.__dict__.keys() - - if "metric" in attrs: - assert(model.metric>0.1) - if "loss" in attrs: - assert(model.loss<0.75) - - # for PatchCNN model - if pred_mask.ndim > 2: - - jaccard_index_classes /= len(imgs_test) - assert(jaccard_index_classes>0.1) \ No newline at end of file diff --git a/src/server/test/test_models.py b/src/server/test/test_models.py index 7a91fa9a..7816f594 100644 --- a/src/server/test/test_models.py +++ b/src/server/test/test_models.py @@ -2,18 +2,20 @@ import numpy as np import dcp_server.models as models -from dcp_server.utils import read_config +from dcp_server.models.classifiers import FeatureClassifier +from dcp_server.utils.helpers import read_config def test_eval_rf_not_fitted(): """ Tests the evaluation of a random forest model that has not been fitted. """ - model_config = read_config('model', config_path='test/configs/test_config_RF.cfg') - train_config = read_config('train', config_path='test/configs/test_config_RF.cfg') - eval_config = read_config('eval', config_path='test/configs/test_config_RF.cfg') + model_config = read_config('model', config_path='test/configs/test_config_Inst2MultiSeg_RF.yaml') + data_config = read_config('data', config_path='test/configs/test_config_Inst2MultiSeg_RF.yaml') + train_config = read_config('train', config_path='test/configs/test_config_Inst2MultiSeg_RF.yaml') + eval_config = read_config('eval', config_path='test/configs/test_config_Inst2MultiSeg_RF.yaml') - model_rf = models.CellClassifierShallowModel(model_config,train_config,eval_config) + model_rf = FeatureClassifier("Random Forest", model_config, data_config, train_config, eval_config) X_test = np.array([[1, 2, 3]]) # if we don't fit the model then the model returns zeros @@ -24,11 +26,12 @@ def test_update_configs(): Tests the update of model training and evaluation configurations. 
""" - model_config = read_config('model', config_path='test/configs/test_config_RF.cfg') - train_config = read_config('train', config_path='test/configs/test_config_RF.cfg') - eval_config = read_config('eval', config_path='test/configs/test_config_RF.cfg') + model_config = read_config('model', config_path='test/configs/test_config_Inst2MultiSeg_RF.yaml') + data_config = read_config('data', config_path='test/configs/test_config_Inst2MultiSeg_RF.yaml') + train_config = read_config('train', config_path='test/configs/test_config_Inst2MultiSeg_RF.yaml') + eval_config = read_config('eval', config_path='test/configs/test_config_Inst2MultiSeg_RF.yaml') - model = models.CustomCellposeModel(model_config,train_config,eval_config, "Cellpose") + model = models.CustomCellpose("Cellpose", model_config, data_config, train_config, eval_config) new_train_config = {"param1": "value1"} new_eval_config = {"param2": "value2"} diff --git a/src/server/test/test_utils.py b/src/server/test/test_utils.py index 35678a22..fd02c044 100644 --- a/src/server/test/test_utils.py +++ b/src/server/test/test_utils.py @@ -1,6 +1,6 @@ import numpy as np import pytest -from dcp_server.utils import find_max_patch_size +from dcp_server.utils.processing import find_max_patch_size @pytest.fixture def sample_mask(): From 03273adf2a5c254399167a9b041800d1e956f23a Mon Sep 17 00:00:00 2001 From: Christina Bukas Date: Fri, 8 Mar 2024 00:21:38 +0100 Subject: [PATCH 13/26] update models init --- src/server/dcp_server/models/custom_cellpose.py | 2 +- src/server/dcp_server/models/inst_to_multi_seg.py | 3 ++- src/server/dcp_server/models/multicellpose.py | 4 ++-- src/server/dcp_server/models/unet.py | 5 +++-- 4 files changed, 8 insertions(+), 6 deletions(-) diff --git a/src/server/dcp_server/models/custom_cellpose.py b/src/server/dcp_server/models/custom_cellpose.py index 71fe3a2b..ed3c798a 100644 --- a/src/server/dcp_server/models/custom_cellpose.py +++ b/src/server/dcp_server/models/custom_cellpose.py @@ -40,7 +40,7 @@ def __init__(self, # Initialize the cellpose model # super().__init__(**model_config["segmentor"]) - #nn.Module.__init__(self) + Model.__init__(self, model_name, model_config, data_config, train_config, eval_config) models.CellposeModel.__init__(self, **model_config["segmentor"]) self.model_config = model_config self.data_config = data_config diff --git a/src/server/dcp_server/models/inst_to_multi_seg.py b/src/server/dcp_server/models/inst_to_multi_seg.py index a270cfa3..2f8c6e77 100644 --- a/src/server/dcp_server/models/inst_to_multi_seg.py +++ b/src/server/dcp_server/models/inst_to_multi_seg.py @@ -49,7 +49,8 @@ def __init__(self, :param eval_config: Evaluation configuration. :type eval_config: dict """ - super().__init__() + #super().__init__() + Model.__init__(self, model_name, model_config, data_config, train_config, eval_config) self.model_name = model_name self.model_config = model_config diff --git a/src/server/dcp_server/models/multicellpose.py b/src/server/dcp_server/models/multicellpose.py index f24d101e..9d1b110d 100644 --- a/src/server/dcp_server/models/multicellpose.py +++ b/src/server/dcp_server/models/multicellpose.py @@ -29,7 +29,8 @@ def __init__(self, :param eval_config: Evaluation configuration. 
:type eval_config: dict """ - + Model.__init__(self, model_name, model_config, data_config, train_config, eval_config) + self.model_config = model_config self.data_config = data_config self.train_config = train_config @@ -71,7 +72,6 @@ def train(self, mask_class[1]!=(i+1) ] = 0 masks_class.append(mask_class) - print(masks_class[0].shape) self.cellpose_models[i].train(imgs, masks_class) self.metric = np.mean([self.cellpose_models[i].metric for i in range(self.num_of_channels)]) diff --git a/src/server/dcp_server/models/unet.py b/src/server/dcp_server/models/unet.py index ef728932..cb2565f6 100644 --- a/src/server/dcp_server/models/unet.py +++ b/src/server/dcp_server/models/unet.py @@ -86,8 +86,9 @@ def __init__(self, :param eval_config: Evaluation configuration. :type eval_config: dict """ - - super().__init__() + Model.__init__(self, model_name, model_config, data_config, train_config, eval_config) + nn.Module.__init__(self) + #super().__init__() self.model_name = model_name self.model_config = model_config From a1facb0ae3f3e4c2af0939ffc0a283aaa25b6bef Mon Sep 17 00:00:00 2001 From: Francesco Campi Date: Fri, 8 Mar 2024 12:01:12 +0100 Subject: [PATCH 14/26] Code review --- .github/workflows/test.yml | 2 +- README.md | 2 +- src/client/pyproject.toml | 4 ++++ src/client/requirements.txt | 6 +----- src/server/dcp_server/models/custom_cellpose.py | 3 ++- src/server/dcp_server/utils/processing.py | 6 +++--- 6 files changed, 12 insertions(+), 11 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 22e2297e..7ac1f499 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -48,7 +48,7 @@ jobs: pip install pytest-qt pip install pytest-xvfb pip install coverage - pip install -e ".[testing]" + pip install -e ".[dev]" pip install matplotlib working-directory: src/client diff --git a/README.md b/README.md index 97e28581..4a906d2b 100644 --- a/README.md +++ b/README.md @@ -29,5 +29,5 @@ Our platform encourages the use of data centric practices. 
With the user friendl - Focus on data curation: no interaction with model parameters during training and inference #### *Get more with less!* - + diff --git a/src/client/pyproject.toml b/src/client/pyproject.toml index 2e521a11..2621ca0a 100644 --- a/src/client/pyproject.toml +++ b/src/client/pyproject.toml @@ -34,6 +34,10 @@ maintainers = [ [project.optional-dependencies] dev = [ "pytest", + pytest>=7.4.3 + pytest-qt>=4.2.0 + sphinx + sphinx-rtd-theme ] [project.urls] diff --git a/src/client/requirements.txt b/src/client/requirements.txt index 98109d47..e47ad839 100644 --- a/src/client/requirements.txt +++ b/src/client/requirements.txt @@ -1,6 +1,2 @@ napari[pyqt5]>=0.4.17 -bentoml[grpc]==1.0.16 -pytest>=7.4.3 -pytest-qt>=4.2.0 -sphinx -sphinx-rtd-theme \ No newline at end of file +bentoml[grpc]==1.0.16 \ No newline at end of file diff --git a/src/server/dcp_server/models/custom_cellpose.py b/src/server/dcp_server/models/custom_cellpose.py index ed3c798a..989fc935 100644 --- a/src/server/dcp_server/models/custom_cellpose.py +++ b/src/server/dcp_server/models/custom_cellpose.py @@ -103,7 +103,8 @@ def eval_all_outputs(self, return super().eval(x=img, **self.eval_config["segmentor"]) - def compute_masks_flows(self, imgs, masks): + # I introduced typing here as suggest by the docstring + def compute_masks_flows(self, imgs: List[np.ndarray], masks:List[np.ndarray],) -> tuple: """ Computes instance, binary mask and flows in x and y - needed for loss and metric computations :param imgs: images to train on (training data) diff --git a/src/server/dcp_server/utils/processing.py b/src/server/dcp_server/utils/processing.py index 414b5aad..508cabcd 100644 --- a/src/server/dcp_server/utils/processing.py +++ b/src/server/dcp_server/utils/processing.py @@ -316,9 +316,9 @@ def create_patch_dataset(imgs, masks_classes, masks_instances, noise_intensity, for img, mask_class, mask_instance in zip(imgs, masks_classes, masks_instances): # mask_instance has dimension WxH # mask_class has dimension WxH - patch, patch_mask, _, label = get_centered_patches(img, - mask_instance, - max_patch_size, + patch, patch_mask, _, label = get_centered_patches(img=img, + mask=mask_instance, + p_size=max_patch_size, noise_intensity=noise_intensity, mask_class=mask_class, include_mask = include_mask) From 8f4f8dc23900ae8f592fe610882fc2920a0df355 Mon Sep 17 00:00:00 2001 From: Christina Bukas Date: Mon, 11 Mar 2024 14:08:59 +0100 Subject: [PATCH 15/26] move dev packages into toml --- .github/workflows/test.yml | 7 ++----- docs/source/dcp_server_installation.rst | 2 +- src/client/pyproject.toml | 9 ++++----- src/server/pyproject.toml | 4 +++- 4 files changed, 10 insertions(+), 12 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 7ac1f499..90a92f78 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -44,8 +44,6 @@ jobs: run: | python -m pip install --upgrade pip python -m pip install setuptools - pip install pytest - pip install pytest-qt pip install pytest-xvfb pip install coverage pip install -e ".[dev]" @@ -54,7 +52,7 @@ jobs: - name: Install server dependencies (for communication tests) run: | - pip install -e ".[testing]" + pip install -e ".[dev]" working-directory: src/server - name: Test with pytest @@ -94,10 +92,9 @@ jobs: python -m pip install --upgrade pip python -m pip install --upgrade setuptools pip install numpy - pip install pytest pip install wheel pip install coverage - pip install -e ".[testing]" + pip install -e ".[dev]" working-directory: src/server - 
name: Test with pytest diff --git a/docs/source/dcp_server_installation.rst b/docs/source/dcp_server_installation.rst index a005ad3e..823f97af 100644 --- a/docs/source/dcp_server_installation.rst +++ b/docs/source/dcp_server_installation.rst @@ -19,7 +19,7 @@ Before starting make sure you have navigated to ``data-centric-platform/src/serv .. code-block:: bash - pip install -e . + pip install -e ".[dev]" Launch DCP Server ------------------ diff --git a/src/client/pyproject.toml b/src/client/pyproject.toml index 2621ca0a..93af7bd7 100644 --- a/src/client/pyproject.toml +++ b/src/client/pyproject.toml @@ -33,11 +33,10 @@ maintainers = [ [project.optional-dependencies] dev = [ - "pytest", - pytest>=7.4.3 - pytest-qt>=4.2.0 - sphinx - sphinx-rtd-theme + "pytest>=7.4.3", + "pytest-qt>=4.2.0", + "sphinx", + "sphinx-rtd-theme" ] [project.urls] diff --git a/src/server/pyproject.toml b/src/server/pyproject.toml index 783e0dfb..4acd006c 100644 --- a/src/server/pyproject.toml +++ b/src/server/pyproject.toml @@ -33,7 +33,9 @@ maintainers = [ [project.optional-dependencies] dev = [ - "pytest", + "pytest>=7.4.3", + "sphinx", + "sphinx-rtd-theme" ] [project.urls] From 680496eef5519d17010a6d5e45723f6d28bb1aa6 Mon Sep 17 00:00:00 2001 From: Christina Bukas Date: Wed, 13 Mar 2024 12:14:56 +0100 Subject: [PATCH 16/26] updated with code review comments, typing etc --- README.md | 1 - src/server/dcp_server/main.py | 5 +- src/server/dcp_server/models/classifiers.py | 7 +- .../dcp_server/models/custom_cellpose.py | 4 +- .../dcp_server/models/inst_to_multi_seg.py | 3 +- src/server/dcp_server/models/model.py | 30 ++++---- src/server/dcp_server/models/unet.py | 2 +- src/server/dcp_server/segmentationclasses.py | 12 ++- src/server/dcp_server/serviceclasses.py | 14 ++-- src/server/dcp_server/utils/fsimagestorage.py | 29 ++++---- src/server/dcp_server/utils/helpers.py | 12 +-- src/server/dcp_server/utils/processing.py | 74 +++++++++++-------- src/server/requirements.txt | 5 +- 13 files changed, 107 insertions(+), 91 deletions(-) diff --git a/README.md b/README.md index 4a906d2b..bf0e3499 100644 --- a/README.md +++ b/README.md @@ -29,5 +29,4 @@ Our platform encourages the use of data centric practices. With the user friendl - Focus on data curation: no interaction with model parameters during training and inference #### *Get more with less!* - diff --git a/src/server/dcp_server/main.py b/src/server/dcp_server/main.py index c54c4c59..84d8b003 100644 --- a/src/server/dcp_server/main.py +++ b/src/server/dcp_server/main.py @@ -1,9 +1,10 @@ -import subprocess from os import path import sys +import subprocess + from dcp_server.utils.helpers import read_config -def main(): +def main() -> None: """ Contains main functionality related to the server. """ diff --git a/src/server/dcp_server/models/classifiers.py b/src/server/dcp_server/models/classifiers.py index bd154597..d093ab14 100644 --- a/src/server/dcp_server/models/classifiers.py +++ b/src/server/dcp_server/models/classifiers.py @@ -119,7 +119,7 @@ def eval(self, y_hat = torch.argmax(preds, 1) return y_hat - def build_model(self): + def build_model(self) -> None: """ Builds the PatchClassifer. """ in_channels = self.model_config["in_channels"] @@ -207,8 +207,8 @@ def __init__(self, def train(self, - X_train: np.ndarray, - y_train: np.ndarray + X_train: List[np.ndarray], + y_train: List[np.ndarray] ) -> None: """ Trains the model using the provided training data. @@ -217,7 +217,6 @@ def train(self, :param y_train: Labels of the training data. 
:type y_train: numpy.ndarray """ - self.model.fit(X_train,y_train) y_hat = self.model.predict(X_train) diff --git a/src/server/dcp_server/models/custom_cellpose.py b/src/server/dcp_server/models/custom_cellpose.py index 989fc935..f3a85d0b 100644 --- a/src/server/dcp_server/models/custom_cellpose.py +++ b/src/server/dcp_server/models/custom_cellpose.py @@ -65,7 +65,7 @@ def train(self, """ if self.train_config["segmentor"]["n_epochs"]==0: return super().train( - train_data=deepcopy(imgs), + train_data=deepcopy(imgs), #Cellpose changes the images train_labels=masks, **self.train_config["segmentor"] ) @@ -104,7 +104,7 @@ def eval_all_outputs(self, return super().eval(x=img, **self.eval_config["segmentor"]) # I introduced typing here as suggest by the docstring - def compute_masks_flows(self, imgs: List[np.ndarray], masks:List[np.ndarray],) -> tuple: + def compute_masks_flows(self, imgs: List[np.ndarray], masks:List[np.ndarray]) -> tuple: """ Computes instance, binary mask and flows in x and y - needed for loss and metric computations :param imgs: images to train on (training data) diff --git a/src/server/dcp_server/models/inst_to_multi_seg.py b/src/server/dcp_server/models/inst_to_multi_seg.py index 2f8c6e77..df52fd46 100644 --- a/src/server/dcp_server/models/inst_to_multi_seg.py +++ b/src/server/dcp_server/models/inst_to_multi_seg.py @@ -1,4 +1,3 @@ -from copy import deepcopy from typing import List import numpy as np @@ -93,7 +92,7 @@ def train(self, # train cellpose masks_instances = [mask[0] for mask in masks] #masks_instances = list(np.array(masks)[:,0,...]) #[mask.sum(-1) for mask in masks] if masks[0].ndim == 3 else masks - self.segmentor.train(deepcopy(imgs), masks_instances) + self.segmentor.train(imgs, masks_instances) masks_classes = [mask[1] for mask in masks] # create patch dataset to train classifier #masks_classes = list( diff --git a/src/server/dcp_server/models/model.py b/src/server/dcp_server/models/model.py index 65811905..4809bdf5 100644 --- a/src/server/dcp_server/models/model.py +++ b/src/server/dcp_server/models/model.py @@ -20,20 +20,6 @@ def __init__(self, self.loss = 1e6 self.metric = 0 - def update_configs(self, - train_config: dict, - eval_config: dict - ) -> None: - """ Update the training and evaluation configurations. - - :param train_config: Dictionary containing the training configuration. - :type train_config: dict - :param eval_config: Dictionary containing the evaluation configuration. - :type eval_config: dict - """ - self.train_config = train_config - self.eval_config = eval_config - @abstractmethod def train(self, imgs: List[np.array], @@ -46,6 +32,22 @@ def eval(self, img: np.array ) -> np.array: pass + + ''' + def update_configs(self, + config: dict, + ctype: str + ) -> None: + """ Update the training or evaluation configurations. + + :param config: Dictionary containing the updated configuration. + :type config: dict + :param ctype:type of config to update, will be train or eval + :type ctype: str + """ + if ctype=='train': self.train_config = config + else: self.eval_config = config + ''' #from segment_anything import SamPredictor, sam_model_registry diff --git a/src/server/dcp_server/models/unet.py b/src/server/dcp_server/models/unet.py index cb2565f6..61d06a66 100644 --- a/src/server/dcp_server/models/unet.py +++ b/src/server/dcp_server/models/unet.py @@ -179,7 +179,7 @@ def eval(self, return final_mask - def build_model(self): + def build_model(self) -> None: """ Builds the UNet. 
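build_model above sizes the UNet from the config's in_channels and features entries. A toy sketch of how such a features list can unroll into encoder stages, assuming features=[64, 128, 256, 512] as in the YAML configs (illustrative only, not the committed layer code):

    import torch.nn as nn

    features = [64, 128, 256, 512]
    in_ch = 1                              # "in_channels" from the config
    encoder = nn.ModuleList()
    for f in features:
        encoder.append(nn.Sequential(nn.Conv2d(in_ch, f, 3, padding=1), nn.ReLU()))
        in_ch = f                          # each stage widens to the next feature count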
""" in_channels = self.model_config["classifier"]["in_channels"] diff --git a/src/server/dcp_server/segmentationclasses.py b/src/server/dcp_server/segmentationclasses.py index 326bf069..d9a5f4ee 100644 --- a/src/server/dcp_server/segmentationclasses.py +++ b/src/server/dcp_server/segmentationclasses.py @@ -1,5 +1,8 @@ import os + from dcp_server.utils import helpers +from dcp_server.utils.fsimagestorage import FilesystemImageStorage +from dcp_server import models as DCPModels # Import configuration setup_config = helpers.read_config('setup', config_path = 'config.yaml') @@ -7,7 +10,7 @@ class GeneralSegmentation(): """ Segmentation class. Defining the main functions needed for this project and served by service - segment image and train on images. """ - def __init__(self, imagestorage, runner, model): + def __init__(self, imagestorage: FilesystemImageStorage, runner, model: DCPModels) -> None: """ Constructs all the necessary attributes for the GeneralSegmentation. :param imagestorage: imagestorage system used (see fsimagestorage.py) @@ -22,7 +25,7 @@ def __init__(self, imagestorage, runner, model): self.model = model self.no_files_msg = "No image-label pairs found in curated directory" - async def segment_image(self, input_path, list_of_images): + async def segment_image(self, input_path: str, list_of_images: str) -> None: """ Segments images from the given directory :param input_path: directory where the images are saved and where segmentation results will be saved @@ -44,7 +47,7 @@ async def segment_image(self, input_path, list_of_images): seg_name = helpers.get_path_stem(img_filepath) + setup_config['seg_name_string'] + '.tiff' self.imagestorage.save_image(os.path.join(input_path, seg_name), mask) - async def train(self, input_path): + async def train(self, input_path: str) -> str: """ Train model on images and masks in the given input directory. Calls the runner's train function. @@ -62,7 +65,7 @@ async def train(self, input_path): model_save_path = await self.runner.train.async_run(imgs, masks) return model_save_path - +''' class GFPProjectSegmentation(GeneralSegmentation): def __init__(self, imagestorage, runner): @@ -126,3 +129,4 @@ async def segment_image(self, input_path, list_of_images): # Save segmentation seg_name = helpers.get_path_stem(img_filepath) + setup_config['seg_name_string'] + '.tiff' self.imagestorage.save_image(os.path.join(input_path, seg_name), new_mask) +''' \ No newline at end of file diff --git a/src/server/dcp_server/serviceclasses.py b/src/server/dcp_server/serviceclasses.py index c66b81c8..3ba77ba8 100644 --- a/src/server/dcp_server/serviceclasses.py +++ b/src/server/dcp_server/serviceclasses.py @@ -5,7 +5,7 @@ from typing import List from dcp_server import models as DCPModels - +import dcp_server.segmentationclasses as DCPSegClasses class CustomRunnable(bentoml.Runnable): ''' @@ -15,7 +15,7 @@ class CustomRunnable(bentoml.Runnable): SUPPORTED_RESOURCES = ("cpu",) #TODO add here? SUPPORTS_CPU_MULTI_THREADING = False - def __init__(self, model, save_model_path): + def __init__(self, model: DCPModels, save_model_path: str) -> None: """Constructs all the necessary attributes for the CustomRunnable. :param model: model to be trained or evaluated - will be one of classes in models.py @@ -45,7 +45,7 @@ def evaluate(self, img: np.ndarray) -> np.ndarray: return mask - def check_and_load_model(self): + def check_and_load_model(self) -> None: """Checks if the specified model exists in BentoML's model repository. 
If the model exists, it loads the latest version of the model into memory.
@@ -84,7 +84,7 @@ def train(self, imgs: List[np.ndarray], masks: List[np.ndarray]) -> str:
 class CustomBentoService():
     """BentoML Service class. Contains all the functions necessary to serve the service with BentoML
     """
-    def __init__(self, runner, segmentation, service_name):
+    def __init__(self, runner: CustomRunnable, segmentation: DCPSegClasses, service_name: str) -> None:
         """Constructs all the necessary attributes for the class CustomBentoService():
 
         :param runner: runner used in the service
@@ -98,7 +98,7 @@ def __init__(self, runner, segmentation, service_name):
         self.segmentation = segmentation
         self.service_name = service_name
 
-    def start_service(self):
+    def start_service(self) -> None:
         """Starts the service
 
         :return: service object needed in service.py and for the bentoml serve call.
@@ -106,7 +106,7 @@ def start_service(self):
         svc = bentoml.Service(self.service_name, runners=[self.runner])
 
         @svc.api(input=Text(), output=NumpyNdarray()) #input path to the image output message with success and the save path
-        async def segment_image(input_path: str):
+        async def segment_image(input_path: str) -> np.ndarray:
             """function served within the service, used to segment images
 
             :param input_path: directory where the images for segmentation are saved
@@ -125,7 +125,7 @@ async def segment_image(input_path: str):
             return np.array(list_of_files_not_suported)
 
         @svc.api(input=Text(), output=Text())
-        async def train(input_path):
+        async def train(input_path: str) -> str:
             """function served within the service, used to retrain the model
 
             :param input_path: directory where the images for training are saved
diff --git a/src/server/dcp_server/utils/fsimagestorage.py b/src/server/dcp_server/utils/fsimagestorage.py
index 492e4169..555cbce0 100644
--- a/src/server/dcp_server/utils/fsimagestorage.py
+++ b/src/server/dcp_server/utils/fsimagestorage.py
@@ -1,4 +1,5 @@
 import os
+from typing import Optional, List
 import numpy as np
 from skimage.io import imread, imsave
 from skimage.transform import resize, rescale
@@ -13,7 +14,7 @@ class FilesystemImageStorage():
     """ Class used to deal with everything related to image storing and processing - loading, saving, transforming.
     """
-    def __init__(self, data_config, model_used):
+    def __init__(self, data_config: dict, model_used: str) -> None:
         self.root_dir = data_config["data_root"]
         self.gray = bool(data_config["gray"])
         self.rescale = bool(data_config["rescale"])
@@ -22,13 +23,13 @@ def __init__(self, data_config, model_used):
         self.img_height = None
         self.img_width = None
 
-    def load_image(self, cur_selected_img, gray=None):
+    def load_image(self, cur_selected_img: str, gray: Optional[bool]=None) -> Optional[np.ndarray]:
         """Load the image (using skimage)
 
         :param cur_selected_img: full path of the image that needs to be loaded
         :type cur_selected_img: str
         :param gray: whether to load the image as a grayscale or not
-        :type gray: bool, default=False
+        :type gray: bool or None, default=None
         :return: loaded image
         :rtype: ndarray
         """
@@ -37,7 +38,7 @@ def load_image(self, cur_selected_img, gray=None):
             return imread(os.path.join(self.root_dir , cur_selected_img), as_gray=gray)
         except ValueError: return None
-    def save_image(self, to_save_path, img):
+    def save_image(self, to_save_path: str, img: np.ndarray) -> None:
         """ Save given image using skimage.
 
        :param to_save_path: full path to the directory that the image needs to be saved into (use also image name in the path, eg.
'/users/new_image.png') @@ -47,7 +48,7 @@ def save_image(self, to_save_path, img): """ imsave(os.path.join(self.root_dir, to_save_path), img) - def search_images(self, directory): + def search_images(self, directory: str) -> List[str]: """ Get a list of full paths of the images in the directory. :param directory: Path to the directory to search for images. @@ -62,7 +63,7 @@ def search_images(self, directory): image_files = [os.path.join(directory, file_name) for file_name in os.listdir(directory) if (file_name not in seg_files) and (helpers.get_file_extension(file_name) in setup_config['accepted_types'])] return image_files - def search_segs(self, cur_selected_img): + def search_segs(self, cur_selected_img: str) -> List[str]: """ Returns a list of full paths of segmentations for an image. :param cur_selected_img: Full path of the image for which segmentations are needed. @@ -82,7 +83,7 @@ def search_segs(self, cur_selected_img): return seg_files - def get_image_seg_pairs(self, directory): + def get_image_seg_pairs(self, directory:str) -> List[tuple]: """ Get pairs of (image, image_seg). Used, e.g., in training to create training data-training labels pairs. @@ -101,7 +102,7 @@ def get_image_seg_pairs(self, directory): seg_files.append(seg[0]) return list(zip(image_files, seg_files)) - def get_unsupported_files(self, directory): + def get_unsupported_files(self, directory:str) -> List[str]: """ Get unsupported files found in the given directory. :param directory: Directory path to search for files in. @@ -112,7 +113,7 @@ def get_unsupported_files(self, directory): return [file_name for file_name in os.listdir(os.path.join(self.root_dir, directory)) if not file_name.startswith('.') and helpers.get_file_extension(file_name) not in setup_config['accepted_types']] - def get_image_size_properties(self, img, file_extension): + def get_image_size_properties(self, img:np.ndarray, file_extension:str) -> None: """Set properties of the image size :param img: Image (numpy array). @@ -145,7 +146,7 @@ def get_image_size_properties(self, img, file_extension): print('File not currently supported. See documentation for accepted types') - def rescale_image(self, img, order=2): + def rescale_image(self, img: np.ndarray, order: int=2) -> np.ndarray: """rescale image :param img: Image. @@ -164,7 +165,7 @@ def rescale_image(self, img, order=2): rescale_factor = max_dim/512 return rescale(img, 1/rescale_factor, order=order, channel_axis=self.channel_ax) - def resize_mask(self, mask, channel_ax=None, order=0): + def resize_mask(self, mask: np.ndarray, channel_ax: Optional[int]=None, order: int=0) -> np.ndarray: """resize the mask so it matches the original image size :param mask: Image. @@ -206,7 +207,7 @@ def resize_mask(self, mask, channel_ax=None, order=0): else: output_size = [self.img_height, self.img_width] return resize(mask, output_size, order=order) - def prepare_images_and_masks_for_training(self, train_img_mask_pairs): + def prepare_images_and_masks_for_training(self, train_img_mask_pairs: List[tuple]) -> tuple: """ Image and mask processing for training. :param train_img_mask_pairs: List pairs of (image, image_seg) (as returned by get_image_seg_pairs() function). @@ -233,7 +234,7 @@ def prepare_images_and_masks_for_training(self, train_img_mask_pairs): masks.append(mask) return imgs, masks - def prepare_img_for_eval(self, img_file): + def prepare_img_for_eval(self, img_file:str) -> np.ndarray: """Image processing for model inference. 
:param img_file: the path to the image @@ -250,7 +251,7 @@ def prepare_img_for_eval(self, img_file): img = self.rescale_image(img) return img - def prepare_mask_for_save(self, mask, channel_ax): + def prepare_mask_for_save(self, mask: np.ndarray, channel_ax: int) -> np.ndarray: """Prepares the mask output of the model to be saved. :param mask: the mask diff --git a/src/server/dcp_server/utils/helpers.py b/src/server/dcp_server/utils/helpers.py index 72d52a78..6d3eb61b 100644 --- a/src/server/dcp_server/utils/helpers.py +++ b/src/server/dcp_server/utils/helpers.py @@ -1,7 +1,7 @@ from pathlib import Path import yaml -def read_config(name, config_path = 'config.yaml') -> dict: +def read_config(name:str, config_path:str = 'config.yaml') -> dict: """Reads the configuration file :param name: name of the section you want to read (e.g. 'setup','train') @@ -17,17 +17,17 @@ def read_config(name, config_path = 'config.yaml') -> dict: assert all([i in config_dict.keys() for i in ['setup', 'service', 'model', 'train', 'eval']]) return config_dict[name] -def get_path_stem(filepath): return str(Path(filepath).stem) +def get_path_stem(filepath: str) -> str: return str(Path(filepath).stem) -def get_path_name(filepath): return str(Path(filepath).name) +def get_path_name(filepath: str) -> str: return str(Path(filepath).name) -def get_path_parent(filepath): return str(Path(filepath).parent) +def get_path_parent(filepath: str) -> str: return str(Path(filepath).parent) -def join_path(root_dir, filepath): return str(Path(root_dir, filepath)) +def join_path(root_dir:str, filepath: str) -> str: return str(Path(root_dir, filepath)) -def get_file_extension(file): return str(Path(file).suffix) +def get_file_extension(file: str) -> str: return str(Path(file).suffix) diff --git a/src/server/dcp_server/utils/processing.py b/src/server/dcp_server/utils/processing.py index 508cabcd..ba44dfc8 100644 --- a/src/server/dcp_server/utils/processing.py +++ b/src/server/dcp_server/utils/processing.py @@ -1,15 +1,16 @@ from copy import deepcopy +from typing import List, Optional, Union import numpy as np + from scipy.ndimage import find_objects from skimage import measure -from copy import deepcopy import SimpleITK as sitk from radiomics import shape2D import torch -def normalise(img, norm='min-max') -> np.ndarray: +def normalise(img: np.ndarray, norm: str='min-max') -> np.ndarray: """ Normalises the image based on the chosen method. Currently available methods are: - - min max normalisation + - min max normalisation. :param img: image to be normalised :type img: np.ndarray @@ -22,8 +23,8 @@ def normalise(img, norm='min-max') -> np.ndarray: return (img - np.min(img)) / (np.max(img) - np.min(img)) -def pad_image(img, height, width, channel_ax=None, dividable = 16) -> np.ndarray: - """ Pads the image such that it is dividable by a given number, +def pad_image(img: np.ndarray, height: int, width: int, channel_ax: Optional[int]=None, dividable:int = 16) -> np.ndarray: + """ Pads the image such that it is dividable by a given number. :param img: image to be padded :type img: np.ndarray @@ -48,7 +49,18 @@ def pad_image(img, height, width, channel_ax=None, dividable = 16) -> np.ndarray img = np.pad(img, ((0, height_pad), (0, width_pad))) return img -def convert_to_tensor(imgs, dtype, unsqueeze=True): +def convert_to_tensor(imgs: List[np.ndarray], dtype: type, unsqueeze: bool=True) -> torch.Tensor: + """ Convert the imgs to tensors of type dtype and add extra dimension if input bool is true. 
+
+    :param imgs: the list of images to convert
+    :type imgs: List[np.ndarray]
+    :param dtype: the data type to convert the image tensor
+    :type dtype: type
+    :param unsqueeze: If True an extra dim will be added at location zero
+    :type unsqueeze: bool
+    :return: the converted images
+    :rtype: torch.Tensor
+    """
     # Convert images tensors
     imgs = torch.stack([
         torch.from_numpy(img.astype(dtype)) for img in imgs
@@ -57,11 +69,11 @@ def convert_to_tensor(imgs, dtype, unsqueeze=True):
     return imgs
 
 def crop_centered_padded_patch(img: np.ndarray,
-                               patch_center_xy,
-                               patch_size,
-                               obj_label,
+                               patch_center_xy: tuple,
+                               patch_size: tuple,
+                               obj_label: int,
                                mask: np.ndarray=None,
-                               noise_intensity=None) -> np.ndarray:
+                               noise_intensity: int=None) -> np.ndarray:
     """ Crop a patch from an array centered at coordinates patch_center_xy with size patch_size,
     and apply padding if necessary.
 
@@ -74,12 +86,11 @@ def crop_centered_padded_patch(img: np.ndarray,
     :param obj_label: the instance label of the mask at the patch
     :type obj_label: int
     :param mask: The mask array associated with the array x.
-            Mask is used during training to mask out non-central elements.
-            For RandomForest, it is used to calculate pyradiomics features.
+        Mask is used during training to mask out non-central elements.
+        For RandomForest, it is used to calculate pyradiomics features.
     :type mask: np.ndarray, optional
     :param noise_intensity: intensity of noise to be added to the background
     :type noise_intensity: float, optional
-
     :return: the cropped patch with applied padding
     :rtype: np.ndarray
     """
@@ -146,16 +157,14 @@ def crop_centered_padded_patch(img: np.ndarray,
 
     return patch, mask
 
-def get_center_of_mass_and_label(mask: np.ndarray) -> np.ndarray:
+def get_center_of_mass_and_label(mask: np.ndarray) -> tuple:
     """ Computes the centers of mass for each object in a mask.
 
     :param mask: the input mask containing labeled objects
     :type mask: np.ndarray
     :return:
         - A list of tuples representing the coordinates (row, column) of the centers of mass for each object.
-
-        - A list of ints representing the labels for each object in the mask.
-
+        - A list of ints representing the labels for each object in the mask.
     :rtype:
         - List [tuple]
         - List [int]
@@ -176,12 +185,12 @@ def get_center_of_mass_and_label(mask: np.ndarray) -> np.ndarray:
 
 
 
-def get_centered_patches(img,
-                         mask,
+def get_centered_patches(img: np.ndarray,
+                         mask: np.ndarray,
                          p_size: int,
-                         noise_intensity=5,
-                         mask_class=None,
-                         include_mask=False):
+                         noise_intensity: int=5,
+                         mask_class: Optional[int]=None,
+                         include_mask: bool=False) -> tuple:
     """ Extracts centered patches from the input image based on the centers of objects identified in the mask.
 
     :param img: The input image.
@@ -237,7 +246,7 @@ def get_centered_patches(img,
 
     return patches, patch_masks, instance_labels, class_labels
 
-def get_objects(mask):
+def get_objects(mask: np.ndarray) -> List:
     """ Finds labeled connected components in a binary mask.
 
     :param mask: The binary mask representing objects.
@@ -247,7 +256,7 @@ def get_objects(mask):
     """
     return find_objects(mask)
 
-def find_max_patch_size(mask):
+def find_max_patch_size(mask: np.ndarray) -> float:
     """ Finds the maximum patch size in a mask.
 
     :param mask: The binary mask representing objects.
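For orientation, the return contract documented for get_center_of_mass_and_label above (a list of (row, column) centers plus a list of integer instance labels) can be reproduced with skimage's regionprops. The snippet below is a minimal illustrative sketch with a made-up toy mask, not the implementation used in this file:

    import numpy as np
    from skimage.measure import regionprops

    # toy instance mask with two labelled objects
    mask = np.zeros((64, 64), dtype=np.uint16)
    mask[10:20, 10:20] = 1
    mask[30:40, 35:50] = 2

    props = regionprops(mask)
    centers = [p.centroid for p in props]  # [(row, col), ...] centers of mass
    labels = [p.label for p in props]      # [1, 2] instance labels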
@@ -285,7 +294,12 @@ def find_max_patch_size(mask): return max_patch_size_edge -def create_patch_dataset(imgs, masks_classes, masks_instances, noise_intensity, max_patch_size, include_mask): +def create_patch_dataset(imgs: List[np.ndarray], + masks_classes: Optional[Union[List[np.ndarray], torch.Tensor]], + masks_instances: Optional[Union[List[np.ndarray], torch.Tensor]], + noise_intensity: int, + max_patch_size: int, + include_mask: bool) -> tuple: """ Splits images and masks into patches of equal size centered around the cells. :param imgs: A list of input images. @@ -295,9 +309,9 @@ def create_patch_dataset(imgs, masks_classes, masks_instances, noise_intensity, :param masks_instances: A list of binary masks representing instances. :type masks_instances: list of numpy.ndarray or torch.Tensor :param noise_intensity: The intensity of noise to add to the patches. - :type noise_intensity: float + :type noise_intensity: int :param max_patch_size: The maximum size of the bounding box edge for objects in the mask. - :type max_patch_size: float + :type max_patch_size: int :param include_mask: A flag indicating whether to include the mask along with patches. :type include_mask: bool :return: A tuple containing the patches, patch masks, and labels. @@ -328,7 +342,7 @@ def create_patch_dataset(imgs, masks_classes, masks_instances, noise_intensity, return patches, patch_masks, labels -def get_shape_features(img, mask): +def get_shape_features(img: np.ndarray, mask: np.ndarray) -> np.ndarray: """ Calculate shape-based radiomic features from an image within the region defined by the mask. :param img: The input image. @@ -349,7 +363,7 @@ def get_shape_features(img, mask): return np.array(list(shape_features.values())) -def extract_intensity_features(image, mask): +def extract_intensity_features(image: np.ndarray, mask: np.ndarray) -> np.ndarray: """ Extracts intensity-based features from an image within the region defined by the mask. :param image: The input image. @@ -377,7 +391,7 @@ def extract_intensity_features(image, mask): return np.array(list(features.values())) -def create_dataset_for_rf(imgs, masks): +def create_dataset_for_rf(imgs: List[np.ndarray], masks: List[np.ndarray]) -> List[np.ndarray]: """ Extracts shape and intensity-based features from images within regions defined by masks. :param imgs: A list of input images. 
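Taken together, the helpers typed in this file form the patch-extraction path used by the patch-based classifier. A minimal sketch of how they compose, assuming the signatures shown in this diff; the image, mask, and noise_intensity values are invented for illustration:

    import numpy as np
    from dcp_server.utils.processing import (
        find_max_patch_size,
        get_centered_patches,
        convert_to_tensor,
    )

    img = np.random.rand(256, 256)                # toy grayscale image
    mask = np.zeros((256, 256), dtype=np.uint16)
    mask[50:80, 60:90] = 1                        # one labelled object

    p_size = int(find_max_patch_size(mask))       # largest bounding-box edge
    patches, patch_masks, instance_labels, class_labels = get_centered_patches(
        img, mask, p_size=p_size, noise_intensity=5
    )
    # stack patches into a single torch tensor with an extra channel dim
    batch = convert_to_tensor(patches, dtype=np.float32, unsqueeze=True)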
diff --git a/src/server/requirements.txt b/src/server/requirements.txt index b6e7f266..a42ba5eb 100644 --- a/src/server/requirements.txt +++ b/src/server/requirements.txt @@ -4,10 +4,7 @@ bentoml==1.0.16 scikit-image>=0.19.3 torchmetrics>=0.11.4 torch>=2.1.0 -pytest>=7.4.3 numpy scikit-learn>=1.2.2 SimpleITK>=2.2.1 -pyradiomics==3.0.1 -sphinx -sphinx-rtd-theme \ No newline at end of file +pyradiomics==3.0.1 \ No newline at end of file From 942f9ef68735a156707a1a1483755b1557839d25 Mon Sep 17 00:00:00 2001 From: Christina Bukas Date: Wed, 13 Mar 2024 12:43:38 +0100 Subject: [PATCH 17/26] ran formatter --- src/client/dcp_client/__init__.py | 2 +- src/client/dcp_client/app.py | 143 ++++--- src/client/dcp_client/gui/_my_widget.py | 39 +- src/client/dcp_client/gui/main_window.py | 158 ++++---- src/client/dcp_client/gui/napari_window.py | 192 ++++++---- src/client/dcp_client/gui/welcome_window.py | 121 +++--- src/client/dcp_client/main.py | 28 +- src/client/dcp_client/utils/bentoml_model.py | 30 +- src/client/dcp_client/utils/fsimagestorage.py | 19 +- src/client/dcp_client/utils/settings.py | 2 +- src/client/dcp_client/utils/sync_src_dst.py | 66 ++-- src/client/dcp_client/utils/utils.py | 153 +++++--- src/client/test/test_app.py | 93 +++-- src/client/test/test_compute4mask.py | 101 +++-- src/client/test/test_fsimagestorage.py | 32 +- src/client/test/test_main_window.py | 112 +++--- src/client/test/test_mywidget.py | 26 +- src/client/test/test_napari_window.py | 67 ++-- src/client/test/test_sync_src_dst.py | 21 +- src/client/test/test_utils.py | 43 ++- src/client/test/test_welcome_window.py | 66 ++-- src/server/dcp_server/main.py | 32 +- src/server/dcp_server/models/__init__.py | 5 +- src/server/dcp_server/models/classifiers.py | 145 ++++--- .../dcp_server/models/custom_cellpose.py | 107 +++--- .../dcp_server/models/inst_to_multi_seg.py | 139 +++---- src/server/dcp_server/models/model.py | 37 +- src/server/dcp_server/models/multicellpose.py | 111 +++--- src/server/dcp_server/models/unet.py | 123 +++--- src/server/dcp_server/segmentationclasses.py | 64 ++-- src/server/dcp_server/service.py | 55 +-- src/server/dcp_server/serviceclasses.py | 68 ++-- src/server/dcp_server/utils/fsimagestorage.py | 229 ++++++----- src/server/dcp_server/utils/helpers.py | 31 +- src/server/dcp_server/utils/processing.py | 357 +++++++++++------- src/server/test/synthetic_dataset.py | 170 ++++++--- src/server/test/test_integration.py | 108 +++--- src/server/test/test_models.py | 53 ++- src/server/test/test_utils.py | 6 +- 39 files changed, 1957 insertions(+), 1397 deletions(-) diff --git a/src/client/dcp_client/__init__.py b/src/client/dcp_client/__init__.py index 65273344..f4ffe44b 100644 --- a/src/client/dcp_client/__init__.py +++ b/src/client/dcp_client/__init__.py @@ -39,4 +39,4 @@ This package structure allows for easy management of GUI components, image storage, model interactions, and server connectivity within the dcp_client application. 
-""" \ No newline at end of file +""" diff --git a/src/client/dcp_client/app.py b/src/client/dcp_client/app.py index 9adbb587..8ae4e8a9 100644 --- a/src/client/dcp_client/app.py +++ b/src/client/dcp_client/app.py @@ -11,7 +11,7 @@ class Model(ABC): @abstractmethod def run_train(self, path: str) -> None: pass - + @abstractmethod def run_inference(self, path: str) -> None: pass @@ -21,7 +21,7 @@ class DataSync(ABC): @abstractmethod def sync(self, src: str, dst: str, path: str) -> None: pass - + class ImageStorage(ABC): @abstractmethod @@ -35,22 +35,29 @@ def save_image(self, to_directory, cur_selected_img, img) -> None: def search_segs(self, img_directory, cur_selected_img): """Returns a list of full paths of segmentations for an image""" # Take all segmentations of the image from the current directory: - search_string = utils.get_path_stem(cur_selected_img) + '_seg' - seg_files = [file_name for file_name in os.listdir(img_directory) if (search_string == utils.get_path_stem(file_name) or str(file_name).startswith(search_string))] + search_string = utils.get_path_stem(cur_selected_img) + "_seg" + seg_files = [ + file_name + for file_name in os.listdir(img_directory) + if ( + search_string == utils.get_path_stem(file_name) + or str(file_name).startswith(search_string) + ) + ] return seg_files class Application: def __init__( - self, + self, ml_model: Model, syncer: DataSync, image_storage: ImageStorage, server_ip: str, server_port: int, - eval_data_path: str = '', - train_data_path: str = '', - inprogr_data_path: str = '', + eval_data_path: str = "", + train_data_path: str = "", + inprogr_data_path: str = "", ): self.ml_model = ml_model self.syncer = syncer @@ -60,73 +67,90 @@ def __init__( self.eval_data_path = eval_data_path self.train_data_path = train_data_path self.inprogr_data_path = inprogr_data_path - self.cur_selected_img = '' - self.cur_selected_path = '' - self.seg_filepaths = [] + self.cur_selected_img = "" + self.cur_selected_path = "" + self.seg_filepaths = [] def upload_data_to_server(self): """ Uploads the train and eval data to the server. """ - success_f1, message1 = self.syncer.first_sync(path=self.train_data_path) - success_f2, message2 = self.syncer.first_sync(path=self.eval_data_path) + success_f1, message1 = self.syncer.first_sync(path=self.train_data_path) + success_f2, message2 = self.syncer.first_sync(path=self.eval_data_path) return success_f1, success_f2, message1, message2 def try_server_connection(self): """ Checks if the ml model is connected to server and attempts to connect if not. """ - connection_success = self.ml_model.connect(ip=self.server_ip, port=self.server_port) + connection_success = self.ml_model.connect( + ip=self.server_ip, port=self.server_port + ) return connection_success - + def run_train(self): - """ Checks if the ml model is connected to the server, connects if not (and if possible), and trains the model with all data available in train_data_path """ - if not self.ml_model.is_connected and not self.try_server_connection(): + """Checks if the ml model is connected to the server, connects if not (and if possible), and trains the model with all data available in train_data_path""" + if not self.ml_model.is_connected and not self.try_server_connection(): message_title = "Warning" message_text = "Connection could not be established. Please check if the server is running and try again." 
return message_text, message_title # if syncer.host name is None then local machine is used to train message_title = "Information" - if self.syncer.host_name=="local": + if self.syncer.host_name == "local": message_text = self.ml_model.run_train(self.train_data_path) else: - success_sync, srv_relative_path = self.syncer.sync(src='client', dst='server', path=self.train_data_path) + success_sync, srv_relative_path = self.syncer.sync( + src="client", dst="server", path=self.train_data_path + ) # make sure syncing of folders was successful - if success_sync=="Success": message_text = self.ml_model.run_train(srv_relative_path) - else: message_text = None - if message_text is None: + if success_sync == "Success": + message_text = self.ml_model.run_train(srv_relative_path) + else: + message_text = None + if message_text is None: message_text = "An error has occured on the server. Please check your image data and configurations. If the problem persists contact your software provider." message_title = "Error" return message_text, message_title - + def run_inference(self): - """ Checks if the ml model is connected to the server, connects if not (and if possible), and runs inference on all images in eval_data_path """ - if not self.ml_model.is_connected and not self.try_server_connection(): + """Checks if the ml model is connected to the server, connects if not (and if possible), and runs inference on all images in eval_data_path""" + if not self.ml_model.is_connected and not self.try_server_connection(): message_title = "Warning" message_text = "Connection could not be established. Please check if the server is running and try again." return message_text, message_title - - if self.syncer.host_name=="local": + + if self.syncer.host_name == "local": # model serving directly from local - list_of_files_not_suported = self.ml_model.run_inference(self.eval_data_path) - success_sync = "Success" + list_of_files_not_suported = self.ml_model.run_inference( + self.eval_data_path + ) + success_sync = "Success" else: # sync data so that server gets updated files in client - e.g. if file was moved to curated srv_relative_path = utils.get_relative_path(self.eval_data_path) - success_sync, _ = self.syncer.sync(src='client', dst='server', path=self.eval_data_path) + success_sync, _ = self.syncer.sync( + src="client", dst="server", path=self.eval_data_path + ) # model serving from server list_of_files_not_suported = self.ml_model.run_inference(srv_relative_path) - # sync data so that client gets new masks - success_sync, _ = self.syncer.sync(src='server', dst='client', path=self.eval_data_path) + # sync data so that client gets new masks + success_sync, _ = self.syncer.sync( + src="server", dst="client", path=self.eval_data_path + ) # check if serving could not be performed for some files and prepare message - if list_of_files_not_suported is None or success_sync=="Error": + if list_of_files_not_suported is None or success_sync == "Error": message_text = "An error has occured on the server. Please check your image data and configurations. If the problem persists contact your software provider." message_title = "Error" else: list_of_files_not_suported = list(list_of_files_not_suported) if len(list_of_files_not_suported) > 0: - message_text = "Image types not supported. Only 2D and 3D image shapes currently supported. 3D stacks must be of type grayscale. \ - Currently supported image file formats are: " + ", ".join(settings.accepted_types)+ ". 
The files that were not supported are: " + ", ".join(list_of_files_not_suported) + message_text = ( + "Image types not supported. Only 2D and 3D image shapes currently supported. 3D stacks must be of type grayscale. \ + Currently supported image file formats are: " + + ", ".join(settings.accepted_types) + + ". The files that were not supported are: " + + ", ".join(list_of_files_not_suported) + ) message_title = "Warning" else: message_text = "Success! Masks generated for all images" @@ -145,51 +169,58 @@ def load_image(self, image_name=None): """ if image_name is None: - return self.fs_image_storage.load_image(self.cur_selected_path, self.cur_selected_img) - else: return self.fs_image_storage.load_image(self.cur_selected_path, image_name) - + return self.fs_image_storage.load_image( + self.cur_selected_path, self.cur_selected_img + ) + else: + return self.fs_image_storage.load_image(self.cur_selected_path, image_name) + def search_segs(self): - """ Searches in cur_selected_path for all possible segmentation files associated to cur_selected_img. - These files should have a _seg extension to the cur_selected_img filename. """ - self.seg_filepaths = self.fs_image_storage.search_segs(self.cur_selected_path, self.cur_selected_img) - + """Searches in cur_selected_path for all possible segmentation files associated to cur_selected_img. + These files should have a _seg extension to the cur_selected_img filename.""" + self.seg_filepaths = self.fs_image_storage.search_segs( + self.cur_selected_path, self.cur_selected_img + ) + def save_image(self, dst_directory, image_name, img): - """ Saves img array image in the dst_directory with filename cur_selected_img - + """Saves img array image in the dst_directory with filename cur_selected_img + :param dst_directory: The destination directory where the image will be saved. :type dst_directory: str :param image_name: The name of the image file. :type image_name: str - :param img: The image that will be saved. + :param img: The image that will be saved. :type img: numpy.ndarray """ self.fs_image_storage.save_image(dst_directory, image_name, img) def move_images(self, dst_directory, move_segs=False): """ - Moves cur_selected_img image from the current directory to the dst_directory. - + Moves cur_selected_img image from the current directory to the dst_directory. + :param dst_directory: The destination directory where the images will be moved. :type dst_directory: str :param move_segs: If True, moves the corresponding segmentation along with the image. Default is False. :type move_segs: bool - + """ - #if image_name is None: - self.fs_image_storage.move_image(self.cur_selected_path, dst_directory, self.cur_selected_img) + # if image_name is None: + self.fs_image_storage.move_image( + self.cur_selected_path, dst_directory, self.cur_selected_img + ) if move_segs: for seg_name in self.seg_filepaths: - self.fs_image_storage.move_image(self.cur_selected_path, dst_directory, seg_name) + self.fs_image_storage.move_image( + self.cur_selected_path, dst_directory, seg_name + ) def delete_images(self, image_names): - """ If image_name in the image_names list exists in the current directory it is deleted. - + """If image_name in the image_names list exists in the current directory it is deleted. + :param image_names: A list of image names to be deleted. 
:type image_names: list[str] """ for image_name in image_names: - if os.path.exists(os.path.join(self.cur_selected_path, image_name)): + if os.path.exists(os.path.join(self.cur_selected_path, image_name)): self.fs_image_storage.delete_image(self.cur_selected_path, image_name) - - diff --git a/src/client/dcp_client/gui/_my_widget.py b/src/client/dcp_client/gui/_my_widget.py index 8298360e..acf54b61 100644 --- a/src/client/dcp_client/gui/_my_widget.py +++ b/src/client/dcp_client/gui/_my_widget.py @@ -1,18 +1,25 @@ from PyQt5.QtWidgets import QWidget, QMessageBox from PyQt5.QtCore import QTimer + class MyWidget(QWidget): """ This class represents a custom widget. """ msg = None - sim = False # will be used for testing to simulate user click + sim = False # will be used for testing to simulate user click - def create_warning_box(self, message_text: str=" ", message_title: str="Information", add_cancel_btn: bool=False, custom_dialog=None) -> None: + def create_warning_box( + self, + message_text: str = " ", + message_title: str = "Information", + add_cancel_btn: bool = False, + custom_dialog=None, + ) -> None: """Creates a warning box with the specified message and options. - :param message_text: The text to be displayed in the message box. + :param message_text: The text to be displayed in the message box. :type message_text: str :param message_title: The title of the message box. Default is "Information". :type message_title: str @@ -21,14 +28,16 @@ def create_warning_box(self, message_text: str=" ", message_title: str="Informat :param custom_dialog: An optional custom dialog to use instead of creating a new QMessageBox instance. Default is None. :type custom_dialog: Any :return: None - """ - #setup box - if custom_dialog is not None: self.msg = custom_dialog - else: self.msg = QMessageBox() + """ + # setup box + if custom_dialog is not None: + self.msg = custom_dialog + else: + self.msg = QMessageBox() - if message_title=="Warning": + if message_title == "Warning": message_type = QMessageBox.Warning - elif message_title=="Error": + elif message_title == "Error": message_type = QMessageBox.Critical else: message_type = QMessageBox.Information @@ -39,12 +48,16 @@ def create_warning_box(self, message_text: str=" ", message_title: str="Informat if add_cancel_btn: self.msg.setStandardButtons(QMessageBox.Ok | QMessageBox.Cancel) # simulate button click if specified - workaround used for testing - if self.sim: QTimer.singleShot(0, self.msg.button(QMessageBox.Cancel).clicked) + if self.sim: + QTimer.singleShot(0, self.msg.button(QMessageBox.Cancel).clicked) else: self.msg.setStandardButtons(QMessageBox.Ok) # simulate button click if specified - workaround used for testing - if self.sim: QTimer.singleShot(0, self.msg.button(QMessageBox.Ok).clicked) + if self.sim: + QTimer.singleShot(0, self.msg.button(QMessageBox.Ok).clicked) # return if user clicks Ok and False otherwise usr_response = self.msg.exec() - if usr_response == QMessageBox.Ok: return True - else: return False \ No newline at end of file + if usr_response == QMessageBox.Ok: + return True + else: + return False diff --git a/src/client/dcp_client/gui/main_window.py b/src/client/dcp_client/gui/main_window.py index 8407c1dc..ae4917ed 100644 --- a/src/client/dcp_client/gui/main_window.py +++ b/src/client/dcp_client/gui/main_window.py @@ -1,7 +1,16 @@ from __future__ import annotations from typing import TYPE_CHECKING -from PyQt5.QtWidgets import QPushButton, QVBoxLayout, QFileSystemModel, QHBoxLayout, QLabel, QTreeView, QProgressBar, QShortcut 
+from PyQt5.QtWidgets import ( + QPushButton, + QVBoxLayout, + QFileSystemModel, + QHBoxLayout, + QLabel, + QTreeView, + QProgressBar, + QShortcut, +) from PyQt5.QtCore import Qt, QThread, pyqtSignal from PyQt5.QtGui import QKeySequence @@ -14,13 +23,21 @@ if TYPE_CHECKING: from dcp_client.app import Application + class WorkerThread(QThread): """ - Worker thread for displaying Pulse ProgressBar during model serving. - + Worker thread for displaying Pulse ProgressBar during model serving. + """ + task_finished = pyqtSignal(tuple) - def __init__(self, app: Application, task: str = None, parent = None,): + + def __init__( + self, + app: Application, + task: str = None, + parent=None, + ): """ Initialize the WorkerThread. @@ -36,13 +53,13 @@ def __init__(self, app: Application, task: str = None, parent = None,): def run(self): """ - Once run_inference or run_train is executed, the tuple of + Once run_inference or run_train is executed, the tuple of (message_text, message_title) will be returned to on_finished. """ try: - if self.task == 'inference': + if self.task == "inference": message_text, message_title = self.app.run_inference() - elif self.task == 'train': + elif self.task == "train": message_text, message_title = self.app.run_train() else: message_text, message_title = "Unknown task", "Error" @@ -53,17 +70,18 @@ def run(self): self.task_finished.emit((message_text, message_title)) + class MainWindow(MyWidget): """ Main Window Widget object. - Opens the main window of the app where selected images in both directories are listed. + Opens the main window of the app where selected images in both directories are listed. User can view the images, train the model to get the labels, and visualise the result. - + :param eval_data_path: Chosen path to images without labeles, selected by the user in the WelcomeWindow :type eval_data_path: string :param train_data_path: Chosen path to images with labeles, selected by the user in the WelcomeWindow :type train_data_path: string - """ + """ def __init__(self, app: Application): """ @@ -81,21 +99,20 @@ def __init__(self, app: Application): self.title = "Data Overview" self.worker_thread = None self.main_window() - + def main_window(self): - """Sets up the GUI - """ + """Sets up the GUI""" self.setWindowTitle(self.title) - #self.resize(1000, 1500) + # self.resize(1000, 1500) main_layout = QVBoxLayout() - dir_layout = QHBoxLayout() - + dir_layout = QHBoxLayout() + self.uncurated_layout = QVBoxLayout() self.inprogress_layout = QVBoxLayout() self.curated_layout = QVBoxLayout() - self.eval_dir_layout = QVBoxLayout() - self.eval_dir_layout.setContentsMargins(0,0,0,0) + self.eval_dir_layout = QVBoxLayout() + self.eval_dir_layout.setContentsMargins(0, 0, 0, 0) self.label_eval = QLabel(self) self.label_eval.setText("Uncurated dataset") self.eval_dir_layout.addWidget(self.label_eval) @@ -104,45 +121,55 @@ def main_window(self): model_eval.setIconProvider(IconProvider()) self.list_view_eval = QTreeView(self) self.list_view_eval.setModel(model_eval) - for i in range(1,4): + for i in range(1, 4): self.list_view_eval.hideColumn(i) - #self.list_view_eval.setFixedSize(600, 600) - self.list_view_eval.setRootIndex(model_eval.setRootPath(self.app.eval_data_path)) + # self.list_view_eval.setFixedSize(600, 600) + self.list_view_eval.setRootIndex( + model_eval.setRootPath(self.app.eval_data_path) + ) self.list_view_eval.clicked.connect(self.on_item_eval_selected) - + self.eval_dir_layout.addWidget(self.list_view_eval) self.uncurated_layout.addLayout(self.eval_dir_layout) # 
add buttons self.inference_button = QPushButton("Generate Labels", self) - self.inference_button.clicked.connect(self.on_run_inference_button_clicked) # add selected image + self.inference_button.clicked.connect( + self.on_run_inference_button_clicked + ) # add selected image self.uncurated_layout.addWidget(self.inference_button, alignment=Qt.AlignCenter) dir_layout.addLayout(self.uncurated_layout) # In progress layout - self.inprogr_dir_layout = QVBoxLayout() - self.inprogr_dir_layout.setContentsMargins(0,0,0,0) + self.inprogr_dir_layout = QVBoxLayout() + self.inprogr_dir_layout.setContentsMargins(0, 0, 0, 0) self.label_inprogr = QLabel(self) self.label_inprogr.setText("Curation in progress") self.inprogr_dir_layout.addWidget(self.label_inprogr) # add in progress dir list model_inprogr = QFileSystemModel() - #self.list_view = QListView(self) + # self.list_view = QListView(self) self.list_view_inprogr = QTreeView(self) model_inprogr.setIconProvider(IconProvider()) self.list_view_inprogr.setModel(model_inprogr) - for i in range(1,4): + for i in range(1, 4): self.list_view_inprogr.hideColumn(i) - #self.list_view_inprogr.setFixedSize(600, 600) - self.list_view_inprogr.setRootIndex(model_inprogr.setRootPath(self.app.inprogr_data_path)) + # self.list_view_inprogr.setFixedSize(600, 600) + self.list_view_inprogr.setRootIndex( + model_inprogr.setRootPath(self.app.inprogr_data_path) + ) self.list_view_inprogr.clicked.connect(self.on_item_inprogr_selected) self.inprogr_dir_layout.addWidget(self.list_view_inprogr) self.inprogress_layout.addLayout(self.inprogr_dir_layout) self.launch_nap_button = QPushButton("View image and fix label", self) - self.launch_nap_button.clicked.connect(self.on_launch_napari_button_clicked) # add selected image - self.inprogress_layout.addWidget(self.launch_nap_button, alignment=Qt.AlignCenter) + self.launch_nap_button.clicked.connect( + self.on_launch_napari_button_clicked + ) # add selected image + self.inprogress_layout.addWidget( + self.launch_nap_button, alignment=Qt.AlignCenter + ) # Create a shortcut for the Enter key to click the button enter_shortcut = QShortcut(QKeySequence(Qt.Key_Return), self) enter_shortcut.activated.connect(self.on_launch_napari_button_clicked) @@ -150,27 +177,31 @@ def main_window(self): dir_layout.addLayout(self.inprogress_layout) # Curated layout - self.train_dir_layout = QVBoxLayout() - self.train_dir_layout.setContentsMargins(0,0,0,0) + self.train_dir_layout = QVBoxLayout() + self.train_dir_layout.setContentsMargins(0, 0, 0, 0) self.label_train = QLabel(self) self.label_train.setText("Curated dataset") self.train_dir_layout.addWidget(self.label_train) # add train dir list model_train = QFileSystemModel() - #self.list_view = QListView(self) + # self.list_view = QListView(self) self.list_view_train = QTreeView(self) model_train.setIconProvider(IconProvider()) self.list_view_train.setModel(model_train) - for i in range(1,4): + for i in range(1, 4): self.list_view_train.hideColumn(i) - #self.list_view_train.setFixedSize(600, 600) - self.list_view_train.setRootIndex(model_train.setRootPath(self.app.train_data_path)) + # self.list_view_train.setFixedSize(600, 600) + self.list_view_train.setRootIndex( + model_train.setRootPath(self.app.train_data_path) + ) self.list_view_train.clicked.connect(self.on_item_train_selected) self.train_dir_layout.addWidget(self.list_view_train) self.curated_layout.addLayout(self.train_dir_layout) - + self.train_button = QPushButton("Train Model", self) - 
self.train_button.clicked.connect(self.on_train_button_clicked) # add selected image + self.train_button.clicked.connect( + self.on_train_button_clicked + ) # add selected image self.curated_layout.addWidget(self.train_button, alignment=Qt.AlignCenter) dir_layout.addLayout(self.curated_layout) @@ -178,9 +209,9 @@ def main_window(self): # add progress bar progress_layout = QHBoxLayout() - progress_layout.addStretch(1) + progress_layout.addStretch(1) self.progress_bar = QProgressBar(self) - self.progress_bar.setRange(0,1) + self.progress_bar.setRange(0, 1) progress_layout.addWidget(self.progress_bar) main_layout.addLayout(progress_layout) @@ -222,9 +253,9 @@ def on_train_button_clicked(self): Is called once user clicks the "Train Model" button. """ self.train_button.setEnabled(False) - self.progress_bar.setRange(0,0) + self.progress_bar.setRange(0, 0) # initialise the worker thread - self.worker_thread = WorkerThread(app=self.app, task='train') + self.worker_thread = WorkerThread(app=self.app, task="train") self.worker_thread.task_finished.connect(self.on_finished) # start the worker thread to train self.worker_thread.start() @@ -234,23 +265,23 @@ def on_run_inference_button_clicked(self): Is called once user clicks the "Generate Labels" button. """ self.inference_button.setEnabled(False) - self.progress_bar.setRange(0,0) + self.progress_bar.setRange(0, 0) # initialise the worker thread - self.worker_thread = WorkerThread(app=self.app, task='inference') + self.worker_thread = WorkerThread(app=self.app, task="inference") self.worker_thread.task_finished.connect(self.on_finished) # start the worker thread to run inference self.worker_thread.start() - def on_launch_napari_button_clicked(self): - """ + def on_launch_napari_button_clicked(self): + """ Launches the napari window after the image is selected. """ - if not self.app.cur_selected_img or '_seg.tiff' in self.app.cur_selected_img: + if not self.app.cur_selected_img or "_seg.tiff" in self.app.cur_selected_img: message_text = "Please first select an image you wish to visualise. The selected image must be an original image, not a mask." _ = self.create_warning_box(message_text, message_title="Warning") else: self.nap_win = NapariWindow(self.app) - self.nap_win.show() + self.nap_win.show() def on_finished(self, result): """ @@ -258,9 +289,9 @@ def on_finished(self, result): :param result: The result emitted by the worker thread. 
See return type of WorkerThread.run :type result: tuple - """ + """ # Stop the pulsation - self.progress_bar.setRange(0,1) + self.progress_bar.setRange(0, 1) # Display message of result message_text, message_title = result _ = self.create_warning_box(message_text, message_title) @@ -282,20 +313,21 @@ def on_finished(self, result): from dcp_client.utils.fsimagestorage import FilesystemImageStorage from dcp_client.utils import settings from dcp_client.utils.sync_src_dst import DataRSync + settings.init() image_storage = FilesystemImageStorage() ml_model = BentomlModel() - data_sync = DataRSync(user_name="local", - host_name="local", - server_repo_path=None) + data_sync = DataRSync(user_name="local", host_name="local", server_repo_path=None) app = QApplication(sys.argv) - app_ = Application(ml_model=ml_model, - syncer=data_sync, - image_storage=image_storage, - server_ip='0.0.0.0', - server_port=7010, - eval_data_path='data', - train_data_path='', # set path - inprogr_data_path='') # set path + app_ = Application( + ml_model=ml_model, + syncer=data_sync, + image_storage=image_storage, + server_ip="0.0.0.0", + server_port=7010, + eval_data_path="data", + train_data_path="", # set path + inprogr_data_path="", + ) # set path window = MainWindow(app=app_) - sys.exit(app.exec()) \ No newline at end of file + sys.exit(app.exec()) diff --git a/src/client/dcp_client/gui/napari_window.py b/src/client/dcp_client/gui/napari_window.py index 99da89ef..2ca2a18f 100644 --- a/src/client/dcp_client/gui/napari_window.py +++ b/src/client/dcp_client/gui/napari_window.py @@ -11,6 +11,7 @@ from dcp_client.utils.utils import get_path_stem, check_equal_arrays, Compute4Mask from dcp_client.gui._my_widget import MyWidget + class NapariWindow(MyWidget): """Napari Window Widget object. Opens the napari image viewer to view and fix the labeles. 
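The _seg files handled by NapariWindow below are assumed to be two-channel arrays, with the instance mask in channel 0 and the class (labels) mask in channel 1; this is what the data[0]/data[1] indexing in the following hunks relies on. A toy example of that layout, for orientation only:

    import numpy as np

    # one object: instance id 1 in channel 0, class label 2 in channel 1
    seg = np.zeros((2, 256, 256), dtype=np.uint16)
    seg[0, 50:80, 60:90] = 1
    seg[1, 50:80, 60:90] = 2

    instance_mask, class_mask = seg[0], seg[1]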
@@ -37,17 +38,24 @@ def __init__(self, app: Application): self.viewer = napari.Viewer(show=False) self.viewer.add_image(img, name=get_path_stem(self.app.cur_selected_img)) for seg_file in self.seg_files: - self.viewer.add_labels(self.app.load_image(seg_file), name=get_path_stem(seg_file)) + self.viewer.add_labels( + self.app.load_image(seg_file), name=get_path_stem(seg_file) + ) main_window = self.viewer.window._qt_window layout = QGridLayout() layout.addWidget(main_window, 0, 0, 1, 4) # select the first seg as the currently selected layer if there are any segs - if len(self.seg_files) and len(self.viewer.layers[get_path_stem(self.seg_files[0])].data.shape) > 2: + if ( + len(self.seg_files) + and len(self.viewer.layers[get_path_stem(self.seg_files[0])].data.shape) > 2 + ): self.cur_selected_seg = self.viewer.layers.selection.active.name self.layer = self.viewer.layers[self.cur_selected_seg] - self.viewer.layers.selection.events.changed.connect(self.on_seg_channel_changed) + self.viewer.layers.selection.events.changed.connect( + self.on_seg_channel_changed + ) # set first mask as active by default self.active_mask_index = 0 self.viewer.dims.events.current_step.connect(self.axis_changed) @@ -58,27 +66,39 @@ def __init__(self, app: Application): for seg_file in self.seg_files: layer_name = get_path_stem(seg_file) # get unique instance labels for each seg - self.original_instance_mask[layer_name] = deepcopy(self.viewer.layers[layer_name].data[0]) - self.original_class_mask[layer_name] = deepcopy(self.viewer.layers[layer_name].data[1]) + self.original_instance_mask[layer_name] = deepcopy( + self.viewer.layers[layer_name].data[0] + ) + self.original_class_mask[layer_name] = deepcopy( + self.viewer.layers[layer_name].data[1] + ) # compute unique instance ids - self.instances[layer_name] = Compute4Mask.get_unique_objects(self.original_instance_mask[layer_name]) + self.instances[layer_name] = Compute4Mask.get_unique_objects( + self.original_instance_mask[layer_name] + ) # remove border from class mask - self.contours_mask[layer_name] = Compute4Mask.get_contours(self.original_instance_mask[layer_name], contours_level=0.8) - self.viewer.layers[layer_name].data[1][self.contours_mask[layer_name]!=0] = 0 - + self.contours_mask[layer_name] = Compute4Mask.get_contours( + self.original_instance_mask[layer_name], contours_level=0.8 + ) + self.viewer.layers[layer_name].data[1][ + self.contours_mask[layer_name] != 0 + ] = 0 + self.qctrl = self.viewer.window.qt_viewer.controls.widgets[self.layer] if len(self.layer.data.shape) > 2: # User hint - message_label = QLabel('Choose an active mask') + message_label = QLabel("Choose an active mask") message_label.setAlignment(Qt.AlignRight) layout.addWidget(message_label, 1, 0) - + # Drop list to choose which is an active mask self.mask_choice_dropdown = QComboBox() self.mask_choice_dropdown.setEnabled(False) - self.mask_choice_dropdown.addItem('Instance Segmentation Mask', userData=0) - self.mask_choice_dropdown.addItem('Labels Mask', userData=1) + self.mask_choice_dropdown.addItem( + "Instance Segmentation Mask", userData=0 + ) + self.mask_choice_dropdown.addItem("Labels Mask", userData=1) layout.addWidget(self.mask_choice_dropdown, 1, 1) # when user has chosen the mask, we don't want to change it anymore to avoid errors @@ -91,11 +111,15 @@ def __init__(self, app: Application): self.layer = None # add buttons for moving images to other dirs - add_to_inprogress_button = QPushButton('Move to \'Curatation in progress\' folder') + add_to_inprogress_button = QPushButton( 
+            "Move to 'Curation in progress' folder"
+        )
         layout.addWidget(add_to_inprogress_button, 2, 0, 1, 2)
-        add_to_inprogress_button.clicked.connect(self.on_add_to_inprogress_button_clicked)
+        add_to_inprogress_button.clicked.connect(
+            self.on_add_to_inprogress_button_clicked
+        )
 
-        add_to_curated_button = QPushButton('Move to \'Curated dataset\' folder')
+        add_to_curated_button = QPushButton("Move to 'Curated dataset' folder")
         layout.addWidget(add_to_curated_button, 2, 2, 1, 2)
         add_to_curated_button.clicked.connect(self.on_add_to_curated_button_clicked)
 
@@ -116,33 +140,41 @@ def on_seg_channel_changed(self, event):
         if (act := self.viewer.layers.selection.active) is not None:
             # update cur_selected_seg with the new selection from the user
             self.cur_selected_seg = act.name
-            if type(self.viewer.layers[self.cur_selected_seg]) == napari.layers.Image: pass
+            if type(self.viewer.layers[self.cur_selected_seg]) == napari.layers.Image:
+                pass
             # set self.layer to new selection from user
-            elif self.layer is not None: self.layer = self.viewer.layers[self.cur_selected_seg]
-            else: pass
+            elif self.layer is not None:
+                self.layer = self.viewer.layers[self.cur_selected_seg]
+            else:
+                pass
 
     def axis_changed(self, event):
         """
-        Is triggered each time the user switches the viewer between the mask channels. At this point the class mask
+        Is triggered each time the user switches the viewer between the mask channels. At this point the class mask
         needs to be updated according to the changes made to the instance segmentation mask.
         """
         self.active_mask_index = self.viewer.dims.current_step[0]
         masks = deepcopy(self.layer.data)
         # if user has switched to the instance mask
-        if self.active_mask_index==0:
+        if self.active_mask_index == 0:
             class_mask_with_contours = Compute4Mask.add_contour(masks[1], masks[0])
-            if not check_equal_arrays(class_mask_with_contours.astype(bool), self.original_class_mask[self.cur_selected_seg].astype(bool)):
+            if not check_equal_arrays(
+                class_mask_with_contours.astype(bool),
+                self.original_class_mask[self.cur_selected_seg].astype(bool),
+            ):
                 self.update_instance_mask(masks[0], masks[1])
                 self.switch_to_instance_mask()
         # else if user has switched to the class mask
-        elif self.active_mask_index==1:
-            if not check_equal_arrays(masks[0], self.original_instance_mask[self.cur_selected_seg]):
+        elif self.active_mask_index == 1:
+            if not check_equal_arrays(
+                masks[0], self.original_instance_mask[self.cur_selected_seg]
+            ):
                 self.update_labels_mask(masks[0])
                 self.switch_to_labels_mask()
 
     def switch_to_instance_mask(self):
         """
-        Switch the application to the active mask mode by enabling 'paint_button', 'erase_button'
+        Switch the application to the active mask mode by enabling 'paint_button', 'erase_button'
         and 'fill_button'.
         """
         self.switch_controls("paint_button", True)
@@ -154,12 +186,16 @@ def switch_to_labels_mask(self):
         Switch the application to non-active mask mode by enabling 'fill_button' and disabling 'paint_button' and 'erase_button'.
         """
         if self.cur_selected_seg in [layer.name for layer in self.viewer.layers]:
-            self.viewer.layers[self.cur_selected_seg].mode = 'pan_zoom'
-            info_message_paint = "Painting objects is only possible in the instance layer for now."
-            info_message_erase = "Erasing objects is only possible in the instance layer for now."
+            self.viewer.layers[self.cur_selected_seg].mode = "pan_zoom"
+            info_message_paint = (
+                "Painting objects is only possible in the instance layer for now."
+ ) + info_message_erase = ( + "Erasing objects is only possible in the instance layer for now." + ) self.switch_controls("paint_button", False, info_message_paint) self.switch_controls("erase_button", False, info_message_erase) - self.switch_controls("fill_button", True) + self.switch_controls("fill_button", True) def update_labels_mask(self, instance_mask): """Updates the class mask based on changes in the instance mask. @@ -170,17 +206,25 @@ def update_labels_mask(self, instance_mask): :type instance_mask: numpy.ndarray :return: None """ - self.original_class_mask[self.cur_selected_seg] = Compute4Mask.compute_new_labels_mask(self.original_class_mask[self.cur_selected_seg], - instance_mask, - self.original_instance_mask[self.cur_selected_seg], - self.instances[self.cur_selected_seg]) + self.original_class_mask[self.cur_selected_seg] = ( + Compute4Mask.compute_new_labels_mask( + self.original_class_mask[self.cur_selected_seg], + instance_mask, + self.original_instance_mask[self.cur_selected_seg], + self.instances[self.cur_selected_seg], + ) + ) # update original instance mask and instances self.original_instance_mask[self.cur_selected_seg] = instance_mask - self.instances[self.cur_selected_seg] = Compute4Mask.get_unique_objects(self.original_instance_mask[self.cur_selected_seg]) + self.instances[self.cur_selected_seg] = Compute4Mask.get_unique_objects( + self.original_instance_mask[self.cur_selected_seg] + ) # compute contours to remove from class mask visualisation - self.contours_mask[self.cur_selected_seg] = Compute4Mask.get_contours(instance_mask, contours_level=0.8) + self.contours_mask[self.cur_selected_seg] = Compute4Mask.get_contours( + instance_mask, contours_level=0.8 + ) vis_labels_mask = deepcopy(self.original_class_mask[self.cur_selected_seg]) - vis_labels_mask[self.contours_mask[self.cur_selected_seg]!=0] = 0 + vis_labels_mask[self.contours_mask[self.cur_selected_seg] != 0] = 0 # update the viewer self.layer.data[1] = vis_labels_mask self.layer.refresh() @@ -198,9 +242,12 @@ def update_instance_mask(self, instance_mask, labels_mask): # add contours back to labels mask labels_mask = Compute4Mask.add_contour(labels_mask, instance_mask) # and compute the updated instance mask - self.original_instance_mask[self.cur_selected_seg] = Compute4Mask.compute_new_instance_mask(labels_mask, - instance_mask) - self.instances[self.cur_selected_seg] = Compute4Mask.get_unique_objects(self.original_instance_mask[self.cur_selected_seg]) + self.original_instance_mask[self.cur_selected_seg] = ( + Compute4Mask.compute_new_instance_mask(labels_mask, instance_mask) + ) + self.instances[self.cur_selected_seg] = Compute4Mask.get_unique_objects( + self.original_instance_mask[self.cur_selected_seg] + ) self.original_class_mask[self.cur_selected_seg] = labels_mask # update the viewer self.layer.data[0] = self.original_instance_mask[self.cur_selected_seg] @@ -224,67 +271,75 @@ def switch_controls(self, target_widget, status: bool, info_message=None): pass def on_add_to_curated_button_clicked(self): - """Defines what happens when the "Move to curated dataset folder" button is clicked. 
-        """
+        """Defines what happens when the "Move to curated dataset folder" button is clicked."""
         if self.app.cur_selected_path == str(self.app.train_data_path):
-            message_text = "Image is already in the \'Curated data\' folder and should not be changed again"
+            message_text = "Image is already in the 'Curated data' folder and should not be changed again"
             _ = self.create_warning_box(message_text, message_title="Warning")
             return
-
+
         # take the name of the currently selected layer (by the user)
         seg_name_to_save = self.viewer.layers.selection.active.name
         # TODO if more than one item is selected this will break!
-        if '_seg' not in seg_name_to_save:
+        if "_seg" not in seg_name_to_save:
             message_text = (
-                "Please select the segmenation you wish to save from the layer list."
-                "The labels layer should have the same name as the image to which it corresponds, followed by _seg."
+                "Please select the segmentation you wish to save from the layer list."
+                "The labels layer should have the same name as the image to which it corresponds, followed by _seg."
             )
             _ = self.create_warning_box(message_text, message_title="Warning")
             return
-
+
         # Save the (changed) seg
         seg = self.viewer.layers[seg_name_to_save].data
         seg[1] = Compute4Mask.add_contour(seg[1], seg[0])
-        annot_error, mask_mismatch_error, faulty_ids_annot, faulty_ids_missmatch = Compute4Mask.assert_consistent_labels(seg)
+        annot_error, mask_mismatch_error, faulty_ids_annot, faulty_ids_missmatch = (
+            Compute4Mask.assert_consistent_labels(seg)
+        )
         if annot_error:
-            message_text = ("There seems to be a problem with your mask. We expect each object to be a connected component. For object(s) with ID(s) \n"
-                            +str(faulty_ids_annot)+"\n"
-                            "more than one connected component was found. Please go back and fix this.")
+            message_text = (
+                "There seems to be a problem with your mask. We expect each object to be a connected component. For object(s) with ID(s) \n"
+                + str(faulty_ids_annot)
+                + "\n"
+                "more than one connected component was found. Please go back and fix this."
+            )
             self.create_warning_box(message_text, "Warning")
         elif mask_mismatch_error:
-            message_text = ("There seems to be a mismatch between your class and instance masks for object(s) with ID(s) \n"
-                            +str(faulty_ids_missmatch)+"\n"
-                            "This should not occur and will cause a problem later during model training. Please go back and check.")
+            message_text = (
+                "There seems to be a mismatch between your class and instance masks for object(s) with ID(s) \n"
+                + str(faulty_ids_missmatch)
+                + "\n"
+                "This should not occur and will cause a problem later during model training. Please go back and check."
+            )
             self.create_warning_box(message_text, "Warning")
-        else:
+        else:
             # Move original image
             self.app.move_images(self.app.train_data_path)
-            self.app.save_image(self.app.train_data_path, seg_name_to_save+'.tiff', seg)
+            self.app.save_image(
+                self.app.train_data_path, seg_name_to_save + ".tiff", seg
+            )
             # We remove seg from the current directory if it exists (both eval and inprogr allowed)
             self.app.delete_images(self.seg_files)
-            # TODO Create the Archive folder for the rest? Or move them as well?
+            # TODO Create the Archive folder for the rest? Or move them as well?
             self.viewer.close()
             self.close()
 
     def on_add_to_inprogress_button_clicked(self):
-        """Defines what happens when the "Move to curation in progress folder" button is clicked.
- """ + """Defines what happens when the "Move to curation in progress folder" button is clicked.""" # TODO: Do we allow this? What if they moved it by mistake? User can always manually move from their folders?) if self.app.cur_selected_path == str(self.app.train_data_path): - message_text = "Images from '\Curated data'\ folder can not be moved back to \'Curatation in progress\' folder." + message_text = "Images from '\Curated data'\ folder can not be moved back to 'Curatation in progress' folder." _ = self.create_warning_box(message_text, message_title="Warning") return - + # take the name of the currently selected layer (by the user) seg_name_to_save = self.viewer.layers.selection.active.name # TODO if more than one item is selected this will break! - if '_seg' not in seg_name_to_save: + if "_seg" not in seg_name_to_save: message_text = ( - "Please select the segmenation you wish to save from the layer list." - "The labels layer should have the same name as the image to which it corresponds, followed by _seg." + "Please select the segmenation you wish to save from the layer list." + "The labels layer should have the same name as the image to which it corresponds, followed by _seg." ) _ = self.create_warning_box(message_text, message_title="Warning") return @@ -294,8 +349,7 @@ def on_add_to_inprogress_button_clicked(self): # Save the (changed) seg - this will overwrite existing seg if seg name hasn't been changed in viewer seg = self.viewer.layers[seg_name_to_save].data seg[1] = Compute4Mask.add_contour(seg[1], seg[0]) - self.app.save_image(self.app.inprogr_data_path, seg_name_to_save+'.tiff', seg) - + self.app.save_image(self.app.inprogr_data_path, seg_name_to_save + ".tiff", seg) + self.viewer.close() self.close() - \ No newline at end of file diff --git a/src/client/dcp_client/gui/welcome_window.py b/src/client/dcp_client/gui/welcome_window.py index 74b3c55d..0856403a 100644 --- a/src/client/dcp_client/gui/welcome_window.py +++ b/src/client/dcp_client/gui/welcome_window.py @@ -1,7 +1,14 @@ from __future__ import annotations from typing import TYPE_CHECKING -from qtpy.QtWidgets import QPushButton, QVBoxLayout, QHBoxLayout, QLabel, QFileDialog, QLineEdit +from qtpy.QtWidgets import ( + QPushButton, + QVBoxLayout, + QHBoxLayout, + QLabel, + QFileDialog, + QLineEdit, +) from qtpy.QtCore import Qt from dcp_client.gui.main_window import MainWindow @@ -10,9 +17,10 @@ if TYPE_CHECKING: from dcp_client.app import Application + class WelcomeWindow(MyWidget): """Welcome Window Widget object. - The first window of the application providing a dialog that allows users to select directories. + The first window of the application providing a dialog that allows users to select directories. Currently supported image file types that can be selected for segmentation are: .jpg, .jpeg, .png, .tiff, .tif. By clicking 'start' the MainWindow is called. """ @@ -30,7 +38,9 @@ def __init__(self, app: Application): self.main_layout = QVBoxLayout() input_layout = QHBoxLayout() label = QLabel(self) - label.setText('Welcome to Helmholtz AI data centric tool! Please select your dataset folder') + label.setText( + "Welcome to Helmholtz AI data centric tool! 
Please select your dataset folder" + ) self.main_layout.addWidget(label) self.text_layout = QVBoxLayout() @@ -38,35 +48,41 @@ def __init__(self, app: Application): self.button_layout = QVBoxLayout() val_label = QLabel(self) - val_label.setText('Uncurated dataset path:') + val_label.setText("Uncurated dataset path:") inprogr_label = QLabel(self) - inprogr_label.setText('Curation in progress path:') + inprogr_label.setText("Curation in progress path:") train_label = QLabel(self) - train_label.setText('Curated dataset path:') + train_label.setText("Curated dataset path:") self.text_layout.addWidget(val_label) self.text_layout.addWidget(inprogr_label) self.text_layout.addWidget(train_label) self.val_textbox = QLineEdit(self) - self.val_textbox.textEdited.connect(lambda x: self.on_text_changed(self.val_textbox, "eval", x)) + self.val_textbox.textEdited.connect( + lambda x: self.on_text_changed(self.val_textbox, "eval", x) + ) self.inprogr_textbox = QLineEdit(self) - self.inprogr_textbox.textEdited.connect(lambda x: self.on_text_changed(self.inprogr_textbox, "inprogress", x)) + self.inprogr_textbox.textEdited.connect( + lambda x: self.on_text_changed(self.inprogr_textbox, "inprogress", x) + ) self.train_textbox = QLineEdit(self) - self.train_textbox.textEdited.connect(lambda x: self.on_text_changed(self.train_textbox, "train", x)) + self.train_textbox.textEdited.connect( + lambda x: self.on_text_changed(self.train_textbox, "train", x) + ) self.path_layout.addWidget(self.val_textbox) self.path_layout.addWidget(self.inprogr_textbox) self.path_layout.addWidget(self.train_textbox) - - self.file_open_button_val = QPushButton('Browse',self) + + self.file_open_button_val = QPushButton("Browse", self) self.file_open_button_val.show() self.file_open_button_val.clicked.connect(self.browse_eval_clicked) - self.file_open_button_prog = QPushButton('Browse',self) + self.file_open_button_prog = QPushButton("Browse", self) self.file_open_button_prog.show() self.file_open_button_prog.clicked.connect(self.browse_inprogr_clicked) - self.file_open_button_train = QPushButton('Browse',self) + self.file_open_button_train = QPushButton("Browse", self) self.file_open_button_train.show() self.file_open_button_train.clicked.connect(self.browse_train_clicked) self.button_layout.addWidget(self.file_open_button_val) @@ -78,11 +94,11 @@ def __init__(self, app: Application): input_layout.addLayout(self.button_layout) self.main_layout.addLayout(input_layout) - self.start_button = QPushButton('Start', self) + self.start_button = QPushButton("Start", self) self.start_button.setFixedSize(120, 30) self.start_button.show() # check if we need to upload data to server - self.done_upload = False # we only do once + self.done_upload = False # we only do once if self.app.syncer.host_name == "local": self.start_button.clicked.connect(self.start_main) else: @@ -93,7 +109,7 @@ def __init__(self, app: Application): self.show() def browse_eval_clicked(self): - """Activates when the user clicks the button to choose the evaluation directory (QFileDialog) and + """Activates when the user clicks the button to choose the evaluation directory (QFileDialog) and displays the name of the evaluation directory chosen in the validation textbox line (QLineEdit). 
""" self.fd = QFileDialog() @@ -104,9 +120,9 @@ def browse_eval_clicked(self): self.val_textbox.setText(self.app.eval_data_path) finally: self.fd = None - + def browse_train_clicked(self): - """Activates when the user clicks the button to choose the train directory (QFileDialog) and + """Activates when the user clicks the button to choose the train directory (QFileDialog) and displays the name of the train directory chosen in the train textbox line (QLineEdit). """ @@ -118,7 +134,7 @@ def browse_train_clicked(self): def on_text_changed(self, field_obj, field_name, text): """ - Update data paths based on text changes in input fields. + Update data paths based on text changes in input fields. Used for copying paths in the welcome window. :param field_obj: The QLineEdit object. @@ -136,30 +152,37 @@ def on_text_changed(self, field_obj, field_name, text): elif field_name == "inprogress": self.app.inprogr_data_path = text field_obj.setText(text) - - def browse_inprogr_clicked(self): """ - Activates when the user clicks the button to choose the curation in progress directory (QFileDialog) and + Activates when the user clicks the button to choose the curation in progress directory (QFileDialog) and displays the name of the evaluation directory chosen in the validation textbox line (QLineEdit). """ fd = QFileDialog() fd.setFileMode(QFileDialog.Directory) - if fd.exec_(): # Browse clicked - self.app.inprogr_data_path = fd.selectedFiles()[0] #TODO: case when browse is clicked but nothing is specified - currently it is filled with os.getcwd() + if fd.exec_(): # Browse clicked + self.app.inprogr_data_path = fd.selectedFiles()[ + 0 + ] # TODO: case when browse is clicked but nothing is specified - currently it is filled with os.getcwd() self.inprogr_textbox.setText(self.app.inprogr_data_path) - + def start_main(self): - """Starts the main window after the user clicks 'Start' and only if both evaluation and train directories are chosen and all unique. - """ - - if len({self.app.inprogr_data_path, self.app.train_data_path, self.app.eval_data_path})<3: + """Starts the main window after the user clicks 'Start' and only if both evaluation and train directories are chosen and all unique.""" + + if ( + len( + { + self.app.inprogr_data_path, + self.app.train_data_path, + self.app.eval_data_path, + } + ) + < 3 + ): self.message_text = "All directory names must be distinct." _ = self.create_warning_box(self.message_text, message_title="Warning") - elif self.app.train_data_path and self.app.eval_data_path: self.hide() self.mw = MainWindow(self.app) @@ -173,21 +196,29 @@ def start_upload_and_main(self): to the server and the upload starts before launching the main window. """ if self.done_upload is False: - message_text = ("Your current configurations are set to run some operations on the cloud. \n" - "For this we need to upload your data to our server." - "We will now upload your data. Click ok to continue. \n" - "If you do not agree close the application and contact your software provider.") - usr_response = self.create_warning_box(message_text, message_title="Warning", add_cancel_btn=True) - if usr_response: + message_text = ( + "Your current configurations are set to run some operations on the cloud. \n" + "For this we need to upload your data to our server." + "We will now upload your data. Click ok to continue. \n" + "If you do not agree close the application and contact your software provider." 
+ ) + usr_response = self.create_warning_box( + message_text, message_title="Warning", add_cancel_btn=True + ) + if usr_response: success_up1, success_up2, _, _ = self.app.upload_data_to_server() - if success_up1=="Error" or success_up2=="Error": - message_text = ("An error has occured during data upload to the server. \n" - "Please check your configuration file and ensure that the server connection settings are correct and you have been given access to the server. \n" - "If the problem persists contact your software provider. Exiting now.") - usr_response = self.create_warning_box(message_text, message_title="Error") - self.close() - else: + if success_up1 == "Error" or success_up2 == "Error": + message_text = ( + "An error has occured during data upload to the server. \n" + "Please check your configuration file and ensure that the server connection settings are correct and you have been given access to the server. \n" + "If the problem persists contact your software provider. Exiting now." + ) + usr_response = self.create_warning_box( + message_text, message_title="Error" + ) + self.close() + else: self.done_upload = True self.start_upload_and_main() - else: self.start_main() - \ No newline at end of file + else: + self.start_main() diff --git a/src/client/dcp_client/main.py b/src/client/dcp_client/main.py index 978eed80..51c7b13c 100644 --- a/src/client/dcp_client/main.py +++ b/src/client/dcp_client/main.py @@ -11,27 +11,35 @@ from dcp_client.gui.welcome_window import WelcomeWindow import warnings -warnings.simplefilter('ignore') + +warnings.simplefilter("ignore") def main(): settings.init() dir_name = path.dirname(path.abspath(sys.argv[0])) - server_config = read_config('server', config_path = path.join(dir_name, 'config.yaml')) + server_config = read_config( + "server", config_path=path.join(dir_name, "config.yaml") + ) image_storage = FilesystemImageStorage() ml_model = BentomlModel() - data_sync = DataRSync(user_name=server_config["user"], - host_name=server_config["host"], - server_repo_path=server_config["data-path"]) - welcome_app = Application(ml_model=ml_model, - syncer=data_sync, - image_storage=image_storage, - server_ip=server_config["ip"], - server_port=server_config["port"]) + data_sync = DataRSync( + user_name=server_config["user"], + host_name=server_config["host"], + server_repo_path=server_config["data-path"], + ) + welcome_app = Application( + ml_model=ml_model, + syncer=data_sync, + image_storage=image_storage, + server_ip=server_config["ip"], + server_port=server_config["port"], + ) app = QApplication(sys.argv) window = WelcomeWindow(welcome_app) sys.exit(app.exec()) + if __name__ == "__main__": main() diff --git a/src/client/dcp_client/utils/bentoml_model.py b/src/client/dcp_client/utils/bentoml_model.py index 25204ac5..5c2e58fa 100644 --- a/src/client/dcp_client/utils/bentoml_model.py +++ b/src/client/dcp_client/utils/bentoml_model.py @@ -5,14 +5,11 @@ from dcp_client.app import Model + class BentomlModel(Model): - """BentomlModel class for connecting to a BentoML server and running training and inference tasks. - """ + """BentomlModel class for connecting to a BentoML server and running training and inference tasks.""" - def __init__( - self, - client: Optional[BentoClient] = None - ): + def __init__(self, client: Optional[BentoClient] = None): """Initializes the BentomlModel. :param client: Optional BentoClient instance. If None, it will be initialized during connection. 
@@ -20,7 +17,7 @@ def __init__( """ self.client = client - def connect(self, ip: str = '0.0.0.0', port: int = 7010): + def connect(self, ip: str = "0.0.0.0", port: int = 7010): """Connects to the BentoML server. :param ip: IP address of the BentoML server. Default is '0.0.0.0'. @@ -30,12 +27,13 @@ def connect(self, ip: str = '0.0.0.0', port: int = 7010): :return: True if connection is successful, False otherwise. :rtype: bool """ - url = f"http://{ip}:{port}" #"http://0.0.0.0:7010" + url = f"http://{ip}:{port}" # "http://0.0.0.0:7010" try: - self.client = BentoClient.from_url(url) + self.client = BentoClient.from_url(url) return True - except : return False # except ConnectionRefusedError - + except: + return False # except ConnectionRefusedError + @property def is_connected(self): """Checks if the BentomlModel is connected to the BentoML server. @@ -55,7 +53,8 @@ async def _run_train(self, data_path): try: response = await self.client.async_train(data_path) return response - except BentoMLException: return None + except BentoMLException: + return None def run_train(self, data_path): """Runs the training. @@ -76,8 +75,9 @@ async def _run_inference(self, data_path): try: response = await self.client.async_segment_image(data_path) return response - except BentoMLException: return None - + except BentoMLException: + return None + def run_inference(self, data_path): """Runs the inference. @@ -86,4 +86,4 @@ def run_inference(self, data_path): :return: List of files not supported by the server if unsuccessful, otherwise returns None. """ list_of_files_not_suported = asyncio.run(self._run_inference(data_path)) - return list_of_files_not_suported \ No newline at end of file + return list_of_files_not_suported diff --git a/src/client/dcp_client/utils/fsimagestorage.py b/src/client/dcp_client/utils/fsimagestorage.py index d33371ff..52d6c006 100644 --- a/src/client/dcp_client/utils/fsimagestorage.py +++ b/src/client/dcp_client/utils/fsimagestorage.py @@ -3,9 +3,9 @@ from dcp_client.app import ImageStorage + class FilesystemImageStorage(ImageStorage): - """FilesystemImageStorage class for handling image storage operations on the local filesystem. - """ + """FilesystemImageStorage class for handling image storage operations on the local filesystem.""" def load_image(self, from_directory, cur_selected_img): """Loads an image from the specified directory. @@ -18,7 +18,7 @@ def load_image(self, from_directory, cur_selected_img): """ # Read the selected image and read the segmentation if any: return imread(os.path.join(from_directory, cur_selected_img)) - + def move_image(self, from_directory, to_directory, cur_selected_img): """Moves an image from one directory to another. @@ -29,8 +29,13 @@ def move_image(self, from_directory, to_directory, cur_selected_img): :param cur_selected_img: Name of the image file. :type cur_selected_img: str """ - print(f"from:{os.path.join(from_directory, cur_selected_img)}, to:{os.path.join(to_directory, cur_selected_img)}") - os.replace(os.path.join(from_directory, cur_selected_img), os.path.join(to_directory, cur_selected_img)) + print( + f"from:{os.path.join(from_directory, cur_selected_img)}, to:{os.path.join(to_directory, cur_selected_img)}" + ) + os.replace( + os.path.join(from_directory, cur_selected_img), + os.path.join(to_directory, cur_selected_img), + ) def save_image(self, to_directory, cur_selected_img, img): """Saves an image to the specified directory. 
@@ -41,9 +46,9 @@ def save_image(self, to_directory, cur_selected_img, img): :type cur_selected_img: str :param img: Image data to be saved. """ - + imsave(os.path.join(to_directory, cur_selected_img), img) - + def delete_image(self, from_directory, cur_selected_img): """Deletes an image from the specified directory. diff --git a/src/client/dcp_client/utils/settings.py b/src/client/dcp_client/utils/settings.py index 2fd6bcb2..3decd50f 100644 --- a/src/client/dcp_client/utils/settings.py +++ b/src/client/dcp_client/utils/settings.py @@ -2,4 +2,4 @@ def init(): global accepted_types accepted_types = (".jpg", ".jpeg", ".png", ".tiff", ".tif") global seg_name_string - seg_name_string = '_seg' + seg_name_string = "_seg" diff --git a/src/client/dcp_client/utils/sync_src_dst.py b/src/client/dcp_client/utils/sync_src_dst.py index 66d0a4b7..951f8cb3 100644 --- a/src/client/dcp_client/utils/sync_src_dst.py +++ b/src/client/dcp_client/utils/sync_src_dst.py @@ -6,13 +6,15 @@ class DataRSync(DataSync): - ''' + """ Class which uses rsync bash command to sync data between client and server - ''' - def __init__(self, - user_name: str, - host_name: str, - server_repo_path: str, + """ + + def __init__( + self, + user_name: str, + host_name: str, + server_repo_path: str, ): """Constructs all the necessary attributes for the CustomRunnable. @@ -22,7 +24,7 @@ def __init__(self, :type: host_name: str :param server_repo_path: the server path where we wish to sync data - if None, then it is assumed that local machine is used for the server :type server_repo_path: str - """ + """ self.user_name = user_name self.host_name = host_name self.server_repo_path = server_repo_path @@ -30,25 +32,20 @@ def __init__(self, def first_sync(self, path): """ During the first sync the folder structure should be created on the server - + :param path: Path to the local directory to synchronize. :type path: str """ - server = self.user_name + "@" + self.host_name + ":" + self.server_repo_path + server = self.user_name + "@" + self.host_name + ":" + self.server_repo_path try: # Run the subprocess command - result = subprocess.run(["rsync", - "-azP" , - path, - server], - check=True) + result = subprocess.run(["rsync", "-azP", path, server], check=True) return ("Success", result.stdout) except subprocess.CalledProcessError as e: return ("Error", e) - def sync(self, src, dst, path): - """ Syncs the data between the src and the dst. Both src and dst can be one of either + """Syncs the data between the src and the dst. Both src and dst can be one of either 'client' or 'server', whereas path is the local path we wish to sync :param src: A string specifying the source, from where the data will be sent to dst. Can be 'client' or 'server'. @@ -57,16 +54,16 @@ def sync(self, src, dst, path): :type dst: str :param path: Path to the directory we want to synchronize. :type path: str - + """ - path += '/' # otherwise it doesn't go in the directory - rel_path = get_relative_path(path) # get last folder, i.e. uncurated, curated + path += "/" # otherwise it doesn't go in the directory + rel_path = get_relative_path(path) # get last folder, i.e. 
uncurated, curated server_full_path = os.path.join(self.server_repo_path, rel_path) - server_full_path += '/' - server = self.user_name + "@" + self.host_name + ":" + server_full_path - print('server is: ', server) - - if src=='server': + server_full_path += "/" + server = self.user_name + "@" + self.host_name + ":" + server_full_path + print("server is: ", server) + + if src == "server": src = server dst = path else: @@ -74,19 +71,14 @@ def sync(self, src, dst, path): dst = server try: # Run the subprocess command - _ = subprocess.run(["rsync", - "-r" , - "--delete", - src, - dst], - check=True) + _ = subprocess.run(["rsync", "-r", "--delete", src, dst], check=True) return ("Success", server_full_path) except subprocess.CalledProcessError as e: return ("Error", e) - -if __name__=="__main__": - ds = DataRSync() #vm2 + +if __name__ == "__main__": + ds = DataRSync() # vm2 # These combinations work for me: # ubuntu@jusuf-vm2:/path... # jusuf-vm2:/path... @@ -94,6 +86,8 @@ def sync(self, src, dst, path): src = "client" # dst = 'client' # src = 'server' - #path = "data/" - path = "/Users/christina.bukas/Documents/AI_projects/code/data-centric-platform/data" - ds.sync(src, dst, path) \ No newline at end of file + # path = "data/" + path = ( + "/Users/christina.bukas/Documents/AI_projects/code/data-centric-platform/data" + ) + ds.sync(src, dst, path) diff --git a/src/client/dcp_client/utils/utils.py b/src/client/dcp_client/utils/utils.py index fe43004b..5b2ef133 100644 --- a/src/client/dcp_client/utils/utils.py +++ b/src/client/dcp_client/utils/utils.py @@ -1,4 +1,4 @@ -from PyQt5.QtWidgets import QFileIconProvider +from PyQt5.QtWidgets import QFileIconProvider from PyQt5.QtCore import QSize from PyQt5.QtGui import QPixmap, QIcon import numpy as np @@ -10,15 +10,15 @@ from dcp_client.utils import settings + class IconProvider(QFileIconProvider): def __init__(self) -> None: - """ Initializes the IconProvider with the default icon size. - """ + """Initializes the IconProvider with the default icon size.""" super().__init__() - self.ICON_SIZE = QSize(512,512) + self.ICON_SIZE = QSize(512, 512) - def icon(self, type: 'QFileIconProvider.IconType'): - """ Returns the icon for the specified file type. + def icon(self, type: "QFileIconProvider.IconType"): + """Returns the icon for the specified file type. :param type: The type of the file for which the icon is requested. :type type: QFileIconProvider.IconType @@ -27,7 +27,8 @@ def icon(self, type: 'QFileIconProvider.IconType'): """ try: fn = type.filePath() - except AttributeError: return super().icon(type) # TODO handle exception differently? + except AttributeError: + return super().icon(type) # TODO handle exception differently? if fn.endswith(settings.accepted_types): a = QPixmap(self.ICON_SIZE) @@ -36,7 +37,8 @@ def icon(self, type: 'QFileIconProvider.IconType'): else: return super().icon(type) -def read_config(name, config_path = 'config.yaml') -> dict: + +def read_config(name, config_path="config.yaml") -> dict: """Reads the configuration file :param name: name of the section you want to read (e.g. 
'setup','train') @@ -45,15 +47,18 @@ def read_config(name, config_path = 'config.yaml') -> dict: :type config_path: str, optional :return: dictionary from the config section given by name :rtype: dict - """ + """ with open(config_path) as config_file: - config_dict = yaml.safe_load(config_file) # json.load(config_file) for .cfg file + config_dict = yaml.safe_load( + config_file + ) # json.load(config_file) for .cfg file # Check if config file has main mandatory keys - assert all([i in config_dict.keys() for i in ['server']]) + assert all([i in config_dict.keys() for i in ["server"]]) return config_dict[name] + def get_relative_path(filepath): - """ Returns the name of the file from the given filepath. + """Returns the name of the file from the given filepath. :param filepath: The path of the file. :type filepath: str @@ -62,8 +67,9 @@ def get_relative_path(filepath): """ return PurePath(filepath).name -def get_path_stem(filepath): - """ Returns the stem (filename without its extension) from the given filepath. + +def get_path_stem(filepath): + """Returns the stem (filename without its extension) from the given filepath. :param filepath: The path of the file. :type filepath: str @@ -72,8 +78,9 @@ def get_path_stem(filepath): """ return str(Path(filepath).stem) + def get_path_name(filepath): - """ Returns the name of the file from the given filepath. + """Returns the name of the file from the given filepath. :param filepath: The path of the file. :type filepath: str @@ -82,8 +89,9 @@ def get_path_name(filepath): """ return str(Path(filepath).name) + def get_path_parent(filepath): - """ Returns the parent directory of the given filepath. + """Returns the parent directory of the given filepath. :param filepath: The path of the file. :type filepath: str @@ -92,8 +100,9 @@ def get_path_parent(filepath): """ return str(Path(filepath).parent) + def join_path(root_dir, filepath): - """ Joins the root directory path with the given filepath. + """Joins the root directory path with the given filepath. :param root_dir: The root directory. :type root_dir: str @@ -104,8 +113,9 @@ def join_path(root_dir, filepath): """ return str(Path(root_dir, filepath)) + def check_equal_arrays(array1, array2): - """ Checks if two arrays are equal. + """Checks if two arrays are equal. :param array1: The first array. :type array1: numpy.ndarray @@ -116,6 +126,7 @@ def check_equal_arrays(array1, array2): """ return np.array_equal(array1, array2) + class Compute4Mask: """ Compute4Mask provides methods for manipulating masks. @@ -123,7 +134,7 @@ class Compute4Mask: @staticmethod def get_contours(instance_mask, contours_level=None): - """ Find contours of objects in the instance mask. This function is used to identify the contours of the objects to prevent the problem of the merged + """Find contours of objects in the instance mask. This function is used to identify the contours of the objects to prevent the problem of the merged objects in napari window (mask). :param instance_mask: The instance mask array. 
@@ -134,29 +145,37 @@ def get_contours(instance_mask, contours_level=None): :rtype: numpy.ndarray """ - instance_ids = Compute4Mask.get_unique_objects(instance_mask) # get object instance labels ignoring background - contour_mask= np.zeros_like(instance_mask) + instance_ids = Compute4Mask.get_unique_objects( + instance_mask + ) # get object instance labels ignoring background + contour_mask = np.zeros_like(instance_mask) for instance_id in instance_ids: # get a binary mask only of object single_obj_mask = np.zeros_like(instance_mask) - single_obj_mask[instance_mask==instance_id] = 1 + single_obj_mask[instance_mask == instance_id] = 1 try: # compute contours for mask contours = find_contours(single_obj_mask, contours_level) # sometimes little dots appeas as additional contours so remove these - if len(contours)>1: + if len(contours) > 1: contour_sizes = [contour.shape[0] for contour in contours] - contour = contours[contour_sizes.index(max(contour_sizes))].astype(int) - else: contour = contours[0] + contour = contours[contour_sizes.index(max(contour_sizes))].astype( + int + ) + else: + contour = contours[0] # and draw onto contours mask - rr, cc = polygon_perimeter(contour[:, 0], contour[:, 1], contour_mask.shape) + rr, cc = polygon_perimeter( + contour[:, 0], contour[:, 1], contour_mask.shape + ) contour_mask[rr, cc] = instance_id - except: print("Could not create contour for instance id", instance_id) + except: + print("Could not create contour for instance id", instance_id) return contour_mask - + @staticmethod def add_contour(labels_mask, instance_mask): - """ Add contours of objects to the labels mask. + """Add contours of objects to the labels mask. :param labels_mask: The class mask array without the contour pixels annotated. :type labels_mask: numpy.ndarray @@ -167,20 +186,21 @@ def add_contour(labels_mask, instance_mask): """ instance_ids = Compute4Mask.get_unique_objects(instance_mask) for instance_id in instance_ids: - where_instances = np.where(instance_mask==instance_id) + where_instances = np.where(instance_mask == instance_id) # get unique class ids where the object is present - class_vals, counts = np.unique(labels_mask[where_instances], return_counts=True) + class_vals, counts = np.unique( + labels_mask[where_instances], return_counts=True + ) # and take the class id which is most heavily represented class_id = class_vals[np.argmax(counts)] # make sure instance mask and class mask match - labels_mask[np.where(instance_mask==instance_id)] = class_id + labels_mask[np.where(instance_mask == instance_id)] = class_id return labels_mask - @staticmethod def compute_new_instance_mask(labels_mask, instance_mask): - """ Given an updated labels mask, update also the instance mask accordingly. - So far the user can only remove an entire object in the labels mask view by + """Given an updated labels mask, update also the instance mask accordingly. + So far the user can only remove an entire object in the labels mask view by setting the color of the object to the background. Therefore the instance mask can only change by entirely removing an object. 
@@ -193,15 +213,21 @@ def compute_new_instance_mask(labels_mask, instance_mask): """ instance_ids = Compute4Mask.get_unique_objects(instance_mask) for instance_id in instance_ids: - unique_items_in_class_mask = list(np.unique(labels_mask[instance_mask==instance_id])) - if len(unique_items_in_class_mask)==1 and unique_items_in_class_mask[0]==0: - instance_mask[instance_mask==instance_id] = 0 + unique_items_in_class_mask = list( + np.unique(labels_mask[instance_mask == instance_id]) + ) + if ( + len(unique_items_in_class_mask) == 1 + and unique_items_in_class_mask[0] == 0 + ): + instance_mask[instance_mask == instance_id] = 0 return instance_mask - @staticmethod - def compute_new_labels_mask(labels_mask, instance_mask, original_instance_mask, old_instances): - """ Given the existing labels mask, the updated instance mask is used to update the labels mask. + def compute_new_labels_mask( + labels_mask, instance_mask, original_instance_mask, old_instances + ): + """Given the existing labels mask, the updated instance mask is used to update the labels mask. :param labels_mask: The existing labels mask, which needs to be updated. :type labels_mask: numpy.ndarray @@ -216,34 +242,37 @@ def compute_new_labels_mask(labels_mask, instance_mask, original_instance_mask, """ new_labels_mask = np.zeros_like(labels_mask) for instance_id in np.unique(instance_mask): - where_instance = np.where(instance_mask==instance_id) + where_instance = np.where(instance_mask == instance_id) # if the label is background skip - if instance_id==0: continue + if instance_id == 0: + continue # if the label is a newly added object, add with the same id to the labels mask # this is an indication to the user that this object needs to be assigned a class elif instance_id not in old_instances: new_labels_mask[where_instance] = instance_id else: - where_instance_orig = np.where(original_instance_mask==instance_id) + where_instance_orig = np.where(original_instance_mask == instance_id) # if the locations of the instance haven't changed, means object wasn't changed, do nothing num_classes = np.unique(labels_mask[where_instance]) # if area was erased and object retains same class - if len(num_classes)==1: + if len(num_classes) == 1: new_labels_mask[where_instance] = num_classes[0] # area was added where there is background or other class else: - old_class_id, counts = np.unique(labels_mask[where_instance_orig], return_counts=True) - #assert len(old_class_id)==1 - #old_class_id = old_class_id[0] + old_class_id, counts = np.unique( + labels_mask[where_instance_orig], return_counts=True + ) + # assert len(old_class_id)==1 + # old_class_id = old_class_id[0] # and take the class id which is most heavily represented old_class_id = old_class_id[np.argmax(counts)] new_labels_mask[where_instance] = old_class_id - + return new_labels_mask - + @staticmethod def get_unique_objects(active_mask): - """ Gets unique objects from the active mask. + """Gets unique objects from the active mask. :param active_mask: The mask array. :type active_mask: numpy.ndarray @@ -251,10 +280,10 @@ def get_unique_objects(active_mask): :rtype: list """ return list(np.unique(active_mask)[1:]) - + @staticmethod def assert_consistent_labels(mask): - """ Before saving the final mask make sure the user has not mistakenly made an error during annotation, + """Before saving the final mask make sure the user has not mistakenly made an error during annotation, such that one instance id does not correspond to exactly one class id. 
Also checks whether for one instance id multiple classes exist. :param mask: The mask which we want to test. @@ -265,10 +294,10 @@ def assert_consistent_labels(mask): - A list with all the instance ids for which more than one connected component was found. - A list with all the instance ids for which a missmatch between class and instance masks was found. :rtype : - - bool - bool + - bool + - list[int] - list[int] - - list[int] """ user_annot_error = False mask_mismatch_error = False @@ -278,12 +307,20 @@ def assert_consistent_labels(mask): instance_ids = Compute4Mask.get_unique_objects(instance_mask) for instance_id in instance_ids: # check if there are more than one objects (connected components) with same instance_id - if np.unique(label(instance_mask==instance_id)).shape[0] > 2: + if np.unique(label(instance_mask == instance_id)).shape[0] > 2: user_annot_error = True faulty_ids_annot.append(instance_id) # and check if there is a mismatch between class mask and instance mask - should never happen! - if np.unique(class_mask[np.where(instance_mask==instance_id)]).shape[0]>1: + if ( + np.unique(class_mask[np.where(instance_mask == instance_id)]).shape[0] + > 1 + ): mask_mismatch_error = True faulty_ids_missmatch.append(instance_id) - return user_annot_error, mask_mismatch_error, faulty_ids_annot, faulty_ids_missmatch \ No newline at end of file + return ( + user_annot_error, + mask_mismatch_error, + faulty_ids_annot, + faulty_ids_missmatch, + ) diff --git a/src/client/test/test_app.py b/src/client/test/test_app.py index e4e6d1f9..ad31285a 100644 --- a/src/client/test/test_app.py +++ b/src/client/test/test_app.py @@ -1,5 +1,6 @@ import os import sys + sys.path.append("../") import pytest import subprocess @@ -13,88 +14,101 @@ from dcp_client.utils.fsimagestorage import FilesystemImageStorage from dcp_client.utils.sync_src_dst import DataRSync + @pytest.fixture def app(): img1 = data.astronaut() img2 = data.coffee() img3 = data.cat() - if not os.path.exists('in_prog'): - os.mkdir('in_prog') - imsave('in_prog/coffee.png', img2) + if not os.path.exists("in_prog"): + os.mkdir("in_prog") + imsave("in_prog/coffee.png", img2) - if not os.path.exists('eval_data_path'): - os.mkdir('eval_data_path') - imsave('eval_data_path/cat.png', img3) + if not os.path.exists("eval_data_path"): + os.mkdir("eval_data_path") + imsave("eval_data_path/cat.png", img3) - rsyncer = DataRSync(user_name="local", host_name="local", server_repo_path='.') - app = Application(BentomlModel(), - rsyncer, - FilesystemImageStorage(), - "0.0.0.0", - 7010, - os.path.join(os.getcwd(), 'eval_data_path')) + rsyncer = DataRSync(user_name="local", host_name="local", server_repo_path=".") + app = Application( + BentomlModel(), + rsyncer, + FilesystemImageStorage(), + "0.0.0.0", + 7010, + os.path.join(os.getcwd(), "eval_data_path"), + ) return app, img1, img2, img3 + def test_load_image(app): app, img, img2, _ = app # Unpack the app, img, and img2 from the fixture - - app.cur_selected_img = 'coffee.png' - app.cur_selected_path = 'in_prog' + + app.cur_selected_img = "coffee.png" + app.cur_selected_path = "in_prog" img_test = app.load_image() # if image_name is None assert img.all() == img_test.all() - app.cur_selected_path = 'eval_data_path' - img_test2 = app.load_image('cat.png') # if a filename is given + app.cur_selected_path = "eval_data_path" + img_test2 = app.load_image("cat.png") # if a filename is given assert img2.all() == img_test2.all() + def test_run_inference_no_connection(app): - app, _, _, _ = app + app, _, _, _ = app 
message_text, message_title = app.run_inference() - assert message_text=="Connection could not be established. Please check if the server is running and try again." - assert message_title=="Warning" + assert ( + message_text + == "Connection could not be established. Please check if the server is running and try again." + ) + assert message_title == "Warning" + def test_run_inference_run(app): - app, _, _, _ = app + app, _, _, _ = app # start the sevrer in the background locally command = [ "bentoml", - "serve", - '--working-dir', - '../server/dcp_server', + "serve", + "--working-dir", + "../server/dcp_server", "service:svc", "--reload", "--port=7010", ] process = subprocess.Popen(command, stdin=subprocess.PIPE, shell=False) # and wait until it is setup - if sys.platform == 'win32' or sys.platform == 'cygwin': time.sleep(240) - else: time.sleep(60) + if sys.platform == "win32" or sys.platform == "cygwin": + time.sleep(240) + else: + time.sleep(60) # then do model serving message_text, message_title = app.run_inference() # and assert returning message print(f"HERE: {message_text, message_title}") - assert message_text== "Success! Masks generated for all images" - assert message_title=="Information" + assert message_text == "Success! Masks generated for all images" + assert message_title == "Information" # finally clean up process process.terminate() process.wait() process.kill() + def test_search_segs(app): - app, _, _, _ = app - app.cur_selected_img = 'cat.png' - app.cur_selected_path = 'eval_data_path' + app, _, _, _ = app + app.cur_selected_img = "cat.png" + app.cur_selected_path = "eval_data_path" app.search_segs() - res = app.seg_filepaths - assert len(res)==1 - assert res[0]=='cat_seg.tiff' + res = app.seg_filepaths + assert len(res) == 1 + assert res[0] == "cat_seg.tiff" # also remove the seg as it is not needed for other scripts - os.remove('eval_data_path/cat_seg.tiff') + os.remove("eval_data_path/cat_seg.tiff") -''' + +""" def test_run_train(): pass @@ -107,7 +121,4 @@ def test_move_images(): def test_delete_images(): pass -''' - - - +""" diff --git a/src/client/test/test_compute4mask.py b/src/client/test/test_compute4mask.py index e76dfc1c..b4cc7435 100644 --- a/src/client/test/test_compute4mask.py +++ b/src/client/test/test_compute4mask.py @@ -2,82 +2,113 @@ import pytest from dcp_client.utils.utils import Compute4Mask + @pytest.fixture def sample_data(): - instance_mask = np.array([[0, 1, 1, 1, 0], - [0, 1, 1, 1, 0], - [0, 1, 1, 1, 0], - [2, 2, 0, 0, 0], - [0, 0, 3, 3, 0]]) - labels_mask = np.array([[0, 1, 1, 1, 0], - [0, 1, 1, 1, 0], - [0, 1, 1, 1, 0], - [2, 2, 0, 0, 0], - [0, 0, 1, 1, 0]]) + instance_mask = np.array( + [ + [0, 1, 1, 1, 0], + [0, 1, 1, 1, 0], + [0, 1, 1, 1, 0], + [2, 2, 0, 0, 0], + [0, 0, 3, 3, 0], + ] + ) + labels_mask = np.array( + [ + [0, 1, 1, 1, 0], + [0, 1, 1, 1, 0], + [0, 1, 1, 1, 0], + [2, 2, 0, 0, 0], + [0, 0, 1, 1, 0], + ] + ) return instance_mask, labels_mask + def test_get_unique_objects(sample_data): instance_mask, _ = sample_data unique_objects = Compute4Mask.get_unique_objects(instance_mask) assert unique_objects == [1, 2, 3] + def test_get_contours(sample_data): instance_mask, _ = sample_data contour_mask = Compute4Mask.get_contours(instance_mask) assert contour_mask.shape == instance_mask.shape - assert contour_mask[0,1] == 1 # randomly check a contour location is present + assert contour_mask[0, 1] == 1 # randomly check a contour location is present + def test_add_contour(sample_data): instance_mask, labels_mask = sample_data contours_mask = 
Compute4Mask.get_contours(instance_mask, contours_level=0.1)
     labels_mask_wo_contour = np.copy(labels_mask)
-    labels_mask_wo_contour[contours_mask!=0] = 0
-    updated_labels_mask = Compute4Mask.add_contour(labels_mask_wo_contour, instance_mask)
+    labels_mask_wo_contour[contours_mask != 0] = 0
+    updated_labels_mask = Compute4Mask.add_contour(
+        labels_mask_wo_contour, instance_mask
+    )
     assert np.array_equal(updated_labels_mask[:3], labels_mask[:3])

+
 def test_compute_new_instance_mask(sample_data):
     instance_mask, labels_mask = sample_data
-    labels_mask[labels_mask==1] = 0
-    updated_instance_mask = Compute4Mask.compute_new_instance_mask(labels_mask, instance_mask)
-    assert list(np.unique(updated_instance_mask))==[0,2]
+    labels_mask[labels_mask == 1] = 0
+    updated_instance_mask = Compute4Mask.compute_new_instance_mask(
+        labels_mask, instance_mask
+    )
+    assert list(np.unique(updated_instance_mask)) == [0, 2]
+

 def test_compute_new_labels_mask_obj_added(sample_data):
     instance_mask, labels_mask = sample_data
     original_instance_mask = np.copy(instance_mask)
     instance_mask[0, 0] = 4
     old_instances = Compute4Mask.get_unique_objects(original_instance_mask)
-    new_labels_mask = Compute4Mask.compute_new_labels_mask(labels_mask, instance_mask, original_instance_mask, old_instances)
-    assert new_labels_mask[0,0]==4
+    new_labels_mask = Compute4Mask.compute_new_labels_mask(
+        labels_mask, instance_mask, original_instance_mask, old_instances
+    )
+    assert new_labels_mask[0, 0] == 4
+

 def test_compute_new_labels_mask_obj_erased(sample_data):
     instance_mask, labels_mask = sample_data
     original_instance_mask = np.copy(instance_mask)
     instance_mask[0] = 0
     old_instances = Compute4Mask.get_unique_objects(original_instance_mask)
-    new_labels_mask = Compute4Mask.compute_new_labels_mask(labels_mask, instance_mask, original_instance_mask, old_instances)
-    assert np.all(new_labels_mask[0])==0
+    new_labels_mask = Compute4Mask.compute_new_labels_mask(
+        labels_mask, instance_mask, original_instance_mask, old_instances
+    )
+    assert np.all(new_labels_mask[0] == 0)
     assert np.array_equal(new_labels_mask[1:], labels_mask[1:])

+
 def test_compute_new_labels_mask_obj_added(sample_data):
     instance_mask, labels_mask = sample_data
     original_instance_mask = np.copy(instance_mask)
     instance_mask[:, -1] = 1
     old_instances = Compute4Mask.get_unique_objects(original_instance_mask)
-    new_labels_mask = Compute4Mask.compute_new_labels_mask(labels_mask, instance_mask, original_instance_mask, old_instances)
-    assert np.all(new_labels_mask[:, -1])==1
+    new_labels_mask = Compute4Mask.compute_new_labels_mask(
+        labels_mask, instance_mask, original_instance_mask, old_instances
+    )
+    assert np.all(new_labels_mask[:, -1] == 1)
+

 def assert_consistent_labels(sample_data):
     instance_mask, labels_mask = sample_data
-    user_annot_error, mask_mismatch_error, faulty_ids_annot, faulty_ids_missmatch = Compute4Mask.assert_consistent_labels(sample_data)
-    assert user_annot_error==False
-    assert mask_mismatch_error==False
-    assert len(faulty_ids_annot)==len(faulty_ids_missmatch)==0
-    instance_mask[instance_mask==3] = 1
-    labels_mask[1,2] = 2
-    user_annot_error, mask_mismatch_error, faulty_ids_annot, faulty_ids_missmatch = Compute4Mask.assert_consistent_labels(np.stack(instance_mask, labels_mask))
-    assert user_annot_error==True
-    assert mask_mismatch_error==True
-    assert len(faulty_ids_annot)==1
-    assert faulty_ids_annot[0]==1
-    assert len(faulty_ids_missmatch)==1
-    assert faulty_ids_missmatch[0]==1
+    user_annot_error, mask_mismatch_error, faulty_ids_annot, faulty_ids_missmatch = (
+        Compute4Mask.assert_consistent_labels(sample_data)
+    )
+    assert user_annot_error == False
+    assert mask_mismatch_error == False
+    assert len(faulty_ids_annot) == len(faulty_ids_missmatch) == 0
+    instance_mask[instance_mask == 3] = 1
+    labels_mask[1, 2] = 2
+    user_annot_error, mask_mismatch_error, faulty_ids_annot, faulty_ids_missmatch = (
+        Compute4Mask.assert_consistent_labels(np.stack([instance_mask, labels_mask]))
+    )
+    assert user_annot_error == True
+    assert mask_mismatch_error == True
+    assert len(faulty_ids_annot) == 1
+    assert faulty_ids_annot[0] == 1
+    assert len(faulty_ids_missmatch) == 1
+    assert faulty_ids_missmatch[0] == 1
diff --git a/src/client/test/test_fsimagestorage.py b/src/client/test/test_fsimagestorage.py
index 275e5f0b..f971fbfe 100644
--- a/src/client/test/test_fsimagestorage.py
+++ b/src/client/test/test_fsimagestorage.py
@@ -5,42 +5,48 @@

 from dcp_client.utils.fsimagestorage import FilesystemImageStorage

+
 @pytest.fixture
 def fis():
     return FilesystemImageStorage()

+
 @pytest.fixture
 def sample_image():
     # Create a sample image
     img = data.astronaut()
-    fname = 'test_img.png'
+    fname = "test_img.png"
     imsave(fname, img)
     return fname
-    
+

 def test_load_image(fis, sample_image):
-    img_test = fis.load_image('.', sample_image)
+    img_test = fis.load_image(".", sample_image)
     assert img_test.all() == data.astronaut().all()
     os.remove(sample_image)

+
 def test_move_image(fis, sample_image):
-    temp_dir = 'temp'
+    temp_dir = "temp"
     os.mkdir(temp_dir)
-    fis.move_image('.', temp_dir, sample_image)
-    assert os.path.exists(os.path.join(temp_dir, 'test_img.png'))
-    os.remove(os.path.join(temp_dir, 'test_img.png'))
+    fis.move_image(".", temp_dir, sample_image)
+    assert os.path.exists(os.path.join(temp_dir, "test_img.png"))
+    os.remove(os.path.join(temp_dir, "test_img.png"))
     os.rmdir(temp_dir)

+
 def test_save_image(fis):
     img = data.astronaut()
-    fname = 'output.png'
-    fis.save_image('.', fname, img)
+    fname = "output.png"
+    fis.save_image(".", fname, img)
     assert os.path.exists(fname)
     os.remove(fname)

+
 def test_delete_image(fis, sample_image):
-    temp_dir = 'temp'
+    temp_dir = "temp"
     os.mkdir(temp_dir)
-    fis.move_image('.', temp_dir, sample_image)
-    fis.delete_image(temp_dir, 'test_img.png')
-    assert not os.path.exists(os.path.join(temp_dir, 'test_img.png'))
+    fis.move_image(".", temp_dir, sample_image)
+    fis.delete_image(temp_dir, "test_img.png")
+    assert not os.path.exists(os.path.join(temp_dir, "test_img.png"))
     os.rmdir(temp_dir)
diff --git a/src/client/test/test_main_window.py b/src/client/test/test_main_window.py
index d5fae533..788dea3c 100644
--- a/src/client/test/test_main_window.py
+++ b/src/client/test/test_main_window.py
@@ -1,7 +1,8 @@
 import os
 import pytest
 import sys
-sys.path.append('../')
+
+sys.path.append("../")

 from skimage import data
 from skimage.io import imsave
@@ -16,11 +17,13 @@
 from dcp_client.utils.sync_src_dst import DataRSync
 from dcp_client.utils import settings

+
 @pytest.fixture()
 def setup_global_variable():
     settings.accepted_types = (".jpg", ".jpeg", ".png", ".tiff", ".tif")
     yield settings.accepted_types

+
 @pytest.fixture
 def app(qtbot, setup_global_variable):

     img1 = data.astronaut()
     img2 = data.coffee()
     img3 = data.cat()

-    if not os.path.exists('train_data_path'):
-        os.mkdir('train_data_path')
-    imsave('train_data_path/astronaut.png', img1)
-
-    if not os.path.exists('in_prog'):
-        os.mkdir('in_prog')
-    imsave('in_prog/coffee.png', img2)
-
-    if not os.path.exists('eval_data_path'):
-        
os.mkdir('eval_data_path') - imsave('eval_data_path/cat.png', img3) - - rsyncer = DataRSync(user_name="local", host_name="local", server_repo_path='.') - application = Application(BentomlModel(), - rsyncer, - FilesystemImageStorage(), - "0.0.0.0", - 7010, - 'eval_data_path', - 'train_data_path', - 'in_prog') + if not os.path.exists("train_data_path"): + os.mkdir("train_data_path") + imsave("train_data_path/astronaut.png", img1) + + if not os.path.exists("in_prog"): + os.mkdir("in_prog") + imsave("in_prog/coffee.png", img2) + + if not os.path.exists("eval_data_path"): + os.mkdir("eval_data_path") + imsave("eval_data_path/cat.png", img3) + + rsyncer = DataRSync(user_name="local", host_name="local", server_repo_path=".") + application = Application( + BentomlModel(), + rsyncer, + FilesystemImageStorage(), + "0.0.0.0", + 7010, + "eval_data_path", + "train_data_path", + "in_prog", + ) # Create an instance of MainWindow widget = MainWindow(application) qtbot.addWidget(widget) yield widget widget.close() - + + def test_main_window_setup(qtbot, app, setup_global_variable): settings.accepted_types = setup_global_variable assert app.title == "Data Overview" + def test_item_train_selected(qtbot, app, setup_global_variable): settings.accepted_types = setup_global_variable # Select the first item in the tree view - #index = app.list_view_train.model().index(0, 0) + # index = app.list_view_train.model().index(0, 0) index = app.list_view_train.indexAt(app.list_view_train.viewport().rect().topLeft()) pos = app.list_view_train.visualRect(index).center() # Simulate file click - QTest.mouseClick(app.list_view_train.viewport(), - Qt.LeftButton, - pos=pos) + QTest.mouseClick(app.list_view_train.viewport(), Qt.LeftButton, pos=pos) app.on_item_train_selected(index) # Assert that the selected item matches the expected item assert app.list_view_train.selectionModel().currentIndex() == index - assert app.app.cur_selected_img=='astronaut.png' - assert app.app.cur_selected_path==app.app.train_data_path + assert app.app.cur_selected_img == "astronaut.png" + assert app.app.cur_selected_path == app.app.train_data_path + def test_item_inprog_selected(qtbot, app, setup_global_variable): settings.accepted_types = setup_global_variable # Select the first item in the tree view - index = app.list_view_inprogr.indexAt(app.list_view_inprogr.viewport().rect().topLeft()) + index = app.list_view_inprogr.indexAt( + app.list_view_inprogr.viewport().rect().topLeft() + ) pos = app.list_view_inprogr.visualRect(index).center() # Simulate file click - QTest.mouseClick(app.list_view_inprogr.viewport(), - Qt.LeftButton, - pos=pos) + QTest.mouseClick(app.list_view_inprogr.viewport(), Qt.LeftButton, pos=pos) app.on_item_inprogr_selected(index) # Assert that the selected item matches the expected item assert app.list_view_inprogr.selectionModel().currentIndex() == index assert app.app.cur_selected_img == "coffee.png" assert app.app.cur_selected_path == app.app.inprogr_data_path + def test_item_eval_selected(qtbot, app, setup_global_variable): settings.accepted_types = setup_global_variable # Select the first item in the tree view index = app.list_view_eval.indexAt(app.list_view_eval.viewport().rect().topLeft()) pos = app.list_view_eval.visualRect(index).center() # Simulate file click - QTest.mouseClick(app.list_view_eval.viewport(), - Qt.LeftButton, - pos=pos) + QTest.mouseClick(app.list_view_eval.viewport(), Qt.LeftButton, pos=pos) app.on_item_eval_selected(index) # Assert that the selected item matches the expected item assert 
app.list_view_eval.selectionModel().currentIndex() == index - assert app.app.cur_selected_img=='cat.png' - assert app.app.cur_selected_path==app.app.eval_data_path + assert app.app.cur_selected_img == "cat.png" + assert app.app.cur_selected_path == app.app.eval_data_path + def test_train_button_click(qtbot, app): # Click the "Train Model" button app.sim = True QTest.mouseClick(app.train_button, Qt.LeftButton) # Wait until the worker thread is done - while app.worker_thread.isRunning(): QTest.qSleep(1000) + while app.worker_thread.isRunning(): + QTest.qSleep(1000) # The train functionality of the thread is tested with app tests + def test_inference_button_click(qtbot, app): # Click the "Generate Labels" button app.sim = True QTest.mouseClick(app.inference_button, Qt.LeftButton) # Wait until the worker thread is done - while app.worker_thread.isRunning(): QTest.qSleep(1000) - #QTest.qWaitForWindowActive(app, timeout=5000) + while app.worker_thread.isRunning(): + QTest.qSleep(1000) + # QTest.qWaitForWindowActive(app, timeout=5000) # The inference functionality of the thread is tested with app tests + def test_on_finished(qtbot, app): # Assert that the on_finished function re-enabled the buttons and set the worker thread to None assert app.train_button.isEnabled() assert app.inference_button.isEnabled() assert app.worker_thread is None + def test_launch_napari_button_click_without_selection(qtbot, app): # Try clicking the view button without having selected an image app.sim = True qtbot.mouseClick(app.launch_nap_button, Qt.LeftButton) - assert not hasattr(app, 'nap_win') + assert not hasattr(app, "nap_win") + def test_launch_napari_button_click(qtbot, app): settings.accepted_types = setup_global_variable @@ -143,28 +155,28 @@ def test_launch_napari_button_click(qtbot, app): index = app.list_view_eval.indexAt(app.list_view_eval.viewport().rect().topLeft()) pos = app.list_view_eval.visualRect(index).center() # Simulate file click - QTest.mouseClick(app.list_view_eval.viewport(), - Qt.LeftButton, - pos=pos) + QTest.mouseClick(app.list_view_eval.viewport(), Qt.LeftButton, pos=pos) app.on_item_eval_selected(index) # Now click the view button qtbot.mouseClick(app.launch_nap_button, Qt.LeftButton) # Assert that the napari window has launched - assert hasattr(app, 'nap_win') + assert hasattr(app, "nap_win") assert app.nap_win.isVisible() -@pytest.fixture(scope='session', autouse=True) + +@pytest.fixture(scope="session", autouse=True) def cleanup_files(request): # This code runs after all tests from all files have completed yield # Clean up - paths_to_clean = ['train_data_path', 'in_prog', 'eval_data_path'] + paths_to_clean = ["train_data_path", "in_prog", "eval_data_path"] for path in paths_to_clean: try: for fname in os.listdir(path): os.remove(os.path.join(path, fname)) os.rmdir(path) - except FileNotFoundError: pass + except FileNotFoundError: + pass except Exception as e: # Handle other exceptions - print(f"An error occurred while cleaning up {path}: {e}") \ No newline at end of file + print(f"An error occurred while cleaning up {path}: {e}") diff --git a/src/client/test/test_mywidget.py b/src/client/test/test_mywidget.py index e75172c1..7e10f53f 100644 --- a/src/client/test/test_mywidget.py +++ b/src/client/test/test_mywidget.py @@ -1,35 +1,47 @@ import pytest import sys -sys.path.append('../') + +sys.path.append("../") from PyQt5.QtWidgets import QMessageBox from dcp_client.gui._my_widget import MyWidget + @pytest.fixture def app(qtbot): - #q_app = QApplication([]) + # q_app = QApplication([]) 
widget = MyWidget() qtbot.addWidget(widget) yield widget widget.close() + def test_create_warning_box_ok(qtbot, app): result = None app.sim = True + def execute_warning_box(): nonlocal result box = QMessageBox() result = app.create_warning_box("Test Message", custom_dialog=box) - qtbot.waitUntil(execute_warning_box, timeout=5000) - assert result is True + + qtbot.waitUntil(execute_warning_box, timeout=5000) + assert result is True + def test_create_warning_box_cancel(qtbot, app): result = None app.sim = True + def execute_warning_box(): nonlocal result box = QMessageBox() - result = app.create_warning_box("Test Message", add_cancel_btn=True, custom_dialog=box) - qtbot.waitUntil(execute_warning_box, timeout=5000) # Add a timeout for the function to execute - assert result is False + result = app.create_warning_box( + "Test Message", add_cancel_btn=True, custom_dialog=box + ) + + qtbot.waitUntil( + execute_warning_box, timeout=5000 + ) # Add a timeout for the function to execute + assert result is False diff --git a/src/client/test/test_napari_window.py b/src/client/test/test_napari_window.py index 06978ebf..8c31ebcf 100644 --- a/src/client/test/test_napari_window.py +++ b/src/client/test/test_napari_window.py @@ -21,11 +21,12 @@ # yield napari_app # napari_app.close() + @pytest.fixture def napari_window(qtbot): - #img1 = data.astronaut() - #img2 = data.coffee() + # img1 = data.astronaut() + # img2 = data.coffee() img = data.cat() img_mask = np.zeros((2, img.shape[0], img.shape[1]), dtype=np.uint8) img_mask[0, 50:50, 50:50] = 1 @@ -34,61 +35,63 @@ def napari_window(qtbot): img_mask[1, 100:200, 100:200] = 1 img_mask[0, 200:300, 200:300] = 3 img_mask[1, 200:300, 200:300] = 2 - #img3_mask = img2_mask.copy() + # img3_mask = img2_mask.copy() + + if not os.path.exists("train_data_path"): + os.mkdir("train_data_path") - if not os.path.exists('train_data_path'): - os.mkdir('train_data_path') + if not os.path.exists("in_prog"): + os.mkdir("in_prog") - if not os.path.exists('in_prog'): - os.mkdir('in_prog') + if not os.path.exists("eval_data_path"): + os.mkdir("eval_data_path") + imsave("eval_data_path/cat.png", img) - if not os.path.exists('eval_data_path'): - os.mkdir('eval_data_path') - imsave('eval_data_path/cat.png', img) - - imsave('eval_data_path/cat_seg.tiff', img_mask) + imsave("eval_data_path/cat_seg.tiff", img_mask) - rsyncer = DataRSync(user_name="local", host_name="local", server_repo_path='.') + rsyncer = DataRSync(user_name="local", host_name="local", server_repo_path=".") application = Application( - BentomlModel(), - rsyncer, - FilesystemImageStorage(), - "0.0.0.0", + BentomlModel(), + rsyncer, + FilesystemImageStorage(), + "0.0.0.0", 7010, - os.path.join(os.getcwd(), 'eval_data_path'), - os.path.join(os.getcwd(), 'train_data_path'), - os.path.join(os.getcwd(), 'in_prog') + os.path.join(os.getcwd(), "eval_data_path"), + os.path.join(os.getcwd(), "train_data_path"), + os.path.join(os.getcwd(), "in_prog"), ) - application.cur_selected_img = 'cat.png' + application.cur_selected_img = "cat.png" application.cur_selected_path = application.eval_data_path widget = NapariWindow(application) - qtbot.addWidget(widget) - yield widget + qtbot.addWidget(widget) + yield widget widget.close() + def test_napari_window_initialization(napari_window): assert napari_window.viewer is not None assert napari_window.qctrl is not None assert napari_window.mask_choice_dropdown is not None + def test_on_add_to_curated_button_clicked(napari_window, monkeypatch): # Mock the create_warning_box method def 
mock_create_warning_box(message_text, message_title): - return None - monkeypatch.setattr(napari_window, 'create_warning_box', mock_create_warning_box) + return None + + monkeypatch.setattr(napari_window, "create_warning_box", mock_create_warning_box) - napari_window.app.cur_selected_img = 'cat.png' + napari_window.app.cur_selected_img = "cat.png" napari_window.app.cur_selected_path = napari_window.app.eval_data_path - napari_window.viewer.layers.selection.active.name = 'cat_seg' + napari_window.viewer.layers.selection.active.name = "cat_seg" # Simulate the button click napari_window.on_add_to_curated_button_clicked() - assert not os.path.exists('eval_data_path/cat.tiff') - assert not os.path.exists('eval_data_path/cat_seg.tiff') - assert os.path.exists('train_data_path/cat.png') - assert os.path.exists('train_data_path/cat_seg.tiff') - + assert not os.path.exists("eval_data_path/cat.tiff") + assert not os.path.exists("eval_data_path/cat_seg.tiff") + assert os.path.exists("train_data_path/cat.png") + assert os.path.exists("train_data_path/cat_seg.tiff") diff --git a/src/client/test/test_sync_src_dst.py b/src/client/test/test_sync_src_dst.py index ca652644..15ed79d3 100644 --- a/src/client/test/test_sync_src_dst.py +++ b/src/client/test/test_sync_src_dst.py @@ -1,26 +1,25 @@ import pytest -from dcp_client.utils.sync_src_dst import DataRSync +from dcp_client.utils.sync_src_dst import DataRSync @pytest.fixture def rsyncer(): - syncer = DataRSync(user_name="local", - host_name="local", - server_repo_path='.') + syncer = DataRSync(user_name="local", host_name="local", server_repo_path=".") return syncer + def test_init(rsyncer): - assert rsyncer.user_name=="local" - assert rsyncer.host_name=="local" - assert rsyncer.server_repo_path=='.' + assert rsyncer.user_name == "local" + assert rsyncer.host_name == "local" + assert rsyncer.server_repo_path == "." 
+ def test_first_sync_e(rsyncer): msg, _ = rsyncer.first_sync("eval_data_path") - assert msg=="Error" + assert msg == "Error" + def test_sync(rsyncer): msg, _ = rsyncer.sync("server", "client", "eval_data_path") - assert msg=="Error" - - + assert msg == "Error" diff --git a/src/client/test/test_utils.py b/src/client/test/test_utils.py index d09c8df6..88d2ce5b 100644 --- a/src/client/test/test_utils.py +++ b/src/client/test/test_utils.py @@ -3,35 +3,38 @@ sys.path.append("../") from dcp_client.utils import utils + def test_get_relative_path(): - filepath = '/here/we/are/testing/something.txt' - assert utils.get_relative_path(filepath)== 'something.txt' + filepath = "/here/we/are/testing/something.txt" + assert utils.get_relative_path(filepath) == "something.txt" + def test_get_path_stem(): - filepath = '/here/we/are/testing/something.txt' - assert utils.get_path_stem(filepath)== 'something' + filepath = "/here/we/are/testing/something.txt" + assert utils.get_path_stem(filepath) == "something" + def test_get_path_name(): - filepath = '/here/we/are/testing/something.txt' - assert utils.get_path_name(filepath)== 'something.txt' + filepath = "/here/we/are/testing/something.txt" + assert utils.get_path_name(filepath) == "something.txt" + def test_get_path_parent(): - if sys.platform == 'win32' or sys.platform == 'cygwin': - filepath = '\\here\\we\\are\\testing\\something.txt' - assert utils.get_path_parent(filepath)== '\\here\\we\\are\\testing' + if sys.platform == "win32" or sys.platform == "cygwin": + filepath = "\\here\\we\\are\\testing\\something.txt" + assert utils.get_path_parent(filepath) == "\\here\\we\\are\\testing" else: - filepath = '/here/we/are/testing/something.txt' - assert utils.get_path_parent(filepath)== '/here/we/are/testing' + filepath = "/here/we/are/testing/something.txt" + assert utils.get_path_parent(filepath) == "/here/we/are/testing" + def test_join_path(): - if sys.platform == 'win32' or sys.platform == 'cygwin': - filepath = '\\here\\we\\are\\testing\\something.txt' - path1 = '\\here\\we\\are\\testing' - path2 = 'something.txt' + if sys.platform == "win32" or sys.platform == "cygwin": + filepath = "\\here\\we\\are\\testing\\something.txt" + path1 = "\\here\\we\\are\\testing" + path2 = "something.txt" else: - filepath = '/here/we/are/testing/something.txt' - path1 = '/here/we/are/testing' - path2 = 'something.txt' + filepath = "/here/we/are/testing/something.txt" + path1 = "/here/we/are/testing" + path2 = "something.txt" assert utils.join_path(path1, path2) == filepath - - diff --git a/src/client/test/test_welcome_window.py b/src/client/test/test_welcome_window.py index 4b15803d..9fdaa49c 100644 --- a/src/client/test/test_welcome_window.py +++ b/src/client/test/test_welcome_window.py @@ -1,6 +1,7 @@ import pytest import sys -sys.path.append('../') + +sys.path.append("../") from PyQt5.QtCore import Qt from PyQt5.QtWidgets import QMessageBox @@ -12,39 +13,48 @@ from dcp_client.utils.sync_src_dst import DataRSync from dcp_client.utils import settings + @pytest.fixture def setup_global_variable(): settings.accepted_types = (".jpg", ".jpeg", ".png", ".tiff", ".tif") yield settings.accepted_types + @pytest.fixture def app(qtbot): - rsyncer = DataRSync(user_name="local", host_name="local", server_repo_path='.') - application = Application(BentomlModel(), rsyncer, FilesystemImageStorage(), "0.0.0.0", 7010) + rsyncer = DataRSync(user_name="local", host_name="local", server_repo_path=".") + application = Application( + BentomlModel(), rsyncer, FilesystemImageStorage(), 
"0.0.0.0", 7010 + ) # Create an instance of WelcomeWindow # q_app = QApplication([]) widget = WelcomeWindow(application) qtbot.addWidget(widget) - yield widget + yield widget widget.close() + @pytest.fixture def app_remote(qtbot): - rsyncer = DataRSync(user_name="remote", host_name="remote", server_repo_path='.') - application = Application(BentomlModel(), rsyncer, FilesystemImageStorage(), "0.0.0.0", 7010) + rsyncer = DataRSync(user_name="remote", host_name="remote", server_repo_path=".") + application = Application( + BentomlModel(), rsyncer, FilesystemImageStorage(), "0.0.0.0", 7010 + ) # Create an instance of WelcomeWindow # q_app = QApplication([]) widget = WelcomeWindow(application) qtbot.addWidget(widget) - yield widget + yield widget widget.close() + def test_welcome_window_initialization(app): assert app.title == "Select Dataset" assert app.val_textbox.text() == "" assert app.inprogr_textbox.text() == "" assert app.train_textbox.text() == "" + def test_warning_for_same_paths(qtbot, app, monkeypatch): app.app.eval_data_path = "/same/path" app.app.train_data_path = "/same/path" @@ -54,32 +64,43 @@ def test_warning_for_same_paths(qtbot, app, monkeypatch): def custom_exec(self): return QMessageBox.Ok - monkeypatch.setattr(QMessageBox, 'exec', custom_exec) - qtbot.mouseClick(app.start_button, Qt.LeftButton) + monkeypatch.setattr(QMessageBox, "exec", custom_exec) + qtbot.mouseClick(app.start_button, Qt.LeftButton) assert app.create_warning_box assert app.message_text == "All directory names must be distinct." + def test_on_text_changed(qtbot, app): app.app.train_data_path = "/initial/train/path" app.app.eval_data_path = "/initial/eval/path" app.app.inprogr_data_path = "/initial/inprogress/path" - app.on_text_changed(field_obj=app.train_textbox, field_name="train", text="/new/train/path") + app.on_text_changed( + field_obj=app.train_textbox, field_name="train", text="/new/train/path" + ) assert app.app.train_data_path == "/new/train/path" - app.on_text_changed(field_obj=app.val_textbox, field_name="eval", text="/new/eval/path") + app.on_text_changed( + field_obj=app.val_textbox, field_name="eval", text="/new/eval/path" + ) assert app.app.eval_data_path == "/new/eval/path" - app.on_text_changed(field_obj=app.inprogr_textbox, field_name="inprogress", text="/new/inprogress/path") + app.on_text_changed( + field_obj=app.inprogr_textbox, + field_name="inprogress", + text="/new/inprogress/path", + ) assert app.app.inprogr_data_path == "/new/inprogress/path" + def test_start_main_not_selected(qtbot, app): app.app.train_data_path = None app.app.eval_data_path = None app.sim = True qtbot.mouseClick(app.start_button, Qt.LeftButton) - assert not hasattr(app, 'mw') + assert not hasattr(app, "mw") + def test_start_main(qtbot, app, setup_global_variable): settings.accepted_types = setup_global_variable @@ -93,11 +114,12 @@ def test_start_main(qtbot, app, setup_global_variable): # Simulate clicking the start button qtbot.mouseClick(app.start_button, Qt.LeftButton) # Check if the main window is created - #assert qtbot.waitUntil(lambda: hasattr(app, 'mw'), timeout=1000) - assert hasattr(app, 'mw') + # assert qtbot.waitUntil(lambda: hasattr(app, 'mw'), timeout=1000) + assert hasattr(app, "mw") # Check if the WelcomeWindow is hidden assert app.isHidden() + def test_start_upload_and_main(qtbot, app_remote, setup_global_variable, monkeypatch): settings.accepted_types = setup_global_variable app_remote.app.eval_data_path = "/path/to/eval" @@ -107,15 +129,15 @@ def test_start_upload_and_main(qtbot, app_remote, 
setup_global_variable, monkeyp def custom_exec(self): return QMessageBox.Ok - monkeypatch.setattr(QMessageBox, 'exec', custom_exec) - qtbot.mouseClick(app_remote.start_button, Qt.LeftButton) + monkeypatch.setattr(QMessageBox, "exec", custom_exec) + qtbot.mouseClick(app_remote.start_button, Qt.LeftButton) # should close because error on upload! - assert app_remote.done_upload==False + assert app_remote.done_upload == False assert not app_remote.isVisible() - assert not hasattr(app_remote, 'mw') - + assert not hasattr(app_remote, "mw") + -'''' +"""' # TODO wait for github respose def test_browse_eval_clicked(qtbot, app, monkeypatch): # Mock the QFileDialog so that it immediately returns a directory @@ -162,4 +184,4 @@ def test_browse_inprogr_clicked(qtbot, app): # Check if the textbox is updated with the selected path assert app.inprogr_textbox.text() == app.app.inprogr_data_path -''' \ No newline at end of file +""" diff --git a/src/server/dcp_server/main.py b/src/server/dcp_server/main.py index 84d8b003..9c149b5b 100644 --- a/src/server/dcp_server/main.py +++ b/src/server/dcp_server/main.py @@ -4,7 +4,8 @@ from dcp_server.utils.helpers import read_config -def main() -> None: + +def main() -> None: """ Contains main functionality related to the server. """ @@ -16,21 +17,24 @@ def main() -> None: # else: # config_path = 'config.cfg' - local_path = path.join(__file__, '..') + local_path = path.join(__file__, "..") dir_name = path.dirname(path.abspath(sys.argv[0])) - service_config = read_config('service', config_path = path.join(dir_name, 'config.yaml')) - port = str(service_config['port']) + service_config = read_config( + "service", config_path=path.join(dir_name, "config.yaml") + ) + port = str(service_config["port"]) - subprocess.run([ - "bentoml", - "serve", - '--working-dir', - local_path, - "service:svc", - "--reload", - "--port="+port, - ]) - + subprocess.run( + [ + "bentoml", + "serve", + "--working-dir", + local_path, + "service:svc", + "--reload", + "--port=" + port, + ] + ) if __name__ == "__main__": diff --git a/src/server/dcp_server/models/__init__.py b/src/server/dcp_server/models/__init__.py index ba003253..eba3d089 100644 --- a/src/server/dcp_server/models/__init__.py +++ b/src/server/dcp_server/models/__init__.py @@ -5,7 +5,4 @@ from .multicellpose import MultiCellpose from .unet import UNet -__all__ = ['CustomCellpose', - 'Inst2MultiSeg', - 'MultiCellpose', - 'UNet'] +__all__ = ["CustomCellpose", "Inst2MultiSeg", "MultiCellpose", "UNet"] diff --git a/src/server/dcp_server/models/classifiers.py b/src/server/dcp_server/models/classifiers.py index d093ab14..43fed489 100644 --- a/src/server/dcp_server/models/classifiers.py +++ b/src/server/dcp_server/models/classifiers.py @@ -14,25 +14,24 @@ class PatchClassifier(nn.Module): - - """ Fully convolutional classifier for cell images. NOTE -> This model cannot be used as a standalone model in DCP - """ - - def __init__(self, - model_name: str, - model_config: dict, - data_config: dict, - train_config: dict, - eval_config: dict - ) -> None: - """ Initialize the fully convolutional classifier. + """Fully convolutional classifier for cell images. NOTE -> This model cannot be used as a standalone model in DCP""" + + def __init__( + self, + model_name: str, + model_config: dict, + data_config: dict, + train_config: dict, + eval_config: dict, + ) -> None: + """Initialize the fully convolutional classifier. :param model_name: Name of the model. :type model_name: str :param model_config: Model configuration. 
:type model_config: dict :param data_config: Data configuration. - :type data_config: dict + :type data_config: dict :param train_config: Training configuration. :type train_config: dict :param eval_config: Evaluation configuration. @@ -40,20 +39,16 @@ def __init__(self, """ super().__init__() - self.model_name = model_name self.model_config = model_config["classifier"] self.data_config = data_config self.train_config = train_config["classifier"] self.eval_config = eval_config["classifier"] - + self.build_model() - def train (self, - imgs: List[np.ndarray], - labels: List[np.ndarray] - ) -> None: - """ Trains the given model + def train(self, imgs: List[np.ndarray], labels: List[np.ndarray]) -> None: + """Trains the given model :param imgs: List of input images with shape (3, dx, dy). :type imgs: List[np.ndarray[np.uint8]] @@ -63,36 +58,33 @@ def train (self, # Convert input images and labels to tensors imgs = torch.stack([torch.from_numpy(img.astype(np.float32)) for img in imgs]) - imgs = torch.permute(imgs, (0, 3, 1, 2)) + imgs = torch.permute(imgs, (0, 3, 1, 2)) # Your classification label mask labels = torch.LongTensor([label for label in labels]) # Create a training dataset and dataloader train_dataloader = DataLoader( - TensorDataset(imgs, labels), - batch_size=self.train_config["batch_size"]) + TensorDataset(imgs, labels), batch_size=self.train_config["batch_size"] + ) loss_fn = nn.CrossEntropyLoss() - optimizer = Adam( - params=self.parameters(), - lr=self.train_config["lr"] - ) + optimizer = Adam(params=self.parameters(), lr=self.train_config["lr"]) # optimizer_class = self.train_config["optimizer"] - #eval(f'{optimizer_class}(params={self.parameters()}, lr={lr})') - + # eval(f'{optimizer_class}(params={self.parameters()}, lr={lr})') + # TODO check if we should replace self.parameters with super.parameters() for _ in tqdm( range(self.train_config["n_epochs"]), - desc="Running PatchClassifier training" - ): - + desc="Running PatchClassifier training", + ): + self.loss, self.metric = 0, 0 for data in train_dataloader: imgs, labels = data optimizer.zero_grad() preds = self.forward(imgs) - + l = loss_fn(preds, labels) l.backward() optimizer.step() @@ -100,13 +92,11 @@ def train (self, self.metric += self.metric_fn(preds, labels) - self.loss /= len(train_dataloader) + self.loss /= len(train_dataloader) self.metric /= len(train_dataloader) - - def eval(self, - img: np.ndarray - ) -> torch.Tensor: - """ Evaluates the model on the provided image and return the predicted label. + + def eval(self, img: np.ndarray) -> torch.Tensor: + """Evaluates the model on the provided image and return the predicted label. :param img: Input image for evaluation. :type img: np.ndarray[np.uint8] @@ -114,16 +104,19 @@ def eval(self, :rtype: torch.Tensor """ # convert to tensor - img = torch.permute(torch.tensor(img.astype(np.float32)), (2, 0, 1)).unsqueeze(0) + img = torch.permute(torch.tensor(img.astype(np.float32)), (2, 0, 1)).unsqueeze( + 0 + ) preds = self.forward(img) y_hat = torch.argmax(preds, 1) return y_hat def build_model(self) -> None: - """ Builds the PatchClassifer. 
-        """
+        """Builds the PatchClassifier."""
         in_channels = self.model_config["in_channels"]
-        in_channels = in_channels + 1 if self.model_config["include_mask"] else in_channels
+        in_channels = (
+            in_channels + 1 if self.model_config["include_mask"] else in_channels
+        )

         self.layer1 = nn.Sequential(
             nn.Conv2d(in_channels, 16, 3, 2, 5),
@@ -145,18 +138,15 @@ def build_model(self) -> None:
             nn.ReLU(),
             nn.Dropout2d(p=0.2),
         )
-        self.final_conv = nn.Conv2d(128,
-                                    self.model_config["num_classes"],
-                                    1)
+        self.final_conv = nn.Conv2d(128, self.model_config["num_classes"], 1)
         self.pooling = nn.AdaptiveMaxPool2d(1)

-        self.metric_fn = F1Score(num_classes=self.model_config["num_classes"],
-                                 task="multiclass")
+        self.metric_fn = F1Score(
+            num_classes=self.model_config["num_classes"], task="multiclass"
+        )

-    def forward(self,
-                x: torch.Tensor
-                ) -> torch.Tensor:
-        """ Performs forward pass of the PatchClassifier.
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Performs forward pass of the PatchClassifier.

         :param x: Input tensor.
         :type x: torch.Tensor
@@ -175,17 +165,17 @@ def forward(self,

 class FeatureClassifier:
-    """ This class implements a shallow model for cell classification using scikit-learn.
-    """
-
-    def __init__(self,
-                 model_name: str,
-                 model_config: dict,
-                 data_config: dict,
-                 train_config: dict,
-                 eval_config: dict
-                 ) -> None:
-        """ Constructs all the necessary attributes for the FeatureClassifier
+    """This class implements a shallow model for cell classification using scikit-learn."""
+
+    def __init__(
+        self,
+        model_name: str,
+        model_config: dict,
+        data_config: dict,
+        train_config: dict,
+        eval_config: dict,
+    ) -> None:
+        """Constructs all the necessary attributes for the FeatureClassifier

         :param model_config: Model configuration.
         :type model_config: dict
@@ -198,39 +188,34 @@ def __init__(self,
         """

         self.model_name = model_name
-        self.model_config = model_config["classifier"] # use for initialising model
+        self.model_config = model_config["classifier"]  # use for initialising model
         # self.data_config = data_config
         # self.train_config = train_config
         # self.eval_config = eval_config

-        self.model = RandomForestClassifier(**self.model_config) # TODO chnage config so RandomForestClassifier accepts input params
+        self.model = RandomForestClassifier(
+            **self.model_config
+        )  # TODO change config so RandomForestClassifier accepts input params

-
-    def train(self,
-              X_train: List[np.ndarray],
-              y_train: List[np.ndarray]
-              ) -> None:
-        """ Trains the model using the provided training data.
+    def train(self, X_train: List[np.ndarray], y_train: List[np.ndarray]) -> None:
+        """Trains the model using the provided training data.

         :param X_train: Features of the training data.
         :type X_train: numpy.ndarray
         :param y_train: Labels of the training data.
         :type y_train: numpy.ndarray
         """
-        self.model.fit(X_train,y_train)
+        self.model.fit(X_train, y_train)

         y_hat = self.model.predict(X_train)
         y_hat_proba = self.model.predict_proba(X_train)

         # Binary Cross-Entropy Loss
         self.loss = log_loss(y_train, y_hat_proba)

-        self.metric = f1_score(y_train, y_hat, average='micro')
+        self.metric = f1_score(y_train, y_hat, average="micro")

-
-    def eval(self,
-             X_test: np.ndarray
-             ) -> np.ndarray:
-        """ Evaluates the model on the provided test data.
+    def eval(self, X_test: np.ndarray) -> np.ndarray:
+        """Evaluates the model on the provided test data.

         :param X_test: Features of the test data. 
         :type X_test: numpy.ndarray
@@ -238,11 +223,11 @@ def eval(self,
         :rtype: numpy.ndarray
         """

-        X_test = X_test.reshape(1,-1)
+        X_test = X_test.reshape(1, -1)

         try:
             y_hat = self.model.predict(X_test)
         except NotFittedError as e:
             y_hat = np.zeros(X_test.shape[0])
-        
+
         return y_hat
diff --git a/src/server/dcp_server/models/custom_cellpose.py b/src/server/dcp_server/models/custom_cellpose.py
index f3a85d0b..b41d04bb 100644
--- a/src/server/dcp_server/models/custom_cellpose.py
+++ b/src/server/dcp_server/models/custom_cellpose.py
@@ -11,50 +11,52 @@

 from .model import Model

+
 class CustomCellpose(models.CellposeModel, Model):
     """Custom cellpose model inheriting the attributes and functions from the original CellposeModel and implementing
     additional attributes and methods needed for this project.
-    """
-    def __init__(self,
-                 model_name: str,
-                 model_config: dict,
-                 data_config: dict,
-                 train_config: dict,
-                 eval_config: dict,
-                 ) -> None:
-        """Constructs all the necessary attributes for the CustomCellpose.
+    """
+
+    def __init__(
+        self,
+        model_name: str,
+        model_config: dict,
+        data_config: dict,
+        train_config: dict,
+        eval_config: dict,
+    ) -> None:
+        """Constructs all the necessary attributes for the CustomCellpose.
         The model inherits all attributes from the parent class; the init allows passing any other argument that the parent class accepts.
-        Please, visit here https://cellpose.readthedocs.io/en/latest/api.html#id4 for more details on arguments accepted.
+        Please visit https://cellpose.readthedocs.io/en/latest/api.html#id4 for more details on the accepted arguments.

         :param model_name: The name of the current model
         :type model_name: str
         :param model_config: dictionary passed from the config file with all the arguments for the __init__ function and model initialization
         :type model_config: dict
         :param data_config: dictionary passed from the config file with all the data configurations
-        :type data_config: dict
+        :type data_config: dict
         :param train_config: dictionary passed from the config file with all the arguments for training function
         :type train_config: dict
         :param eval_config: dictionary passed from the config file with all the arguments for eval function
         :type eval_config: dict
         """
-
+
         # Initialize the cellpose model
         # super().__init__(**model_config["segmentor"])
-        Model.__init__(self, model_name, model_config, data_config, train_config, eval_config)
+        Model.__init__(
+            self, model_name, model_config, data_config, train_config, eval_config
+        )
         models.CellposeModel.__init__(self, **model_config["segmentor"])
         self.model_config = model_config
         self.data_config = data_config
         self.train_config = train_config
         self.eval_config = eval_config
         self.model_name = model_name
-        self.mkldnn = False # otherwise we get error with saving model
+        self.mkldnn = False  # otherwise we get an error when saving the model
         self.loss = 1e6
         self.metric = 0

-    def train(self,
-              imgs: List[np.ndarray],
-              masks: List[np.ndarray]
-              ) -> None:
+    def train(self, imgs: List[np.ndarray], masks: List[np.ndarray]) -> None:
         """Trains the given model
         Calls the original train function. 
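
To make the wrapper's interface concrete, here is a minimal usage sketch; the config values below are illustrative assumptions, not the repository's shipped configuration (per the init above, only the "segmentor" keys are forwarded: "model_type" to cellpose.models.CellposeModel and "n_epochs" to its train function).

    import numpy as np
    from dcp_server.models import CustomCellpose

    # Hypothetical minimal configs; unset cellpose arguments fall back to their defaults.
    model = CustomCellpose(
        model_name="Cellpose",
        model_config={"segmentor": {"model_type": "cyto"}},
        data_config={},
        train_config={"segmentor": {"n_epochs": 1}},
        eval_config={"segmentor": {}},
    )
    mask = model.eval(np.random.rand(256, 256))  # 2D image in, labelled instance mask out
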
@@ -62,50 +64,47 @@ def train(self, :type imgs: List[np.ndarray] :param masks: masks of the given images (training labels) :type masks: List[np.ndarray] - """ - if self.train_config["segmentor"]["n_epochs"]==0: return + """ + if self.train_config["segmentor"]["n_epochs"] == 0: + return super().train( - train_data=deepcopy(imgs), #Cellpose changes the images + train_data=deepcopy(imgs), # Cellpose changes the images train_labels=masks, **self.train_config["segmentor"] - ) + ) pred_masks, pred_flows, true_flows = self.compute_masks_flows(imgs, masks) # get loss, combination of mse for flows and bce for cell probability - self.loss = self.loss_fn(true_flows, pred_flows) + self.loss = self.loss_fn(true_flows, pred_flows) self.metric = np.mean(aggregated_jaccard_index(masks, pred_masks)) - def eval(self, - img: np.ndarray - ) -> np.ndarray: + def eval(self, img: np.ndarray) -> np.ndarray: """Evaluate the model - find mask of the given image - Calls the original eval function. + Calls the original eval function. :param img: image to evaluate on :type img: np.ndarray :return: mask of the image, list of 2D arrays, or single 3D array (if do_3D=True) labelled image. :rtype: np.ndarray - """ + """ # 0 to take only mask - inline with other models eval should always return the final mask - return super().eval(x=img, **self.eval_config["segmentor"])[ - 0 - ] + return super().eval(x=img, **self.eval_config["segmentor"])[0] - def eval_all_outputs(self, - img: np.ndarray - ) -> tuple: + def eval_all_outputs(self, img: np.ndarray) -> tuple: """Get all outputs of the model when running eval. :param img: Input image for segmentation. :type img: numpy.ndarray - :return: mask, flows, styles etc. Returns the same as cellpose.models.CellposeModel.eval - see Cellpose API Guide for more details. + :return: mask, flows, styles etc. Returns the same as cellpose.models.CellposeModel.eval - see Cellpose API Guide for more details. 
         :rtype: tuple
         """
         return super().eval(x=img, **self.eval_config["segmentor"])
-    
+
     # I introduced typing here as suggested by the docstring
-    def compute_masks_flows(self, imgs: List[np.ndarray], masks:List[np.ndarray]) -> tuple:
-        """ Computes instance, binary mask and flows in x and y - needed for loss and metric computations
+    def compute_masks_flows(
+        self, imgs: List[np.ndarray], masks: List[np.ndarray]
+    ) -> tuple:
+        """Computes instance, binary mask and flows in x and y - needed for loss and metric computations

         :param imgs: images to train on (training data)
         :type imgs: List[np.ndarray]
@@ -114,12 +113,12 @@ def compute_masks_flows(self, imgs: List[np.ndarray], masks:List[np.ndarray]) ->
         :return: A tuple containing the following elements:
             - pred_masks List [np.ndarray]: A list of predicted instance masks
             - pred_flows (torch.Tensor): A tensor holding the stacked predicted cell probability map, horizontal and vertical flows for all images
-            - true_lbl (np.ndarray): A numpy array holding the stacked true binary mask, horizontal and vertical flows for all images
+            - true_lbl (np.ndarray): A numpy array holding the stacked true binary mask, horizontal and vertical flows for all images
         :rtype: tuple
-        """
+        """
         # compute for loss and metric
-        true_bin_masks = [mask>0 for mask in masks] # get binary masks
-        true_flows = labels_to_flows(masks) # get cellpose flows
+        true_bin_masks = [mask > 0 for mask in masks]  # get binary masks
+        true_flows = labels_to_flows(masks)  # get cellpose flows
         # get predicted flows and cell probability
         pred_masks = []
         pred_flows = []
@@ -127,23 +126,25 @@ def compute_masks_flows(self, imgs: List[np.ndarray], masks:List[np.ndarray]) ->
         for idx, img in enumerate(imgs):
             mask, flows, _ = super().eval(x=img, **self.eval_config["segmentor"])
             pred_masks.append(mask)
-            pred_flows.append(np.stack([flows[1][0], flows[1][1], flows[2]])) # stack cell probability map, horizontal and vertical flow
-            true_lbl.append(np.stack([true_bin_masks[idx], true_flows[idx][2], true_flows[idx][3]]))
-
+            pred_flows.append(
+                np.stack([flows[1][0], flows[1][1], flows[2]])
+            )  # stack cell probability map, horizontal and vertical flow
+            true_lbl.append(
+                np.stack([true_bin_masks[idx], true_flows[idx][2], true_flows[idx][3]])
+            )
+
         true_lbl = np.stack(true_lbl)
-        pred_flows=np.stack(pred_flows)
-        pred_flows = torch.from_numpy(pred_flows).float().to('cpu')
+        pred_flows = np.stack(pred_flows)
+        pred_flows = torch.from_numpy(pred_flows).float().to("cpu")
         return pred_masks, pred_flows, true_lbl

-    def masks_to_outlines(self,
-                          mask: np.ndarray
-                          ) -> np.ndarray:
-        """ get outlines of masks as a 0-1 array
+    def masks_to_outlines(self, mask: np.ndarray) -> np.ndarray:
+        """get outlines of masks as a 0-1 array
         Calls the original cellpose.utils.masks_to_outlines function

         :param mask: int, 2D or 3D array, mask of an image
         :type mask: ndarray
         :return: outlines
         :rtype: ndarray
-        """
-        return utils.masks_to_outlines(mask) # [True, False] outputs
+        """
+        return utils.masks_to_outlines(mask)  # [True, False] outputs
diff --git a/src/server/dcp_server/models/inst_to_multi_seg.py b/src/server/dcp_server/models/inst_to_multi_seg.py
index df52fd46..fdd9c07a 100644
--- a/src/server/dcp_server/models/inst_to_multi_seg.py
+++ b/src/server/dcp_server/models/inst_to_multi_seg.py
@@ -10,46 +10,46 @@
     get_centered_patches,
     find_max_patch_size,
     create_patch_dataset,
-    create_dataset_for_rf
+    create_dataset_for_rf,
 )

 # Dictionary mapping class names to their corresponding classes
-segmentor_mapping = {
-    "Cellpose": 
CustomCellpose -} +segmentor_mapping = {"Cellpose": CustomCellpose} classifier_mapping = { "PatchClassifier": PatchClassifier, - "RandomForest": FeatureClassifier + "RandomForest": FeatureClassifier, } class Inst2MultiSeg(Model): - """ A two stage model for: 1. instance segmentation and 2. object wise classification - """ - - def __init__(self, - model_name:str, - model_config: dict, - data_config: dict, - train_config: dict, - eval_config:dict - ) -> None: - """ Constructs all the necessary attributes for the Inst2MultiSeg + """A two stage model for: 1. instance segmentation and 2. object wise classification""" + + def __init__( + self, + model_name: str, + model_config: dict, + data_config: dict, + train_config: dict, + eval_config: dict, + ) -> None: + """Constructs all the necessary attributes for the Inst2MultiSeg :param model_name: Name of the model. :type model_name: str :param model_config: Model configuration. :type model_config: dict :param data_config: Data configurations - :type data_config: dict + :type data_config: dict :param train_config: Training configuration. :type train_config: dict :param eval_config: Evaluation configuration. :type eval_config: dict - """ - #super().__init__() - Model.__init__(self, model_name, model_config, data_config, train_config, eval_config) + """ + # super().__init__() + Model.__init__( + self, model_name, model_config, data_config, train_config, eval_config + ) self.model_name = model_name self.model_config = model_config @@ -58,93 +58,100 @@ def __init__(self, self.eval_config = eval_config self.segmentor_class = self.model_config.get("segmentor_name", "Cellpose") - self.classifier_class = self.model_config.get("classifier_name", "PatchClassifier") + self.classifier_class = self.model_config.get( + "classifier_name", "PatchClassifier" + ) # Initialize the cellpose model and the classifier segmentor = segmentor_mapping.get(self.segmentor_class) self.segmentor = segmentor( - self.segmentor_class, self.model_config, self.data_config, self.train_config, self.eval_config - ) + self.segmentor_class, + self.model_config, + self.data_config, + self.train_config, + self.eval_config, + ) classifier = classifier_mapping.get(self.classifier_class) self.classifier = classifier( - self.classifier_class, self.model_config, self.data_config, self.train_config, self.eval_config - ) - - # make sure include mask is set to False if we are using the random forest model - if self.classifier_class=="RandomForest": - if "include_mask" not in self.model_config["classifier"].keys() or self.model_config["classifier"]["include_mask"] is True: - #print("Include mask=True was found, but for Random Forest, this parameter must be set to False. Doing this now.") + self.classifier_class, + self.model_config, + self.data_config, + self.train_config, + self.eval_config, + ) + + # make sure include mask is set to False if we are using the random forest model + if self.classifier_class == "RandomForest": + if ( + "include_mask" not in self.model_config["classifier"].keys() + or self.model_config["classifier"]["include_mask"] is True + ): + # print("Include mask=True was found, but for Random Forest, this parameter must be set to False. Doing this now.") self.model_config["classifier"]["include_mask"] = False - - - def train(self, - imgs: List[np.ndarray], - masks: List[np.ndarray] - ) -> None: - """ Trains the given model. First trains the segmentor and then the clasiffier. + + def train(self, imgs: List[np.ndarray], masks: List[np.ndarray]) -> None: + """Trains the given model. 
First trains the segmentor and then the classifier.

         :param imgs: images to train on (training data)
         :type imgs: List[np.ndarray]
         :param masks: masks of the given images (training labels)
-        :type masks: List[np.ndarray] of same shape as output of eval, i.e. one channel instances,
+        :type masks: List[np.ndarray] of same shape as output of eval, i.e. one channel instances,
         second channel classes, so [2, H, W] or [2, 3, H, W] for 3D
-        """
+        """
         # train cellpose
         masks_instances = [mask[0] for mask in masks]
-        #masks_instances = list(np.array(masks)[:,0,...]) #[mask.sum(-1) for mask in masks] if masks[0].ndim == 3 else masks
+        # masks_instances = list(np.array(masks)[:,0,...]) #[mask.sum(-1) for mask in masks] if masks[0].ndim == 3 else masks
         self.segmentor.train(imgs, masks_instances)
         masks_classes = [mask[1] for mask in masks]
         # create patch dataset to train classifier
-        #masks_classes = list(
+        # masks_classes = list(
         #    masks[:,1,...]
-        #) #[((mask > 0) * np.arange(1, 4)).sum(-1) for mask in masks]
+        # ) #[((mask > 0) * np.arange(1, 4)).sum(-1) for mask in masks]
         x, patch_masks, labels = create_patch_dataset(
             imgs,
             masks_classes,
             masks_instances,
-            noise_intensity = self.data_config["noise_intensity"],
-            max_patch_size = self.data_config["patch_size"],
-            include_mask = self.model_config["classifier"]["include_mask"]
+            noise_intensity=self.data_config["noise_intensity"],
+            max_patch_size=self.data_config["patch_size"],
+            include_mask=self.model_config["classifier"]["include_mask"],
         )
         # additionally extract features from the patches when using the RF model
-        if self.classifier_class == "RandomForest":
+        if self.classifier_class == "RandomForest":
             x = create_dataset_for_rf(x, patch_masks)
         # train classifier
         self.classifier.train(x, labels)
         # and compute metric and loss
         self.metric = (self.segmentor.metric + self.classifier.metric) / 2
-        self.loss = (self.segmentor.loss + self.classifier.loss)/2
+        self.loss = (self.segmentor.loss + self.classifier.loss) / 2

-    def eval(self,
-             img: np.ndarray
-             ) -> np.ndarray:
-        """ Evaluate the model on the provided image and return the final mask.
+    def eval(self, img: np.ndarray) -> np.ndarray:
+        """Evaluate the model on the provided image and return the final mask.

         :param img: Input image for evaluation.
         :type img: np.ndarray[np.uint8]
         :return: Final mask containing instance mask and class masks. 
:rtype: np.ndarray[np.uint16] - """ + """ # TBD we assume image is 2D [H, W] (see fsimage storage) - # The final mask which is returned should have + # The final mask which is returned should have # first channel the output of cellpose and the rest are the class channels with torch.no_grad(): # get instance mask from segmentor instance_mask = self.segmentor.eval(img) # find coordinates of detected objects class_mask = np.zeros(instance_mask.shape) - + max_patch_size = self.data_config["patch_size"] - if max_patch_size is None: + if max_patch_size is None: max_patch_size = find_max_patch_size(instance_mask) - + # get patches centered around detected objects x, patch_masks, instance_labels, _ = get_centered_patches( img, instance_mask, max_patch_size, noise_intensity=self.data_config["noise_intensity"], - include_mask=self.model_config["classifier"]["include_mask"] + include_mask=self.model_config["classifier"]["include_mask"], ) if self.classifier_class == "RandomForest": x = create_dataset_for_rf(x, patch_masks) @@ -152,15 +159,17 @@ def eval(self, for idx in range(len(x)): patch_class = self.classifier.eval(x[idx]) # Assign predicted class to corresponding location in final_mask - patch_class = patch_class.item() if isinstance(patch_class, torch.Tensor) else patch_class - class_mask[instance_mask==instance_labels[idx]] = ( - patch_class + 1 + patch_class = ( + patch_class.item() + if isinstance(patch_class, torch.Tensor) + else patch_class ) + class_mask[instance_mask == instance_labels[idx]] = patch_class + 1 # Apply mask to final_mask, retaining only regions where cellpose_mask is greater than 0 final_mask = np.stack( - (instance_mask, class_mask), axis=self.eval_config['mask_channel_axis'] - ).astype( - np.uint16 - ) # size 2xHxW - + (instance_mask, class_mask), axis=self.eval_config["mask_channel_axis"] + ).astype( + np.uint16 + ) # size 2xHxW + return final_mask diff --git a/src/server/dcp_server/models/model.py b/src/server/dcp_server/models/model.py index 4809bdf5..3cda12c1 100644 --- a/src/server/dcp_server/models/model.py +++ b/src/server/dcp_server/models/model.py @@ -2,37 +2,34 @@ from typing import List import numpy as np + class Model(ABC): - def __init__(self, - model_name: str, - model_config: dict, - data_config: dict, - train_config: dict, - eval_config: dict, - ) -> None: - + def __init__( + self, + model_name: str, + model_config: dict, + data_config: dict, + train_config: dict, + eval_config: dict, + ) -> None: + self.model_name = model_name self.model_config = model_config self.data_config = data_config self.train_config = train_config self.eval_config = eval_config - + self.loss = 1e6 self.metric = 0 @abstractmethod - def train(self, - imgs: List[np.array], - masks: List[np.array] - ) -> None: + def train(self, imgs: List[np.array], masks: List[np.array]) -> None: pass - + @abstractmethod - def eval(self, - img: np.array - ) -> np.array: + def eval(self, img: np.array) -> np.array: pass - + ''' def update_configs(self, config: dict, @@ -50,8 +47,8 @@ def update_configs(self, ''' -#from segment_anything import SamPredictor, sam_model_registry -#from segment_anything.automatic_mask_generator import SamAutomaticMaskGenerator +# from segment_anything import SamPredictor, sam_model_registry +# from segment_anything.automatic_mask_generator import SamAutomaticMaskGenerator # class CustomSAMModel(): # # https://github.com/facebookresearch/segment-anything/blob/main/notebooks/automatic_mask_generator_example.ipynb # def __init__(self): diff --git 
a/src/server/dcp_server/models/multicellpose.py b/src/server/dcp_server/models/multicellpose.py
index 9d1b110d..5ece6b97 100644
--- a/src/server/dcp_server/models/multicellpose.py
+++ b/src/server/dcp_server/models/multicellpose.py
@@ -5,21 +5,23 @@
 from .model import Model
 from .custom_cellpose import CustomCellpose

+
 class MultiCellpose(Model):
-    '''
+    """
     Multichannel image segmentation model.
     Runs a separate CustomCellpose model for each channel and
     returns the mask corresponding to each object type.
-    '''
-
-    def __init__(self,
-                 model_name: str,
-                 model_config: dict,
-                 data_config: dict,
-                 train_config: dict,
-                 eval_config: dict,
-                 ) -> None:
+    """
+
+    def __init__(
+        self,
+        model_name: str,
+        model_config: dict,
+        data_config: dict,
+        train_config: dict,
+        eval_config: dict,
+    ) -> None:
         """Constructs all the necessary attributes for the MultiCellpose model.
-
+
         :param model_name: Name of the model.
         :type model_name: str
         :param model_config: Model configuration.
@@ -29,8 +31,10 @@ def __init__(self,
         :param eval_config: Evaluation configuration.
         :type eval_config: dict
         """
-        Model.__init__(self, model_name, model_config, data_config, train_config, eval_config)
-
+        Model.__init__(
+            self, model_name, model_config, data_config, train_config, eval_config
+        )
+
         self.model_config = model_config
         self.data_config = data_config
         self.train_config = train_config
@@ -41,17 +45,15 @@ def __init__(self,
         self.cellpose_models = [
             CustomCellpose(
                 "Cellpose",
-                self.model_config,
+                self.model_config,
                 self.data_config,
                 self.train_config,
                 self.eval_config,
-            ) for _ in range(self.num_of_channels)
-        ]
+            )
+            for _ in range(self.num_of_channels)
+        ]

-    def train(self,
-              imgs: List[np.ndarray],
-              masks: List[np.ndarray]
-              ) -> None:
+    def train(self, imgs: List[np.ndarray], masks: List[np.ndarray]) -> None:
         """
         Train the model on the provided images and masks.

@@ -60,27 +62,26 @@ def train(self,
         :param masks: Masks corresponding to the input images.
         :type masks: list[numpy.ndarray]
         """
-
+
         for i in range(self.num_of_channels):
-
+
             masks_class = []
             for mask in masks:
-                mask_class = mask[0].copy() # TODO - Do we need copy??
+                mask_class = mask[0].copy()  # TODO - Do we need copy??
                 # set all instances in the instance mask not corresponding to the class in question to zero
-                mask_class[0][
-                    mask_class[1]!=(i+1)
-                ] = 0
+                mask_class[0][mask_class[1] != (i + 1)] = 0
                 masks_class.append(mask_class)
             self.cellpose_models[i].train(imgs, masks_class)

-        self.metric = np.mean([self.cellpose_models[i].metric for i in range(self.num_of_channels)])
-        self.loss = np.mean([self.cellpose_models[i].loss for i in range(self.num_of_channels)])
-
+        self.metric = np.mean(
+            [self.cellpose_models[i].metric for i in range(self.num_of_channels)]
+        )
+        self.loss = np.mean(
+            [self.cellpose_models[i].loss for i in range(self.num_of_channels)]
+        )

-    def eval(self,
-             img: np.ndarray
-             ) -> np.ndarray:
+    def eval(self, img: np.ndarray) -> np.ndarray:
         """Evaluate the model on the provided image. The instance masks are computed as the union of the predicted model outputs, while the class of
         each object is assigned based on majority voting between the models. 
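
To make the majority-voting rule concrete, here is a toy numpy sketch with made-up 2x2 masks, mirroring the loop in the hunk below:

    import numpy as np

    instance_mask = np.array([[1, 1], [1, 0]])    # one object (id 1) on background 0
    class_mask = np.array([[2, 2], [1, 0]])       # per-pixel class votes disagree
    for inst_id in np.unique(instance_mask)[1:]:  # skip the background label
        where_inst_id = np.where(instance_mask == inst_id)
        vals, counts = np.unique(class_mask[where_inst_id], return_counts=True)
        class_mask[where_inst_id] = vals[np.argmax(counts)]
    # class_mask is now [[2, 2], [2, 0]]: the object takes its majority class, 2.
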
@@ -94,12 +95,12 @@ def eval(self,

         for i in range(self.num_of_channels):
             # get the instance mask and pixel-wise cell probability mask
-            instance_mask, probs, _ = self.cellpose_models[i].eval_all_outputs(img)
+            instance_mask, probs, _ = self.cellpose_models[i].eval_all_outputs(img)
             confidence_map = probs[2]
             # assign the appropriate class to all objects detected by this model
             class_mask = np.zeros_like(instance_mask)
-            class_mask[instance_mask>0]=(i + 1)
-
+            class_mask[instance_mask > 0] = i + 1
+
             instance_masks.append(instance_mask)
             class_masks.append(class_mask)
             model_confidences.append(confidence_map)
@@ -108,26 +109,25 @@ def eval(self,
             instance_masks, class_masks, model_confidences
         )
         # set all connected components to the same label in the instance mask
-        instance_mask = label_mask(merged_mask_instances>0)
+        instance_mask = label_mask(merged_mask_instances > 0)
         # and set the class with the most pixels to that object
-        for inst_id in np.unique(instance_mask)[1:]:
-            where_inst_id = np.where(instance_mask==inst_id)
+        for inst_id in np.unique(instance_mask)[1:]:
+            where_inst_id = np.where(instance_mask == inst_id)
             vals, counts = np.unique(class_mask[where_inst_id], return_counts=True)
             class_mask[where_inst_id] = vals[np.argmax(counts)]
         # take the final mask by stacking instance and class mask
         final_mask = np.stack(
-            (instance_mask, class_mask), axis=self.eval_config['mask_channel_axis']
-        ).astype(
-            np.uint16
-        )
-
+            (instance_mask, class_mask), axis=self.eval_config["mask_channel_axis"]
+        ).astype(np.uint16)
+
         return final_mask
-
-    def merge_masks(self,
-                    inst_masks: List[np.ndarray],
-                    class_masks: List[np.ndarray],
-                    probabilities: List[np.ndarray]
-                    ) -> tuple:
+
+    def merge_masks(
+        self,
+        inst_masks: List[np.ndarray],
+        class_masks: List[np.ndarray],
+        probabilities: List[np.ndarray],
+    ) -> tuple:
         """Merges the instance and class masks resulting from the different models using the pixel-wise cell probability.
         The output of the model with the maximum probability is selected for each pixel.

@@ -146,17 +146,20 @@ def merge_masks(self,
         inst_masks = np.array(inst_masks)
         class_masks = np.array(class_masks)
         probabilities = np.array(probabilities)
-
+
         # Find the index of the mask with the maximum probability for each pixel
         max_prob_indices = np.argmax(probabilities, axis=0)
-
+
         # Use the index to select the corresponding mask for each pixel
         final_mask_inst = inst_masks[
-            max_prob_indices, np.arange(inst_masks.shape[1])[:, None], np.arange(inst_masks.shape[2])
+            max_prob_indices,
+            np.arange(inst_masks.shape[1])[:, None],
+            np.arange(inst_masks.shape[2]),
         ]
         final_mask_class = class_masks[
-            max_prob_indices, np.arange(class_masks.shape[1])[:, None], np.arange(class_masks.shape[2])
+            max_prob_indices,
+            np.arange(class_masks.shape[1])[:, None],
+            np.arange(class_masks.shape[2]),
         ]

         return final_mask_inst, final_mask_class
-
diff --git a/src/server/dcp_server/models/unet.py b/src/server/dcp_server/models/unet.py
index 61d06a66..e4afb098 100644
--- a/src/server/dcp_server/models/unet.py
+++ b/src/server/dcp_server/models/unet.py
@@ -13,10 +13,9 @@

 class UNet(nn.Module, Model):
-
     """
     Unet is a convolutional neural network architecture for semantic segmentation.
-
+
     :param in_channels: Number of input channels (default: 3).
     :type in_channels: int
     :param out_channels: Number of output channels (default: 4).
@@ -24,17 +23,14 @@ class UNet(nn.Module, Model):
     :param features: List of feature channels for each encoder level (default: [64,128,256,512]). 
:type features: list """ - + class DoubleConv(nn.Module): """ DoubleConv module consists of two consecutive convolutional layers with batch normalization and ReLU activation functions. """ - def __init__(self, - in_channels: int, - out_channels: int - ) -> None: + def __init__(self, in_channels: int, out_channels: int) -> None: """ Initialize DoubleConv module. @@ -43,7 +39,7 @@ def __init__(self, :param out_channels: Number of output channels. :type out_channels: int """ - + super().__init__() self.conv = nn.Sequential( @@ -55,40 +51,40 @@ def __init__(self, nn.ReLU(), ) - def forward(self, - x: torch.Tensor - ) -> torch.Tensor: + def forward(self, x: torch.Tensor) -> torch.Tensor: """Forward pass through the DoubleConv module. :param x: Input tensor. :type x: torch.Tensor """ return self.conv(x) - - def __init__(self, - model_name: str, - model_config: dict, - data_config: dict, - train_config: dict, - eval_config: dict, - ) -> None: - """ Constructs all the necessary attributes for the UNet model. + def __init__( + self, + model_name: str, + model_config: dict, + data_config: dict, + train_config: dict, + eval_config: dict, + ) -> None: + """Constructs all the necessary attributes for the UNet model. :param model_name: Name of the model. :type model_name: str :param model_config: Model configuration. :type model_config: dict :param data_config: Data configurations - :type data_config: dict + :type data_config: dict :param train_config: Training configuration. :type train_config: dict :param eval_config: Evaluation configuration. :type eval_config: dict """ - Model.__init__(self, model_name, model_config, data_config, train_config, eval_config) + Model.__init__( + self, model_name, model_config, data_config, train_config, eval_config + ) nn.Module.__init__(self) - #super().__init__() + # super().__init__() self.model_name = model_name self.model_config = model_config @@ -101,10 +97,7 @@ def __init__(self, self.build_model() - def train(self, - imgs: List[np.ndarray], - masks: List[np.ndarray] - ) -> None: + def train(self, imgs: List[np.ndarray], masks: List[np.ndarray]) -> None: """ Trains the UNet model using the provided images and masks. @@ -115,73 +108,71 @@ def train(self, """ imgs = convert_to_tensor(imgs, np.float32) - masks = convert_to_tensor([mask[1] for mask in masks], np.int16, unsqueeze=False) - + masks = convert_to_tensor( + [mask[1] for mask in masks], np.int16, unsqueeze=False + ) + # Create a training dataset and dataloader train_dataloader = DataLoader( TensorDataset(imgs, masks), - batch_size=self.train_config["classifier"]["batch_size"]) + batch_size=self.train_config["classifier"]["batch_size"], + ) loss_fn = nn.CrossEntropyLoss() optimizer = Adam( - params=self.parameters(), - lr=self.train_config["classifier"]["lr"] + params=self.parameters(), lr=self.train_config["classifier"]["lr"] ) for _ in tqdm( range(self.train_config["classifier"]["n_epochs"]), - desc="Running UNet training" + desc="Running UNet training", ): self.loss = 0 for imgs, masks in train_dataloader: - #forward path + # forward path preds = self.forward(imgs.float()) loss = loss_fn(preds, masks.long()) - #backward path + # backward path optimizer.zero_grad() loss.backward() optimizer.step() self.loss += loss.detach().mean().item() - self.loss /= len(train_dataloader) + self.loss /= len(train_dataloader) + + def eval(self, img: np.ndarray) -> np.ndarray: + """Evaluate the model on the provided image and return the predicted label. 
- def eval(self, - img: np.ndarray - ) -> np.ndarray: - """ Evaluate the model on the provided image and return the predicted label. - :param img: Input image for evaluation. :type img: np.ndarray[np.uint8] :return: predicted mask consists of instance and class masks :rtype: numpy.ndarray - """ + """ with torch.no_grad(): - #img = torch.from_numpy(img).float().unsqueeze(0) - #img = img.unsqueeze(1) if img.ndim == 3 else img + # img = torch.from_numpy(img).float().unsqueeze(0) + # img = img.unsqueeze(1) if img.ndim == 3 else img img = convert_to_tensor([img], np.float32) - + preds = self.forward(img) - class_mask = torch.argmax(preds, 1).numpy()[0] + class_mask = torch.argmax(preds, 1).numpy()[0] if self.eval_config["compute_instance"] is True: instance_mask = label((class_mask > 0).astype(int))[0] final_mask = np.stack( - [instance_mask, class_mask], - axis=self.eval_config['mask_channel_axis'] - ).astype( - np.uint16 - ) - else: final_mask = class_mask.astype(np.uint16) + [instance_mask, class_mask], + axis=self.eval_config["mask_channel_axis"], + ).astype(np.uint16) + else: + final_mask = class_mask.astype(np.uint16) return final_mask - + def build_model(self) -> None: - """ Builds the UNet. - """ + """Builds the UNet.""" in_channels = self.model_config["classifier"]["in_channels"] out_channels = self.model_config["classifier"]["num_classes"] + 1 features = self.model_config["classifier"]["features"] @@ -193,28 +184,20 @@ def build_model(self) -> None: # Encoder for feature in features: - self.encoder.append( - UNet.DoubleConv(in_channels, feature) - ) + self.encoder.append(UNet.DoubleConv(in_channels, feature)) in_channels = feature # Decoder for feature in features[::-1]: self.decoder.append( - nn.ConvTranspose2d( - feature*2, feature, kernel_size=2, stride=2 - ) - ) - self.decoder.append( - UNet.DoubleConv(feature*2, feature) + nn.ConvTranspose2d(feature * 2, feature, kernel_size=2, stride=2) ) + self.decoder.append(UNet.DoubleConv(feature * 2, feature)) - self.bottle_neck = UNet.DoubleConv(features[-1], features[-1]*2) + self.bottle_neck = UNet.DoubleConv(features[-1], features[-1] * 2) self.output_conv = nn.Conv2d(features[0], out_channels, kernel_size=1) - def forward(self, - x: torch.Tensor - ) -> torch.Tensor: + def forward(self, x: torch.Tensor) -> torch.Tensor: """ Forward pass of the UNet model. @@ -234,8 +217,8 @@ def forward(self, for i in np.arange(len(self.decoder), step=2): x = self.decoder[i](x) - skip_connection = skip_connections[i//2] + skip_connection = skip_connections[i // 2] concatenate_skip = torch.cat((skip_connection, x), dim=1) - x = self.decoder[i+1](concatenate_skip) + x = self.decoder[i + 1](concatenate_skip) return self.output_conv(x) diff --git a/src/server/dcp_server/segmentationclasses.py b/src/server/dcp_server/segmentationclasses.py index d9a5f4ee..78b1d326 100644 --- a/src/server/dcp_server/segmentationclasses.py +++ b/src/server/dcp_server/segmentationclasses.py @@ -5,66 +5,82 @@ from dcp_server import models as DCPModels # Import configuration -setup_config = helpers.read_config('setup', config_path = 'config.yaml') +setup_config = helpers.read_config("setup", config_path="config.yaml") -class GeneralSegmentation(): - """ Segmentation class. Defining the main functions needed for this project and served by service - segment image and train on images. - """ - def __init__(self, imagestorage: FilesystemImageStorage, runner, model: DCPModels) -> None: - """ Constructs all the necessary attributes for the GeneralSegmentation. 
+ +class GeneralSegmentation: + """Segmentation class. Defining the main functions needed for this project and served by service - segment image and train on images.""" + + def __init__( + self, imagestorage: FilesystemImageStorage, runner, model: DCPModels + ) -> None: + """Constructs all the necessary attributes for the GeneralSegmentation. :param imagestorage: imagestorage system used (see fsimagestorage.py) :type imagestorage: FilesystemImageStorage class object :param runner: runner used in the service :type runner: CustomRunnable class object - :param model: model used for segmentation + :param model: model used for segmentation :type model: class object from the models.py - """ + """ self.imagestorage = imagestorage - self.runner = runner + self.runner = runner self.model = model self.no_files_msg = "No image-label pairs found in curated directory" - + async def segment_image(self, input_path: str, list_of_images: str) -> None: - """ Segments images from the given directory + """Segments images from the given directory :param input_path: directory where the images are saved and where segmentation results will be saved :type input_path: str :param list_of_images: list of image objects from the directory that are currently supported :type list_of_images: list - """ + """ for img_filepath in list_of_images: img = self.imagestorage.prepare_img_for_eval(img_filepath) # Add channel ax into the model's evaluation parameters dictionary - if self.imagestorage.model_used!="UNet": - self.model.eval_config['segmentor']['channel_axis'] = self.imagestorage.channel_ax + if self.imagestorage.model_used != "UNet": + self.model.eval_config["segmentor"][ + "channel_axis" + ] = self.imagestorage.channel_ax # Evaluate the model - mask = await self.runner.evaluate.async_run(img = img) + mask = await self.runner.evaluate.async_run(img=img) # And prepare the mask for saving - mask = self.imagestorage.prepare_mask_for_save(mask, self.model.eval_config['mask_channel_axis']) + mask = self.imagestorage.prepare_mask_for_save( + mask, self.model.eval_config["mask_channel_axis"] + ) # Save segmentation - seg_name = helpers.get_path_stem(img_filepath) + setup_config['seg_name_string'] + '.tiff' + seg_name = ( + helpers.get_path_stem(img_filepath) + + setup_config["seg_name_string"] + + ".tiff" + ) self.imagestorage.save_image(os.path.join(input_path, seg_name), mask) async def train(self, input_path: str) -> str: - """ Train model on images and masks in the given input directory. + """Train model on images and masks in the given input directory. Calls the runner's train function. 
:param input_path: directory where the images are saved :type input_path: str :return: runner's train function output - path of the saved model :rtype: str - """ + """ train_img_mask_pairs = self.imagestorage.get_image_seg_pairs(input_path) - if not train_img_mask_pairs: return self.no_files_msg - - imgs, masks = self.imagestorage.prepare_images_and_masks_for_training(train_img_mask_pairs) - model_save_path = await self.runner.train.async_run(imgs, masks) + if not train_img_mask_pairs: + return self.no_files_msg + + imgs, masks = self.imagestorage.prepare_images_and_masks_for_training( + train_img_mask_pairs + ) + model_save_path = await self.runner.train.async_run(imgs, masks) return model_save_path + + ''' class GFPProjectSegmentation(GeneralSegmentation): @@ -129,4 +145,4 @@ async def segment_image(self, input_path, list_of_images): # Save segmentation seg_name = helpers.get_path_stem(img_filepath) + setup_config['seg_name_string'] + '.tiff' self.imagestorage.save_image(os.path.join(input_path, seg_name), new_mask) -''' \ No newline at end of file +''' diff --git a/src/server/dcp_server/service.py b/src/server/dcp_server/service.py index cb560bed..d464545b 100644 --- a/src/server/dcp_server/service.py +++ b/src/server/dcp_server/service.py @@ -12,33 +12,46 @@ segmentation_module = __import__("segmentationclasses") # Import configuration -service_config = read_config('service', config_path = 'config.yaml') -model_config = read_config('model', config_path = 'config.yaml') -data_config = read_config('data', config_path = 'config.yaml') -train_config = read_config('train', config_path = 'config.yaml') -eval_config = read_config('eval', config_path = 'config.yaml') -setup_config = read_config('setup', config_path = 'config.yaml') +service_config = read_config("service", config_path="config.yaml") +model_config = read_config("model", config_path="config.yaml") +data_config = read_config("data", config_path="config.yaml") +train_config = read_config("train", config_path="config.yaml") +eval_config = read_config("eval", config_path="config.yaml") +setup_config = read_config("setup", config_path="config.yaml") # instantiate the model -model_class = getattr(models_module, setup_config['model_to_use']) -model = model_class(model_name=setup_config['model_to_use'], - model_config = model_config, - data_config = data_config, - train_config = train_config, - eval_config = eval_config) +model_class = getattr(models_module, setup_config["model_to_use"]) +model = model_class( + model_name=setup_config["model_to_use"], + model_config=model_config, + data_config=data_config, + train_config=train_config, + eval_config=eval_config, +) custom_model_runner = t.cast( - "CustomRunner", bentoml.Runner(CustomRunnable, name=service_config['runner_name'], - runnable_init_params={"model": model, "save_model_path": service_config['bento_model_path']}) + "CustomRunner", + bentoml.Runner( + CustomRunnable, + name=service_config["runner_name"], + runnable_init_params={ + "model": model, + "save_model_path": service_config["bento_model_path"], + }, + ), ) # instantiate the segmentation type -segm_class = getattr(segmentation_module, setup_config['segmentation']) -fsimagestorage = FilesystemImageStorage(data_config, setup_config['model_to_use']) -segmentation = segm_class(imagestorage=fsimagestorage, - runner = custom_model_runner, - model = model) +segm_class = getattr(segmentation_module, setup_config["segmentation"]) +fsimagestorage = FilesystemImageStorage(data_config, setup_config["model_to_use"]) +segmentation = 
segm_class( + imagestorage=fsimagestorage, runner=custom_model_runner, model=model +) # Call the service -service = CustomBentoService(runner=segmentation.runner, segmentation=segmentation, service_name=service_config['service_name']) -svc = service.start_service() \ No newline at end of file +service = CustomBentoService( + runner=segmentation.runner, + segmentation=segmentation, + service_name=service_config["service_name"], +) +svc = service.start_service() diff --git a/src/server/dcp_server/serviceclasses.py b/src/server/dcp_server/serviceclasses.py index 3ba77ba8..bb1b8e30 100644 --- a/src/server/dcp_server/serviceclasses.py +++ b/src/server/dcp_server/serviceclasses.py @@ -7,12 +7,14 @@ from dcp_server import models as DCPModels import dcp_server.segmentationclasses as DCPSegClasses + class CustomRunnable(bentoml.Runnable): - ''' + """ BentoML, Runner represents a unit of computation that can be executed on a remote Python worker and scales independently. CustomRunnable is a custom runner defined to meet all the requirements needed for this project. - ''' - SUPPORTED_RESOURCES = ("cpu",) #TODO add here? + """ + + SUPPORTED_RESOURCES = ("cpu",) # TODO add here? SUPPORTS_CPU_MULTI_THREADING = False def __init__(self, model: DCPModels, save_model_path: str) -> None: @@ -21,8 +23,8 @@ def __init__(self, model: DCPModels, save_model_path: str) -> None: :param model: model to be trained or evaluated - will be one of classes in models.py :param save_model_path: full path of the model object that it will be saved into :type save_model_path: str - """ - + """ + self.model = model self.save_model_path = save_model_path # update with the latest model if it already exists to continue training from there? @@ -44,16 +46,20 @@ def evaluate(self, img: np.ndarray) -> np.ndarray: mask = self.model.eval(img=img) return mask - + def check_and_load_model(self) -> None: """Checks if the specified model exists in BentoML's model repository. - If the model exists, it loads the latest version of the model into - memory. + If the model exists, it loads the latest version of the model into + memory. """ bento_model_list = [model.tag.name for model in bentoml.models.list()] if self.save_model_path in bento_model_list: - loaded_model = bentoml.picklable_model.load_model(self.save_model_path+":latest") - assert loaded_model.__class__.__name__ == self.model.__class__.__name__, 'Check your config, loaded model and model to use not the same!' + loaded_model = bentoml.picklable_model.load_model( + self.save_model_path + ":latest" + ) + assert ( + loaded_model.__class__.__name__ == self.model.__class__.__name__ + ), "Check your config, loaded model and model to use not the same!" self.model = loaded_model @bentoml.Runnable.method(batchable=False) @@ -66,14 +72,14 @@ def train(self, imgs: List[np.ndarray], masks: List[np.ndarray]) -> str: :type masks: List[np.ndarray] :return: path of the saved model :rtype: str - """ + """ self.model.train(imgs, masks) # Save the bentoml model bentoml.picklable_model.save_model( - self.save_model_path, + self.save_model_path, self.model, external_modules=[DCPModels], - ) + ) # bentoml.pytorch.save_model(self.save_model_path, # Model name in the local Model Store # self.model, # Model instance being saved # external_modules=[DCPModels] @@ -81,19 +87,22 @@ def train(self, imgs: List[np.ndarray], masks: List[np.ndarray]) -> str: return self.save_model_path -class CustomBentoService(): - """BentoML Service class. 
Contains all the functions necessary to serve the service with BentoML - """ - def __init__(self, runner: CustomRunnable, segmentation: DCPSegClasses, service_name: str) -> None: + +class CustomBentoService: + """BentoML Service class. Contains all the functions necessary to serve the service with BentoML""" + + def __init__( + self, runner: CustomRunnable, segmentation: DCPSegClasses, service_name: str + ) -> None: """Constructs all the necessary attributes for the class CustomBentoService(): :param runner: runner used in the service :type runner: CustomRunnable class object :param segmentation: segmentation type used in the service :type segmentation: segmentation class object from the segmentationclasses.py - :param service_name: name of the service + :param service_name: name of the service :type service_name: str - """ + """ self.runner = runner self.segmentation = segmentation self.service_name = service_name @@ -102,10 +111,12 @@ def start_service(self) -> None: """Starts the service :return: service object needed in service.py and for the bentoml serve call. - """ + """ svc = bentoml.Service(self.service_name, runners=[self.runner]) - @svc.api(input=Text(), output=NumpyNdarray()) #input path to the image output message with success and the save path + @svc.api( + input=Text(), output=NumpyNdarray() + ) # input path to the image output message with success and the save path async def segment_image(input_path: str) -> np.ndarray: """function served within the service, used to segment images @@ -113,10 +124,12 @@ async def segment_image(input_path: str) -> np.ndarray: :type input_path: str :return: list of files not supported :rtype: ndarray - """ + """ list_of_images = self.segmentation.imagestorage.search_images(input_path) - list_of_files_not_suported = self.segmentation.imagestorage.get_unsupported_files(input_path) - + list_of_files_not_suported = ( + self.segmentation.imagestorage.get_unsupported_files(input_path) + ) + if not list_of_images: return np.array(list_of_images) else: @@ -132,13 +145,12 @@ async def train(input_path: str) -> str: :type input_path: str :return: message of success if training went well :rtype: str - """ + """ print("Calling retrain from server.") # Train the model msg = await self.segmentation.train(input_path) - if msg!=self.segmentation.no_files_msg: + if msg != self.segmentation.no_files_msg: msg = "Success! Trained model saved in: " + msg return msg - + return svc - diff --git a/src/server/dcp_server/utils/fsimagestorage.py b/src/server/dcp_server/utils/fsimagestorage.py index 555cbce0..529b14f0 100644 --- a/src/server/dcp_server/utils/fsimagestorage.py +++ b/src/server/dcp_server/utils/fsimagestorage.py @@ -8,22 +8,26 @@ from dcp_server.utils.processing import pad_image, normalise # Import configuration -setup_config = helpers.read_config("setup", config_path = "config.yaml") +setup_config = helpers.read_config("setup", config_path="config.yaml") -class FilesystemImageStorage(): - """ + +class FilesystemImageStorage: + """ Class used to deal with everything related to image storing and processing - loading, saving, transforming. 
- """ + """ + def __init__(self, data_config: dict, model_used: str) -> None: self.root_dir = data_config["data_root"] - self.gray = bool(data_config["gray"]) + self.gray = bool(data_config["gray"]) self.rescale = bool(data_config["rescale"]) self.model_used = model_used self.channel_ax = None self.img_height = None self.img_width = None - - def load_image(self, cur_selected_img: str, gray: Optional[bool]=None) -> Optional[np.ndarray]: + + def load_image( + self, cur_selected_img: str, gray: Optional[bool] = None + ) -> Optional[np.ndarray]: """Load the image (using skiimage) :param cur_selected_img: full path of the image that needs to be loaded @@ -32,24 +36,26 @@ def load_image(self, cur_selected_img: str, gray: Optional[bool]=None) -> Option :type gray: bool or None, default=Nonee :return: loaded image :rtype: ndarray - """ - if gray is None: gray = self.gray + """ + if gray is None: + gray = self.gray try: - return imread(os.path.join(self.root_dir , cur_selected_img), as_gray=gray) - except ValueError: return None - + return imread(os.path.join(self.root_dir, cur_selected_img), as_gray=gray) + except ValueError: + return None + def save_image(self, to_save_path: str, img: np.ndarray) -> None: - """ Save given image using skimage. + """Save given image using skimage. :param to_save_path: full path to the directory that the image needs to be save into (use also image name in the path, eg. '/users/new_image.png') :type to_save_path: str :param img: image you wish to save :type img: ndarray - """ + """ imsave(os.path.join(self.root_dir, to_save_path), img) - + def search_images(self, directory: str) -> List[str]: - """ Get a list of full paths of the images in the directory. + """Get a list of full paths of the images in the directory. :param directory: Path to the directory to search for images. :type directory: str @@ -58,33 +64,55 @@ def search_images(self, directory: str) -> List[str]: """ # Take all segmentations of the image from the current directory: directory = os.path.join(self.root_dir, directory) - seg_files = [file_name for file_name in os.listdir(directory) if setup_config['seg_name_string'] in file_name] + seg_files = [ + file_name + for file_name in os.listdir(directory) + if setup_config["seg_name_string"] in file_name + ] # Take the image files - difference between the list of all the files in the directory and the list of seg files and only file extensions currently accepted - image_files = [os.path.join(directory, file_name) for file_name in os.listdir(directory) if (file_name not in seg_files) and (helpers.get_file_extension(file_name) in setup_config['accepted_types'])] + image_files = [ + os.path.join(directory, file_name) + for file_name in os.listdir(directory) + if (file_name not in seg_files) + and ( + helpers.get_file_extension(file_name) in setup_config["accepted_types"] + ) + ] return image_files - + def search_segs(self, cur_selected_img: str) -> List[str]: - """ Returns a list of full paths of segmentations for an image. + """Returns a list of full paths of segmentations for an image. :param cur_selected_img: Full path of the image for which segmentations are needed. :type cur_selected_img: str :return: List of segmentation paths for the given image. 
:rtype: list """ - + # Check the directory the image was selected from: - img_directory = helpers.get_path_parent(os.path.join(self.root_dir, cur_selected_img)) + img_directory = helpers.get_path_parent( + os.path.join(self.root_dir, cur_selected_img) + ) # Take all segmentations of the image from the current directory: - search_string = helpers.get_path_stem(cur_selected_img) + setup_config['seg_name_string'] - #seg_files = [os.path.join(img_directory, file_name) for file_name in os.listdir(img_directory) if search_string in file_name] + search_string = ( + helpers.get_path_stem(cur_selected_img) + setup_config["seg_name_string"] + ) + # seg_files = [os.path.join(img_directory, file_name) for file_name in os.listdir(img_directory) if search_string in file_name] # TODO: check where this is used - copied the command from app's search_segs function (to fix the 1_seg and 11_seg bug) - seg_files = [os.path.join(img_directory, file_name) for file_name in os.listdir(img_directory) if (search_string == helpers.get_path_stem(file_name) or str(file_name).startswith(search_string))] + seg_files = [ + os.path.join(img_directory, file_name) + for file_name in os.listdir(img_directory) + if ( + search_string == helpers.get_path_stem(file_name) + or str(file_name).startswith(search_string) + ) + ] return seg_files - - def get_image_seg_pairs(self, directory:str) -> List[tuple]: - """ Get pairs of (image, image_seg). + + def get_image_seg_pairs(self, directory: str) -> List[tuple]: + """Get pairs of (image, image_seg). Used, e.g., in training to create training data-training labels pairs. @@ -93,60 +121,66 @@ def get_image_seg_pairs(self, directory:str) -> List[tuple]: :return: List of tuple pairs (image, image_seg). :rtype: list """ - + image_files = self.search_images(os.path.join(self.root_dir, directory)) seg_files = [] for image in image_files: seg = self.search_segs(image) - #TODO - the search seg returns all the segs, but here we need only one, hence the seg[0]. Check if it is from training path? + # TODO - the search seg returns all the segs, but here we need only one, hence the seg[0]. Check if it is from training path? seg_files.append(seg[0]) return list(zip(image_files, seg_files)) - - def get_unsupported_files(self, directory:str) -> List[str]: - """ Get unsupported files found in the given directory. + + def get_unsupported_files(self, directory: str) -> List[str]: + """Get unsupported files found in the given directory. :param directory: Directory path to search for files in. :type directory: str :return: List of unsupported files. - :rtype: list + :rtype: list """ - return [file_name for file_name in os.listdir(os.path.join(self.root_dir, directory)) - if not file_name.startswith('.') and helpers.get_file_extension(file_name) not in setup_config['accepted_types']] - - def get_image_size_properties(self, img:np.ndarray, file_extension:str) -> None: + return [ + file_name + for file_name in os.listdir(os.path.join(self.root_dir, directory)) + if not file_name.startswith(".") + and helpers.get_file_extension(file_name) + not in setup_config["accepted_types"] + ] + + def get_image_size_properties(self, img: np.ndarray, file_extension: str) -> None: """Set properties of the image size :param img: Image (numpy array). :type img: ndarray :param file_extension: File extension of the image as saved in the directory. :type file_extension: str - """ + """ # TODO simplify! 
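+        # Overview (illustrative summary of the branches below): the method
+        # infers (img_height, img_width, channel_ax) from the file type:
+        #   RGB .png/.jpg/.jpeg       -> channel_ax = 2
+        #   grayscale .png/.jpg/.jpeg -> channel_ax = None
+        #   2D .tif/.tiff             -> channel_ax = None
+        #   3D .tif/.tiff             -> channels assumed last, channel_ax = 2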
- + orig_size = img.shape - # png and jpeg will be RGB by default and 2D + # png and jpeg will be RGB by default and 2D # tif can be grayscale 2D or 3D [Z, H, W] # image channels have already been removed in imread if self.gray=True # skimage.imread reads RGB or RGBA images in always with channel axis in dim=2 - if file_extension in (".jpg", ".jpeg", ".png") and self.gray==False: + if file_extension in (".jpg", ".jpeg", ".png") and self.gray == False: self.img_height, self.img_width = orig_size[0], orig_size[1] self.channel_ax = 2 - elif file_extension in (".jpg", ".jpeg", ".png") and self.gray==True: + elif file_extension in (".jpg", ".jpeg", ".png") and self.gray == True: self.img_height, self.img_width = orig_size[0], orig_size[1] self.channel_ax = None - elif file_extension in (".tiff", ".tif") and len(orig_size)==2: + elif file_extension in (".tiff", ".tif") and len(orig_size) == 2: self.img_height, self.img_width = orig_size[0], orig_size[1] self.channel_ax = None # if we have 3 dimensions the [Z, H, W] - elif file_extension in (".tiff", ".tif") and len(orig_size)==3: - print('Warning: 3D image stack found. We are assuming your last dimension is your channel dimension. Please cross check this.') - self.img_height, self.img_width = orig_size[0], orig_size[1] + elif file_extension in (".tiff", ".tif") and len(orig_size) == 3: + print( + "Warning: 3D image stack found. We are assuming your last dimension is your channel dimension. Please cross check this." + ) + self.img_height, self.img_width = orig_size[0], orig_size[1] self.channel_ax = 2 else: - print('File not currently supported. See documentation for accepted types') + print("File not currently supported. See documentation for accepted types") - - def rescale_image(self, img: np.ndarray, order: int=2) -> np.ndarray: + def rescale_image(self, img: np.ndarray, order: int = 2) -> np.ndarray: """rescale image :param img: Image. @@ -156,16 +190,22 @@ def rescale_image(self, img: np.ndarray, order: int=2) -> np.ndarray: :return: Rescaled image. :rtype: ndarray """ - + if self.model_used == "UNet": - return pad_image(img, self.img_height, self.img_width, self.channel_ax, dividable= 16) + return pad_image( + img, self.img_height, self.img_width, self.channel_ax, dividable=16 + ) else: # Cellpose segmentation runs best with 512 size? TODO: check - max_dim = max(self.img_height, self.img_width) - rescale_factor = max_dim/512 - return rescale(img, 1/rescale_factor, order=order, channel_axis=self.channel_ax) - - def resize_mask(self, mask: np.ndarray, channel_ax: Optional[int]=None, order: int=0) -> np.ndarray: + max_dim = max(self.img_height, self.img_width) + rescale_factor = max_dim / 512 + return rescale( + img, 1 / rescale_factor, order=order, channel_axis=self.channel_ax + ) + + def resize_mask( + self, mask: np.ndarray, channel_ax: Optional[int] = None, order: int = 0 + ) -> np.ndarray: """resize the mask so it matches the original image size :param mask: Image. @@ -179,45 +219,48 @@ def resize_mask(self, mask: np.ndarray, channel_ax: Optional[int]=None, order: i :return: Resized image. 
         :rtype: ndarray
         """
-        
+
         if self.model_used == "UNet":
             # we assume an order C, H, W
-            if channel_ax is not None and channel_ax==0: 
+            if channel_ax is not None and channel_ax == 0:
                 height_pad = mask.shape[1] - self.img_height
-                width_pad = mask.shape[2]- self.img_width
+                width_pad = mask.shape[2] - self.img_width
                 return mask[:, :-height_pad, :-width_pad]
-            elif channel_ax is not None and channel_ax==2: 
+            elif channel_ax is not None and channel_ax == 2:
                 height_pad = mask.shape[0] - self.img_height
-                width_pad = mask.shape[1]-self.img_width
+                width_pad = mask.shape[1] - self.img_width
                 return mask[:-height_pad, :-width_pad, :]
-            elif channel_ax is not None and channel_ax==1: 
+            elif channel_ax is not None and channel_ax == 1:
                 height_pad = mask.shape[2] - self.img_height
-                width_pad = mask.shape[0]- self.img_width
+                width_pad = mask.shape[0] - self.img_width
                 return mask[:-width_pad, :, :-height_pad]
-            else: 
+            else:
                 height_pad = mask.shape[0] - self.img_height
-                width_pad = mask.shape[1]-self.img_width
-                return mask[:-height_pad,:-width_pad]
+                width_pad = mask.shape[1] - self.img_width
+                return mask[:-height_pad, :-width_pad]
 
-        else: 
+        else:
             if channel_ax is not None:
                 n_channel_dim = mask.shape[channel_ax]
                 output_size = [self.img_height, self.img_width]
                 output_size.insert(channel_ax, n_channel_dim)
-            else: output_size = [self.img_height, self.img_width]
+            else:
+                output_size = [self.img_height, self.img_width]
             return resize(mask, output_size, order=order)
-    
-    def prepare_images_and_masks_for_training(self, train_img_mask_pairs: List[tuple]) -> tuple:
-        """ Image and mask processing for training.
+
+    def prepare_images_and_masks_for_training(
+        self, train_img_mask_pairs: List[tuple]
+    ) -> tuple:
+        """Image and mask processing for training.
 
         :param train_img_mask_pairs: List pairs of (image, image_seg) (as returned by get_image_seg_pairs() function).
         :type train_img_mask_pairs: list
         :return: Lists of processed images and masks.
         :rtype: tuple
         """
-        
-        imgs=[]
-        masks=[]
+
+        imgs = []
+        masks = []
         for img_file, mask_file in train_img_mask_pairs:
             img = self.load_image(img_file)
             img = normalise(img)
@@ -225,23 +268,31 @@
             self.get_image_size_properties(img, helpers.get_file_extension(img_file))
             # Unet only accepts image sizes divisible by 16
             if self.model_used == "UNet":
-                img = pad_image(img, self.img_height, self.img_width, channel_ax=self.channel_ax, dividable= 16)
-                mask = pad_image(mask, self.img_height, self.img_width, channel_ax=0, dividable= 16)
-            if self.model_used == "CustomCellpose" and len(mask.shape)==3:
+                img = pad_image(
+                    img,
+                    self.img_height,
+                    self.img_width,
+                    channel_ax=self.channel_ax,
+                    dividable=16,
+                )
+                mask = pad_image(
+                    mask, self.img_height, self.img_width, channel_ax=0, dividable=16
+                )
+            if self.model_used == "CustomCellpose" and len(mask.shape) == 3:
                 # if we also have class mask drop it
-                mask = masks[0] #assuming mask_channel_axis=0
+                mask = mask[0]  # assuming mask_channel_axis=0
             imgs.append(img)
             masks.append(mask)
         return imgs, masks
-    
-    def prepare_img_for_eval(self, img_file:str) -> np.ndarray:
+
+    def prepare_img_for_eval(self, img_file: str) -> np.ndarray:
         """Image processing for model inference.
- :param img_file: the path to the image + :param img_file: the path to the image :type img_file: str :return: the loaded and processed image :rtype: np.ndarray - """ + """ # Load and normalise the image img = self.load_image(img_file) img = normalise(img) @@ -250,7 +301,7 @@ def prepare_img_for_eval(self, img_file:str) -> np.ndarray: if self.rescale: img = self.rescale_image(img) return img - + def prepare_mask_for_save(self, mask: np.ndarray, channel_ax: int) -> np.ndarray: """Prepares the mask output of the model to be saved. @@ -260,9 +311,11 @@ def prepare_mask_for_save(self, mask: np.ndarray, channel_ax: int) -> np.ndarray :rype channel_ax: int :return: the ready to save mask :rtype: np.ndarray - """ + """ # Resize the mask if rescaling took place before - if self.rescale is True: - if len(mask.shape)<3: channel_ax=None + if self.rescale is True: + if len(mask.shape) < 3: + channel_ax = None return self.resize_mask(mask, channel_ax) - else: return mask \ No newline at end of file + else: + return mask diff --git a/src/server/dcp_server/utils/helpers.py b/src/server/dcp_server/utils/helpers.py index 6d3eb61b..4e700587 100644 --- a/src/server/dcp_server/utils/helpers.py +++ b/src/server/dcp_server/utils/helpers.py @@ -1,7 +1,8 @@ from pathlib import Path import yaml -def read_config(name:str, config_path:str = 'config.yaml') -> dict: + +def read_config(name: str, config_path: str = "config.yaml") -> dict: """Reads the configuration file :param name: name of the section you want to read (e.g. 'setup','train') @@ -10,24 +11,36 @@ def read_config(name:str, config_path:str = 'config.yaml') -> dict: :type config_path: str, optional :return: dictionary from the config section given by name :rtype: dict - """ + """ with open(config_path) as config_file: - config_dict = yaml.safe_load(config_file) # json.load(config_file) for .cfg file + config_dict = yaml.safe_load( + config_file + ) # json.load(config_file) for .cfg file # Check if config file has main mandatory keys - assert all([i in config_dict.keys() for i in ['setup', 'service', 'model', 'train', 'eval']]) + assert all( + [ + i in config_dict.keys() + for i in ["setup", "service", "model", "train", "eval"] + ] + ) return config_dict[name] -def get_path_stem(filepath: str) -> str: return str(Path(filepath).stem) +def get_path_stem(filepath: str) -> str: + return str(Path(filepath).stem) -def get_path_name(filepath: str) -> str: return str(Path(filepath).name) +def get_path_name(filepath: str) -> str: + return str(Path(filepath).name) -def get_path_parent(filepath: str) -> str: return str(Path(filepath).parent) +def get_path_parent(filepath: str) -> str: + return str(Path(filepath).parent) -def join_path(root_dir:str, filepath: str) -> str: return str(Path(root_dir, filepath)) +def join_path(root_dir: str, filepath: str) -> str: + return str(Path(root_dir, filepath)) -def get_file_extension(file: str) -> str: return str(Path(file).suffix) +def get_file_extension(file: str) -> str: + return str(Path(file).suffix) diff --git a/src/server/dcp_server/utils/processing.py b/src/server/dcp_server/utils/processing.py index ba44dfc8..bdcbca6c 100644 --- a/src/server/dcp_server/utils/processing.py +++ b/src/server/dcp_server/utils/processing.py @@ -5,13 +5,14 @@ from scipy.ndimage import find_objects from skimage import measure import SimpleITK as sitk -from radiomics import shape2D +from radiomics import shape2D import torch -def normalise(img: np.ndarray, norm: str='min-max') -> np.ndarray: - """ Normalises the image based on the chosen method. 
Currently available methods are:
+
+def normalise(img: np.ndarray, norm: str = "min-max") -> np.ndarray:
+    """Normalises the image based on the chosen method. Currently available methods are:
     - min max normalisation.
-    
+
     :param img: image to be normalised
     :type img: np.ndarray
     :param norm: the normalisation method to apply
@@ -19,42 +20,51 @@
     :return: the normalised image
     :rtype: np.ndarray
     """
-    if norm=='min-max':
-        return (img - np.min(img)) / (np.max(img) - np.min(img))
-    
+    if norm == "min-max":
+        return (img - np.min(img)) / (np.max(img) - np.min(img))
+
+
+def pad_image(
+    img: np.ndarray,
+    height: int,
+    width: int,
+    channel_ax: Optional[int] = None,
+    dividable: int = 16,
+) -> np.ndarray:
+    """Pads the image such that it is dividable by a given number.
 
-def pad_image(img: np.ndarray, height: int, width: int, channel_ax: Optional[int]=None, dividable:int = 16) -> np.ndarray:
-    """ Pads the image such that it is dividable by a given number.
-    
     :param img: image to be padded
     :type img: np.ndarray
     :param height: image height
    :type height: int
    :param width: image width
-    : type width: int 
+    :type width: int
    :param channel_ax: index of the channel axis of the image, if any
    :type channel_ax: int or None
     :param dividable: the number by which the padded image size should be perfectly dividable
     :type dividable: int
     :return: the padded image
     :rtype: np.ndarray
     """
-    height_pad = (height//dividable + 1)*dividable - height
-    width_pad = (width//dividable + 1)*dividable - width
-    if channel_ax==0:
+    height_pad = (height // dividable + 1) * dividable - height
+    width_pad = (width // dividable + 1) * dividable - width
+    if channel_ax == 0:
         img = np.pad(img, ((0, 0), (0, height_pad), (0, width_pad)))
-    elif channel_ax==2:
+    elif channel_ax == 2:
         img = np.pad(img, ((0, height_pad), (0, width_pad), (0, 0)))
     else:
-        img = np.pad(img, ((0, height_pad), (0, width_pad))) 
+        img = np.pad(img, ((0, height_pad), (0, width_pad)))
     return img
 
-def convert_to_tensor(imgs: List[np.ndarray], dtype: type, unsqueeze: bool=True) -> torch.Tensor:
-    """ Convert the imgs to tensors of type dtype and add extra dimension if input bool is true.
+
+def convert_to_tensor(
+    imgs: List[np.ndarray], dtype: type, unsqueeze: bool = True
+) -> torch.Tensor:
+    """Convert the imgs to tensors of type dtype and add extra dimension if input bool is true.
:param imgs: the list of images to convert :type img: List[np.ndarray] - :param dtype: the data type to convert the image tensor + :param dtype: the data type to convert the image tensor :type dtype: type :param unsqueeze: If True an extra dim will be added at location zero :type unsqueeze: bool @@ -62,19 +72,20 @@ def convert_to_tensor(imgs: List[np.ndarray], dtype: type, unsqueeze: bool=True) :rtype: torch.Tensor """ # Convert images tensors - imgs = torch.stack([ - torch.from_numpy(img.astype(dtype)) for img in imgs - ]) + imgs = torch.stack([torch.from_numpy(img.astype(dtype)) for img in imgs]) imgs = imgs.unsqueeze(1) if imgs.ndim == 3 and unsqueeze is True else imgs return imgs -def crop_centered_padded_patch(img: np.ndarray, - patch_center_xy: tuple, - patch_size: tuple, - obj_label: int, - mask: np.ndarray=None, - noise_intensity: int=None) -> np.ndarray: - """ Crop a patch from an array centered at coordinates patch_center_xy with size patch_size, + +def crop_centered_padded_patch( + img: np.ndarray, + patch_center_xy: tuple, + patch_size: tuple, + obj_label: int, + mask: np.ndarray = None, + noise_intensity: int = None, +) -> np.ndarray: + """Crop a patch from an array centered at coordinates patch_center_xy with size patch_size, and apply padding if necessary. :param img: the input array from which the patch will be cropped @@ -85,19 +96,19 @@ def crop_centered_padded_patch(img: np.ndarray, :type patch_size: tuple :param obj_label: the instance label of the mask at the patch :type obj_label: int - :param mask: The mask array associated with the array x. - Mask is used during training to mask out non-central elements. + :param mask: The mask array associated with the array x. + Mask is used during training to mask out non-central elements. For RandomForest, it is used to calculate pyradiomics features. 
:type mask: np.ndarray, optional :param noise_intensity: intensity of noise to be added to the background :type noise_intensity: float, optional :return: the cropped patch with applied padding :rtype: np.ndarray - """ + """ height, width = patch_size # Size of the patch - img_height, img_width = img.shape[0], img.shape[1] # Size of the input image - + img_height, img_width = img.shape[0], img.shape[1] # Size of the input image + # Calculate the boundaries of the patch top = patch_center_xy[0] - height // 2 bottom = top + height @@ -111,70 +122,121 @@ def crop_centered_padded_patch(img: np.ndarray, mask_other_objs = (mask_ != obj_label) & (mask_ > 0) img[mask_other_objs] = 0 # Add random noise at locations where other objects are present if noise_intensity is given - if noise_intensity is not None: img[mask_other_objs] = np.random.normal(scale=noise_intensity, size=img[mask_other_objs].shape) + if noise_intensity is not None: + img[mask_other_objs] = np.random.normal( + scale=noise_intensity, size=img[mask_other_objs].shape + ) mask[mask_other_objs] = 0 # crop the mask - mask = mask[max(top, 0):min(bottom, img_height), max(left, 0):min(right, img_width), :] + mask = mask[ + max(top, 0) : min(bottom, img_height), + max(left, 0) : min(right, img_width), + :, + ] - patch = img[max(top, 0):min(bottom, img_height), max(left, 0):min(right, img_width), :] + patch = img[ + max(top, 0) : min(bottom, img_height), max(left, 0) : min(right, img_width), : + ] # Calculate the required padding amounts and apply padding if necessary - if left < 0: - patch = np.hstack(( - np.random.normal(scale=noise_intensity, size=(patch.shape[0], abs(left), patch.shape[2])).astype(np.uint8), - patch)) - if mask is not None: - mask = np.hstack(( - np.zeros((mask.shape[0], abs(left), mask.shape[2])).astype(np.uint8), - mask)) + if left < 0: + patch = np.hstack( + ( + np.random.normal( + scale=noise_intensity, + size=(patch.shape[0], abs(left), patch.shape[2]), + ).astype(np.uint8), + patch, + ) + ) + if mask is not None: + mask = np.hstack( + ( + np.zeros((mask.shape[0], abs(left), mask.shape[2])).astype( + np.uint8 + ), + mask, + ) + ) # Apply padding on the right side if necessary - if right > img_width: - patch = np.hstack(( - patch, - np.random.normal(scale=noise_intensity, size=(patch.shape[0], (right - img_width), patch.shape[2])).astype(np.uint8))) - if mask is not None: - mask = np.hstack(( - mask, - np.zeros((mask.shape[0], (right - img_width), mask.shape[2])).astype(np.uint8))) + if right > img_width: + patch = np.hstack( + ( + patch, + np.random.normal( + scale=noise_intensity, + size=(patch.shape[0], (right - img_width), patch.shape[2]), + ).astype(np.uint8), + ) + ) + if mask is not None: + mask = np.hstack( + ( + mask, + np.zeros( + (mask.shape[0], (right - img_width), mask.shape[2]) + ).astype(np.uint8), + ) + ) # Apply padding on the top side if necessary - if top < 0: - patch = np.vstack(( - np.random.normal(scale=noise_intensity, size=(abs(top), patch.shape[1], patch.shape[2])).astype(np.uint8), - patch)) - if mask is not None: - mask = np.vstack(( - np.zeros((abs(top), mask.shape[1], mask.shape[2])).astype(np.uint8), - mask)) + if top < 0: + patch = np.vstack( + ( + np.random.normal( + scale=noise_intensity, + size=(abs(top), patch.shape[1], patch.shape[2]), + ).astype(np.uint8), + patch, + ) + ) + if mask is not None: + mask = np.vstack( + ( + np.zeros((abs(top), mask.shape[1], mask.shape[2])).astype(np.uint8), + mask, + ) + ) # Apply padding on the bottom side if necessary - if bottom > img_height: - 
patch = np.vstack(( - patch, - np.random.normal(scale=noise_intensity, size=(bottom - img_height, patch.shape[1], patch.shape[2])).astype(np.uint8))) - if mask is not None: - mask = np.vstack(( - mask, - np.zeros((bottom - img_height, mask.shape[1], mask.shape[2])).astype(np.uint8))) - return patch, mask + if bottom > img_height: + patch = np.vstack( + ( + patch, + np.random.normal( + scale=noise_intensity, + size=(bottom - img_height, patch.shape[1], patch.shape[2]), + ).astype(np.uint8), + ) + ) + if mask is not None: + mask = np.vstack( + ( + mask, + np.zeros( + (bottom - img_height, mask.shape[1], mask.shape[2]) + ).astype(np.uint8), + ) + ) + return patch, mask def get_center_of_mass_and_label(mask: np.ndarray) -> tuple: - """ Computes the centers of mass for each object in a mask. + """Computes the centers of mass for each object in a mask. :param mask: the input mask containing labeled objects :type mask: np.ndarray - :return: + :return: - A list of tuples representing the coordinates (row, column) of the centers of mass for each object. - - A list of ints representing the labels for each object in the mask. - :rtype: + - A list of ints representing the labels for each object in the mask. + :rtype: - List [tuple] - List [int] """ # Compute the centers of mass for each labeled object in the mask - - #return [(int(x[0]), int(x[1])) - # for x in center_of_mass(mask, mask, np.arange(1, mask.max() + 1))] - + + # return [(int(x[0]), int(x[1])) + # for x in center_of_mass(mask, mask, np.arange(1, mask.max() + 1))] + centers = [] labels = [] for region in measure.regionprops(mask): @@ -182,17 +244,17 @@ def get_center_of_mass_and_label(mask: np.ndarray) -> tuple: centers.append((int(center[0]), int(center[1]))) labels.append(region.label) return centers, labels - - -def get_centered_patches(img: np.ndarray, - mask: np.ndarray, - p_size: int, - noise_intensity: int=5, - mask_class: Optional[int]=None, - include_mask: bool=False) -> tuple: - """ Extracts centered patches from the input image based on the centers of objects identified in the mask. +def get_centered_patches( + img: np.ndarray, + mask: np.ndarray, + p_size: int, + noise_intensity: int = 5, + mask_class: Optional[int] = None, + include_mask: bool = False, +) -> tuple: + """Extracts centered patches from the input image based on the centers of objects identified in the mask. :param img: The input image. :type img: numpy.ndarray @@ -205,49 +267,54 @@ def get_centered_patches(img: np.ndarray, :param mask_class: The class represented in the patch. :type mask_class: int :param include_mask: Whether or not to include the mask as an input argument to the model. - :type include_mask: bool + :type include_mask: bool :return: A tuple containing the following elements: - patches (numpy.ndarray): Extracted patches. - patch_masks (numpy.ndarray): Masks corresponding to the extracted patches. - instance_labels (list): Labels identifying each object instance in the extracted patches. - class_labels (list): Labels identifying the class of each object instance in the extracted patches. 
- :rtype: tuple - """ + :rtype: tuple + """ - patches, patch_masks, instance_labels, class_labels = [], [], [], [] + patches, patch_masks, instance_labels, class_labels = [], [], [], [] # if image is 2D add an additional dim for channels - if img.ndim<3: img = img[:, :, np.newaxis] - if mask.ndim<3: mask = mask[:, :, np.newaxis] + if img.ndim < 3: + img = img[:, :, np.newaxis] + if mask.ndim < 3: + mask = mask[:, :, np.newaxis] # compute center of mass of objects centers_of_mass, instance_labels = get_center_of_mass_and_label(mask) # Crop patches around each center of mass for c, obj_label in zip(centers_of_mass, instance_labels): c_x, c_y = c - patch, patch_mask = crop_centered_padded_patch(img.copy(), - (c_x, c_y), - (p_size, p_size), - obj_label, - mask=deepcopy(mask), - noise_intensity=noise_intensity) + patch, patch_mask = crop_centered_padded_patch( + img.copy(), + (c_x, c_y), + (p_size, p_size), + obj_label, + mask=deepcopy(mask), + noise_intensity=noise_intensity, + ) if include_mask is True: patch_mask = 255 * (patch_mask > 0).astype(np.uint8) patch = np.concatenate((patch, patch_mask), axis=-1) - + patches.append(patch) patch_masks.append(patch_mask) if mask_class is not None: # get the class instance for the specific object instance_labels.append(obj_label) - class_l = np.unique(mask_class[mask[:,:,0]==obj_label]) - assert class_l.shape[0] == 1, "ERROR"+str(class_l) + class_l = np.unique(mask_class[mask[:, :, 0] == obj_label]) + assert class_l.shape[0] == 1, "ERROR" + str(class_l) class_l = int(class_l[0]) - #-1 because labels from mask start from 1, we want classes to start from 0 - class_labels.append(class_l-1) - + # -1 because labels from mask start from 1, we want classes to start from 0 + class_labels.append(class_l - 1) + return patches, patch_masks, instance_labels, class_labels + def get_objects(mask: np.ndarray) -> List: - """ Finds labeled connected components in a binary mask. + """Finds labeled connected components in a binary mask. :param mask: The binary mask representing objects. :type mask: numpy.ndarray @@ -256,8 +323,9 @@ def get_objects(mask: np.ndarray) -> List: """ return find_objects(mask) + def find_max_patch_size(mask: np.ndarray) -> float: - """ Finds the maximum patch size in a mask. + """Finds the maximum patch size in a mask. :param mask: The binary mask representing objects. :type mask: numpy.ndarray @@ -289,18 +357,21 @@ def find_max_patch_size(mask: np.ndarray) -> float: # Check if the current patch size is larger than the maximum if total_size > max_patch_size: max_patch_size = total_size - + max_patch_size_edge = np.ceil(np.sqrt(max_patch_size)) return max_patch_size_edge - -def create_patch_dataset(imgs: List[np.ndarray], - masks_classes: Optional[Union[List[np.ndarray], torch.Tensor]], - masks_instances: Optional[Union[List[np.ndarray], torch.Tensor]], - noise_intensity: int, - max_patch_size: int, - include_mask: bool) -> tuple: - """ Splits images and masks into patches of equal size centered around the cells. + + +def create_patch_dataset( + imgs: List[np.ndarray], + masks_classes: Optional[Union[List[np.ndarray], torch.Tensor]], + masks_instances: Optional[Union[List[np.ndarray], torch.Tensor]], + noise_intensity: int, + max_patch_size: int, + include_mask: bool, +) -> tuple: + """Splits images and masks into patches of equal size centered around the cells. :param imgs: A list of input images. :type imgs: list of numpy.ndarray or torch.Tensor @@ -320,30 +391,32 @@ def create_patch_dataset(imgs: List[np.ndarray], .. 
note:: If patch_size is not given, the algorithm should first run through all images to find the max cell size, and use the max cell size to define the patch size. All patches and masks should then be returned - in the same format as imgs and masks (same type, i.e. check if tensor or np.array and same + in the same format as imgs and masks (same type, i.e. check if tensor or np.array and same convention of dims, e.g. CxHxW) """ if max_patch_size is None: max_patch_size = np.max([find_max_patch_size(mask) for mask in masks_instances]) - + patches, patch_masks, labels = [], [], [] - for img, mask_class, mask_instance in zip(imgs, masks_classes, masks_instances): + for img, mask_class, mask_instance in zip(imgs, masks_classes, masks_instances): # mask_instance has dimension WxH # mask_class has dimension WxH - patch, patch_mask, _, label = get_centered_patches(img=img, - mask=mask_instance, - p_size=max_patch_size, - noise_intensity=noise_intensity, - mask_class=mask_class, - include_mask = include_mask) + patch, patch_mask, _, label = get_centered_patches( + img=img, + mask=mask_instance, + p_size=max_patch_size, + noise_intensity=noise_intensity, + mask_class=mask_class, + include_mask=include_mask, + ) patches.extend(patch) patch_masks.extend(patch_mask) - labels.extend(label) + labels.extend(label) return patches, patch_masks, labels def get_shape_features(img: np.ndarray, mask: np.ndarray) -> np.ndarray: - """ Calculate shape-based radiomic features from an image within the region defined by the mask. + """Calculate shape-based radiomic features from an image within the region defined by the mask. :param img: The input image. :type img: numpy.ndarray @@ -357,14 +430,17 @@ def get_shape_features(img: np.ndarray, mask: np.ndarray) -> np.ndarray: image = sitk.GetImageFromArray(img.squeeze()) roi_mask = sitk.GetImageFromArray(mask.squeeze()) - shape_calculator = shape2D.RadiomicsShape2D(inputImage=image, inputMask=roi_mask, label=255) + shape_calculator = shape2D.RadiomicsShape2D( + inputImage=image, inputMask=roi_mask, label=255 + ) # Calculate the shape-based radiomic features shape_features = shape_calculator.execute() return np.array(list(shape_features.values())) + def extract_intensity_features(image: np.ndarray, mask: np.ndarray) -> np.ndarray: - """ Extracts intensity-based features from an image within the region defined by the mask. + """Extracts intensity-based features from an image within the region defined by the mask. :param image: The input image. :type image: numpy.ndarray @@ -373,33 +449,36 @@ def extract_intensity_features(image: np.ndarray, mask: np.ndarray) -> np.ndarra :return: An array containing the extracted intensity-based features, including median intensity, mean intensity, and 25th/75th percentile intensity within the masked region. 
:rtype: numpy.ndarray """ - + features = {} - + # Ensure the image and mask have the same dimensions if image.shape != mask.shape: raise ValueError("Image and mask must have the same dimensions") - masked_image = image[(mask>0)] + masked_image = image[(mask > 0)] # features["min_intensity"] = np.min(masked_image) # features["max_intensity"] = np.max(masked_image) features["median_intensity"] = np.median(masked_image) features["mean_intensity"] = np.mean(masked_image) features["25th_percentile_intensity"] = np.percentile(masked_image, 25) features["75th_percentile_intensity"] = np.percentile(masked_image, 75) - + return np.array(list(features.values())) -def create_dataset_for_rf(imgs: List[np.ndarray], masks: List[np.ndarray]) -> List[np.ndarray]: - """ Extracts shape and intensity-based features from images within regions defined by masks. + +def create_dataset_for_rf( + imgs: List[np.ndarray], masks: List[np.ndarray] +) -> List[np.ndarray]: + """Extracts shape and intensity-based features from images within regions defined by masks. :param imgs: A list of input images. :type imgs: list :param masks: A list of corresponding masks defining regions of interest. :type masks: list :return: A list of arrays containing shape and intensity-based features. - :rtype: list + :rtype: list """ X = [] for img, mask in zip(imgs, masks): @@ -407,5 +486,5 @@ def create_dataset_for_rf(imgs: List[np.ndarray], masks: List[np.ndarray]) -> Li intensity_features = extract_intensity_features(img, mask) features_list = np.concatenate((shape_features, intensity_features), axis=0) X.append(features_list) - - return X \ No newline at end of file + + return X diff --git a/src/server/test/synthetic_dataset.py b/src/server/test/synthetic_dataset.py index 58c87ba6..5c9e0fcb 100644 --- a/src/server/test/synthetic_dataset.py +++ b/src/server/test/synthetic_dataset.py @@ -41,7 +41,8 @@ def assign_unique_colors(labels, colors): return label_colors -def custom_label2rgb(labels, colors=['red', 'green', 'blue'], bg_label=0, alpha=0.5): + +def custom_label2rgb(labels, colors=["red", "green", "blue"], bg_label=0, alpha=0.5): """ Converts a label array to an RGB image using assigned colors for each label. @@ -64,14 +65,17 @@ def custom_label2rgb(labels, colors=['red', 'green', 'blue'], bg_label=0, alpha= for label in np.unique(labels): mask = labels == label if label in label_colors: - rgb = color.label2rgb(mask, colors=[label_colors[label]], bg_label=bg_label, alpha=alpha) + rgb = color.label2rgb( + mask, colors=[label_colors[label]], bg_label=bg_label, alpha=alpha + ) rgb_image += rgb return rgb_image + def add_padding_for_rotation(image, angle): """ - Apply padding and rotation to an image. + Apply padding and rotation to an image. The purpose of this function is to ensure that the rotated image fits within its original dimensions by adding padding, preventing any parts of the image from being cropped. 
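A minimal usage sketch of add_padding_for_rotation, assuming the module is importable and a square grayscale input; the padded shape follows from the angle-dependent padding computed in the hunk below:

    import numpy as np

    # a bright 20x20 square on a 100x100 canvas, rotated by 45 degrees
    img = np.zeros((100, 100), dtype=np.uint8)
    img[40:60, 40:60] = 255
    rotated = add_padding_for_rotation(img, 45)

    # the canvas is padded before rotating, so no part of the square is cropped
    assert rotated.shape[0] >= img.shape[0] and rotated.shape[1] >= img.shape[1]
    assert np.count_nonzero(rotated) > 0
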
@@ -97,20 +101,25 @@ def add_padding_for_rotation(image, angle): pad_h = (new_h - h) // 2 # Add padding to the image - padded_image = cv2.copyMakeBorder(image, pad_h, pad_h, pad_w, pad_w, cv2.BORDER_CONSTANT) + padded_image = cv2.copyMakeBorder( + image, pad_h, pad_h, pad_w, pad_w, cv2.BORDER_CONSTANT + ) # Rotate the padded image center = (padded_image.shape[1] // 2, padded_image.shape[0] // 2) rotation_matrix = cv2.getRotationMatrix2D(center, angle, 1.0) - rotated_image = cv2.warpAffine(padded_image, rotation_matrix, (padded_image.shape[1], padded_image.shape[0])) + rotated_image = cv2.warpAffine( + padded_image, rotation_matrix, (padded_image.shape[1], padded_image.shape[0]) + ) return rotated_image + def get_object_images(objects): """ Load object images from file paths. - :param objects: A list of dictionaries containing information about the objects such as name, path, intensity + :param objects: A list of dictionaries containing information about the objects such as name, path, intensity :type objects: list[dict] :return: A list of object images loaded from the specified file paths. :rtype: list[numpy.ndarray] @@ -119,14 +128,22 @@ def get_object_images(objects): object_images = [] for obj in objects: - img = cv2.imread(obj['path']) + img = cv2.imread(obj["path"]) # img = cv2.resize(img, obj['size']) img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) object_images.append(img) return object_images -def generate_dataset(num_samples, objects, canvas_size, max_object_counts=None, noise_intensity=None, max_rotation_angle=None): + +def generate_dataset( + num_samples, + objects, + canvas_size, + max_object_counts=None, + noise_intensity=None, + max_rotation_angle=None, +): """ Generate a synthetic dataset with images and masks. @@ -150,7 +167,7 @@ def generate_dataset(num_samples, objects, canvas_size, max_object_counts=None, dataset_masks = [] object_images = get_object_images(objects) - class_intensities = [ (obj['intensity'][0], obj['intensity'][1]) for obj in objects] + class_intensities = [(obj["intensity"][0], obj["intensity"][1]) for obj in objects] if len(object_images[0].shape) == 3: num_of_img_channels = object_images[0].shape[-1] @@ -161,8 +178,12 @@ def generate_dataset(num_samples, objects, canvas_size, max_object_counts=None, max_object_counts = [10] * len(object_images) for _ in range(num_samples): - canvas = np.zeros((canvas_size[0], canvas_size[1], num_of_img_channels), dtype=np.uint8) - mask = np.zeros((canvas_size[0], canvas_size[1], len(object_images)), dtype=np.uint8) + canvas = np.zeros( + (canvas_size[0], canvas_size[1], num_of_img_channels), dtype=np.uint8 + ) + mask = np.zeros( + (canvas_size[0], canvas_size[1], len(object_images)), dtype=np.uint8 + ) for object_index, object_img in enumerate(object_images): @@ -170,70 +191,104 @@ def generate_dataset(num_samples, objects, canvas_size, max_object_counts=None, object_count = random.randint(1, max_count) for _ in range(object_count): - + canvas_range = max(canvas_size) - object_size = random.randint(canvas_range//20, canvas_range//5) + object_size = random.randint(canvas_range // 20, canvas_range // 5) object_img_resized = cv2.resize(object_img, (object_size, object_size)) # object_img_resized = (object_img_resized>0).astype(np.uint8)*(255 - object_size) - intensity_mean = (class_intensities[object_index][1] - class_intensities[object_index][0])/2 - intensity_scale = (class_intensities[object_index][1] - intensity_mean)/3 - class_intensity = np.random.normal(loc=intensity_mean, scale=intensity_scale) - class_intensity = 
np.clip(class_intensity, class_intensities[object_index][0], class_intensities[object_index][1]) + intensity_mean = ( + class_intensities[object_index][1] + - class_intensities[object_index][0] + ) / 2 + intensity_scale = ( + class_intensities[object_index][1] - intensity_mean + ) / 3 + class_intensity = np.random.normal( + loc=intensity_mean, scale=intensity_scale + ) + class_intensity = np.clip( + class_intensity, + class_intensities[object_index][0], + class_intensities[object_index][1], + ) # class_intensity = random.randint(int(class_intensities[object_index][0]), int(class_intensities[object_index][1])) - object_img_resized = (object_img_resized>0).astype(np.uint8)*(class_intensity)*255 + object_img_resized = ( + (object_img_resized > 0).astype(np.uint8) * (class_intensity) * 255 + ) if num_of_img_channels == 1: - + if max_rotation_angle is not None: # Randomly rotate the object image - rotation_angle = random.uniform(-max_rotation_angle, max_rotation_angle) - object_img_transformed = add_padding_for_rotation(object_img_resized, rotation_angle) + rotation_angle = random.uniform( + -max_rotation_angle, max_rotation_angle + ) + object_img_transformed = add_padding_for_rotation( + object_img_resized, rotation_angle + ) else: object_img_transformed = object_img_resized - - object_size_x, object_size_y = object_img_transformed.shape - + object_size_x, object_size_y = object_img_transformed.shape object_mask = np.zeros((object_size_x, object_size_y), dtype=np.uint8) if num_of_img_channels == 1: # Grayscale image object_mask[object_img_transformed > 0] = object_index + 1 # object_img_resized = np.expand_dims(object_img_resized, axis=-1) - object_img_transformed = np.expand_dims(object_img_transformed, axis=-1) + object_img_transformed = np.expand_dims( + object_img_transformed, axis=-1 + ) else: # Color image with alpha channel object_mask[object_img_resized[:, :, -1] > 0] = object_index + 1 - x = random.randint(0, canvas_size[1] - object_size_x) y = random.randint(0, canvas_size[0] - object_size_y) - intersecting_mask = mask[y:y + object_size_y, x:x + object_size_x].max(axis=-1) + intersecting_mask = mask[ + y : y + object_size_y, x : x + object_size_x + ].max(axis=-1) if (intersecting_mask > 0).any(): continue # Skip if there is an intersection with objects from other classes - - assert mask[y:y + object_size_y, x:x + object_size_x, object_index].shape == object_mask.shape - canvas[y:y + object_size_y, x:x + object_size_x] = object_img_transformed - mask[y:y + object_size_y, x:x + object_size_x, object_index] = np.maximum( - mask[y:y + object_size_y, x:x + object_size_x, object_index], object_mask + assert ( + mask[ + y : y + object_size_y, x : x + object_size_x, object_index + ].shape + == object_mask.shape + ) + + canvas[y : y + object_size_y, x : x + object_size_x] = ( + object_img_transformed + ) + mask[y : y + object_size_y, x : x + object_size_x, object_index] = ( + np.maximum( + mask[ + y : y + object_size_y, x : x + object_size_x, object_index + ], + object_mask, + ) ) - # Add noise to the canvas if noise_intensity is not None: if num_of_img_channels == 1: - noise = np.random.normal(scale=noise_intensity, size=(canvas_size[0], canvas_size[1], 1)) + noise = np.random.normal( + scale=noise_intensity, size=(canvas_size[0], canvas_size[1], 1) + ) # noise = random_noise(canvas, mode='speckle', mean=noise_intensity) - + else: - noise = np.random.normal(scale=noise_intensity, size=(canvas_size[0], canvas_size[1], num_of_img_channels)) + noise = np.random.normal( + scale=noise_intensity, 
+ size=(canvas_size[0], canvas_size[1], num_of_img_channels), + ) noisy_canvas = canvas + noise.astype(np.uint8) - dataset_images.append(noisy_canvas.squeeze(2)) - + dataset_images.append(noisy_canvas.squeeze(2)) + else: dataset_images.append(canvas.squeeze(2)) @@ -251,7 +306,10 @@ def generate_dataset(num_samples, objects, canvas_size, max_object_counts=None, return dataset_images, dataset_masks -def get_synthetic_dataset(num_samples, canvas_size=(512,512), max_object_counts=[15, 15, 15]): + +def get_synthetic_dataset( + num_samples, canvas_size=(512, 512), max_object_counts=[15, 15, 15] +): """Generates a synthetic dataset with images and masks. :param num_samples: The number of samples to generate. @@ -264,23 +322,21 @@ def get_synthetic_dataset(num_samples, canvas_size=(512,512), max_object_counts= :rtype: tuple """ objects = [ - { - - 'name': 'triangle', - 'path': 'test/shapes/triangle.png', - 'intensity' : [0, 0.33] - }, - { - 'name': 'circle', - 'path': 'test/shapes/circle.png', - 'intensity' : [0.34, 0.66] - }, - { - 'name': 'square', - 'path': 'test/shapes/square.png', - 'intensity' : [0.67, 1.0] - }, + { + "name": "triangle", + "path": "test/shapes/triangle.png", + "intensity": [0, 0.33], + }, + {"name": "circle", "path": "test/shapes/circle.png", "intensity": [0.34, 0.66]}, + {"name": "square", "path": "test/shapes/square.png", "intensity": [0.67, 1.0]}, ] - - images, masks = generate_dataset(num_samples, objects, canvas_size=canvas_size, max_object_counts=max_object_counts, noise_intensity=5, max_rotation_angle=30) + + images, masks = generate_dataset( + num_samples, + objects, + canvas_size=canvas_size, + max_object_counts=max_object_counts, + noise_intensity=5, + max_rotation_angle=30, + ) return images, masks diff --git a/src/server/test/test_integration.py b/src/server/test/test_integration.py index 3ef619fc..8637377e 100644 --- a/src/server/test/test_integration.py +++ b/src/server/test/test_integration.py @@ -1,13 +1,15 @@ import sys + sys.path.append(".") from glob import glob import pytest -#import inspect + +# import inspect import random import numpy as np -import torch +import torch from torchmetrics import JaccardIndex from dcp_server.models import * @@ -23,107 +25,123 @@ "CustomCellpose": CustomCellpose, "Inst2MultiSeg": Inst2MultiSeg, "MultiCellpose": MultiCellpose, - "UNet": UNet + "UNet": UNet, } config_paths = glob("test/configs/*.yaml") + @pytest.fixture(params=config_paths) def config_path(request): return request.param + @pytest.fixture() -#def model(model_class, config_path): +# def model(model_class, config_path): def model(config_path): - setup_config = read_config('setup', config_path=config_path) - model_config = read_config('model', config_path=config_path) - data_config = read_config('data', config_path=config_path) - train_config = read_config('train', config_path=config_path) - eval_config = read_config('eval', config_path=config_path) - + setup_config = read_config("setup", config_path=config_path) + model_config = read_config("model", config_path=config_path) + data_config = read_config("data", config_path=config_path) + train_config = read_config("train", config_path=config_path) + eval_config = read_config("eval", config_path=config_path) + model_name = setup_config["model_to_use"] model_class = model_mapping.get(model_name) - model = model_class(model_name, model_config, data_config, train_config, eval_config) + model = model_class( + model_name, model_config, data_config, train_config, eval_config + ) # str(model_class) return model + 
@pytest.fixture def data_train(): - images, masks = get_synthetic_dataset(num_samples=4, canvas_size=(512,768)) + images, masks = get_synthetic_dataset(num_samples=4, canvas_size=(512, 768)) masks = [np.array(mask) for mask in masks] masks_instances = [mask.sum(-1) for mask in masks] masks_classes = [((mask > 0) * np.arange(1, 4)).sum(-1) for mask in masks] - masks_ = [np.stack((instances, classes)) for instances, classes in zip(masks_instances, masks_classes)] + masks_ = [ + np.stack((instances, classes)) + for instances, classes in zip(masks_instances, masks_classes) + ] return images, masks_ + @pytest.fixture -def data_eval(): +def data_eval(): img, msk = get_synthetic_dataset(num_samples=1) msk = np.array(msk) - msk_ = np.stack((msk.sum(-1), ((msk > 0) * np.arange(1, 4)).sum(-1)), axis=0).transpose(1,0,2,3) + msk_ = np.stack( + (msk.sum(-1), ((msk > 0) * np.arange(1, 4)).sum(-1)), axis=0 + ).transpose(1, 0, 2, 3) return img, msk_ + def test_train_eval_run(data_train, data_eval, model): """ Performs testing, training, and evaluation with the provided data and model. """ images, masks = data_train - if model.model_name == "CustomCellpose": masks = [mask[0] for mask in masks] + if model.model_name == "CustomCellpose": + masks = [mask[0] for mask in masks] model.train(images, masks) - + imgs_test, masks_test = data_eval - if model.model_name == "CustomCellpose": masks = [mask[0] for mask in masks_test] + if model.model_name == "CustomCellpose": + masks = [mask[0] for mask in masks_test] jaccard_index_instances = 0 jaccard_index_classes = 0 - jaccard_metric_binary = JaccardIndex(task="multiclass", num_classes=2, average="macro", ignore_index=0) - jaccard_metric_multi = JaccardIndex(task="multiclass", num_classes=4, average="macro", ignore_index=0) + jaccard_metric_binary = JaccardIndex( + task="multiclass", num_classes=2, average="macro", ignore_index=0 + ) + jaccard_metric_multi = JaccardIndex( + task="multiclass", num_classes=4, average="macro", ignore_index=0 + ) for img, mask in zip(imgs_test, masks_test): - #mask - instance segmentation mask + classes (2, 512, 512) - #pred_mask - tuple of cellpose (512, 512), patch net multiclass segmentation (512, 512, 2) + # mask - instance segmentation mask + classes (2, 512, 512) + # pred_mask - tuple of cellpose (512, 512), patch net multiclass segmentation (512, 512, 2) + + pred_mask = model.eval(img) - pred_mask = model.eval(img) - if pred_mask.ndim > 2: - pred_mask_bin = torch.tensor((pred_mask[0]>0).astype(bool).astype(int)) + pred_mask_bin = torch.tensor((pred_mask[0] > 0).astype(bool).astype(int)) else: pred_mask_bin = torch.tensor((pred_mask > 0).astype(bool).astype(int)) - bin_mask = torch.tensor((mask[0]>0).astype(bool).astype(int)) + bin_mask = torch.tensor((mask[0] > 0).astype(bool).astype(int)) - jaccard_index_instances += jaccard_metric_binary( - pred_mask_bin, - bin_mask - ) + jaccard_index_instances += jaccard_metric_binary(pred_mask_bin, bin_mask) if pred_mask.ndim > 2: jaccard_index_classes += jaccard_metric_multi( - torch.tensor(pred_mask[1].astype(int)), - torch.tensor(mask[1].astype(int)) + torch.tensor(pred_mask[1].astype(int)), + torch.tensor(mask[1].astype(int)), ) - + jaccard_index_instances /= len(imgs_test) - assert(jaccard_index_instances>0.2) + assert jaccard_index_instances > 0.2 # retrieve the attribute names of the class of the current model attrs = model.__dict__.keys() if "metric" in attrs: - assert(model.metric>0.1) + assert model.metric > 0.1 if "loss" in attrs: - assert(model.loss<0.83) + assert model.loss < 
0.83 - # for PatchCNN model + # for PatchCNN model if pred_mask.ndim > 2: jaccard_index_classes /= len(imgs_test) - assert(jaccard_index_classes>0.1) + assert jaccard_index_classes > 0.1 + # def test_train_run(data_train, model): @@ -140,12 +158,12 @@ def test_train_eval_run(data_train, data_eval, model): # assert(model.metric>0.1) # if "loss" in attrs: # assert(model.loss<0.3) - + # def test_eval_run(data_train, data_eval, model): # images, masks = data_train # model.train(images, masks) - + # imgs_test, masks_test = data_eval # jaccard_index_instances = 0 @@ -160,7 +178,7 @@ def test_train_eval_run(data_train, data_eval, model): # #pred_mask - tuple of cellpose (512, 512), patch net multiclass segmentation (512, 512, 2) # pred_mask = model.eval(img) #, channels=[0,0]) - + # if pred_mask.ndim > 2: # pred_mask_bin = torch.tensor((pred_mask[0]>0).astype(bool).astype(int)) # else: @@ -169,21 +187,21 @@ def test_train_eval_run(data_train, data_eval, model): # bin_mask = torch.tensor((mask[0]>0).astype(bool).astype(int)) # jaccard_index_instances += jaccard_metric_binary( -# pred_mask_bin, +# pred_mask_bin, # bin_mask # ) # if pred_mask.ndim > 2: # jaccard_index_classes += jaccard_metric_multi( -# torch.tensor(pred_mask[1].astype(int)), +# torch.tensor(pred_mask[1].astype(int)), # torch.tensor(mask[1].astype(int)) # ) - + # jaccard_index_instances /= len(imgs_test) # assert(jaccard_index_instances>0.2) -# # for PatchCNN model +# # for PatchCNN model # if pred_mask.ndim > 2: # jaccard_index_classes /= len(imgs_test) diff --git a/src/server/test/test_models.py b/src/server/test/test_models.py index 7816f594..86529e1c 100644 --- a/src/server/test/test_models.py +++ b/src/server/test/test_models.py @@ -5,33 +5,55 @@ from dcp_server.models.classifiers import FeatureClassifier from dcp_server.utils.helpers import read_config + def test_eval_rf_not_fitted(): """ Tests the evaluation of a random forest model that has not been fitted. """ - model_config = read_config('model', config_path='test/configs/test_config_Inst2MultiSeg_RF.yaml') - data_config = read_config('data', config_path='test/configs/test_config_Inst2MultiSeg_RF.yaml') - train_config = read_config('train', config_path='test/configs/test_config_Inst2MultiSeg_RF.yaml') - eval_config = read_config('eval', config_path='test/configs/test_config_Inst2MultiSeg_RF.yaml') - - model_rf = FeatureClassifier("Random Forest", model_config, data_config, train_config, eval_config) + model_config = read_config( + "model", config_path="test/configs/test_config_Inst2MultiSeg_RF.yaml" + ) + data_config = read_config( + "data", config_path="test/configs/test_config_Inst2MultiSeg_RF.yaml" + ) + train_config = read_config( + "train", config_path="test/configs/test_config_Inst2MultiSeg_RF.yaml" + ) + eval_config = read_config( + "eval", config_path="test/configs/test_config_Inst2MultiSeg_RF.yaml" + ) + + model_rf = FeatureClassifier( + "Random Forest", model_config, data_config, train_config, eval_config + ) - X_test = np.array([[1, 2, 3]]) + X_test = np.array([[1, 2, 3]]) # if we don't fit the model then the model returns zeros - assert np.all(model_rf.eval(X_test)== np.zeros(X_test.shape)) + assert np.all(model_rf.eval(X_test) == np.zeros(X_test.shape)) + def test_update_configs(): """ Tests the update of model training and evaluation configurations. 
""" - model_config = read_config('model', config_path='test/configs/test_config_Inst2MultiSeg_RF.yaml') - data_config = read_config('data', config_path='test/configs/test_config_Inst2MultiSeg_RF.yaml') - train_config = read_config('train', config_path='test/configs/test_config_Inst2MultiSeg_RF.yaml') - eval_config = read_config('eval', config_path='test/configs/test_config_Inst2MultiSeg_RF.yaml') - - model = models.CustomCellpose("Cellpose", model_config, data_config, train_config, eval_config) + model_config = read_config( + "model", config_path="test/configs/test_config_Inst2MultiSeg_RF.yaml" + ) + data_config = read_config( + "data", config_path="test/configs/test_config_Inst2MultiSeg_RF.yaml" + ) + train_config = read_config( + "train", config_path="test/configs/test_config_Inst2MultiSeg_RF.yaml" + ) + eval_config = read_config( + "eval", config_path="test/configs/test_config_Inst2MultiSeg_RF.yaml" + ) + + model = models.CustomCellpose( + "Cellpose", model_config, data_config, train_config, eval_config + ) new_train_config = {"param1": "value1"} new_eval_config = {"param2": "value2"} @@ -40,6 +62,3 @@ def test_update_configs(): assert model.train_config == new_train_config assert model.eval_config == new_eval_config - - - diff --git a/src/server/test/test_utils.py b/src/server/test/test_utils.py index fd02c044..b0c4f71f 100644 --- a/src/server/test/test_utils.py +++ b/src/server/test/test_utils.py @@ -2,6 +2,7 @@ import pytest from dcp_server.utils.processing import find_max_patch_size + @pytest.fixture def sample_mask(): mask = np.zeros((10, 10), dtype=np.uint8) @@ -9,12 +10,9 @@ def sample_mask(): mask[7:9, 2:5] = 1 return mask + def test_find_max_patch_size(sample_mask): # Test when the function is called with a sample mask result = find_max_patch_size(sample_mask) assert isinstance(result, float) assert result > 0 - - - - From 79c3fd5388591daffab8a4415e9e9de8b13d41e6 Mon Sep 17 00:00:00 2001 From: Christina Bukas Date: Wed, 13 Mar 2024 14:39:37 +0100 Subject: [PATCH 18/26] ingore docs build --- .gitignore | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 24bcbb6c..0f64af41 100644 --- a/.gitignore +++ b/.gitignore @@ -21,7 +21,7 @@ __pycache__/ # Distribution / packaging .Python -# build/ +build/ develop-eggs/ dist/ downloads/ From 2000359ed169e6c8da1008e23130e843f2a3fe45 Mon Sep 17 00:00:00 2001 From: Christina Bukas Date: Wed, 13 Mar 2024 14:40:41 +0100 Subject: [PATCH 19/26] added client typing --- src/client/dcp_client/gui/main_window.py | 22 +- src/client/dcp_client/gui/napari_window.py | 32 ++- src/client/dcp_client/gui/welcome_window.py | 14 +- src/client/dcp_client/utils/bentoml_model.py | 16 +- src/client/dcp_client/utils/compute4mask.py | 210 ++++++++++++++++ src/client/dcp_client/utils/fsimagestorage.py | 11 +- src/client/dcp_client/utils/settings.py | 3 +- src/client/dcp_client/utils/sync_src_dst.py | 10 +- src/client/dcp_client/utils/utils.py | 225 +----------------- 9 files changed, 283 insertions(+), 260 deletions(-) create mode 100644 src/client/dcp_client/utils/compute4mask.py diff --git a/src/client/dcp_client/gui/main_window.py b/src/client/dcp_client/gui/main_window.py index ae4917ed..c1eec891 100644 --- a/src/client/dcp_client/gui/main_window.py +++ b/src/client/dcp_client/gui/main_window.py @@ -11,7 +11,7 @@ QProgressBar, QShortcut, ) -from PyQt5.QtCore import Qt, QThread, pyqtSignal +from PyQt5.QtCore import Qt, QThread, QModelIndex, pyqtSignal from PyQt5.QtGui import QKeySequence from dcp_client.utils import 
settings @@ -51,7 +51,7 @@ def __init__( self.app = app self.task = task - def run(self): + def run(self) -> None: """ Once run_inference or run_train is executed, the tuple of (message_text, message_title) will be returned to on_finished. @@ -83,7 +83,7 @@ class MainWindow(MyWidget): :type train_data_path: string """ - def __init__(self, app: Application): + def __init__(self, app: Application) -> None: """ Initializes the MainWindow. @@ -100,7 +100,7 @@ def __init__(self, app: Application): self.worker_thread = None self.main_window() - def main_window(self): + def main_window(self) -> None: """Sets up the GUI""" self.setWindowTitle(self.title) # self.resize(1000, 1500) @@ -218,7 +218,7 @@ def main_window(self): self.setLayout(main_layout) self.show() - def on_item_train_selected(self, item): + def on_item_train_selected(self, item: QModelIndex) -> None: """ Is called once an image is selected in the 'curated dataset' folder. @@ -228,7 +228,7 @@ def on_item_train_selected(self, item): self.app.cur_selected_img = item.data() self.app.cur_selected_path = self.app.train_data_path - def on_item_eval_selected(self, item): + def on_item_eval_selected(self, item: QModelIndex) -> None: """ Is called once an image is selected in the 'uncurated dataset' folder. @@ -238,7 +238,7 @@ def on_item_eval_selected(self, item): self.app.cur_selected_img = item.data() self.app.cur_selected_path = self.app.eval_data_path - def on_item_inprogr_selected(self, item): + def on_item_inprogr_selected(self, item: QModelIndex) -> None: """ Is called once an image is selected in the 'in progress' folder. @@ -248,7 +248,7 @@ def on_item_inprogr_selected(self, item): self.app.cur_selected_img = item.data() self.app.cur_selected_path = self.app.inprogr_data_path - def on_train_button_clicked(self): + def on_train_button_clicked(self) -> None: """ Is called once user clicks the "Train Model" button. """ @@ -260,7 +260,7 @@ def on_train_button_clicked(self): # start the worker thread to train self.worker_thread.start() - def on_run_inference_button_clicked(self): + def on_run_inference_button_clicked(self) -> None: """ Is called once user clicks the "Generate Labels" button. """ @@ -272,7 +272,7 @@ def on_run_inference_button_clicked(self): # start the worker thread to run inference self.worker_thread.start() - def on_launch_napari_button_clicked(self): + def on_launch_napari_button_clicked(self) -> None: """ Launches the napari window after the image is selected. """ @@ -283,7 +283,7 @@ def on_launch_napari_button_clicked(self): self.nap_win = NapariWindow(self.app) self.nap_win.show() - def on_finished(self, result): + def on_finished(self, result: tuple) -> None: """ Is called once the worker thread emits the on finished signal. 
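The Worker thread wiring typed above follows the standard PyQt recipe: the long-running call executes in run() off the GUI thread, and its (message_text, message_title) result is handed back through a signal connected to on_finished. A self-contained sketch of that pattern, independent of the dcp_client classes (the names here are illustrative, not the project's own):

    from PyQt5.QtCore import QThread, pyqtSignal

    class Worker(QThread):
        task_finished = pyqtSignal(tuple)

        def __init__(self, func, *args) -> None:
            super().__init__()
            self.func = func
            self.args = args

        def run(self) -> None:
            # heavy work happens here, off the GUI thread
            result = self.func(*self.args)
            # hand the (message_text, message_title) tuple back to the GUI thread
            self.task_finished.emit(result)

    # usage inside a widget:
    #   self.worker_thread = Worker(some_long_call, arg1)
    #   self.worker_thread.task_finished.connect(self.on_finished)
    #   self.worker_thread.start()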
diff --git a/src/client/dcp_client/gui/napari_window.py b/src/client/dcp_client/gui/napari_window.py index 2ca2a18f..001720cf 100644 --- a/src/client/dcp_client/gui/napari_window.py +++ b/src/client/dcp_client/gui/napari_window.py @@ -1,14 +1,16 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional from copy import deepcopy from qtpy.QtWidgets import QPushButton, QComboBox, QLabel, QGridLayout from qtpy.QtCore import Qt import napari +import numpy as np if TYPE_CHECKING: from dcp_client.app import Application -from dcp_client.utils.utils import get_path_stem, check_equal_arrays, Compute4Mask +from dcp_client.utils.utils import get_path_stem, check_equal_arrays +from dcp_client.utils.compute4mask import Compute4Mask from dcp_client.gui._my_widget import MyWidget @@ -19,7 +21,7 @@ class NapariWindow(MyWidget): :type app: Application """ - def __init__(self, app: Application): + def __init__(self, app: Application) -> None: """Initializes the NapariWindow. :param app: The Application instance. @@ -125,7 +127,7 @@ def __init__(self, app: Application): self.setLayout(layout) - def set_editable_mask(self): + def set_editable_mask(self) -> None: """ This function is not implemented. In theory the use can choose between which mask to edit. Currently painting and erasing is only possible on instance mask and in the class mask only @@ -133,7 +135,7 @@ def set_editable_mask(self): """ pass - def on_seg_channel_changed(self, event): + def on_seg_channel_changed(self, event) -> None: """ Is triggered each time the user selects a different layer in the viewer. """ @@ -148,7 +150,7 @@ def on_seg_channel_changed(self, event): else: pass - def axis_changed(self, event): + def axis_changed(self, event) -> None: """ Is triggered each time the user switches the viewer between the mask channels. At this point the class mask needs to be updated according to the changes made tot the instance segmentation mask. @@ -172,7 +174,7 @@ def axis_changed(self, event): self.update_labels_mask(masks[0]) self.switch_to_labels_mask() - def switch_to_instance_mask(self): + def switch_to_instance_mask(self) -> None: """ Switch the application to the active mask mode by enabling 'paint_button', 'erase_button' and 'fill_button'. @@ -181,7 +183,7 @@ def switch_to_instance_mask(self): self.switch_controls("erase_button", True) self.switch_controls("fill_button", True) - def switch_to_labels_mask(self): + def switch_to_labels_mask(self) -> None: """ Switch the application to non-active mask mode by enabling 'fill_button' and disabling 'paint_button' and 'erase_button'. """ @@ -197,7 +199,7 @@ def switch_to_labels_mask(self): self.switch_controls("erase_button", False, info_message_erase) self.switch_controls("fill_button", True) - def update_labels_mask(self, instance_mask): + def update_labels_mask(self, instance_mask: np.ndarray) -> None: """Updates the class mask based on changes in the instance mask. If the instance mask has changed since the last switch between channels, the class mask needs to be updated accordingly. @@ -229,7 +231,9 @@ def update_labels_mask(self, instance_mask): self.layer.data[1] = vis_labels_mask self.layer.refresh() - def update_instance_mask(self, instance_mask, labels_mask): + def update_instance_mask( + self, instance_mask: np.ndarray, labels_mask: np.ndarray + ) -> None: """Updates the instance mask based on changes in the labels mask. 
If the labels mask has changed, but only if an object has been removed, the instance mask is updated accordingly. @@ -253,7 +257,9 @@ def update_instance_mask(self, instance_mask, labels_mask): self.layer.data[0] = self.original_instance_mask[self.cur_selected_seg] self.layer.refresh() - def switch_controls(self, target_widget, status: bool, info_message=None): + def switch_controls( + self, target_widget: str, status: bool, info_message: Optional[str] = None + ) -> None: """Enables or disables a specific widget. :param target_widget: The name of the widget to be controlled within the QCtrl object. @@ -270,7 +276,7 @@ def switch_controls(self, target_widget, status: bool, info_message=None): except: pass - def on_add_to_curated_button_clicked(self): + def on_add_to_curated_button_clicked(self) -> None: """Defines what happens when the "Move to curated dataset folder" button is clicked.""" if self.app.cur_selected_path == str(self.app.train_data_path): message_text = "Image is already in the 'Curated data' folder and should not be changed again" @@ -325,7 +331,7 @@ def on_add_to_curated_button_clicked(self): self.viewer.close() self.close() - def on_add_to_inprogress_button_clicked(self): + def on_add_to_inprogress_button_clicked(self) -> None: """Defines what happens when the "Move to curation in progress folder" button is clicked.""" # TODO: Do we allow this? What if they moved it by mistake? User can always manually move from their folders?) if self.app.cur_selected_path == str(self.app.train_data_path): diff --git a/src/client/dcp_client/gui/welcome_window.py b/src/client/dcp_client/gui/welcome_window.py index 0856403a..f4bd73da 100644 --- a/src/client/dcp_client/gui/welcome_window.py +++ b/src/client/dcp_client/gui/welcome_window.py @@ -25,7 +25,7 @@ class WelcomeWindow(MyWidget): By clicking 'start' the MainWindow is called. """ - def __init__(self, app: Application): + def __init__(self, app: Application) -> None: """Initializes the WelcomeWindow. :param app: The Application instance. @@ -108,7 +108,7 @@ def __init__(self, app: Application): self.show() - def browse_eval_clicked(self): + def browse_eval_clicked(self) -> None: """Activates when the user clicks the button to choose the evaluation directory (QFileDialog) and displays the name of the evaluation directory chosen in the validation textbox line (QLineEdit). """ @@ -121,7 +121,7 @@ def browse_eval_clicked(self): finally: self.fd = None - def browse_train_clicked(self): + def browse_train_clicked(self) -> None: """Activates when the user clicks the button to choose the train directory (QFileDialog) and displays the name of the train directory chosen in the train textbox line (QLineEdit). """ @@ -132,7 +132,7 @@ def browse_train_clicked(self): self.app.train_data_path = fd.selectedFiles()[0] self.train_textbox.setText(self.app.train_data_path) - def on_text_changed(self, field_obj, field_name, text): + def on_text_changed(self, field_obj: QLineEdit, field_name: str, text: str) -> None: """ Update data paths based on text changes in input fields. Used for copying paths in the welcome window. @@ -153,7 +153,7 @@ def on_text_changed(self, field_obj, field_name, text): self.app.inprogr_data_path = text field_obj.setText(text) - def browse_inprogr_clicked(self): + def browse_inprogr_clicked(self) -> None: """ Activates when the user clicks the button to choose the curation in progress directory (QFileDialog) and displays the name of the evaluation directory chosen in the validation textbox line (QLineEdit). 
@@ -167,7 +167,7 @@ def browse_inprogr_clicked(self):
         ]  # TODO: case when browse is clicked but nothing is specified - currently it is filled with os.getcwd()
         self.inprogr_textbox.setText(self.app.inprogr_data_path)
 
-    def start_main(self):
+    def start_main(self) -> None:
         """Starts the main window after the user clicks 'Start' and only if both evaluation and train directories are chosen and all unique."""
 
         if (
@@ -190,7 +190,7 @@ def start_main(self):
             self.message_text = "You need to specify a folder both for your uncurated and curated dataset (even if the curated folder is currently empty). Please go back and select folders for both."
             _ = self.create_warning_box(self.message_text, message_title="Warning")
 
-    def start_upload_and_main(self):
+    def start_upload_and_main(self) -> None:
         """
         If the configs are set to use remote not local server then the user is asked to confirm the upload of their data
         to the server and the upload starts before launching the main window.
diff --git a/src/client/dcp_client/utils/bentoml_model.py b/src/client/dcp_client/utils/bentoml_model.py
index 5c2e58fa..c0493dbf 100644
--- a/src/client/dcp_client/utils/bentoml_model.py
+++ b/src/client/dcp_client/utils/bentoml_model.py
@@ -1,5 +1,5 @@
 import asyncio
-from typing import Optional
+from typing import Optional, List
 
 from bentoml.client import Client as BentoClient
 from bentoml.exceptions import BentoMLException
@@ -17,7 +17,7 @@ def __init__(self, client: Optional[BentoClient] = None):
         """
         self.client = client
 
-    def connect(self, ip: str = "0.0.0.0", port: int = 7010):
+    def connect(self, ip: str = "0.0.0.0", port: int = 7010) -> bool:
         """Connects to the BentoML server.
 
         :param ip: IP address of the BentoML server. Default is '0.0.0.0'.
@@ -35,7 +35,7 @@ def connect(self, ip: str = "0.0.0.0", port: int = 7010):
             return False  # except ConnectionRefusedError
 
     @property
-    def is_connected(self):
+    def is_connected(self) -> bool:
         """Checks if the BentomlModel is connected to the BentoML server.
 
         :return: True if connected, False otherwise.
@@ -43,7 +43,7 @@ def is_connected(self):
         """
         return bool(self.client)
 
-    async def _run_train(self, data_path):
+    async def _run_train(self, data_path: str):
         """Runs the training task asynchronously.
 
         :param data_path: Path to the training data.
@@ -52,11 +52,11 @@ async def _run_train(self, data_path):
         """
         try:
             response = await self.client.async_train(data_path)
             return response
         except BentoMLException:
             return None
 
-    def run_train(self, data_path):
+    def run_train(self, data_path: str):
         """Runs the training.
 
         :param data_path: Path to the training data.
@@ -65,7 +65,7 @@ def run_train(self, data_path):
         """
         return asyncio.run(self._run_train(data_path))
 
-    async def _run_inference(self, data_path):
+    async def _run_inference(self, data_path: str):
         """Runs the inference task asynchronously.
 
         :param data_path: Path to the data for inference.
@@ -74,11 +74,11 @@ async def _run_inference(self, data_path):
         """
         try:
             response = await self.client.async_segment_image(data_path)
             return response
         except BentoMLException:
             return None
 
-    def run_inference(self, data_path):
+    def run_inference(self, data_path: str) -> List:
         """Runs the inference.
 
         :param data_path: Path to the data for inference.
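With the annotations above, the client-side contract of BentomlModel is explicit: connect returns a bool, and run_train/run_inference are blocking wrappers around the async BentoML calls that return None when a BentoMLException is raised. A minimal sketch of how a caller might drive it, assuming a DCP server is reachable on the default 0.0.0.0:7010 and noting that "data/uncurated" is an illustrative path:

    from dcp_client.utils.bentoml_model import BentomlModel

    model = BentomlModel()
    if model.connect(ip="0.0.0.0", port=7010):
        response = model.run_inference("data/uncurated")  # blocks until the server replies
        if response is None:
            print("Inference failed on the server side (BentoMLException).")
    else:
        print("Could not connect to the BentoML server.")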
diff --git a/src/client/dcp_client/utils/compute4mask.py b/src/client/dcp_client/utils/compute4mask.py
new file mode 100644
index 00000000..f14bff5d
--- /dev/null
+++ b/src/client/dcp_client/utils/compute4mask.py
@@ -0,0 +1,210 @@
+from typing import List
+import numpy as np
+from skimage.measure import find_contours, label
+from skimage.draw import polygon_perimeter
+
+
+class Compute4Mask:
+    """
+    Compute4Mask provides methods for manipulating masks to make visualisation in the viewer easier.
+    """
+
+    @staticmethod
+    def get_contours(
+        instance_mask: np.ndarray, contours_level: float = None
+    ) -> np.ndarray:
+        """Find contours of objects in the instance mask. This function is used to identify the contours of the objects to prevent the problem of the merged
+        objects in napari window (mask).
+
+        :param instance_mask: The instance mask array.
+        :type instance_mask: numpy.ndarray
+        :param contours_level: Value along which to find contours in the array. See skimage.measure.find_contours for more.
+        :type contours_level: None or float
+        :return: A binary mask where the contours of all objects in the instance segmentation mask are one and the rest is background.
+        :rtype: numpy.ndarray
+
+        """
+        instance_ids = Compute4Mask.get_unique_objects(
+            instance_mask
+        )  # get object instance labels ignoring background
+        contour_mask = np.zeros_like(instance_mask)
+        for instance_id in instance_ids:
+            # get a binary mask only of object
+            single_obj_mask = np.zeros_like(instance_mask)
+            single_obj_mask[instance_mask == instance_id] = 1
+            try:
+                # compute contours for mask
+                contours = find_contours(single_obj_mask, contours_level)
+                # sometimes little dots appear as additional contours, so remove these
+                if len(contours) > 1:
+                    contour_sizes = [contour.shape[0] for contour in contours]
+                    contour = contours[contour_sizes.index(max(contour_sizes))].astype(
+                        int
+                    )
+                else:
+                    contour = contours[0]
+                # and draw onto contours mask
+                rr, cc = polygon_perimeter(
+                    contour[:, 0], contour[:, 1], contour_mask.shape
+                )
+                contour_mask[rr, cc] = instance_id
+            except:
+                print("Could not create contour for instance id", instance_id)
+        return contour_mask
+
+    @staticmethod
+    def add_contour(labels_mask: np.ndarray, instance_mask: np.ndarray) -> np.ndarray:
+        """Add contours of objects to the labels mask.
+
+        :param labels_mask: The class mask array without the contour pixels annotated.
+        :type labels_mask: numpy.ndarray
+        :param instance_mask: The instance mask array.
+        :type instance_mask: numpy.ndarray
+        :return: The updated class mask including contours.
+        :rtype: numpy.ndarray
+        """
+        instance_ids = Compute4Mask.get_unique_objects(instance_mask)
+        for instance_id in instance_ids:
+            where_instances = np.where(instance_mask == instance_id)
+            # get unique class ids where the object is present
+            class_vals, counts = np.unique(
+                labels_mask[where_instances], return_counts=True
+            )
+            # and take the class id which is most heavily represented
+            class_id = class_vals[np.argmax(counts)]
+            # make sure instance mask and class mask match
+            labels_mask[np.where(instance_mask == instance_id)] = class_id
+        return labels_mask
+
+    @staticmethod
+    def compute_new_instance_mask(
+        labels_mask: np.ndarray, instance_mask: np.ndarray
+    ) -> np.ndarray:
+        """Given an updated labels mask, update also the instance mask accordingly.
+        So far the user can only remove an entire object in the labels mask view by
+        setting the color of the object to the background.
+        Therefore the instance mask can only change by entirely removing an object.
+
+        :param labels_mask: The labels mask array, with changes made by the user.
+        :type labels_mask: numpy.ndarray
+        :param instance_mask: The existing instance mask, which needs to be updated.
+        :type instance_mask: numpy.ndarray
+        :return: The updated instance mask.
+        :rtype: numpy.ndarray
+        """
+        instance_ids = Compute4Mask.get_unique_objects(instance_mask)
+        for instance_id in instance_ids:
+            unique_items_in_class_mask = list(
+                np.unique(labels_mask[instance_mask == instance_id])
+            )
+            if (
+                len(unique_items_in_class_mask) == 1
+                and unique_items_in_class_mask[0] == 0
+            ):
+                instance_mask[instance_mask == instance_id] = 0
+        return instance_mask
+
+    @staticmethod
+    def compute_new_labels_mask(
+        labels_mask: np.ndarray,
+        instance_mask: np.ndarray,
+        original_instance_mask: np.ndarray,
+        old_instances: np.ndarray,
+    ) -> np.ndarray:
+        """Given the existing labels mask, the updated instance mask is used to update the labels mask.
+
+        :param labels_mask: The existing labels mask, which needs to be updated.
+        :type labels_mask: numpy.ndarray
+        :param instance_mask: The instance mask array, with changes made by the user.
+        :type instance_mask: numpy.ndarray
+        :param original_instance_mask: The instance mask array, before the changes made by the user.
+        :type original_instance_mask: numpy.ndarray
+        :param old_instances: A list of the instance label ids in original_instance_mask.
+        :type old_instances: list
+        :return: The new labels mask, with updated changes according to those the user has made in the instance mask.
+        :rtype: numpy.ndarray
+        """
+        new_labels_mask = np.zeros_like(labels_mask)
+        for instance_id in np.unique(instance_mask):
+            where_instance = np.where(instance_mask == instance_id)
+            # if the label is background skip
+            if instance_id == 0:
+                continue
+            # if the label is a newly added object, add with the same id to the labels mask
+            # this is an indication to the user that this object needs to be assigned a class
+            elif instance_id not in old_instances:
+                new_labels_mask[where_instance] = instance_id
+            else:
+                where_instance_orig = np.where(original_instance_mask == instance_id)
+                # if the locations of the instance haven't changed, means object wasn't changed, do nothing
+                num_classes = np.unique(labels_mask[where_instance])
+                # if area was erased and object retains same class
+                if len(num_classes) == 1:
+                    new_labels_mask[where_instance] = num_classes[0]
+                # area was added where there is background or other class
+                else:
+                    old_class_id, counts = np.unique(
+                        labels_mask[where_instance_orig], return_counts=True
+                    )
+                    # assert len(old_class_id)==1
+                    # old_class_id = old_class_id[0]
+                    # and take the class id which is most heavily represented
+                    old_class_id = old_class_id[np.argmax(counts)]
+                    new_labels_mask[where_instance] = old_class_id
+
+        return new_labels_mask
+
+    @staticmethod
+    def get_unique_objects(active_mask: np.ndarray) -> List:
+        """Gets unique objects from the active mask.
+
+        :param active_mask: The mask array.
+        :type active_mask: numpy.ndarray
+        :return: A list of unique object labels.
+        :rtype: list
+        """
+        return list(np.unique(active_mask)[1:])
+
+    @staticmethod
+    def assert_consistent_labels(mask: np.ndarray) -> tuple:
+        """Before saving the final mask, make sure the user has not mistakenly made an annotation error,
+        i.e. that each instance id corresponds to exactly one connected component. Also checks whether for one instance id
+        multiple classes exist.
+        :param mask: The mask which we want to test.
+        :type mask: numpy.ndarray
+        :return:
+            - A boolean which is True if there is more than one connected component corresponding to an instance id and False otherwise.
+            - A boolean which is True if there is a mismatch between the instance mask and class masks (not 1-1 correspondence) and False otherwise.
+            - A list with all the instance ids for which more than one connected component was found.
+            - A list with all the instance ids for which a mismatch between class and instance masks was found.
+        :rtype:
+            - bool
+            - bool
+            - list[int]
+            - list[int]
+        """
+        user_annot_error = False
+        mask_mismatch_error = False
+        faulty_ids_annot = []
+        faulty_ids_missmatch = []
+        instance_mask, class_mask = mask[0], mask[1]
+        instance_ids = Compute4Mask.get_unique_objects(instance_mask)
+        for instance_id in instance_ids:
+            # check if there is more than one object (connected component) with the same instance_id
+            if np.unique(label(instance_mask == instance_id)).shape[0] > 2:
+                user_annot_error = True
+                faulty_ids_annot.append(instance_id)
+            # and check if there is a mismatch between class mask and instance mask - should never happen!
+            if (
+                np.unique(class_mask[np.where(instance_mask == instance_id)]).shape[0]
+                > 1
+            ):
+                mask_mismatch_error = True
+                faulty_ids_missmatch.append(instance_id)
+
+        return (
+            user_annot_error,
+            mask_mismatch_error,
+            faulty_ids_annot,
+            faulty_ids_missmatch,
+        )
diff --git a/src/client/dcp_client/utils/fsimagestorage.py b/src/client/dcp_client/utils/fsimagestorage.py
index 52d6c006..3e8a5e3c 100644
--- a/src/client/dcp_client/utils/fsimagestorage.py
+++ b/src/client/dcp_client/utils/fsimagestorage.py
@@ -1,5 +1,6 @@
-from skimage.io import imread, imsave
 import os
+import numpy as np
+from skimage.io import imread, imsave
 
 from dcp_client.app import ImageStorage
 
@@ -7,7 +8,7 @@
 class FilesystemImageStorage(ImageStorage):
     """FilesystemImageStorage class for handling image storage operations on the local filesystem."""
 
-    def load_image(self, from_directory, cur_selected_img):
+    def load_image(self, from_directory: str, cur_selected_img: str) -> np.ndarray:
         """Loads an image from the specified directory.
 
         :param from_directory: Path to the directory containing the image.
@@ -19,7 +20,7 @@ def load_image(self, from_directory, cur_selected_img):
         # Read the selected image and read the segmentation if any:
         return imread(os.path.join(from_directory, cur_selected_img))
 
-    def move_image(self, from_directory, to_directory, cur_selected_img):
+    def move_image(self, from_directory: str, to_directory: str, cur_selected_img: str) -> None:
         """Moves an image from one directory to another.
 
         :param from_directory: Path to the source directory.
@@ -37,7 +38,7 @@ def move_image(self, from_directory, to_directory, cur_selected_img):
             os.path.join(to_directory, cur_selected_img),
         )
 
-    def save_image(self, to_directory, cur_selected_img, img):
+    def save_image(self, to_directory: str, cur_selected_img: str, img: np.ndarray) -> None:
         """Saves an image to the specified directory.
 
         :param to_directory: Path to the directory where the image will be saved.
@@ -49,7 +50,7 @@ def save_image(self, to_directory, cur_selected_img, img):
 
         imsave(os.path.join(to_directory, cur_selected_img), img)
 
-    def delete_image(self, from_directory, cur_selected_img):
+    def delete_image(self, from_directory: str, cur_selected_img: str) -> None:
         """Deletes an image from the specified directory.
 
         :param from_directory: Path to the directory containing the image.
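Taken together, the two modules above define the client's mask round trip: FilesystemImageStorage loads and saves the arrays, while Compute4Mask keeps the two-channel (instance, class) masks consistent before anything is written back to disk. A toy end-to-end sketch using synthetic arrays only (no files or project data are assumed):

    import numpy as np
    from dcp_client.utils.compute4mask import Compute4Mask

    # toy two-channel mask: channel 0 = instance ids, channel 1 = class ids
    instance_mask = np.zeros((64, 64), dtype=np.int32)
    instance_mask[10:20, 10:20] = 1
    instance_mask[30:45, 30:45] = 2
    labels_mask = (instance_mask > 0).astype(np.int32)  # every object has class 1

    # draw per-object contours and transfer the majority class onto each object
    contours = Compute4Mask.get_contours(instance_mask)
    labels_mask = Compute4Mask.add_contour(labels_mask, instance_mask)

    # sanity check before saving: no split objects, no instance/class mismatch
    annot_err, mismatch_err, bad_annot, bad_mismatch = (
        Compute4Mask.assert_consistent_labels(np.stack([instance_mask, labels_mask]))
    )
    assert not annot_err and not mismatch_err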
diff --git a/src/client/dcp_client/utils/settings.py b/src/client/dcp_client/utils/settings.py index 3decd50f..5107fb82 100644 --- a/src/client/dcp_client/utils/settings.py +++ b/src/client/dcp_client/utils/settings.py @@ -1,4 +1,5 @@ -def init(): +def init() -> None: + """ Initialise global variables.""" global accepted_types accepted_types = (".jpg", ".jpeg", ".png", ".tiff", ".tif") global seg_name_string diff --git a/src/client/dcp_client/utils/sync_src_dst.py b/src/client/dcp_client/utils/sync_src_dst.py index 951f8cb3..0698901d 100644 --- a/src/client/dcp_client/utils/sync_src_dst.py +++ b/src/client/dcp_client/utils/sync_src_dst.py @@ -15,7 +15,7 @@ def __init__( user_name: str, host_name: str, server_repo_path: str, - ): + ) -> None: """Constructs all the necessary attributes for the CustomRunnable. :param user_name: the user name of the server - if "local", then it is assumed that local machine is used for the server @@ -29,12 +29,14 @@ def __init__( self.host_name = host_name self.server_repo_path = server_repo_path - def first_sync(self, path): + def first_sync(self, path: str) -> tuple: """ During the first sync the folder structure should be created on the server :param path: Path to the local directory to synchronize. :type path: str + :return: result message of subprocess + :rtype: tuple """ server = self.user_name + "@" + self.host_name + ":" + self.server_repo_path try: @@ -44,7 +46,7 @@ def first_sync(self, path): except subprocess.CalledProcessError as e: return ("Error", e) - def sync(self, src, dst, path): + def sync(self, src: str, dst: str, path: str) -> tuple: """Syncs the data between the src and the dst. Both src and dst can be one of either 'client' or 'server', whereas path is the local path we wish to sync @@ -54,6 +56,8 @@ def sync(self, src, dst, path): :type dst: str :param path: Path to the directory we want to synchronize. :type path: str + :return: result message of subprocess + :rtype: tuple """ path += "/" # otherwise it doesn't go in the directory diff --git a/src/client/dcp_client/utils/utils.py b/src/client/dcp_client/utils/utils.py index 5b2ef133..eb08f881 100644 --- a/src/client/dcp_client/utils/utils.py +++ b/src/client/dcp_client/utils/utils.py @@ -1,12 +1,10 @@ -from PyQt5.QtWidgets import QFileIconProvider -from PyQt5.QtCore import QSize -from PyQt5.QtGui import QPixmap, QIcon -import numpy as np -from skimage.measure import find_contours, label -from skimage.draw import polygon_perimeter +from qtpy.QtWidgets import QFileIconProvider +from qtpy.QtCore import QSize +from qtpy.QtGui import QPixmap, QIcon from pathlib import Path, PurePath import yaml +import numpy as np from dcp_client.utils import settings @@ -17,7 +15,7 @@ def __init__(self) -> None: super().__init__() self.ICON_SIZE = QSize(512, 512) - def icon(self, type: "QFileIconProvider.IconType"): + def icon(self, type: QFileIconProvider.IconType) -> QIcon: """Returns the icon for the specified file type. :param type: The type of the file for which the icon is requested. @@ -38,7 +36,7 @@ def icon(self, type: "QFileIconProvider.IconType"): return super().icon(type) -def read_config(name, config_path="config.yaml") -> dict: +def read_config(name: str, config_path: str = "config.yaml") -> dict: """Reads the configuration file :param name: name of the section you want to read (e.g. 
'setup','train') @@ -57,7 +55,7 @@ def read_config(name, config_path="config.yaml") -> dict: return config_dict[name] -def get_relative_path(filepath): +def get_relative_path(filepath: str) -> str: """Returns the name of the file from the given filepath. :param filepath: The path of the file. @@ -68,7 +66,7 @@ def get_relative_path(filepath): return PurePath(filepath).name -def get_path_stem(filepath): +def get_path_stem(filepath: str) -> str: """Returns the stem (filename without its extension) from the given filepath. :param filepath: The path of the file. @@ -79,7 +77,7 @@ def get_path_stem(filepath): return str(Path(filepath).stem) -def get_path_name(filepath): +def get_path_name(filepath: str) -> str: """Returns the name of the file from the given filepath. :param filepath: The path of the file. @@ -90,7 +88,7 @@ def get_path_name(filepath): return str(Path(filepath).name) -def get_path_parent(filepath): +def get_path_parent(filepath: str) -> str: """Returns the parent directory of the given filepath. :param filepath: The path of the file. @@ -101,7 +99,7 @@ def get_path_parent(filepath): return str(Path(filepath).parent) -def join_path(root_dir, filepath): +def join_path(root_dir: str, filepath: str) -> str: """Joins the root directory path with the given filepath. :param root_dir: The root directory. @@ -114,7 +112,7 @@ def join_path(root_dir, filepath): return str(Path(root_dir, filepath)) -def check_equal_arrays(array1, array2): +def check_equal_arrays(array1: np.ndarray, array2: np.ndarray) -> bool: """Checks if two arrays are equal. :param array1: The first array. @@ -125,202 +123,3 @@ def check_equal_arrays(array1, array2): :rtype: bool """ return np.array_equal(array1, array2) - - -class Compute4Mask: - """ - Compute4Mask provides methods for manipulating masks. - """ - - @staticmethod - def get_contours(instance_mask, contours_level=None): - """Find contours of objects in the instance mask. This function is used to identify the contours of the objects to prevent the problem of the merged - objects in napari window (mask). - - :param instance_mask: The instance mask array. - :type instance_mask: numpy.ndarray - :param contours_level: Value along which to find contours in the array. See skimage.measure.find_contours for more. - :type: None or float - :return: A binary mask where the contours of all objects in the instance segmentation mask are one and the rest is background. - :rtype: numpy.ndarray - - """ - instance_ids = Compute4Mask.get_unique_objects( - instance_mask - ) # get object instance labels ignoring background - contour_mask = np.zeros_like(instance_mask) - for instance_id in instance_ids: - # get a binary mask only of object - single_obj_mask = np.zeros_like(instance_mask) - single_obj_mask[instance_mask == instance_id] = 1 - try: - # compute contours for mask - contours = find_contours(single_obj_mask, contours_level) - # sometimes little dots appeas as additional contours so remove these - if len(contours) > 1: - contour_sizes = [contour.shape[0] for contour in contours] - contour = contours[contour_sizes.index(max(contour_sizes))].astype( - int - ) - else: - contour = contours[0] - # and draw onto contours mask - rr, cc = polygon_perimeter( - contour[:, 0], contour[:, 1], contour_mask.shape - ) - contour_mask[rr, cc] = instance_id - except: - print("Could not create contour for instance id", instance_id) - return contour_mask - - @staticmethod - def add_contour(labels_mask, instance_mask): - """Add contours of objects to the labels mask. 
- - :param labels_mask: The class mask array without the contour pixels annotated. - :type labels_mask: numpy.ndarray - :param instance_mask: The instance mask array. - :type instance_mask: numpy.ndarray - :return: The updated class mask including contours. - :rtype: numpy.ndarray - """ - instance_ids = Compute4Mask.get_unique_objects(instance_mask) - for instance_id in instance_ids: - where_instances = np.where(instance_mask == instance_id) - # get unique class ids where the object is present - class_vals, counts = np.unique( - labels_mask[where_instances], return_counts=True - ) - # and take the class id which is most heavily represented - class_id = class_vals[np.argmax(counts)] - # make sure instance mask and class mask match - labels_mask[np.where(instance_mask == instance_id)] = class_id - return labels_mask - - @staticmethod - def compute_new_instance_mask(labels_mask, instance_mask): - """Given an updated labels mask, update also the instance mask accordingly. - So far the user can only remove an entire object in the labels mask view by - setting the color of the object to the background. - Therefore the instance mask can only change by entirely removing an object. - - :param labels_mask: The labels mask array, with changes made by the user. - :type labels_mask: numpy.ndarray - :param instance_mask: The existing instance mask, which needs to be updated. - :type instance_mask: numpy.ndarray - :return: The updated instance mask. - :rtype: numpy.ndarray - """ - instance_ids = Compute4Mask.get_unique_objects(instance_mask) - for instance_id in instance_ids: - unique_items_in_class_mask = list( - np.unique(labels_mask[instance_mask == instance_id]) - ) - if ( - len(unique_items_in_class_mask) == 1 - and unique_items_in_class_mask[0] == 0 - ): - instance_mask[instance_mask == instance_id] = 0 - return instance_mask - - @staticmethod - def compute_new_labels_mask( - labels_mask, instance_mask, original_instance_mask, old_instances - ): - """Given the existing labels mask, the updated instance mask is used to update the labels mask. - - :param labels_mask: The existing labels mask, which needs to be updated. - :type labels_mask: numpy.ndarray - :param instance_mask: The instance mask array, with changes made by the user. - :type instance_mask: numpy.ndarray - :param original_instance_mask: The instance mask array, before the changes made by the user. - :type original_instance_mask: numpy.ndarray - :param old_instances: A list of the instance label ids in original_instance_mask. - :type old_instances: list - :return: The new labels mask, with updated changes according to those the user has made in the instance mask. 
- :rtype: numpy.ndarray - """ - new_labels_mask = np.zeros_like(labels_mask) - for instance_id in np.unique(instance_mask): - where_instance = np.where(instance_mask == instance_id) - # if the label is background skip - if instance_id == 0: - continue - # if the label is a newly added object, add with the same id to the labels mask - # this is an indication to the user that this object needs to be assigned a class - elif instance_id not in old_instances: - new_labels_mask[where_instance] = instance_id - else: - where_instance_orig = np.where(original_instance_mask == instance_id) - # if the locations of the instance haven't changed, means object wasn't changed, do nothing - num_classes = np.unique(labels_mask[where_instance]) - # if area was erased and object retains same class - if len(num_classes) == 1: - new_labels_mask[where_instance] = num_classes[0] - # area was added where there is background or other class - else: - old_class_id, counts = np.unique( - labels_mask[where_instance_orig], return_counts=True - ) - # assert len(old_class_id)==1 - # old_class_id = old_class_id[0] - # and take the class id which is most heavily represented - old_class_id = old_class_id[np.argmax(counts)] - new_labels_mask[where_instance] = old_class_id - - return new_labels_mask - - @staticmethod - def get_unique_objects(active_mask): - """Gets unique objects from the active mask. - - :param active_mask: The mask array. - :type active_mask: numpy.ndarray - :return: A list of unique object labels. - :rtype: list - """ - return list(np.unique(active_mask)[1:]) - - @staticmethod - def assert_consistent_labels(mask): - """Before saving the final mask make sure the user has not mistakenly made an error during annotation, - such that one instance id does not correspond to exactly one class id. Also checks whether for one instance id - multiple classes exist. - :param mask: The mask which we want to test. - :type mask: numpy.ndarray - :return: - - A boolean which is True if there is more than one connected components corresponding to an instance id and Fale otherwise. - - A boolean which is True if there is a missmatch between the instance mask and class masks (not 1-1 correspondance) and Flase otherwise. - - A list with all the instance ids for which more than one connected component was found. - - A list with all the instance ids for which a missmatch between class and instance masks was found. - :rtype : - - bool - - bool - - list[int] - - list[int] - """ - user_annot_error = False - mask_mismatch_error = False - faulty_ids_annot = [] - faulty_ids_missmatch = [] - instance_mask, class_mask = mask[0], mask[1] - instance_ids = Compute4Mask.get_unique_objects(instance_mask) - for instance_id in instance_ids: - # check if there are more than one objects (connected components) with same instance_id - if np.unique(label(instance_mask == instance_id)).shape[0] > 2: - user_annot_error = True - faulty_ids_annot.append(instance_id) - # and check if there is a mismatch between class mask and instance mask - should never happen! 
- if ( - np.unique(class_mask[np.where(instance_mask == instance_id)]).shape[0] - > 1 - ): - mask_mismatch_error = True - faulty_ids_missmatch.append(instance_id) - - return ( - user_annot_error, - mask_mismatch_error, - faulty_ids_annot, - faulty_ids_missmatch, - ) From 221d289f53cfdab1dc7aac776dea412f65853cc8 Mon Sep 17 00:00:00 2001 From: Christina Bukas Date: Wed, 13 Mar 2024 18:14:52 +0100 Subject: [PATCH 20/26] update documentation --- docs/source/conf.py | 4 +- docs/source/dcp_client.rst | 41 +++++++++++++++++-- docs/source/dcp_server.rst | 38 +++++++++-------- docs/source/dcp_server.utils.rst | 37 +++++++++++++++++ src/server/MANIFEST.in | 1 + src/server/dcp_server/__init__.py | 8 +--- src/server/dcp_server/config.yaml | 6 +-- .../dcp_server/models/inst_to_multi_seg.py | 2 +- src/server/dcp_server/segmentationclasses.py | 5 +-- src/server/dcp_server/utils/__init__.py | 0 src/server/dcp_server/utils/fsimagestorage.py | 18 +++----- src/server/dcp_server/utils/helpers.py | 6 +-- src/server/dcp_server/utils/processing.py | 10 ++--- 13 files changed, 118 insertions(+), 58 deletions(-) create mode 100644 docs/source/dcp_server.utils.rst create mode 100644 src/server/MANIFEST.in create mode 100644 src/server/dcp_server/utils/__init__.py diff --git a/docs/source/conf.py b/docs/source/conf.py index 1145052d..c0df37ab 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -25,7 +25,7 @@ # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output html_theme = 'alabaster' -html_static_path = ['_static'] +#html_static_path = ['_static'] import os import sys @@ -35,7 +35,7 @@ # Add parent dir to known paths p = Path(__file__).parents[2] sys.path.insert(0, os.path.abspath(p)) - +sys.path.insert(0, os.path.join(p, 'src/server/dcp_server')) # Add the following extensions extensions = [ 'sphinx.ext.autodoc', diff --git a/docs/source/dcp_client.rst b/docs/source/dcp_client.rst index ab9b9ecd..f921a211 100644 --- a/docs/source/dcp_client.rst +++ b/docs/source/dcp_client.rst @@ -1,13 +1,32 @@ dcp\_client package =================== +The dcp_client package contains modules and subpackages for interacting with a server for model inference and training. It provides functionalities for managing GUI windows, handling image storage, and connecting to the server for model operations. -.. toctree:: - :maxdepth: 4 +dcp_client.app + Defines the core application class and related functionalities. + - ``dcp_client.app.Application``: Represents the main application and provides methods for image management, model interaction, and server connectivity. + - ``dcp_client.app.DataSync``: Abstract base class for data synchronization operations. + - ``dcp_client.app.ImageStorage``: Abstract base class for image storage operations. + - ``dcp_client.app.Model``: Abstract base class for model operations. - dcp_client.gui - dcp_client.utils +dcp_client.gui + Contains modules for GUI components. + - ``dcp_client.gui.main_window``: Defines the main application window and associated event functions. + - ``dcp_client.gui.napari_window``: Manages the Napari window and its functionalities. + - ``dcp_client.gui.welcome_window``: Implements the welcome window and its interactions. +dcp_client.utils + Contains utility modules for various tasks. + - ``dcp_client.utils.bentoml_model``: Handles interactions with BentoML for model inference and training. + - ``dcp_client.utils.fsimagestorage``: Provides functions for managing images stored in the filesystem. 
+    - ``dcp_client.utils.settings``: Defines initialization functions and settings.
+    - ``dcp_client.utils.sync_src_dst``: Implements data synchronization between source and destination.
+    - ``dcp_client.utils.utils``: Offers various utility functions for common tasks.
+
+
+Submodules
+----------
 
 dcp\_client.app module
 ----------------------
@@ -17,3 +36,17 @@ dcp\_client.app module
    :undoc-members:
    :show-inheritance:
 
+dcp\_client.gui module
+----------------------
+.. toctree::
+   :maxdepth: 4
+
+   dcp_client.gui
+
+dcp\_client.utils module
+------------------------
+.. toctree::
+   :maxdepth: 4
+
+   dcp_client.utils
+ 
\ No newline at end of file
diff --git a/docs/source/dcp_server.rst b/docs/source/dcp_server.rst
index dfb482d9..78a1ef23 100644
--- a/docs/source/dcp_server.rst
+++ b/docs/source/dcp_server.rst
@@ -1,23 +1,25 @@
 dcp\_server package
 ===================
 
-.. automodule:: dcp_server
-   :members:
-   :undoc-members:
-   :show-inheritance:
-   :exclude-members: dcp\_server.main module
+The dcp_server package is structured to handle various server-side functionalities related to model serving for segmentation and training.
 
-Submodules
-----------
+dcp_server.models
+    Defines various models for cell classification and segmentation, including CellClassifierFCNN, CellClassifierShallowModel, CellposePatchCNN, CustomCellposeModel, and UNet.
+    These models handle tasks such as evaluation, forward pass, training, and updating configurations.
 
-dcp\_server.fsimagestorage module
---------------------------------
+dcp_server.segmentationclasses
+    Defines segmentation classes for specific projects, such as GFPProjectSegmentation, GeneralSegmentation, and MitoProjectSegmentation.
+    These classes contain methods for segmenting images and training models on images and masks.
 
-.. automodule:: dcp_server.fsimagestorage
-   :members:
-   :undoc-members:
-   :show-inheritance:
+dcp_server.serviceclasses
+    Defines service classes, such as CustomBentoService and CustomRunnable, for serving the models with BentoML and handling computation on remote Python workers.
 
+dcp_server.utils
+    Provides various utility functions for dealing with image storage, image processing, feature extraction, file handling, configuration reading, and path manipulation.
+
+
+Submodules
+----------
 
 dcp\_server.models module
 -------------------------
@@ -44,9 +46,9 @@ dcp\_server.serviceclasses module
    :show-inheritance:
 
 dcp\_server.utils module
-------------------------
+---------------------------------
 
-.. automodule:: dcp_server.utils
-   :members:
-   :undoc-members:
-   :show-inheritance:
+.. toctree::
+   :maxdepth: 4
+
+   dcp_server.utils
\ No newline at end of file
diff --git a/docs/source/dcp_server.utils.rst b/docs/source/dcp_server.utils.rst
new file mode 100644
index 00000000..a6334330
--- /dev/null
+++ b/docs/source/dcp_server.utils.rst
@@ -0,0 +1,37 @@
+dcp\_server.utils package
+=========================
+
+.. automodule:: dcp_server.utils
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+Submodules
+----------
+
+dcp\_server.utils.fsimagestorage module
+---------------------------------------
+
+.. automodule:: dcp_server.utils.fsimagestorage
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+dcp\_server.utils.helpers module
+---------------------------------------
+
+.. automodule:: dcp_server.utils.helpers
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+dcp\_server.utils.processing module
+-----------------------------------
+
+.. automodule:: dcp_server.utils.processing
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+
+
diff --git a/src/server/MANIFEST.in b/src/server/MANIFEST.in
new file mode 100644
index 00000000..ffd67494
--- /dev/null
+++ b/src/server/MANIFEST.in
@@ -0,0 +1 @@
+include dcp_server/*.yaml
\ No newline at end of file
diff --git a/src/server/dcp_server/__init__.py b/src/server/dcp_server/__init__.py
index ffbb8826..a125355e 100644
--- a/src/server/dcp_server/__init__.py
+++ b/src/server/dcp_server/__init__.py
@@ -2,15 +2,11 @@
 Overview of dcp_server Package
 ==============================
 
-The dcp_server package is structured to handle various server-side functionalities related to image processing, segmentation, and model serving.
+The dcp_server package is structured to handle various server-side functionalities related to model serving for segmentation and training.
 
 Submodules:
 ------------
 
-dcp_server.fsimagestorage
-    Provides a class FilesystemImageStorage for dealing with image storage, loading, saving, and processing.
-    Contains methods for retrieving image-segmentation pairs, getting image size properties, loading images, preparing images and masks for training, rescaling images, resizing masks, saving images, and searching for images and segmentations in directories.
-
 dcp_server.models
     Defines various models for cell classification and segmentation, including CellClassifierFCNN, CellClassifierShallowModel, CellposePatchCNN, CustomCellposeModel, and UNet.
     These models handle tasks such as evaluation, forward pass, training, and updating configurations.
@@ -23,6 +19,6 @@
     Defines service classes, such as CustomBentoService and CustomRunnable, for serving the models with BentoML and handling computation on remote Python workers.
 
 dcp_server.utils
-    Provides various utility functions for image processing, feature extraction, file handling, configuration reading, and path manipulation.
+    Provides various utility functions for dealing with image storage, image processing, feature extraction, file handling, configuration reading, and path manipulation.
 
 """
diff --git a/src/server/dcp_server/config.yaml b/src/server/dcp_server/config.yaml
index 70b0011f..224009f2 100644
--- a/src/server/dcp_server/config.yaml
+++ b/src/server/dcp_server/config.yaml
@@ -1,9 +1,7 @@
 {
     "setup": {
         "segmentation": "GeneralSegmentation",
-        "model_to_use": "Inst2MultiSeg",
-        "accepted_types": [".jpg", ".jpeg", ".png", ".tiff", ".tif"],
-        "seg_name_string": "_seg"
+        "model_to_use": "Inst2MultiSeg"
     },
 
     "service": {
@@ -31,6 +29,8 @@
 
     "data": {
         "data_root": "data",
+        "accepted_types": [".jpg", ".jpeg", ".png", ".tiff", ".tif"],
+        "seg_name_string": "_seg",
         "patch_size": 64,
         "noise_intensity": 5,
         "gray": True,
diff --git a/src/server/dcp_server/models/inst_to_multi_seg.py b/src/server/dcp_server/models/inst_to_multi_seg.py
index fdd9c07a..43c3db01 100644
--- a/src/server/dcp_server/models/inst_to_multi_seg.py
+++ b/src/server/dcp_server/models/inst_to_multi_seg.py
@@ -96,7 +96,7 @@ def train(self, imgs: List[np.ndarray], masks: List[np.ndarray]) -> None:
         :type imgs: List[np.ndarray]
         :param masks: masks of the given images (training labels)
         :type masks: List[np.ndarray] of same shape as output of eval, i.e. one channel instances,
-            second channel classes, so [2, H, W] or [2, 3, H, W] for 3D
+            second channel classes, so [2, H, W] or [2, 3, H, W] for 3D.
""" # train cellpose masks_instances = [mask[0] for mask in masks] diff --git a/src/server/dcp_server/segmentationclasses.py b/src/server/dcp_server/segmentationclasses.py index 78b1d326..b3897ff7 100644 --- a/src/server/dcp_server/segmentationclasses.py +++ b/src/server/dcp_server/segmentationclasses.py @@ -4,9 +4,6 @@ from dcp_server.utils.fsimagestorage import FilesystemImageStorage from dcp_server import models as DCPModels -# Import configuration -setup_config = helpers.read_config("setup", config_path="config.yaml") - class GeneralSegmentation: """Segmentation class. Defining the main functions needed for this project and served by service - segment image and train on images.""" @@ -53,7 +50,7 @@ async def segment_image(self, input_path: str, list_of_images: str) -> None: # Save segmentation seg_name = ( helpers.get_path_stem(img_filepath) - + setup_config["seg_name_string"] + + self.imagestorage.seg_name_string + ".tiff" ) self.imagestorage.save_image(os.path.join(input_path, seg_name), mask) diff --git a/src/server/dcp_server/utils/__init__.py b/src/server/dcp_server/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/server/dcp_server/utils/fsimagestorage.py b/src/server/dcp_server/utils/fsimagestorage.py index 529b14f0..d89025b3 100644 --- a/src/server/dcp_server/utils/fsimagestorage.py +++ b/src/server/dcp_server/utils/fsimagestorage.py @@ -7,9 +7,6 @@ from dcp_server.utils import helpers from dcp_server.utils.processing import pad_image, normalise -# Import configuration -setup_config = helpers.read_config("setup", config_path="config.yaml") - class FilesystemImageStorage: """ @@ -18,6 +15,8 @@ class FilesystemImageStorage: def __init__(self, data_config: dict, model_used: str) -> None: self.root_dir = data_config["data_root"] + self.seg_name_string = data_config["seg_name_string"] + self.accepted_types = data_config["accepted_types"] self.gray = bool(data_config["gray"]) self.rescale = bool(data_config["rescale"]) self.model_used = model_used @@ -67,16 +66,14 @@ def search_images(self, directory: str) -> List[str]: seg_files = [ file_name for file_name in os.listdir(directory) - if setup_config["seg_name_string"] in file_name + if self.seg_name_string in file_name ] # Take the image files - difference between the list of all the files in the directory and the list of seg files and only file extensions currently accepted image_files = [ os.path.join(directory, file_name) for file_name in os.listdir(directory) if (file_name not in seg_files) - and ( - helpers.get_file_extension(file_name) in setup_config["accepted_types"] - ) + and (helpers.get_file_extension(file_name) in self.accepted_types) ] return image_files @@ -94,9 +91,7 @@ def search_segs(self, cur_selected_img: str) -> List[str]: os.path.join(self.root_dir, cur_selected_img) ) # Take all segmentations of the image from the current directory: - search_string = ( - helpers.get_path_stem(cur_selected_img) + setup_config["seg_name_string"] - ) + search_string = helpers.get_path_stem(cur_selected_img) + self.seg_name_string # seg_files = [os.path.join(img_directory, file_name) for file_name in os.listdir(img_directory) if search_string in file_name] # TODO: check where this is used - copied the command from app's search_segs function (to fix the 1_seg and 11_seg bug) @@ -142,8 +137,7 @@ def get_unsupported_files(self, directory: str) -> List[str]: file_name for file_name in os.listdir(os.path.join(self.root_dir, directory)) if not file_name.startswith(".") - and 
helpers.get_file_extension(file_name) - not in setup_config["accepted_types"] + and helpers.get_file_extension(file_name) not in self.accepted_types ] def get_image_size_properties(self, img: np.ndarray, file_extension: str) -> None: diff --git a/src/server/dcp_server/utils/helpers.py b/src/server/dcp_server/utils/helpers.py index 4e700587..b4cb15c6 100644 --- a/src/server/dcp_server/utils/helpers.py +++ b/src/server/dcp_server/utils/helpers.py @@ -2,13 +2,13 @@ import yaml -def read_config(name: str, config_path: str = "config.yaml") -> dict: +def read_config(name: str, config_path: str) -> dict: """Reads the configuration file :param name: name of the section you want to read (e.g. 'setup','train') :type name: string - :param config_path: path to the configuration file, defaults to 'config.yaml' - :type config_path: str, optional + :param config_path: path to the configuration file + :type config_path: str :return: dictionary from the config section given by name :rtype: dict """ diff --git a/src/server/dcp_server/utils/processing.py b/src/server/dcp_server/utils/processing.py index bdcbca6c..9c7f4b03 100644 --- a/src/server/dcp_server/utils/processing.py +++ b/src/server/dcp_server/utils/processing.py @@ -35,10 +35,10 @@ def pad_image( :param img: image to be padded :type img: np.ndarray - : param height: image height - : type height: int - : param width: image width - : type width: int + :param height: image height + :type height: int + :param width: image width + :type width: int :param channel_ax: :type channel_ax: int or None :param dividable: the number with which the new image size should be perfectly dividable by @@ -97,7 +97,7 @@ def crop_centered_padded_patch( :param obj_label: the instance label of the mask at the patch :type obj_label: int :param mask: The mask array associated with the array x. - Mask is used during training to mask out non-central elements. + Mask is used during training to mask out non-central elements. For RandomForest, it is used to calculate pyradiomics features. 
:type mask: np.ndarray, optional
     :param noise_intensity: intensity of noise to be added to the background
 

From 52eda8c97c62d02f684f6746430631e6f9933c3a Mon Sep 17 00:00:00 2001
From: Christina Bukas
Date: Wed, 13 Mar 2024 19:04:59 +0100
Subject: [PATCH 21/26] update test

---
 src/client/test/test_compute4mask.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/client/test/test_compute4mask.py b/src/client/test/test_compute4mask.py
index b4cc7435..5304e2ee 100644
--- a/src/client/test/test_compute4mask.py
+++ b/src/client/test/test_compute4mask.py
@@ -1,6 +1,6 @@
 import numpy as np
 import pytest
-from dcp_client.utils.utils import Compute4Mask
+from dcp_client.utils.compute4mask import Compute4Mask
 
 
 @pytest.fixture

From 091d6bbfdb35de6d566d78ffd719c67321c178ec Mon Sep 17 00:00:00 2001
From: Christina Bukas
Date: Wed, 13 Mar 2024 19:29:34 +0100
Subject: [PATCH 22/26] update tests

---
 src/server/dcp_server/models/unet.py | 13 ++++++++++++-
 1 file changed, 12 insertions(+), 1 deletion(-)

diff --git a/src/server/dcp_server/models/unet.py b/src/server/dcp_server/models/unet.py
index e4afb098..c5c7a34f 100644
--- a/src/server/dcp_server/models/unet.py
+++ b/src/server/dcp_server/models/unet.py
@@ -7,6 +7,7 @@
 from torch import nn
 from torch.optim import Adam
 from torch.utils.data import TensorDataset, DataLoader
+from torchmetrics import JaccardIndex
 
 from .model import Model
 from dcp_server.utils.processing import convert_to_tensor
@@ -94,6 +95,10 @@ def __init__(
 
         self.loss = 1e6
         self.metric = 0
+        self.num_classes = self.model_config["classifier"]["num_classes"] + 1
+        self.metric_f = JaccardIndex(
+            task="multiclass", num_classes=self.num_classes, average="macro", ignore_index=0
+        )
 
         self.build_model()
 
@@ -144,6 +149,12 @@ def train(self, imgs: List[np.ndarray], masks: List[np.ndarray]) -> None:
 
         self.loss /= len(train_dataloader)
 
+        # compute metric on the train set once training is complete
+        for imgs, masks in train_dataloader:
+            pred_masks = self.forward(imgs.float())
+            self.metric += self.metric_f(masks, pred_masks)
+        self.metric /= len(train_dataloader)
+
     def eval(self, img: np.ndarray) -> np.ndarray:
         """Evaluate the model on the provided image and return the predicted label.
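
# --- Illustrative aside (not part of the patch): a minimal sketch of how the
# JaccardIndex metric configured above behaves, assuming torchmetrics>=0.11.
# The num_classes value and the label tensors below are made-up examples; note
# that torchmetrics metrics are called as metric_f(preds, target).
import torch
from torchmetrics import JaccardIndex

num_classes = 3  # e.g. 2 object classes + 1 background class
metric_f = JaccardIndex(
    task="multiclass", num_classes=num_classes, average="macro", ignore_index=0
)
target = torch.tensor([[0, 1, 1], [0, 2, 2], [0, 0, 0]])  # ground-truth labels
preds = torch.tensor([[0, 1, 1], [0, 2, 1], [0, 0, 0]])   # predicted labels
# ignore_index=0 drops the background class from the average: class 1 scores
# IoU 2/3 and class 2 scores IoU 1/2, so the macro average is 7/12 ~ 0.583.
print(metric_f(preds, target))
# --- End of aside ---
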
@@ -174,7 +185,7 @@ def eval(self, img: np.ndarray) -> np.ndarray: def build_model(self) -> None: """Builds the UNet.""" in_channels = self.model_config["classifier"]["in_channels"] - out_channels = self.model_config["classifier"]["num_classes"] + 1 + out_channels = self.num_classes features = self.model_config["classifier"]["features"] self.encoder = nn.ModuleList() From a34e3343fa8e03768ced9ccd54ab457fa2c8bc6d Mon Sep 17 00:00:00 2001 From: Christina Bukas Date: Wed, 13 Mar 2024 19:29:46 +0100 Subject: [PATCH 23/26] update tests --- .../configs/test_config_MultiCellpose.yaml | 2 +- src/server/test/configs/test_config_UNet.yaml | 2 +- src/server/test/test_integration.py | 20 ++++++------- src/server/test/test_models.py | 30 ------------------- 4 files changed, 12 insertions(+), 42 deletions(-) diff --git a/src/server/test/configs/test_config_MultiCellpose.yaml b/src/server/test/configs/test_config_MultiCellpose.yaml index b74476fe..46b913d7 100644 --- a/src/server/test/configs/test_config_MultiCellpose.yaml +++ b/src/server/test/configs/test_config_MultiCellpose.yaml @@ -31,7 +31,7 @@ "train":{ "segmentor":{ - "n_epochs": 20, + "n_epochs": 30, "channels": [0,0], "min_train_masks": 1, "learning_rate":0.01 diff --git a/src/server/test/configs/test_config_UNet.yaml b/src/server/test/configs/test_config_UNet.yaml index f6ee29bc..f4eba079 100644 --- a/src/server/test/configs/test_config_UNet.yaml +++ b/src/server/test/configs/test_config_UNet.yaml @@ -29,7 +29,7 @@ "train":{ "classifier":{ - "n_epochs": 20, + "n_epochs": 30, "lr": 0.005, "batch_size": 5, "optimizer": "Adam" diff --git a/src/server/test/test_integration.py b/src/server/test/test_integration.py index 8637377e..6e37ea22 100644 --- a/src/server/test/test_integration.py +++ b/src/server/test/test_integration.py @@ -82,12 +82,21 @@ def test_train_eval_run(data_train, data_eval, model): """ Performs testing, training, and evaluation with the provided data and model. """ - + # train images, masks = data_train if model.model_name == "CustomCellpose": masks = [mask[0] for mask in masks] model.train(images, masks) + # retrieve the attribute names of the class of the current model + attrs = model.__dict__.keys() + + if "metric" in attrs: + assert model.metric > 0.1 + if "loss" in attrs: + assert model.loss < 0.83 + + # validate imgs_test, masks_test = data_eval if model.model_name == "CustomCellpose": masks = [mask[0] for mask in masks_test] @@ -128,15 +137,6 @@ def test_train_eval_run(data_train, data_eval, model): jaccard_index_instances /= len(imgs_test) assert jaccard_index_instances > 0.2 - # retrieve the attribute names of the class of the current model - attrs = model.__dict__.keys() - - if "metric" in attrs: - assert model.metric > 0.1 - if "loss" in attrs: - assert model.loss < 0.83 - - # for PatchCNN model if pred_mask.ndim > 2: jaccard_index_classes /= len(imgs_test) diff --git a/src/server/test/test_models.py b/src/server/test/test_models.py index 86529e1c..eddf8f94 100644 --- a/src/server/test/test_models.py +++ b/src/server/test/test_models.py @@ -32,33 +32,3 @@ def test_eval_rf_not_fitted(): # if we don't fit the model then the model returns zeros assert np.all(model_rf.eval(X_test) == np.zeros(X_test.shape)) - -def test_update_configs(): - """ - Tests the update of model training and evaluation configurations. 
-    """
-
-    model_config = read_config(
-        "model", config_path="test/configs/test_config_Inst2MultiSeg_RF.yaml"
-    )
-    data_config = read_config(
-        "data", config_path="test/configs/test_config_Inst2MultiSeg_RF.yaml"
-    )
-    train_config = read_config(
-        "train", config_path="test/configs/test_config_Inst2MultiSeg_RF.yaml"
-    )
-    eval_config = read_config(
-        "eval", config_path="test/configs/test_config_Inst2MultiSeg_RF.yaml"
-    )
-
-    model = models.CustomCellpose(
-        "Cellpose", model_config, data_config, train_config, eval_config
-    )
-
-    new_train_config = {"param1": "value1"}
-    new_eval_config = {"param2": "value2"}
-
-    model.update_configs(new_train_config, new_eval_config)
-
-    assert model.train_config == new_train_config
-    assert model.eval_config == new_eval_config

From 219931ed6ad781380e78014592605f71d6db2957 Mon Sep 17 00:00:00 2001
From: Christina Bukas
Date: Wed, 13 Mar 2024 19:40:34 +0100
Subject: [PATCH 24/26] update img file path in docs

---
 docs/source/dcp_client_installation.rst | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/docs/source/dcp_client_installation.rst b/docs/source/dcp_client_installation.rst
index 2b010e42..b4a883b2 100644
--- a/docs/source/dcp_client_installation.rst
+++ b/docs/source/dcp_client_installation.rst
@@ -114,7 +114,7 @@ DCP Shortcuts
 - In the Data Overview window, clicking on an image and then hitting the **Enter** key is equivalent to clicking the 'View Image and Fix Label' button
 - The viewer accepts all Napari Shortcuts. The current list of the shortcuts for macOS can be seen below:
 
-.. image:: https://raw.githubusercontent.com/HelmholtzAI-Consultants-Munich/data-centric-platform/add-documentation/src/client/readme_figs/napari_shortcuts.png
+.. image:: https://raw.githubusercontent.com/HelmholtzAI-Consultants-Munich/data-centric-platform/main/src/client/readme_figs/napari_shortcuts.png
     :width: 600
     :height: 500
     :align: center

From 4cbad7048c1a50ace67ee4b03f17c95c4510aac4 Mon Sep 17 00:00:00 2001
From: Christina Bukas
Date: Wed, 13 Mar 2024 20:11:57 +0100
Subject: [PATCH 25/26] last typing changes and merge update

---
 src/client/MANIFEST.in                       | 2 +-
 src/client/dcp_client/main.py                | 4 ++--
 src/client/dcp_client/utils/bentoml_model.py | 8 ++++----
 src/server/dcp_server/config.yaml            | 2 +-
 4 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/src/client/MANIFEST.in b/src/client/MANIFEST.in
index 8809a659..c6c02f12 100644
--- a/src/client/MANIFEST.in
+++ b/src/client/MANIFEST.in
@@ -1 +1 @@
-include dcp_client/*.cfg
\ No newline at end of file
+include dcp_client/*.yaml
\ No newline at end of file
diff --git a/src/client/dcp_client/main.py b/src/client/dcp_client/main.py
index b87a34be..ef16a971 100644
--- a/src/client/dcp_client/main.py
+++ b/src/client/dcp_client/main.py
@@ -33,11 +33,11 @@ def main():
 
     if args.mode == "local":
         server_config = read_config(
-            "server", config_path=path.join(dir_name, "config.cfg")
+            "server", config_path=path.join(dir_name, "config.yaml")
         )
     elif args.mode == "remote":
         server_config = read_config(
-            "server", config_path=path.join(dir_name, "config_remote.cfg")
+            "server", config_path=path.join(dir_name, "config_remote.yaml")
         )
 
     image_storage = FilesystemImageStorage()
diff --git a/src/client/dcp_client/utils/bentoml_model.py b/src/client/dcp_client/utils/bentoml_model.py
index c0493dbf..189c1294 100644
--- a/src/client/dcp_client/utils/bentoml_model.py
+++ b/src/client/dcp_client/utils/bentoml_model.py
@@ -43,16 +43,16 @@ def is_connected(self) -> bool:
         """
         return bool(self.client)
 
-    async def _run_train(self, data_path: str):
+    async def _run_train(self, data_path: str) -> Optional[str]:
         """Runs the training task asynchronously.
 
         :param data_path: Path to the training data.
         :type data_path: str
         :return: Response from the server if successful, None otherwise.
+        :rtype: str, or None
         """
         try:
             response = await self.client.async_train(data_path)
-            print('kkkkkkkkkkkkkkkkkkkkk', type(response))
             return response
         except BentoMLException:
             return None
@@ -66,16 +66,16 @@ def run_train(self, data_path: str):
         """
         return asyncio.run(self._run_train(data_path))
 
-    async def _run_inference(self, data_path: str):
+    async def _run_inference(self, data_path: str) -> Optional[np.ndarray]:
         """Runs the inference task asynchronously.
 
         :param data_path: Path to the data for inference.
         :type data_path: str
         :return: Segmentation response from the server if successful, None otherwise.
+        :rtype: np.ndarray, or None
         """
         try:
             response = await self.client.async_segment_image(data_path)
-            print('jjjjjjjjjj', type(response))
             return response
         except BentoMLException:
             return None
diff --git a/src/server/dcp_server/config.yaml b/src/server/dcp_server/config.yaml
index 224009f2..5652469a 100644
--- a/src/server/dcp_server/config.yaml
+++ b/src/server/dcp_server/config.yaml
@@ -44,7 +44,7 @@
         "min_train_masks": 1
     },
     "classifier":{
-        "n_epochs": 100,
+        "n_epochs": 20,
         "lr": 0.001,
         "batch_size": 1,
         "optimizer": "Adam"

From 7c7500b5d6abb19a09fa649396568654508b679d Mon Sep 17 00:00:00 2001
From: Christina Bukas
Date: Thu, 14 Mar 2024 00:43:42 +0100
Subject: [PATCH 26/26] update configs

---
 src/client/dcp_client/utils/bentoml_model.py | 1 +
 src/server/dcp_server/config_instance.yaml   | 6 +++---
 src/server/dcp_server/config_semantic.yaml   | 6 +++---
 src/server/dcp_server/models/unet.py         | 2 +-
 4 files changed, 8 insertions(+), 7 deletions(-)

diff --git a/src/client/dcp_client/utils/bentoml_model.py b/src/client/dcp_client/utils/bentoml_model.py
index 189c1294..5f57b421 100644
--- a/src/client/dcp_client/utils/bentoml_model.py
+++ b/src/client/dcp_client/utils/bentoml_model.py
@@ -2,6 +2,7 @@
 from typing import Optional, List
 from bentoml.client import Client as BentoClient
 from bentoml.exceptions import BentoMLException
+import numpy as np
 
 from dcp_client.app import Model
 
diff --git a/src/server/dcp_server/config_instance.yaml b/src/server/dcp_server/config_instance.yaml
index 1af6b1eb..db266da0 100644
--- a/src/server/dcp_server/config_instance.yaml
+++ b/src/server/dcp_server/config_instance.yaml
@@ -1,9 +1,7 @@
 {
     "setup": {
         "segmentation": "GeneralSegmentation",
-        "model_to_use": "CustomCellpose",
-        "accepted_types": [".jpg", ".jpeg", ".png", ".tiff", ".tif"],
-        "seg_name_string": "_seg"
+        "model_to_use": "CustomCellpose"
     },
 
     "service": {
@@ -21,6 +19,8 @@
 
     "data": {
         "data_root": "data",
+        "accepted_types": [".jpg", ".jpeg", ".png", ".tiff", ".tif"],
+        "seg_name_string": "_seg",
         "gray": True,
         "rescale": True
     },
diff --git a/src/server/dcp_server/config_semantic.yaml b/src/server/dcp_server/config_semantic.yaml
index 928eb931..e72459ac 100644
--- a/src/server/dcp_server/config_semantic.yaml
+++ b/src/server/dcp_server/config_semantic.yaml
@@ -1,9 +1,7 @@
 {
     "setup": {
         "segmentation": "GeneralSegmentation",
-        "model_to_use": "UNet",
-        "accepted_types": [".jpg", ".jpeg", ".png", ".tiff", ".tif"],
-        "seg_name_string": "_seg"
+        "model_to_use": "UNet"
     },
 
     "service": {
@@ -23,6 +21,8 @@
 
     "data": {
         "data_root": "data",
+        "accepted_types": [".jpg", ".jpeg", ".png", ".tiff", ".tif"],
+        "seg_name_string": "_seg",
         "gray": True,
"rescale": True }, diff --git a/src/server/dcp_server/models/unet.py b/src/server/dcp_server/models/unet.py index c5c7a34f..9d85a5f7 100644 --- a/src/server/dcp_server/models/unet.py +++ b/src/server/dcp_server/models/unet.py @@ -152,7 +152,7 @@ def train(self, imgs: List[np.ndarray], masks: List[np.ndarray]) -> None: # compute metric on test set after train is complete for imgs, masks in train_dataloader: pred_masks = self.forward(imgs.float()) - self.metric += self.metric_f(masks, pred_masks) + self.metric += self.metric_f(pred_masks, masks) self.metric /= len(train_dataloader) def eval(self, img: np.ndarray) -> np.ndarray: