From 22f07c4103c97ba3f0cf76e231209b9384a44440 Mon Sep 17 00:00:00 2001 From: Mariia Date: Wed, 31 Jan 2024 23:13:07 +0100 Subject: [PATCH 1/8] integrate multicellpose --- src/server/dcp_server/config.cfg | 5 ++- src/server/dcp_server/models.py | 72 ++++++++++++++++++++++++++++++++ 2 files changed, 75 insertions(+), 2 deletions(-) diff --git a/src/server/dcp_server/config.cfg b/src/server/dcp_server/config.cfg index 74ddc832..5c83933b 100644 --- a/src/server/dcp_server/config.cfg +++ b/src/server/dcp_server/config.cfg @@ -1,7 +1,7 @@ { "setup": { "segmentation": "GeneralSegmentation", - "model_to_use": "UNet", + "model_to_use": "CellposeMultichannel", "accepted_types": [".jpg", ".jpeg", ".png", ".tiff", ".tif"], "seg_name_string": "_seg" }, @@ -16,6 +16,7 @@ "model": { "segmentor": { "model_type": "cyto" + }, "classifier":{ "model_class": "RandomForest", @@ -33,7 +34,7 @@ "train":{ "segmentor":{ - "n_epochs": 10, + "n_epochs": 5, "channels": [0,0], "min_train_masks": 1 }, diff --git a/src/server/dcp_server/models.py b/src/server/dcp_server/models.py index b7e41e4f..fa05f801 100644 --- a/src/server/dcp_server/models.py +++ b/src/server/dcp_server/models.py @@ -539,7 +539,79 @@ def eval(self, img): return final_mask +class CellposeMultichannel(): + ''' + Multichannel image segmentation model. + Run the separate cellpose model for each channel return the mask corresponding to each object type. + + Args: + num_of_channels (int): Number of channels in the input image. + device: The device on which the models should run (e.g., 'cuda' for GPU or 'cpu' for CPU). + + ''' + + def __init__(self, model_config, train_config, eval_config, model_name="Cellpose"): + + self.model_config = model_config + self.train_config = train_config + self.eval_config = eval_config + self.model_name = model_name + self.num_of_channels = self.model_config["classifier"]["num_classes"] + + self.cellpose_models = [ + CustomCellposeModel(self.model_config, + self.train_config, + self.eval_config, + self.model_name + ) for _ in range(self.num_of_channels) + ] + + def train(self, imgs, masks): + + for i in range(self.num_of_channels): + + masks_class = [] + + for mask in masks: + mask_class = mask.copy() + mask_class[0][mask_class[1]!=(i+1)] = 0 + masks_class.append(mask_class) + + self.cellpose_models[i].train(imgs, masks_class) + + self.metric = np.mean([self.cellpose_models[i].metric for i in range(self.num_of_channels)]) + self.loss = np.mean([self.cellpose_models[i].loss for i in range(self.num_of_channels)]) + + + def eval(self, img): + + instance_masks, class_masks = [], [] + + instance_offset = 0 + + for i in range(self.num_of_channels): + + res = self.cellpose_models[i].eval(img) + res[res>0] += instance_offset + instance_masks.append(res) + + instance_offset = np.max(res) + + label_mask = res.copy() + label_mask[res>0]=(i + 1) + class_masks.append(label_mask) + + instance_mask, class_mask = sum(instance_masks), sum(class_masks) + final_mask = np.stack((instance_mask, class_mask), axis=self.eval_config['mask_channel_axis']).astype(np.uint16) + return final_mask + + + + # return np.sum([self.cellpose_models[i].eval(img) for i in range(self.num_of_channels)]) + + + # class CustomSAMModel(): # # https://github.com/facebookresearch/segment-anything/blob/main/notebooks/automatic_mask_generator_example.ipynb # def __init__(self): From 381c536be6c57f7e099b3cf0ad194d21f7054f64 Mon Sep 17 00:00:00 2001 From: Mariia Date: Thu, 1 Feb 2024 01:05:22 +0100 Subject: [PATCH 2/8] merge the outputs of the multicellpose models --- src/server/dcp_server/models.py | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/src/server/dcp_server/models.py b/src/server/dcp_server/models.py index fa05f801..09fe0427 100644 --- a/src/server/dcp_server/models.py +++ b/src/server/dcp_server/models.py @@ -50,7 +50,11 @@ def __init__(self, model_config, train_config, eval_config, model_name): def update_configs(self, train_config, eval_config): self.train_config = train_config self.eval_config = eval_config - + + def eval_probas(self, img): + + return super().eval(x=img, **self.eval_config["segmentor"])[1][2] + def eval(self, img): """Evaluate the model - find mask of the given image Calls the original eval function. @@ -585,7 +589,7 @@ def train(self, imgs, masks): def eval(self, img): - instance_masks, class_masks = [], [] + instance_masks, class_masks, class_probas = [], [], [] instance_offset = 0 @@ -601,7 +605,18 @@ def eval(self, img): label_mask[res>0]=(i + 1) class_masks.append(label_mask) - instance_mask, class_mask = sum(instance_masks), sum(class_masks) + class_proba = self.cellpose_models[i].eval_probas(img) + class_probas.append(class_proba) + + instance_mask = sum(instance_masks) + class_probas = np.argmax(np.stack(class_probas), axis=0) + + # merge the outputs of the n models + class_masks = np.stack(class_masks) + indexes = class_probas*class_probas.size + np.arange(class_probas.size).reshape(class_probas.shape) + indexes = np.unravel_index(indexes, class_masks.shape) + class_mask = class_masks[indexes] + final_mask = np.stack((instance_mask, class_mask), axis=self.eval_config['mask_channel_axis']).astype(np.uint16) return final_mask From e1300b29da76a196c6ede793cccd7226c5ea8e25 Mon Sep 17 00:00:00 2001 From: Mariia Date: Wed, 7 Feb 2024 13:53:09 +0100 Subject: [PATCH 3/8] Rewrite the docstrings to reStructuredType style and add documentation for all functions in models.py --- src/server/dcp_server/models.py | 225 +++++++++++++++++++++++++++----- 1 file changed, 190 insertions(+), 35 deletions(-) diff --git a/src/server/dcp_server/models.py b/src/server/dcp_server/models.py index 09fe0427..62dc5d2d 100644 --- a/src/server/dcp_server/models.py +++ b/src/server/dcp_server/models.py @@ -48,10 +48,24 @@ def __init__(self, model_config, train_config, eval_config, model_name): self.model_name = model_name def update_configs(self, train_config, eval_config): + """Update the training and evaluation configurations. + + :param train_config: Dictionary containing the training configuration. + :type train_config: dict + :param eval_config: Dictionary containing the evaluation configuration. + :type eval_config: dict + """ self.train_config = train_config self.eval_config = eval_config def eval_probas(self, img): + """Get the probability mask for the given input image. + + :param img: Input image for segmentation. + :type img: numpy.ndarray + :return: Probability mask for the input image. + :rtype: numpy.ndarray + """ return super().eval(x=img, **self.eval_config["segmentor"])[1][2] @@ -117,17 +131,27 @@ def masks_to_outlines(self, mask): class CellClassifierFCNN(nn.Module): - ''' - Fully convolutional classifier for cell images. NOTE -> This model cannot be used as a standalone model in DCP - - Args: - model_config (dict): Model configuration. - train_config (dict): Training configuration. - eval_config (dict): Evaluation configuration. + """Fully convolutional classifier for cell images. NOTE -> This model cannot be used as a standalone model in DCP + + :param model_config: Model configuration. + :type model_config: dict + :param train_config: Training configuration. + :type train_config: dict + :param eval_config: Evaluation configuration. + :type eval_config: dict - ''' + """ def __init__(self, model_config, train_config, eval_config): + """Initialize the fully convolutional classifier. + + :param model_config: Model configuration. + :type model_config: dict + :param train_config: Training configuration. + :type train_config: dict + :param eval_config: Evaluation configuration. + :type eval_config: dict + """ super().__init__() self.in_channels = model_config["classifier"].get("in_channels",1) @@ -165,10 +189,25 @@ def __init__(self, model_config, train_config, eval_config): self.metric_fn = F1Score(num_classes=self.num_classes, task="multiclass") def update_configs(self, train_config, eval_config): + """ + Update the training and evaluation configurations. + + :param train_config: Dictionary containing the training configuration. + :type train_config: dict + :param eval_config: Dictionary containing the evaluation configuration. + :type eval_config: dict + """ self.train_config = train_config self.eval_config = eval_config def forward(self, x): + """ Performs forward pass of the CellClassifierFCNN. + + :param x: Input tensor. + :type x: torch.Tensor + :return: Output tensor after passing through the network. + :rtype: torch.Tensor + """ x = self.layer1(x) x = self.layer2(x) @@ -180,10 +219,12 @@ def forward(self, x): return x def train (self, imgs, labels): - """ - input: - 1) imgs - List[np.ndarray[np.uint8]] with shape (3, dx, dy) - 2) labels - List[int] + """Trains the given model + + :param imgs: List of input images with shape (3, dx, dy). + :type imgs: List[np.ndarray[np.uint8]] + :param labels: List of classification labels. + :type labels: List[int] """ lr = self.train_config['lr'] @@ -228,12 +269,13 @@ def train (self, imgs, labels): self.metric /= len(train_dataloader) def eval(self, img): + """Evaluates the model on the provided image and return the predicted label. + + :param img: Input image for evaluation. + :type img: np.ndarray[np.uint8] + :return: y_hat - predicted label. + :rtype: torch.Tensor """ - Evaluate the model on the provided image and return the predicted label. - Input: - img: np.ndarray[np.uint8] - Output: y_hat - The predicted label - """ # normalise img = (img-np.min(img))/(np.max(img)-np.min(img)) # convert to tensor @@ -244,12 +286,22 @@ def eval(self, img): class CellposePatchCNN(nn.Module): - """ Cellpose & patches of cells and then cnn to classify each patch """ def __init__(self, model_config, train_config, eval_config, model_name): + """Constructs all the necessary attributes for the CellposePatchCNN + + :param model_config: Model configuration. + :type model_config: dict + :param train_config: Training configuration. + :type train_config: dict + :param eval_config: Evaluation configuration. + :type eval_config: dict + :param model_name: Name of the model. + :type model_name: str + """ super().__init__() self.model_config = model_config @@ -278,6 +330,13 @@ def __init__(self, model_config, train_config, eval_config, model_name): self.include_mask = False def update_configs(self, train_config, eval_config): + """Update the training and evaluation configurations. + + :param train_config: Dictionary containing the training configuration. + :type train_config: dict + :param eval_config: Dictionary containing the evaluation configuration. + :type eval_config: dict + """ self.train_config = train_config self.eval_config = eval_config @@ -312,6 +371,13 @@ def train(self, imgs, masks): self.loss = (self.segmentor.loss + self.classifier.loss)/2 def eval(self, img): + """Evaluate the model on the provided image and return the final mask. + + :param img: Input image for evaluation. + :type img: np.ndarray[np.uint8] + :return: Final mask containing instance mask and class masks. + :rtype: np.ndarray[np.uint16] + """ # TBD we assume image is 2D [H, W] (see fsimage storage) # The final mask which is returned should have # first channel the output of cellpose and the rest are the class channels @@ -346,8 +412,20 @@ def eval(self, img): return final_mask class CellClassifierShallowModel: + """ + This class implements a shallow model for cell classification using scikit-learn. + """ def __init__(self, model_config, train_config, eval_config): + """Constructs all the necessary attributes for the CellClassifierShallowModel + + :param model_config: Model configuration. + :type model_config: dict + :param train_config: Training configuration. + :type train_config: dict + :param eval_config: Evaluation configuration. + :type eval_config: dict + """ self.model_config = model_config self.train_config = train_config @@ -357,6 +435,13 @@ def __init__(self, model_config, train_config, eval_config): def train(self, X_train, y_train): + """Trains the model using the provided training data. + + :param X_train: Features of the training data. + :type X_train: numpy.ndarray + :param y_train: Labels of the training data. + :type y_train: numpy.ndarray + """ self.model.fit(X_train,y_train) @@ -369,6 +454,13 @@ def train(self, X_train, y_train): def eval(self, X_test): + """Evaluates the model on the provided test data. + + :param X_test: Features of the test data. + :type X_test: numpy.ndarray + :return: y_hat - predicted labels. + :rtype: numpy.ndarray + """ X_test = X_test.reshape(1,-1) @@ -384,22 +476,30 @@ class UNet(nn.Module): """ Unet is a convolutional neural network architecture for semantic segmentation. - Args: - in_channels (int): Number of input channels (default: 3). - out_channels (int): Number of output channels (default: 4). - features (list): List of feature channels for each encoder level (default: [64,128,256,512]). + :param in_channels: Number of input channels (default: 3). + :type in_channels: int + :param out_channels: Number of output channels (default: 4). + :type out_channels: int + :param features: List of feature channels for each encoder level (default: [64,128,256,512]). + :type features: list """ class DoubleConv(nn.Module): """ DoubleConv module consists of two consecutive convolutional layers with batch normalization and ReLU activation functions. - Args: - in_channels (int): Number of input channels. - out_channels (int): Number of output channels. """ def __init__(self, in_channels, out_channels): + """ + Initialize DoubleConv module. + + :param in_channels: Number of input channels. + :type in_channels: int + :param out_channels: Number of output channels. + :type out_channels: int + """ + super().__init__() self.conv = nn.Sequential( @@ -412,10 +512,27 @@ def __init__(self, in_channels, out_channels): ) def forward(self, x): + """Forward pass through the DoubleConv module. + + :param x: Input tensor. + :type x: torch.Tensor + """ return self.conv(x) def __init__(self, model_config, train_config, eval_config, model_name): + """Constructs all the necessary attributes for the UNet model. + + + :param model_config: Model configuration. + :type model_config: dict + :param train_config: Training configuration. + :type train_config: dict + :param eval_config: Evaluation configuration. + :type eval_config: dict + :param model_name: Name of the model. + :type model_name: str + """ super().__init__() self.model_config = model_config @@ -458,6 +575,14 @@ def __init__(self, model_config, train_config, eval_config, model_name): self.output_conv = nn.Conv2d(self.features[0], self.out_channels, kernel_size=1) def forward(self, x): + """ + Forward pass of the UNet model. + + :param x: Input tensor. + :type x: torch.Tensor + :return: Output tensor. + :rtype: torch.Tensor + """ skip_connections = [] for encoder in self.encoder: x = encoder(x) @@ -476,6 +601,14 @@ def forward(self, x): return self.output_conv(x) def train(self, imgs, masks): + """ + Trains the UNet model using the provided images and masks. + + :param imgs: Input images for training. + :type imgs: list[numpy.ndarray] + :param masks: Masks corresponding to the input images. + :type masks: list[numpy.ndarray] + """ lr = self.train_config["classifier"]['lr'] epochs = self.train_config["classifier"]['n_epochs'] @@ -523,9 +656,11 @@ def train(self, imgs, masks): def eval(self, img): """ Evaluate the model on the provided image and return the predicted label. - Input: - img: np.ndarray[np.uint8] - Output: y_hat - The predicted label + + :param img: Input image for evaluation. + :type img: np.ndarray[np.uint8] + :return: predicted mask consists of instance and class masks + :rtype: numpy.ndarray """ with torch.no_grad(): # normalise @@ -546,15 +681,21 @@ def eval(self, img): class CellposeMultichannel(): ''' Multichannel image segmentation model. - Run the separate cellpose model for each channel return the mask corresponding to each object type. - - Args: - num_of_channels (int): Number of channels in the input image. - device: The device on which the models should run (e.g., 'cuda' for GPU or 'cpu' for CPU). - + Run the separate CustomCellposeModel models for each channel return the mask corresponding to each object type. ''' def __init__(self, model_config, train_config, eval_config, model_name="Cellpose"): + """Constructs all the necessary attributes for the CellposeMultichannel model. + + :param model_config: Model configuration. + :type model_config: dict + :param train_config: Training configuration. + :type train_config: dict + :param eval_config: Evaluation configuration. + :type eval_config: dict + :param model_name: Name of the model. + :type model_name: str + """ self.model_config = model_config self.train_config = train_config @@ -571,6 +712,14 @@ def __init__(self, model_config, train_config, eval_config, model_name="Cellpose ] def train(self, imgs, masks): + """ + Train the model on the provided images and masks. + + :param imgs: Input images for training. + :type imgs: list[numpy.ndarray] + :param masks: Masks corresponding to the input images. + :type masks: list[numpy.ndarray] + """ for i in range(self.num_of_channels): @@ -588,6 +737,13 @@ def train(self, imgs, masks): def eval(self, img): + """Evaluate the model on the provided image. + + :param img: Input image for evaluation. + :type img: np.ndarray[np.uint8] + :return: predicted mask consists of instance and class masks + :rtype: numpy.ndarray + """ instance_masks, class_masks, class_probas = [], [], [] @@ -623,7 +779,6 @@ def eval(self, img): - # return np.sum([self.cellpose_models[i].eval(img) for i in range(self.num_of_channels)]) From 18ad640089f610f2cfebe99fc82cac7b12fc3183 Mon Sep 17 00:00:00 2001 From: Mariia Date: Thu, 15 Feb 2024 18:07:21 +0100 Subject: [PATCH 4/8] subprocess.Popen for windows platform: try changing stdin --- src/client/test/test_app.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/client/test/test_app.py b/src/client/test/test_app.py index 9a45d116..858f2f2d 100644 --- a/src/client/test/test_app.py +++ b/src/client/test/test_app.py @@ -68,13 +68,14 @@ def test_run_inference_run(app): "--reload", "--port=7010", ] - process = subprocess.Popen(command) + process = subprocess.Popen(command, stdin=subprocess.PIPE, shell=False) # and wait until it is setup if sys.platform == 'win32' or sys.platform == 'cygwin': time.sleep(120) else: time.sleep(60) # then do model serving message_text, message_title = app.run_inference() # and assert returning message + print(f"HERE: {message_text, message_title}") assert message_text== "Success! Masks generated for all images" assert message_title=="Information" # finally clean up process From ef49ce4e67f7e73614cbfdcb7155783ef770a368 Mon Sep 17 00:00:00 2001 From: Mariia Date: Sun, 18 Feb 2024 13:13:03 +0100 Subject: [PATCH 5/8] Change the cellpose version in requirements. --- src/server/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/server/requirements.txt b/src/server/requirements.txt index e3a9efbb..d84b7f43 100644 --- a/src/server/requirements.txt +++ b/src/server/requirements.txt @@ -1,5 +1,5 @@ wheel==0.42.0 -cellpose>=2.2 +cellpose==2.2.3 bentoml==1.0.16 scikit-image>=0.19.3 torchmetrics>=0.11.4 From 675825f346884528e995a9fb86f644b057fcf432 Mon Sep 17 00:00:00 2001 From: Mariia Date: Sun, 18 Feb 2024 13:34:32 +0100 Subject: [PATCH 6/8] Add pip install flags. --- .github/workflows/test.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 255b0ac9..1a800706 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -48,7 +48,7 @@ jobs: pip install pytest-qt pip install pytest-xvfb pip install coverage - pip install -e ".[testing]" + pip install -e --no-cache-dir ".[testing]" working-directory: src/client - name: Install server dependencies (for communication tests) @@ -96,7 +96,7 @@ jobs: pip install pytest pip install wheel pip install coverage - pip install -e ".[testing]" + pip install -e --no-cache-dir ".[testing]" working-directory: src/server - name: Test with pytest From 3e1c85dfaf5dedadcc065df91dd33819ceedeb3e Mon Sep 17 00:00:00 2001 From: Christina Bukas Date: Fri, 23 Feb 2024 16:14:17 +0100 Subject: [PATCH 7/8] adapted eval mode to only have one instance id for each connected component --- src/server/dcp_server/config.cfg | 4 +- src/server/dcp_server/models.py | 63 ++++++++++++++++++-------------- 2 files changed, 38 insertions(+), 29 deletions(-) diff --git a/src/server/dcp_server/config.cfg b/src/server/dcp_server/config.cfg index 5c83933b..24f44ea0 100644 --- a/src/server/dcp_server/config.cfg +++ b/src/server/dcp_server/config.cfg @@ -8,7 +8,7 @@ "service": { "runner_name": "bento_runner", - "bento_model_path": "unetN", + "bento_model_path": "cp-multi", "service_name": "data-centric-platform", "port": 7010 }, @@ -21,7 +21,7 @@ "classifier":{ "model_class": "RandomForest", "in_channels": 1, - "num_classes": 3, + "num_classes": 2, "features":[64,128,256,512], "black_bg": "False", "include_mask": "False" diff --git a/src/server/dcp_server/models.py b/src/server/dcp_server/models.py index 62dc5d2d..a43a01d1 100644 --- a/src/server/dcp_server/models.py +++ b/src/server/dcp_server/models.py @@ -8,6 +8,8 @@ from tqdm import tqdm import numpy as np from scipy.ndimage import label +from skimage.measure import label as label_mask + from sklearn.ensemble import RandomForestClassifier from sklearn.metrics import f1_score, log_loss @@ -58,8 +60,8 @@ def update_configs(self, train_config, eval_config): self.train_config = train_config self.eval_config = eval_config - def eval_probas(self, img): - """Get the probability mask for the given input image. + def eval_all_outputs(self, img): + """Get all outputs of the model when running eval. :param img: Input image for segmentation. :type img: numpy.ndarray @@ -67,7 +69,7 @@ def eval_probas(self, img): :rtype: numpy.ndarray """ - return super().eval(x=img, **self.eval_config["segmentor"])[1][2] + return super().eval(x=img, **self.eval_config["segmentor"]) def eval(self, img): """Evaluate the model - find mask of the given image @@ -745,37 +747,44 @@ def eval(self, img): :rtype: numpy.ndarray """ - instance_masks, class_masks, class_probas = [], [], [] - - instance_offset = 0 + instance_masks, class_masks, model_confidences = [], [], [] for i in range(self.num_of_channels): - res = self.cellpose_models[i].eval(img) - res[res>0] += instance_offset - instance_masks.append(res) - - instance_offset = np.max(res) - - label_mask = res.copy() - label_mask[res>0]=(i + 1) - class_masks.append(label_mask) - - class_proba = self.cellpose_models[i].eval_probas(img) - class_probas.append(class_proba) - - instance_mask = sum(instance_masks) - class_probas = np.argmax(np.stack(class_probas), axis=0) - - # merge the outputs of the n models - class_masks = np.stack(class_masks) - indexes = class_probas*class_probas.size + np.arange(class_probas.size).reshape(class_probas.shape) - indexes = np.unravel_index(indexes, class_masks.shape) - class_mask = class_masks[indexes] + instance_mask, probs, _ = self.cellpose_models[i].eval_all_outputs(img) + confidence = probs[2] + class_mask = np.zeros_like(instance_mask) + class_mask[instance_mask>0]=(i + 1) + + instance_masks.append(instance_mask) + class_masks.append(class_mask) + model_confidences.append(confidence) + + merged_mask_instances, class_mask = self.merge_masks(instance_masks, class_masks, model_confidences) + instance_mask = label_mask(merged_mask_instances>0) + for inst_id in np.unique(instance_mask)[1:]: + where_inst_id = np.where(instance_mask==inst_id) + vals, counts = np.unique(class_mask[where_inst_id], return_counts=True) + class_mask[where_inst_id] = vals[np.argmax(counts)] final_mask = np.stack((instance_mask, class_mask), axis=self.eval_config['mask_channel_axis']).astype(np.uint16) return final_mask + + def merge_masks(self, inst_masks, class_masks, probabilities): + # Convert lists to numpy arrays + inst_masks = np.array(inst_masks) + class_masks = np.array(class_masks) + probabilities = np.array(probabilities) + + # Find the index of the mask with the maximum probability for each pixel + max_prob_indices = np.argmax(probabilities, axis=0) + + # Use the index to select the corresponding mask for each pixel + final_mask_inst = inst_masks[max_prob_indices, np.arange(inst_masks.shape[1])[:, None], np.arange(inst_masks.shape[2])] + final_mask_class = class_masks[max_prob_indices, np.arange(class_masks.shape[1])[:, None], np.arange(class_masks.shape[2])] + + return final_mask_inst, final_mask_class From 88ed260e31d35432e2082db0a2cd2148f75a8e27 Mon Sep 17 00:00:00 2001 From: Christina Bukas Date: Fri, 23 Feb 2024 16:27:37 +0100 Subject: [PATCH 8/8] added documentation --- src/server/dcp_server/models.py | 29 ++++++++++++++++++++++++----- 1 file changed, 24 insertions(+), 5 deletions(-) diff --git a/src/server/dcp_server/models.py b/src/server/dcp_server/models.py index a43a01d1..1bbaacaf 100644 --- a/src/server/dcp_server/models.py +++ b/src/server/dcp_server/models.py @@ -729,9 +729,10 @@ def train(self, imgs, masks): for mask in masks: mask_class = mask.copy() + # set all instances in the instance mask not corresponding to the class in question to zero mask_class[0][mask_class[1]!=(i+1)] = 0 masks_class.append(mask_class) - + self.cellpose_models[i].train(imgs, masks_class) self.metric = np.mean([self.cellpose_models[i].metric for i in range(self.num_of_channels)]) @@ -739,7 +740,8 @@ def train(self, imgs, masks): def eval(self, img): - """Evaluate the model on the provided image. + """Evaluate the model on the provided image. The instance mask are computed as the union of the predicted model outputs, while the class of + each object is assigned based on majority voting between the models. :param img: Input image for evaluation. :type img: np.ndarray[np.uint8] @@ -750,28 +752,45 @@ def eval(self, img): instance_masks, class_masks, model_confidences = [], [], [] for i in range(self.num_of_channels): - + # get the instance mask and pixel-wise cell probability mask instance_mask, probs, _ = self.cellpose_models[i].eval_all_outputs(img) confidence = probs[2] + # assign the appropriate class to all objects detected by this model class_mask = np.zeros_like(instance_mask) class_mask[instance_mask>0]=(i + 1) instance_masks.append(instance_mask) class_masks.append(class_mask) model_confidences.append(confidence) - + # merge the outputs of the different models using the pixel-wise cell probability mask merged_mask_instances, class_mask = self.merge_masks(instance_masks, class_masks, model_confidences) + # set all connected components to the same label in the instance mask instance_mask = label_mask(merged_mask_instances>0) + # and set the class with the most pixels to that object for inst_id in np.unique(instance_mask)[1:]: where_inst_id = np.where(instance_mask==inst_id) vals, counts = np.unique(class_mask[where_inst_id], return_counts=True) class_mask[where_inst_id] = vals[np.argmax(counts)] - + # take the final mask by stancking instance and class mask final_mask = np.stack((instance_mask, class_mask), axis=self.eval_config['mask_channel_axis']).astype(np.uint16) return final_mask def merge_masks(self, inst_masks, class_masks, probabilities): + """Merges the instance and class masks resulting from the different models using the pixel-wise cell probability. The output of the model + with the maximum probability is selected for each pixel. + + :param inst_masks: List of predicted instance masks from each model. + :type inst_masks: List[np.array] + :param class_masks: List of corresponding class masks from each model. + :type class_masks: List[np.array] + :param probabilities: List of corresponding pixel-wise cell probability masks + :type probabilities: List[np.array] + :return: A tuple containing the following elements: + - final_mask_inst (numpy.ndarray): A single instance mask where for each pixel the output of the model with the highest probability is selected + - final_mask_class (numpy.ndarray): A single class mask where for each pixel the output of the model with the highest probability is selected + :rtype: tuple + """ # Convert lists to numpy arrays inst_masks = np.array(inst_masks) class_masks = np.array(class_masks)