From e918071e3593dfd272f20a8d58906c8399b25180 Mon Sep 17 00:00:00 2001 From: Mariia Koren Date: Wed, 12 Jul 2023 15:36:02 +0200 Subject: [PATCH 01/47] integration of patch model --- src/client/dcp_client/config.cfg | 4 +- src/server/dcp_server/models.py | 94 ++++++++++++++++++++++++++++++++ 2 files changed, 96 insertions(+), 2 deletions(-) diff --git a/src/client/dcp_client/config.cfg b/src/client/dcp_client/config.cfg index d7eb494..53b5db3 100644 --- a/src/client/dcp_client/config.cfg +++ b/src/client/dcp_client/config.cfg @@ -1,9 +1,9 @@ { "server":{ "user": "ubuntu", - "host": "jusuf-vm2", + "host": "local", "data-path": "/home/ubuntu/dcp-data", - "ip": "134.94.88.74", + "ip": "0.0.0.0", "port": 7010 } } \ No newline at end of file diff --git a/src/server/dcp_server/models.py b/src/server/dcp_server/models.py index 8125df2..5d4d143 100644 --- a/src/server/dcp_server/models.py +++ b/src/server/dcp_server/models.py @@ -60,6 +60,58 @@ def masks_to_outlines(self, mask): """ return utils.masks_to_outlines(mask) #[True, False] outputs +class CellFullyConvClassifier(nn.Module): + + ''' + Fully convolutional classifier for cell images. + + Args: + in_channels (int): Number of input channels. + num_classes (int): Number of output classes. + ''' + + def __init__(self, in_channels, num_classes): + super().__init__() + + self.in_channels = in_channels + self.num_classes = num_classes + + self.layer1 = nn.Sequential( + nn.Conv2d(in_channels, 16, 3, 2, 5), + nn.BatchNorm2d(16), + nn.ReLU(), + nn.Dropout2d(p=0.2), + ) + + self.layer2 = nn.Sequential( + nn.Conv2d(16, 64, 3, 1, 1), + nn.BatchNorm2d(64), + nn.ReLU(), + nn.Dropout2d(p=0.2), + ) + + self.layer3 = nn.Sequential( + nn.Conv2d(64, 128, 3, 2, 4), + nn.BatchNorm2d(128), + nn.ReLU(), + nn.Dropout2d(p=0.2), + ) + + self.final_conv = nn.Conv2d(128, num_classes, 1) + + self.pooling = nn.AdaptiveMaxPool2d(1) + + def forward(self, x): + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + + x = self.final_conv(x) + x = self.pooling(x) + x = x.view(x.size(0), -1) + + return x # class CustomSAMModel(): @@ -67,5 +119,47 @@ def masks_to_outlines(self, mask): # def __init__(self): # pass +class CellposePatchCNN(): + + """Cellpose & patches of cells and then cnn to classify each patch + """ + + def __init__(self, model_config, train_config, eval_config ): + + # Initialize the cellpose model + self.train_config = train_config + self.eval_config = eval_config + + self.classifier = CellFullyConvClassifier() + self.segmentation = CustomCellposeModel(model_config, train_config, eval_config) + + def train(self, imgs, masks): + + # masks should have first channel as a cellpose mask and + # all other layers corresponds to the classes + ## TODO: take care of the images and masks preparation + self.segmentation.train(imgs, masks) + + def create_patch_dataset(imgs, masks, save_imgs_path, black_bg:bool, include_mask:bool): + ''' + Create a folder with n subfolders corresponding to each class, each subfolder contains the patches of the corresponding celltype. + + Args: + dataset (list): A list of tuples containing image and mask pairs. + save_imgs_path (str): The path to save the generated patch dataset. + black_bg (bool): Flag indicating whether to use a black background for patches. + include_mask (bool): Flag indicating whether to include the mask along with patches. 
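
            Note: nothing is returned; for each class channel the extracted patches
            are written to disk under `save_imgs_path` (one subfolder per class) by
            the `save_patches` call in the loop below.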
+ ''' + + for img, msk in zip(imgs, masks): + + for channel in range(num_of_channels): + + loc = find_objects(msk[channel]) + msk_patches = get_patches(loc, img, msk[channel], black_bg=black_bg, include_mask=include_mask) + save_patches(msk_patches,channel, save_imgs_path) + + + From effbe6aae030956003c53e053c18585f583bb246 Mon Sep 17 00:00:00 2001 From: Christina Bukas Date: Thu, 13 Jul 2023 09:11:13 +0200 Subject: [PATCH 02/47] add comments with to dos --- src/server/dcp_server/models.py | 61 +++++++++++++++++++++------------ 1 file changed, 40 insertions(+), 21 deletions(-) diff --git a/src/server/dcp_server/models.py b/src/server/dcp_server/models.py index 5d4d143..f78168d 100644 --- a/src/server/dcp_server/models.py +++ b/src/server/dcp_server/models.py @@ -101,6 +101,12 @@ def __init__(self, in_channels, num_classes): self.pooling = nn.AdaptiveMaxPool2d(1) + #def train (self, x, y): + ## TODO should call forward repeatedly and perform the entire train loop + + #def eval(self, x): + ## TODO should call forward once, model is in eval mode, and return predicted masks + def forward(self, x): x = self.layer1(x) @@ -119,6 +125,7 @@ def forward(self, x): # def __init__(self): # pass + class CellposePatchCNN(): """Cellpose & patches of cells and then cnn to classify each patch @@ -131,33 +138,45 @@ def __init__(self, model_config, train_config, eval_config ): self.eval_config = eval_config self.classifier = CellFullyConvClassifier() - self.segmentation = CustomCellposeModel(model_config, train_config, eval_config) + self.segmentor = CustomCellposeModel(model_config, train_config, eval_config) def train(self, imgs, masks): # masks should have first channel as a cellpose mask and # all other layers corresponds to the classes - ## TODO: take care of the images and masks preparation - self.segmentation.train(imgs, masks) + ## TODO: take care of the images and masks preparation -> this step isn't for now @Mariia + self.segmentor.train(imgs, masks) + ## TODO call create_patches (adjust create_patch_dataset function) + ## to prepare imgs and masks for training CNN + ## TODO call self.classifier.train(imgs, masks) - def create_patch_dataset(imgs, masks, save_imgs_path, black_bg:bool, include_mask:bool): - ''' - Create a folder with n subfolders corresponding to each class, each subfolder contains the patches of the corresponding celltype. - - Args: - dataset (list): A list of tuples containing image and mask pairs. - save_imgs_path (str): The path to save the generated patch dataset. - black_bg (bool): Flag indicating whether to use a black background for patches. - include_mask (bool): Flag indicating whether to include the mask along with patches. - ''' - - for img, msk in zip(imgs, masks): - - for channel in range(num_of_channels): - - loc = find_objects(msk[channel]) - msk_patches = get_patches(loc, img, msk[channel], black_bg=black_bg, include_mask=include_mask) - save_patches(msk_patches,channel, save_imgs_path) + def eval(self, img, **eval_config): + pass + ## TODO implement the eval pipeline, i.e. first call self.segmentor.eval, then split again into patches + ## using resulting seg and then call self.classifier.eval. The final mask which is returned should have + ## first channel the output of cellpose and the rest are the class channels + + def create_patch_dataset(self, imgs, masks, black_bg:bool, include_mask:bool): + ''' + TODO: Split img and masks into patches of equal size which are centered around the cells. 
+ The algorithm should first run through all images to find the max cell size, and use + the max cell size to define the patch size. All patches and masks should then be returned + in the same foramt as imgs and masks (same type, i.e. check if tensor or np.array and same + convention of dims, e.g. CxHxW) + Args: + imgs (): + masks (): + black_bg (bool): Flag indicating whether to use a black background for patches. + include_mask (bool): Flag indicating whether to include the mask along with patches. + ''' + + for img, msk in zip(imgs, masks): + + for channel in range(num_of_channels): + + loc = find_objects(msk[channel]) + msk_patches = get_patches(loc, img, msk[channel], black_bg=black_bg, include_mask=include_mask) + save_patches(msk_patches,channel, save_imgs_path) From 325b849bbe9e4d3e91ea144da3b1c3ac9a9642dd Mon Sep 17 00:00:00 2001 From: Mariia Koren Date: Sat, 15 Jul 2023 23:36:47 +0200 Subject: [PATCH 03/47] added train and eval parts in models --- src/server/dcp_server/models.py | 255 ++++++++++++++++++++++++++++++-- 1 file changed, 242 insertions(+), 13 deletions(-) diff --git a/src/server/dcp_server/models.py b/src/server/dcp_server/models.py index f78168d..edb0bec 100644 --- a/src/server/dcp_server/models.py +++ b/src/server/dcp_server/models.py @@ -1,4 +1,14 @@ +#models + + from cellpose import models, utils + +import torch +from torch import nn + +from torch.optim import Adam +from torch.utils.data import TensorDataset, DataLoader + #from segment_anything import SamPredictor, sam_model_registry #from segment_anything.automatic_mask_generator import SamAutomaticMaskGenerator @@ -70,12 +80,14 @@ class CellFullyConvClassifier(nn.Module): num_classes (int): Number of output classes. ''' - def __init__(self, in_channels, num_classes): + def __init__(self, in_channels, num_classes, **kwargs): super().__init__() self.in_channels = in_channels self.num_classes = num_classes + self.hparams = kwargs + self.layer1 = nn.Sequential( nn.Conv2d(in_channels, 16, 3, 2, 5), nn.BatchNorm2d(16), @@ -101,11 +113,57 @@ def __init__(self, in_channels, num_classes): self.pooling = nn.AdaptiveMaxPool2d(1) - #def train (self, x, y): + def train (self, imgs, labels): ## TODO should call forward repeatedly and perform the entire train loop + + """ + input: + 1) imgs - List[np.ndarray[np.uint8]] with shape (3, dx, dy) + 2) y - List[int] + """ + + lr = self.hparams.get('lr', 0.001) + epochs = self.hparams.get('epochs', 1) + batch_size = self.hparams.get('batch_size', 1) + optimizer_class = self.hparams.get('optimizer', 'Adam') + + imgs = [ torch.from_numpy(img) for img in imgs] + labels = torch.tensor(labels) + + train_dataset = TensorDataset(imgs, labels) + train_dataloader = DataLoader(train_dataset, batch_size=batch_size) + + # eval method evaluates a python string and returns an object, e.g. 
eval('print(1)') = 1 + # or eval('[1, 2, 3]') = [1, 2, 3] + loss_fn = nn.CrossEntropyLoss() + optimizer = eval(f'{optimizer_class}(lr={lr})') + + for _ in epochs: + for i, data in enumerate(train_dataloader): + + imgs, labels = data + optimizer.zero_grad() + preds = self.forward(imgs) + + y_hats = torch.argmax(preds, 1) + loss = loss_fn(y_hats, labels) + loss.backward() + + optimizer.step() - #def eval(self, x): + def eval(self, imgs): ## TODO should call forward once, model is in eval mode, and return predicted masks + """ + input: + 1) imgs - List[np.ndarray[np.uint8]] with shape (3, dx, dy) + """ + labels = [] + for img in imgs: + + img = torch.from_numpy(img).unsqueeze(0) + labels.append(self.forward(img)) + + return labels def forward(self, x): @@ -140,23 +198,187 @@ def __init__(self, model_config, train_config, eval_config ): self.classifier = CellFullyConvClassifier() self.segmentor = CustomCellposeModel(model_config, train_config, eval_config) - def train(self, imgs, masks): + def train(self, imgs, masks, **kwargs): # masks should have first channel as a cellpose mask and # all other layers corresponds to the classes ## TODO: take care of the images and masks preparation -> this step isn't for now @Mariia + + black_bg = kwargs.get("black_bg", False) + include_mask = kwargs.get("include_mask", False) + self.segmentor.train(imgs, masks) + patches, labels = self.create_patch_dataset(self, imgs, masks, black_bg:bool, include_mask:bool, **kwargs) + self.classifier.train(patches, labels) + ## TODO call create_patches (adjust create_patch_dataset function) ## to prepare imgs and masks for training CNN ## TODO call self.classifier.train(imgs, masks) def eval(self, img, **eval_config): - pass + ## TODO implement the eval pipeline, i.e. first call self.segmentor.eval, then split again into patches ## using resulting seg and then call self.classifier.eval. The final mask which is returned should have ## first channel the output of cellpose and the rest are the class channels + mask = self.segmentor.eval(img) + patches, labels = self.create_patch_dataset(self, [img], [mask], black_bg:bool, include_mask:bool, **kwargs) + results = self.classifier.eval(patches, labels) + + return results + + def find_max_patch_size(mask): + + # Find objects in the binary image + objects = ndi.find_objects(mask) + + # Initialize variables to store the maximum patch size + max_patch_size = 0 + max_patch_indices = None + + # Iterate over the found objects + for obj in objects: + # Extract start and stop values from the slice object + slices = [s for s in obj] + start = [s.start for s in slices] + stop = [s.stop for s in slices] + + # Calculate the size of the patch along each axis + patch_size = tuple(stop[i] - start[i] for i in range(len(start))) + + # Calculate the total size (area) of the patch + total_size = 1 + for size in patch_size: + total_size *= size + + # Check if the current patch size is larger than the maximum + if total_size > max_patch_size: + max_patch_size = total_size + max_patch_indices = obj + + max_patch_size_edge = np.ceil(np.sqrt(max_patch_size)) + + return max_patch_size_edge + + def pad_centered_padded_patch(x: np.ndarray, c, p, mask: np.ndarray=None, noise_intensity=None) -> np.ndarray: + """ + Crop a patch from an array `x` centered at coordinates `c` with size `p`, and apply padding if necessary. + + Args: + x (np.ndarray): The input array from which the patch will be cropped. + c (tuple): The coordinates (row, column, channel) at the center of the patch. 
+ p (tuple): The size of the patch to be cropped (height, width). + remove_other_instances (bool): Flag indicating whether to remove other instances in the patch. + + Returns: + np.ndarray: The cropped patch with applied padding. + """ + + height, width = p # Size of the patch + + # Calculate the boundaries of the patch + top = c[0] - height // 2 + bottom = top + height + + left = c[1] - width // 2 + right = left + width + + # Crop the patch from the input array + + if mask is not None: + + mask_ = mask.max(-1) if len(mask.shape) == 3 else mask + central_label = mask_[c[0], c[1]] + m = (mask_ != central_label) & (mask_ > 0) + x[m] = 0 + + if noise_intensity is not None: + x[m] = np.random.normal(scale=noise_intensity, size=x[m].shape) + + patch = x[max(top, 0):min(bottom, x.shape[0]), max(left, 0):min(right, x.shape[1])] + + if len(c) == 3: + patch = patch[...,c[2]] + + # Calculate the required padding amounts + + size_x, size_y = x.shape[1], x.shape[0] + + # Apply padding if necessary + if left < 0: + patch = np.hstack(( + np.random.normal(scale=noise_intensity, size=(patch.shape[0], abs(left))).astype(np.uint8), patch + )) + + # Apply padding on the right side if necessary + if right > size_x: + patch = np.hstack(( + patch, np.random.normal(scale=noise_intensity, size=(patch.shape[0], right - size_x)).astype(np.uint8) + )) + + # Apply padding on the top side if necessary + if top < 0: + patch = np.vstack(( + np.random.normal(scale=noise_intensity, size=(abs(top), patch.shape[1])).astype(np.uint8), patch + )) + + # Apply padding on the bottom side if necessary + if bottom > size_y: + patch = np.vstack(( + patch, np.random.normal(scale=noise_intensity, size=(bottom - size_y, patch.shape[1])).astype(np.uint8) + )) + + return patch + + + def get_center_of_mass(mask: np.ndarray) -> np.ndarray: + """ + Compute the centers of mass for each object in a mask. + + Args: + mask (np.ndarray): The input mask containing labeled objects. + + Returns: + np.ndarray: An array of coordinates (row, column, channel) representing the centers of mass for each object. + """ + # Compute the centers of mass for each labeled object in the mask + centers_of_mass = np.array( + list(map( + lambda x: (int(x[0]), int(x[1]), int(x[2])) if len(mask.shape) == 3 else (int(x[0]), int(x[1]), -1), + ndi.center_of_mass(mask, mask, np.arange(1, mask.max() + 1)) + )) + ) + + return centers_of_mass + + + def get_centered_patches(img, mask, p_size: int, noise_intensity=5): + + ''' + Extracts centered patches from the input image based on the centers of objects identified in the mask. + + Args: + img: The input image. + mask: The mask representing the objects in the image. + p_size (int): The size of the patches to extract. + noise_intensity: The intensity of noise to add to the patches. - def create_patch_dataset(self, imgs, masks, black_bg:bool, include_mask:bool): + ''' + + patches, labels = [], [] + + centers_of_mass = get_center_of_mass(mask) + # Crop patches around each center of mass and save them + for i, c in enumerate(centers_of_mass): + c_x, c_y, label = c + + patch = pad_centered_padded_patch(img.copy(), (c_x, c_y), (p_size, p_size), mask=mask, noise_intensity=noise_intensity) + + patches.append(patch) + labels.append(label) + + return patches, labels + + def create_patch_dataset(self, imgs, masks, black_bg:bool, include_mask:bool, **kwargs): ''' TODO: Split img and masks into patches of equal size which are centered around the cells. 
The algorithm should first run through all images to find the max cell size, and use @@ -170,15 +392,22 @@ def create_patch_dataset(self, imgs, masks, black_bg:bool, include_mask:bool): include_mask (bool): Flag indicating whether to include the mask along with patches. ''' + noise_intensity = kwargs.get("noise_intensity", 5) + + max_patch_size = np.max([self.find_max_patch_size(mask) in masks]) + + patches, labels = [], [] + for img, msk in zip(imgs, masks): for channel in range(num_of_channels): loc = find_objects(msk[channel]) - msk_patches = get_patches(loc, img, msk[channel], black_bg=black_bg, include_mask=include_mask) - save_patches(msk_patches,channel, save_imgs_path) - - - - - + patch, label = get_centered_patches(img, msk, int(1.5 * img.shape[0] // 5), noise_intensity=noise_intensity) + + patches.append(patch) + labels.append(label) + # msk_patches = get_patches(loc, img, msk[channel], black_bg=black_bg, include_mask=include_mask) + # save_patches(msk_patches,channel, save_imgs_path) + + return patches, labels \ No newline at end of file From 5ca7b15b645ea32d8a073816da0401c602fc0d33 Mon Sep 17 00:00:00 2001 From: Mariia Koren Date: Mon, 17 Jul 2023 11:07:43 +0200 Subject: [PATCH 04/47] add comments and small changes --- src/server/dcp_server/models.py | 66 +++++++++++++++++++-------------- 1 file changed, 38 insertions(+), 28 deletions(-) diff --git a/src/server/dcp_server/models.py b/src/server/dcp_server/models.py index edb0bec..1be8829 100644 --- a/src/server/dcp_server/models.py +++ b/src/server/dcp_server/models.py @@ -86,7 +86,10 @@ def __init__(self, in_channels, num_classes, **kwargs): self.in_channels = in_channels self.num_classes = num_classes - self.hparams = kwargs + self.train_config = train_config + self.eval_config = eval_config + + # self.hparams = kwargs self.layer1 = nn.Sequential( nn.Conv2d(in_channels, 16, 3, 2, 5), @@ -113,6 +116,18 @@ def __init__(self, in_channels, num_classes, **kwargs): self.pooling = nn.AdaptiveMaxPool2d(1) + def forward(self, x): + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + + x = self.final_conv(x) + x = self.pooling(x) + x = x.view(x.size(0), -1) + + return x + def train (self, imgs, labels): ## TODO should call forward repeatedly and perform the entire train loop @@ -122,14 +137,16 @@ def train (self, imgs, labels): 2) y - List[int] """ - lr = self.hparams.get('lr', 0.001) - epochs = self.hparams.get('epochs', 1) - batch_size = self.hparams.get('batch_size', 1) - optimizer_class = self.hparams.get('optimizer', 'Adam') + lr = self.train_config.get('lr', 0.001) + epochs = self.train_config.get('epochs', 1) + batch_size = self.train_config.get('batch_size', 1) + optimizer_class = self.train_config.get('optimizer', 'Adam') + # Convert input images and labels to tensors imgs = [ torch.from_numpy(img) for img in imgs] labels = torch.tensor(labels) + # Create a training dataset and dataloader train_dataset = TensorDataset(imgs, labels) train_dataloader = DataLoader(train_dataset, batch_size=batch_size) @@ -154,6 +171,7 @@ def train (self, imgs, labels): def eval(self, imgs): ## TODO should call forward once, model is in eval mode, and return predicted masks """ + Evaluate the model on the provided images and return predicted labels. 
        input:
        1) imgs - List[np.ndarray[np.uint8]] with shape (3, dx, dy)
        """
        labels = []
        for img in imgs:
@@ -164,20 +182,6 @@ def eval(self, imgs):
            labels.append(self.forward(img))

        return labels
-
-    def forward(self, x):
-
-        x = self.layer1(x)
-        x = self.layer2(x)
-        x = self.layer3(x)
-
-        x = self.final_conv(x)
-        x = self.pooling(x)
-        x = x.view(x.size(0), -1)
-
-        return x
-
 # class CustomSAMModel():
 #     # https://github.com/facebookresearch/segment-anything/blob/main/notebooks/automatic_mask_generator_example.ipynb
 #     def __init__(self):
 #         pass
@@ -186,15 +190,17 @@ def forward(self, x):

 class CellposePatchCNN():

-    """Cellpose & patches of cells and then cnn to classify each patch
+    """
+    Segment cells with Cellpose, then classify each resulting cell patch with a CNN.
     """

     def __init__(self, model_config, train_config, eval_config):

-        # Initialize the cellpose model
         self.train_config = train_config
         self.eval_config = eval_config

+        # Initialize the classifier and the cellpose model
         self.classifier = CellFullyConvClassifier()
         self.segmentor = CustomCellposeModel(model_config, train_config, eval_config)

@@ -204,16 +210,18 @@ def train(self, imgs, masks, **kwargs):
         # all other layers correspond to the classes
         ## TODO: take care of the images and masks preparation -> this step isn't for now @Mariia

+        ## TODO call create_patches (adjust create_patch_dataset function)
+        ## to prepare imgs and masks for training CNN
+        ## TODO call self.classifier.train(imgs, masks)
+
         black_bg = kwargs.get("black_bg", False)
         include_mask = kwargs.get("include_mask", False)
+        include_cellpose_mask = kwargs.get("include_cellpose_mask", True)

         self.segmentor.train(imgs, masks)
         patches, labels = self.create_patch_dataset(imgs, masks, black_bg=black_bg, include_mask=include_mask, **kwargs)
         self.classifier.train(patches, labels)
-        ## TODO call create_patches (adjust create_patch_dataset function)
-        ## to prepare imgs and masks for training CNN
-        ## TODO call self.classifier.train(imgs, masks)

     def eval(self, img, **eval_config):

@@ -222,13 +230,15 @@ def eval(self, img, **eval_config):
         ## first channel the output of cellpose and the rest are the class channels
         mask = self.segmentor.eval(img)
         patches, labels = self.create_patch_dataset([img], [mask],
                                                     black_bg=eval_config.get("black_bg", False),
                                                     include_mask=eval_config.get("include_mask", False))
-        results = self.classifier.eval(patches, labels)
+        # result is a one channel image that contains labels corresponding to the class labels
+        result = self.classifier.eval(patches, labels)
+        result_with_cellpose_mask = torch.stack((mask, result), 0)

-        return results
+        return result_with_cellpose_mask

     def find_max_patch_size(mask):

-        # Find objects in the binary image
+        # Find objects in the mask
         objects = ndi.find_objects(mask)

         # Initialize variables to store the maximum patch size
@@ -267,7 +277,6 @@ def pad_centered_padded_patch(x: np.ndarray, c, p, mask: np.ndarray=None, noise_
             x (np.ndarray): The input array from which the patch will be cropped.
             c (tuple): The coordinates (row, column, channel) at the center of the patch.
             p (tuple): The size of the patch to be cropped (height, width).
-            remove_other_instances (bool): Flag indicating whether to remove other instances in the patch.

         Returns:
             np.ndarray: The cropped patch with applied padding.
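
        Note:
            When the requested patch extends past the image border, the missing rows
            and columns are filled with Gaussian noise scaled by `noise_intensity`
            (see the np.random.normal padding branches in the body below).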
@@ -288,6 +297,7 @@ def pad_centered_padded_patch(x: np.ndarray, c, p, mask: np.ndarray=None, noise_ mask_ = mask.max(-1) if len(mask.shape) == 3 else mask central_label = mask_[c[0], c[1]] + # Zero out values in the patch where the mask is not equal to the central label m = (mask_ != central_label) & (mask_ > 0) x[m] = 0 From 7b4d847cfdb3b6f77594500315075c8efa6fd88e Mon Sep 17 00:00:00 2001 From: Mariia Koren Date: Sun, 23 Jul 2023 22:20:47 +0200 Subject: [PATCH 05/47] finished create_patch_datasetand fixed errors --- src/server/dcp_server/config.cfg | 6 +- src/server/dcp_server/models.py | 142 +++++++++++++++++++++++-------- 2 files changed, 110 insertions(+), 38 deletions(-) diff --git a/src/server/dcp_server/config.cfg b/src/server/dcp_server/config.cfg index 01efeb1..42d33b6 100644 --- a/src/server/dcp_server/config.cfg +++ b/src/server/dcp_server/config.cfg @@ -15,7 +15,11 @@ "model_type":"cyto" }, "data": { - "data_root": "/home/ubuntu/dcp-data" + "data_root": "/home/ubuntu/dcp-data", + "patch_size":64, + "noise_intensity":5, + "num_classes":3, + }, "train":{ "n_epochs": 2, diff --git a/src/server/dcp_server/models.py b/src/server/dcp_server/models.py index 1be8829..6ef5fb6 100644 --- a/src/server/dcp_server/models.py +++ b/src/server/dcp_server/models.py @@ -89,8 +89,6 @@ def __init__(self, in_channels, num_classes, **kwargs): self.train_config = train_config self.eval_config = eval_config - # self.hparams = kwargs - self.layer1 = nn.Sequential( nn.Conv2d(in_channels, 16, 3, 2, 5), nn.BatchNorm2d(16), @@ -172,8 +170,10 @@ def eval(self, imgs): ## TODO should call forward once, model is in eval mode, and return predicted masks """ Evaluate the model on the provided images and return predicted labels. - input: - 1) imgs - List[np.ndarray[np.uint8]] with shape (3, dx, dy) + Input: + imgs: List[np.ndarray[np.uint8]] with shape (3, dx, dy) + Output: + labels: List of predicted labels. """ labels = [] for img in imgs: @@ -204,19 +204,29 @@ def __init__(self, model_config, train_config, eval_config ): self.classifier = CellFullyConvClassifier() self.segmentor = CustomCellposeModel(model_config, train_config, eval_config) - def train(self, imgs, masks, **kwargs): + def init_from_checkpoints(self, chpt_classifier=None, chpt_segmentor=None): + """ + Initialize the model from pre-trained checkpoints. + """ + + self.segmentor = CustomCellposeModel( + model_config={"gpu":torch.cuda.is_available(), "pretrained_model":chpt_segmentor} + ) + self.classifier.load_state_dict(torch.load(chpt_classifier)["model"]) + + + def train(self, imgs, masks, **train_config): # masks should have first channel as a cellpose mask and # all other layers corresponds to the classes ## TODO: take care of the images and masks preparation -> this step isn't for now @Mariia - ## TODO call create_patches (adjust create_patch_dataset function) ## to prepare imgs and masks for training CNN ## TODO call self.classifier.train(imgs, masks) - black_bg = kwargs.get("black_bg", False) - include_mask = kwargs.get("include_mask", False) - include_cellpose_mask = kwargs.get("include_cellpose_mask", True) + black_bg = train_config.get("black_bg", False) + include_mask = train_config.get("include_mask", False) + include_cellpose_mask = train_config.get("include_cellpose_mask", True) self.segmentor.train(imgs, masks) patches, labels = self.create_patch_dataset(self, imgs, masks, black_bg:bool, include_mask:bool, **kwargs) @@ -229,14 +239,59 @@ def eval(self, img, **eval_config): ## using resulting seg and then call self.classifier.eval. 
The final mask which is returned should have ## first channel the output of cellpose and the rest are the class channels mask = self.segmentor.eval(img) - patches, labels = self.create_patch_dataset(self, [img], [mask], black_bg:bool, include_mask:bool, **kwargs) - # result is a one channel image that contains labels corresponding to the class labels - result = self.classifier.eval(patches, labels) - result_with_cellpose_mask = torch.stack((mask, result), 0) + result = self.get_prediction(img, mask) + + return result + + def get_prediction(self, input_image, cellpose_mask, model_classifier): + """ + Performs object segmentation and classification on an input image using the Cellpose model and a classifier model. + + Args: + image_path (str): The file path of the input image. + model (CellposeModel): The Cellpose model used for object segmentation. + model_classifier: The classifier model used for object classification. + + Returns: + tuple: A tuple containing the cellpose_mask and final_mask, representing the segmentation masks obtained from + the Cellpose model and the combined segmentation and classification mask, respectively. + """ - return result_with_cellpose_mask + # Obtain segmentation mask using Cellpose model - def find_max_patch_size(mask): + # Find objects in the cellpose_mask + locs = find_objects(cellpose_mask) + + # Get patches and labels based on object centroids + patches, labels = get_centered_patches(input_image, cellpose_mask, int(1.5 * input_image.shape[0] // 5), noise_intensity=5) + + labels = torch.tensor(labels) + labels_fit = [] + + final_mask = torch.zeros(cellpose_mask.shape) + + with torch.no_grad(): + for i, patch in enumerate(patches): + loc = locs[i] + + # Prepare image patch for classification + img = torch.tensor(patch.astype(np.float32)).unsqueeze(0).repeat(3, 1, 1).unsqueeze(0) / 255 + + # Perform inference using model_classifier + logits = self.classifier(img) + + _, predicted = torch.max(logits, 1) + labels_fit.append(predicted) + + # Assign predicted class to corresponding location in final_mask + final_mask[loc] = predicted + 1 + + # Apply mask to final_mask, retaining only regions where cellpose_mask is greater than 0 + final_mask = final_mask * ((cellpose_mask > 0).astype(np.uint8)) + + return final_mask + + def find_max_patch_size(self, mask): # Find objects in the mask objects = ndi.find_objects(mask) @@ -269,7 +324,7 @@ def find_max_patch_size(mask): return max_patch_size_edge - def pad_centered_padded_patch(x: np.ndarray, c, p, mask: np.ndarray=None, noise_intensity=None) -> np.ndarray: + def pad_centered_padded_patch(self, x: np.ndarray, c, p, mask: np.ndarray=None, noise_intensity=None) -> np.ndarray: """ Crop a patch from an array `x` centered at coordinates `c` with size `p`, and apply padding if necessary. @@ -295,7 +350,7 @@ def pad_centered_padded_patch(x: np.ndarray, c, p, mask: np.ndarray=None, noise_ if mask is not None: - mask_ = mask.max(-1) if len(mask.shape) == 3 else mask + mask_ = mask.max(-1) if len(mask.shape) >= 3 else mask central_label = mask_[c[0], c[1]] # Zero out values in the patch where the mask is not equal to the central label m = (mask_ != central_label) & (mask_ > 0) @@ -340,7 +395,7 @@ def pad_centered_padded_patch(x: np.ndarray, c, p, mask: np.ndarray=None, noise_ return patch - def get_center_of_mass(mask: np.ndarray) -> np.ndarray: + def get_center_of_mass(self, mask: np.ndarray) -> np.ndarray: """ Compute the centers of mass for each object in a mask. 
@@ -351,17 +406,15 @@ def get_center_of_mass(mask: np.ndarray) -> np.ndarray: np.ndarray: An array of coordinates (row, column, channel) representing the centers of mass for each object. """ # Compute the centers of mass for each labeled object in the mask - centers_of_mass = np.array( - list(map( - lambda x: (int(x[0]), int(x[1]), int(x[2])) if len(mask.shape) == 3 else (int(x[0]), int(x[1]), -1), - ndi.center_of_mass(mask, mask, np.arange(1, mask.max() + 1)) - )) - ) + centers_of_mass = [ + (int(x[0]), int(x[1]), int(x[2])) if mask.ndim >= 3 else (int(x[0]), int(x[1]), -1) + for x in ndi.center_of_mass(mask, mask, np.arange(1, mask.max() + 1)) + ] return centers_of_mass - def get_centered_patches(img, mask, p_size: int, noise_intensity=5): + def get_centered_patches(self, img, mask, p_size: int, noise_intensity=5): ''' Extracts centered patches from the input image based on the centers of objects identified in the mask. @@ -393,7 +446,7 @@ def create_patch_dataset(self, imgs, masks, black_bg:bool, include_mask:bool, ** TODO: Split img and masks into patches of equal size which are centered around the cells. The algorithm should first run through all images to find the max cell size, and use the max cell size to define the patch size. All patches and masks should then be returned - in the same foramt as imgs and masks (same type, i.e. check if tensor or np.array and same + in the same format as imgs and masks (same type, i.e. check if tensor or np.array and same convention of dims, e.g. CxHxW) Args: imgs (): @@ -402,22 +455,37 @@ def create_patch_dataset(self, imgs, masks, black_bg:bool, include_mask:bool, ** include_mask (bool): Flag indicating whether to include the mask along with patches. ''' - noise_intensity = kwargs.get("noise_intensity", 5) + noise_intensity = kwargs.get("data", {}).get("noise_intensity", 5) - max_patch_size = np.max([self.find_max_patch_size(mask) in masks]) + #max_patch_size = np.max([self.find_max_patch_size(mask) in masks]) + max_patch_size = kwargs.get("data", {}).get("patch_size", 64) + num_of_channels = kwargs.get("data", {}).get("num_classes", 3) patches, labels = [], [] + is_train = masks[0].ndim == 3 for img, msk in zip(imgs, masks): - for channel in range(num_of_channels): - - loc = find_objects(msk[channel]) - patch, label = get_centered_patches(img, msk, int(1.5 * img.shape[0] // 5), noise_intensity=noise_intensity) - + # training dataset + if is_train: + # mask has dimension WxHxNum_of_channels + for channel in range(num_of_channels): + + patch, label = get_centered_patches(img, msk, int(1.5 * img.shape[0] // 5), + noise_intensity=noise_intensity) + + patches.append(patch) + labels.append(label) + + # test dataset + # mask has dimention WxH + else: + patch, _ = get_centered_patches(img, msk, int(1.5 * img.shape[0] // 5), + noise_intensity=noise_intensity) patches.append(patch) - labels.append(label) - # msk_patches = get_patches(loc, img, msk[channel], black_bg=black_bg, include_mask=include_mask) - # save_patches(msk_patches,channel, save_imgs_path) - return patches, labels \ No newline at end of file + if is_train: + return patches, labels + + else: + return patches \ No newline at end of file From 38a8140ae3d55b6df880af0f2fed7a29faf30e64 Mon Sep 17 00:00:00 2001 From: KorenMary Date: Thu, 27 Jul 2023 11:21:21 +0000 Subject: [PATCH 06/47] rewrite the tests from unittest to pytest --- src/client/test/test_app.py | 109 +++++++++++++------------ src/client/test/test_fsimagestorage.py | 85 +++++++++---------- 2 files changed, 98 insertions(+), 96 
deletions(-) diff --git a/src/client/test/test_app.py b/src/client/test/test_app.py index b9164c1..f2a29ca 100644 --- a/src/client/test/test_app.py +++ b/src/client/test/test_app.py @@ -1,62 +1,69 @@ import os +import sys from skimage import data from skimage.io import imsave -import unittest +import pytest + +sys.path.append("../") from dcp_client.app import Application from dcp_client.utils.bentoml_model import BentomlModel from dcp_client.utils.fsimagestorage import FilesystemImageStorage from dcp_client.utils.sync_src_dst import DataRSync -class TestApplication(unittest.TestCase): - - def test_run_train(self): - pass - - def test_run_inference(self): - pass - - def test_load_image(self): - - img = data.astronaut() - img2 = data.cat() - os.mkdir('in_prog') - imsave('in_prog/test_img.png', img) - imsave('in_prog/test_img2.png', img2) - rsyncer = DataRSync(user_name="local", - host_name="local", - server_repo_path='.') - self.app = Application(BentomlModel(), - rsyncer, - FilesystemImageStorage(), - "0.0.0.0", - 7010) - - self.app.cur_selected_img = 'test_img.png' - self.app.cur_selected_path = 'in_prog' - - img_test = self.app.load_image() # if image_name is None - self.assertEqual(img.all(), img_test.all()) - img_test2 = self.app.load_image('test_img2.png') # if a filename is given - self.assertEqual(img2.all(), img_test2.all()) - - # delete everyting we created - os.remove('in_prog/test_img.png') - os.remove('in_prog/test_img2.png') - os.rmdir('in_prog') - - def test_save_image(self): - pass - - def test_move_images(self): - pass - - def test_delete_images(self): - pass - - def test_search_segs(self): - pass + +@pytest.fixture +def app(): + img = data.astronaut() + img2 = data.cat() + os.mkdir('in_prog') + + imsave('in_prog/test_img.png', img) + imsave('in_prog/test_img2.png', img2) + + rsyncer = DataRSync(user_name="local", host_name="local", server_repo_path='.') + app = Application(BentomlModel(), rsyncer, FilesystemImageStorage(), "0.0.0.0", 7010) + + app.cur_selected_img = 'test_img.png' + app.cur_selected_path = 'in_prog' + + return app, img, img2 + +def test_load_image(app): + app, img, img2 = app # Unpack the app, img, and img2 from the fixture + + img_test = app.load_image() # if image_name is None + assert img.all() == img_test.all() + + img_test2 = app.load_image('test_img2.png') # if a filename is given + assert img2.all() == img_test2.all() + + # delete everything we created + os.remove('in_prog/test_img.png') + os.remove('in_prog/test_img2.png') + os.rmdir('in_prog') + +def test_run_train(): + pass + +def test_run_inference(): + pass + +def test_save_image(): + pass + +def test_move_images(): + pass + +def test_delete_images(): + pass + +def test_search_segs(): + pass + + + + + -if __name__=='__main__': - unittest.main() \ No newline at end of file diff --git a/src/client/test/test_fsimagestorage.py b/src/client/test/test_fsimagestorage.py index 8a6fb9f..275e5f0 100644 --- a/src/client/test/test_fsimagestorage.py +++ b/src/client/test/test_fsimagestorage.py @@ -1,51 +1,46 @@ import os +import pytest from skimage.io import imsave from skimage import data -import unittest from dcp_client.utils.fsimagestorage import FilesystemImageStorage - -class TestFilesystemImageStorage(unittest.TestCase): - - def test_load_image(self): - fis = FilesystemImageStorage() - img = data.astronaut() - fname = 'test_img.png' - imsave(fname, img) - img_test = fis.load_image('.', fname) - self.assertEqual(img.all(), img_test.all()) - os.remove(fname) - - def test_move_image(self): - fis = 
FilesystemImageStorage() - img = data.astronaut() - fname = 'test_img.png' - os.mkdir('temp') - imsave(fname, img) - fis.move_image('.', 'temp', fname) - self.assertTrue(os.path.exists('temp/test_img.png')) - os.remove('temp/test_img.png') - os.rmdir('temp') - - def test_save_image(self): - fis = FilesystemImageStorage() - img = data.astronaut() - fname = 'test_img.png' - fis.save_image('.', fname, img) - self.assertTrue(os.path.exists(fname)) - os.remove(fname) - - def test_delete_image(self): - fis = FilesystemImageStorage() - img = data.astronaut() - fname = 'test_img.png' - os.mkdir('temp') - imsave('temp/test_img.png', img) - fis.delete_image('temp', fname) - self.assertFalse(os.path.exists('temp/test_img.png')) - os.rmdir('temp') - - -if __name__=='__main__': - unittest.main() \ No newline at end of file +@pytest.fixture +def fis(): + return FilesystemImageStorage() + +@pytest.fixture +def sample_image(): + # Create a sample image + img = data.astronaut() + fname = 'test_img.png' + imsave(fname, img) + return fname + +def test_load_image(fis, sample_image): + img_test = fis.load_image('.', sample_image) + assert img_test.all() == data.astronaut().all() + os.remove(sample_image) + +def test_move_image(fis, sample_image): + temp_dir = 'temp' + os.mkdir(temp_dir) + fis.move_image('.', temp_dir, sample_image) + assert os.path.exists(os.path.join(temp_dir, 'test_img.png')) + os.remove(os.path.join(temp_dir, 'test_img.png')) + os.rmdir(temp_dir) + +def test_save_image(fis): + img = data.astronaut() + fname = 'output.png' + fis.save_image('.', fname, img) + assert os.path.exists(fname) + os.remove(fname) + +def test_delete_image(fis, sample_image): + temp_dir = 'temp' + os.mkdir(temp_dir) + fis.move_image('.', temp_dir, sample_image) + fis.delete_image(temp_dir, 'test_img.png') + assert not os.path.exists(os.path.join(temp_dir, 'test_img.png')) + os.rmdir(temp_dir) From c18a796483cee746b2fcd99c898b60e2971976a6 Mon Sep 17 00:00:00 2001 From: KorenMary Date: Tue, 26 Sep 2023 17:29:16 +0000 Subject: [PATCH 07/47] add integration tests --- src/server/dcp_server/test.py | 80 +++++++++++++++++ src/server/dcp_server/test_integration.py | 100 ++++++++++++++++++++++ 2 files changed, 180 insertions(+) create mode 100644 src/server/dcp_server/test.py create mode 100644 src/server/dcp_server/test_integration.py diff --git a/src/server/dcp_server/test.py b/src/server/dcp_server/test.py new file mode 100644 index 0000000..d77a7d6 --- /dev/null +++ b/src/server/dcp_server/test.py @@ -0,0 +1,80 @@ +import numpy as np +import torch + +from tqdm import tqdm + +import sys +import cv2 +import os + +from copy import deepcopy + +import sys +sys.path.append("../") + +from models import CellposePatchCNN +from dcp_server.utils import read_config +from skimage.color import label2rgb + +def get_dataset(dataset_path): + + images_path = os.path.join(dataset_path, "images") + masks_path = os.path.join(dataset_path, "masks") + + + images_files = [img for img in os.listdir(images_path)] + masks_files = [mask for mask in os.listdir(masks_path)] + + images, masks = [], [] + for img_file, mask_file in zip(images_files, masks_files): + + img_path = os.path.join(images_path, img_file) + mask_path = os.path.join(masks_path, mask_file) + + img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE) + msk = np.load(mask_path) + + images.append(img) + masks.append(msk) + + return images, masks + + + +if __name__=='__main__': + + img = cv2.imread("/home/ubuntu/data-centric-platform/src/server/dcp_server/data/img.jpg", cv2.IMREAD_GRAYSCALE) + 
msk = np.load("/home/ubuntu/data-centric-platform/src/server/dcp_server/data/mask.npy") + classifier_model_config, classifier_train_config, classifier_eval_config = {}, {}, {} + + segmentor_model_config = read_config('model', config_path = 'config.cfg') + segmentor_train_config = read_config('train', config_path = 'config.cfg') + segmentor_eval_config = read_config('eval', config_path = 'config.cfg') + + patch_model = CellposePatchCNN( + segmentor_model_config, segmentor_train_config, segmentor_eval_config, + classifier_model_config, classifier_train_config, classifier_eval_config) + + images, masks = get_dataset("/home/ubuntu/data-centric-platform/src/server/dcp_server/data") + + for i in tqdm(range(1)): + loss_train = patch_model.train(deepcopy(images), deepcopy(masks)) + assert(loss_train>1e-2) + + # instance segmentation mask (C, W, H) --> semantic multiclass segmentation mask (W, H) + for i in range(msk.shape[0]): + msk[i, ...][msk[i, ...] > 0] = i + 1 + + msk = msk.sum(0) + + final_mask, jaccard_index = patch_model.eval(img, mask_test=torch.tensor(msk)) + final_mask = final_mask.numpy() + + cv2.imwrite("/home/ubuntu/data-centric-platform/src/server/dcp_server/data/final_mask.jpg", 255*label2rgb(final_mask)) + print(jaccard_index) + + + + + + \ No newline at end of file diff --git a/src/server/dcp_server/test_integration.py b/src/server/dcp_server/test_integration.py new file mode 100644 index 0000000..0ec002d --- /dev/null +++ b/src/server/dcp_server/test_integration.py @@ -0,0 +1,100 @@ +import os +import cv2 +import sys +import torch + +import numpy as np +from tqdm import tqdm +from copy import deepcopy +from skimage.color import label2rgb + +sys.path.append("../") +from dcp_server.models import CellposePatchCNN +from dcp_server.utils import read_config + +import pytest + + +def get_dataset(dataset_path): + + images_path = os.path.join(dataset_path, "images") + masks_path = os.path.join(dataset_path, "masks") + + + images_files = [img for img in os.listdir(images_path)] + masks_files = [mask for mask in os.listdir(masks_path)] + + images, masks = [], [] + for img_file, mask_file in zip(images_files, masks_files): + + img_path = os.path.join(images_path, img_file) + mask_path = os.path.join(masks_path, mask_file) + + img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE) + msk = np.load(mask_path) + + images.append(img) + masks.append(msk) + + return images, masks + + +@pytest.fixture +def patch_model(): + classifier_model_config = { + "chpt_path": + "/home/ubuntu/data-centric-platform/src/server/dcp_server/data/classifier_checkpoint.pth" + } + classifier_train_config, classifier_eval_config = {}, {} + + segmentor_model_config = read_config('model', config_path='config.cfg') + segmentor_train_config = read_config('train', config_path='config.cfg') + segmentor_eval_config = read_config('eval', config_path='config.cfg') + + patch_model = CellposePatchCNN( + segmentor_model_config, segmentor_train_config, segmentor_eval_config, + classifier_model_config, classifier_train_config, classifier_eval_config) + return patch_model + +@pytest.fixture +def data_train(): + images, masks = get_dataset("/home/ubuntu/data-centric-platform/src/server/dcp_server/data") + return images, masks + +@pytest.fixture +def data_eval(): + img = cv2.imread("/home/ubuntu/data-centric-platform/src/server/dcp_server/data/img.jpg", cv2.IMREAD_GRAYSCALE) + msk = np.load("/home/ubuntu/data-centric-platform/src/server/dcp_server/data/mask.npy") + return img, msk + +def test_train_run(data_train, patch_model): + + images, 
masks = data_train + for _ in tqdm(range(1)): + loss_train = patch_model.train(deepcopy(images), deepcopy(masks)) + assert(loss_train>1e-2) + +def test_eval_run(data_eval, patch_model): + + img, msk = data_eval + # instance segmentation mask (C, W, H) --> semantic multiclass segmentation mask (W, H) + for i in range(msk.shape[0]): + msk[i, ...][msk[i, ...] > 0] = i + 1 + + msk = msk.sum(0) + + final_mask, jaccard_index = patch_model.eval(img, mask_test=torch.tensor(msk)) + final_mask = final_mask.numpy() + + cv2.imwrite("/home/ubuntu/data-centric-platform/src/server/dcp_server/data/final_mask.jpg", 255*label2rgb(final_mask)) + assert(jaccard_index<0.6) + + + + + + + + + + \ No newline at end of file From 8a03102162791472e6ee29ecc62e5b21a2f6d8f9 Mon Sep 17 00:00:00 2001 From: KorenMary Date: Wed, 27 Sep 2023 15:36:57 +0000 Subject: [PATCH 08/47] add integration tests and adjusted models.py --- src/server/dcp_server/config.cfg | 9 +- src/server/dcp_server/models.py | 166 +++++++++++++++------- src/server/dcp_server/test.py | 22 +-- src/server/dcp_server/test_integration.py | 8 +- 4 files changed, 132 insertions(+), 73 deletions(-) diff --git a/src/server/dcp_server/config.cfg b/src/server/dcp_server/config.cfg index 42d33b6..586e3ed 100644 --- a/src/server/dcp_server/config.cfg +++ b/src/server/dcp_server/config.cfg @@ -1,10 +1,10 @@ { - "setup":{ + "setup": { "segmentation": "GeneralSegmentation", "accepted_types": [".jpg", ".jpeg", ".png", ".tiff", ".tif"], "seg_name_string": "_seg" }, - "service":{ + "service": { "model_to_use": "CustomCellposeModel", "save_model_path": "mytrainedmodel", "runner_name": "cellpose_runner", @@ -18,11 +18,10 @@ "data_root": "/home/ubuntu/dcp-data", "patch_size":64, "noise_intensity":5, - "num_classes":3, - + "num_classes":3 }, "train":{ - "n_epochs": 2, + "n_epochs": 1, "channels":[0] }, "eval":{ diff --git a/src/server/dcp_server/models.py b/src/server/dcp_server/models.py index 6ef5fb6..c10921b 100644 --- a/src/server/dcp_server/models.py +++ b/src/server/dcp_server/models.py @@ -1,5 +1,6 @@ -#models - +import numpy as np +from scipy.ndimage import find_objects, center_of_mass +from torchmetrics import JaccardIndex from cellpose import models, utils @@ -9,6 +10,10 @@ from torch.optim import Adam from torch.utils.data import TensorDataset, DataLoader +from copy import deepcopy + +from tqdm import tqdm + #from segment_anything import SamPredictor, sam_model_registry #from segment_anything.automatic_mask_generator import SamAutomaticMaskGenerator @@ -57,7 +62,7 @@ def train(self, imgs, masks): :param masks: masks of the given images (training labels) :type masks: List[np.ndarray] """ - super().train(train_data=imgs, train_labels=masks, **self.train_config) + super().train(train_data=deepcopy(imgs), train_labels=masks, **self.train_config) def masks_to_outlines(self, mask): """ get outlines of masks as a 0-1 array @@ -80,17 +85,17 @@ class CellFullyConvClassifier(nn.Module): num_classes (int): Number of output classes. 
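
        Note: as of this patch both values are read from `model_config` (keys
        "in_channels" and "num_classes", defaulting to 1 and 3) rather than being
        passed as positional arguments.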
''' - def __init__(self, in_channels, num_classes, **kwargs): + def __init__(self, model_config, train_config, eval_config): super().__init__() - self.in_channels = in_channels - self.num_classes = num_classes + self.in_channels = model_config.get("in_channels", 1) + self.num_classes = model_config.get("num_classes", 3) self.train_config = train_config self.eval_config = eval_config self.layer1 = nn.Sequential( - nn.Conv2d(in_channels, 16, 3, 2, 5), + nn.Conv2d(self.in_channels, 16, 3, 2, 5), nn.BatchNorm2d(16), nn.ReLU(), nn.Dropout2d(p=0.2), @@ -110,11 +115,11 @@ def __init__(self, in_channels, num_classes, **kwargs): nn.Dropout2d(p=0.2), ) - self.final_conv = nn.Conv2d(128, num_classes, 1) + self.final_conv = nn.Conv2d(128, self.num_classes, 1) self.pooling = nn.AdaptiveMaxPool2d(1) - def forward(self, x): + def forward(self, x): x = self.layer1(x) x = self.layer2(x) @@ -141,7 +146,7 @@ def train (self, imgs, labels): optimizer_class = self.train_config.get('optimizer', 'Adam') # Convert input images and labels to tensors - imgs = [ torch.from_numpy(img) for img in imgs] + imgs = torch.stack([ torch.from_numpy(img) for img in imgs]) labels = torch.tensor(labels) # Create a training dataset and dataloader @@ -151,20 +156,25 @@ def train (self, imgs, labels): # eval method evaluates a python string and returns an object, e.g. eval('print(1)') = 1 # or eval('[1, 2, 3]') = [1, 2, 3] loss_fn = nn.CrossEntropyLoss() - optimizer = eval(f'{optimizer_class}(lr={lr})') + optimizer = Adam(params=self.parameters(), lr=lr) #eval(f'{optimizer_class}(params={self.parameters()}, lr={lr})') - for _ in epochs: + for _ in tqdm(range(epochs), desc="Running training"): for i, data in enumerate(train_dataloader): imgs, labels = data + imgs, labels = imgs.float() / 255, labels.long() + imgs = imgs.unsqueeze(1) + optimizer.zero_grad() preds = self.forward(imgs) y_hats = torch.argmax(preds, 1) - loss = loss_fn(y_hats, labels) + loss = loss_fn(preds, labels) loss.backward() optimizer.step() + + return loss def eval(self, imgs): ## TODO should call forward once, model is in eval mode, and return predicted masks @@ -194,20 +204,29 @@ class CellposePatchCNN(): Cellpose & patches of cells and then cnn to classify each patch """ - def __init__(self, model_config, train_config, eval_config ): + def __init__(self, + segmentor_model_config, segmentor_train_config, segmentor_eval_config, + classifier_model_config, classifier_train_config, classifier_eval_config ): - self.train_config = train_config - self.eval_config = eval_config + self.segmentor_model_config = segmentor_model_config + self.segmentor_train_config = segmentor_train_config + self.segmentor_eval_config = segmentor_eval_config + + self.classifier_model_config = classifier_model_config + self.classifier_train_config = classifier_train_config + self.classifier_eval_config = classifier_eval_config # Initialize the classifier and the cellpose model - self.classifier = CellFullyConvClassifier() - self.segmentor = CustomCellposeModel(model_config, train_config, eval_config) + self.classifier = CellFullyConvClassifier( + classifier_model_config, classifier_train_config, classifier_eval_config) + self.segmentor = CustomCellposeModel( + segmentor_model_config, segmentor_train_config, segmentor_eval_config) def init_from_checkpoints(self, chpt_classifier=None, chpt_segmentor=None): - """ - Initialize the model from pre-trained checkpoints. - """ + """ + Initialize the model from pre-trained checkpoints. 
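
        Args:
            chpt_segmentor (str): Path to a pretrained Cellpose model, passed as
                `pretrained_model` to CustomCellposeModel.
            chpt_classifier (str): Path to a torch checkpoint whose "model" entry
                holds the classifier state dict.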
+ """ self.segmentor = CustomCellposeModel( model_config={"gpu":torch.cuda.is_available(), "pretrained_model":chpt_segmentor} @@ -215,7 +234,7 @@ def init_from_checkpoints(self, chpt_classifier=None, chpt_segmentor=None): self.classifier.load_state_dict(torch.load(chpt_classifier)["model"]) - def train(self, imgs, masks, **train_config): + def train(self, imgs, masks): # masks should have first channel as a cellpose mask and # all other layers corresponds to the classes @@ -224,38 +243,63 @@ def train(self, imgs, masks, **train_config): ## to prepare imgs and masks for training CNN ## TODO call self.classifier.train(imgs, masks) - black_bg = train_config.get("black_bg", False) - include_mask = train_config.get("include_mask", False) - include_cellpose_mask = train_config.get("include_cellpose_mask", True) + num_classes = self.classifier_model_config.get("num_classes", 3) - self.segmentor.train(imgs, masks) - patches, labels = self.create_patch_dataset(self, imgs, masks, black_bg:bool, include_mask:bool, **kwargs) - self.classifier.train(patches, labels) + black_bg = self.classifier_model_config.get("black_bg", False) + include_mask = self.classifier_model_config.get("include_mask", False) + include_cellpose_mask = self.classifier_model_config.get("include_cellpose_mask", True) + masks_1channel = [mask.sum(0) for mask in masks] + masks_classifier = [mask if mask.shape[-1] == num_classes else + mask.transpose(1, 2, 0) for mask in masks] - def eval(self, img, **eval_config): + self.segmentor.train(imgs, masks_1channel) + patches, labels = self.create_patch_dataset( + imgs, masks_classifier, black_bg, include_mask, **self.classifier_train_config + ) + + train_loss = self.classifier.train(patches, labels) + return train_loss + + + def eval(self, img, instance_mask, **eval_config): ## TODO implement the eval pipeline, i.e. first call self.segmentor.eval, then split again into patches ## using resulting seg and then call self.classifier.eval. The final mask which is returned should have ## first channel the output of cellpose and the rest are the class channels - mask = self.segmentor.eval(img) - result = self.get_prediction(img, mask) + - return result + final_mask = self.get_prediction(img, instance_mask) + # instance segmentation mask (C, W, H) --> semantic multiclass segmentation mask (W, H) + for i in range(instance_mask.shape[0]): + instance_mask[i, ...][instance_mask[i, ...] > 0] = i + 1 - def get_prediction(self, input_image, cellpose_mask, model_classifier): - """ + seg_mask = instance_mask.sum(0) + + final_mask = torch.tensor(final_mask).to(int) + seg_mask = torch.tensor(seg_mask) + + + jaccard_score = JaccardIndex(task="multiclass", num_classes=4, ignore_index=0) + + jaccard_index = jaccard_score(final_mask.sum(0), seg_mask) + + + return final_mask, jaccard_index + + def get_prediction(self, input_image, cellpose_mask): + """ Performs object segmentation and classification on an input image using the Cellpose model and a classifier model. Args: image_path (str): The file path of the input image. - model (CellposeModel): The Cellpose model used for object segmentation. - model_classifier: The classifier model used for object classification. + model (CellposeModel): The Cellpose model used for object segmentation, instance segmenation mask. + Returns: tuple: A tuple containing the cellpose_mask and final_mask, representing the segmentation masks obtained from the Cellpose model and the combined segmentation and classification mask, respectively. 
- """ + """ # Obtain segmentation mask using Cellpose model @@ -263,7 +307,7 @@ def get_prediction(self, input_image, cellpose_mask, model_classifier): locs = find_objects(cellpose_mask) # Get patches and labels based on object centroids - patches, labels = get_centered_patches(input_image, cellpose_mask, int(1.5 * input_image.shape[0] // 5), noise_intensity=5) + patches, labels = self.get_centered_patches(input_image, cellpose_mask, int(1.5 * input_image.shape[0] // 5), noise_intensity=5) labels = torch.tensor(labels) labels_fit = [] @@ -275,8 +319,9 @@ def get_prediction(self, input_image, cellpose_mask, model_classifier): loc = locs[i] # Prepare image patch for classification - img = torch.tensor(patch.astype(np.float32)).unsqueeze(0).repeat(3, 1, 1).unsqueeze(0) / 255 - + img = torch.tensor(patch.astype(np.float32)).unsqueeze(0).unsqueeze(0) / 255 + # img = img.mean(dim=1, keepdim=True) + # Perform inference using model_classifier logits = self.classifier(img) @@ -287,14 +332,14 @@ def get_prediction(self, input_image, cellpose_mask, model_classifier): final_mask[loc] = predicted + 1 # Apply mask to final_mask, retaining only regions where cellpose_mask is greater than 0 - final_mask = final_mask * ((cellpose_mask > 0).astype(np.uint8)) + final_mask = final_mask * ((cellpose_mask > 0).long()) return final_mask def find_max_patch_size(self, mask): # Find objects in the mask - objects = ndi.find_objects(mask) + objects = find_objects(mask) # Initialize variables to store the maximum patch size max_patch_size = 0 @@ -336,6 +381,11 @@ def pad_centered_padded_patch(self, x: np.ndarray, c, p, mask: np.ndarray=None, Returns: np.ndarray: The cropped patch with applied padding. """ + if mask.shape[0] < mask.shape[-1]: + if isinstance(mask,torch.Tensor): + mask = mask.permute(1,2,0).numpy() + else: + mask = mask.transpose(1, 2, 0) height, width = p # Size of the patch @@ -351,8 +401,11 @@ def pad_centered_padded_patch(self, x: np.ndarray, c, p, mask: np.ndarray=None, if mask is not None: mask_ = mask.max(-1) if len(mask.shape) >= 3 else mask - central_label = mask_[c[0], c[1]] + # central_label = mask_[c[0], c[1]] + central_label = mask_[c[0]][c[1]] # Zero out values in the patch where the mask is not equal to the central label + + # m = (mask_ != central_label) & (mask_ > 0) m = (mask_ != central_label) & (mask_ > 0) x[m] = 0 @@ -405,10 +458,11 @@ def get_center_of_mass(self, mask: np.ndarray) -> np.ndarray: Returns: np.ndarray: An array of coordinates (row, column, channel) representing the centers of mass for each object. 
""" + # Compute the centers of mass for each labeled object in the mask centers_of_mass = [ (int(x[0]), int(x[1]), int(x[2])) if mask.ndim >= 3 else (int(x[0]), int(x[1]), -1) - for x in ndi.center_of_mass(mask, mask, np.arange(1, mask.max() + 1)) + for x in center_of_mass(mask, mask, np.arange(1, mask.max() + 1)) ] return centers_of_mass @@ -428,14 +482,19 @@ def get_centered_patches(self, img, mask, p_size: int, noise_intensity=5): ''' patches, labels = [], [] + # is_train = mask.ndim == 3 - centers_of_mass = get_center_of_mass(mask) + centers_of_mass = self.get_center_of_mass(mask) # Crop patches around each center of mass and save them for i, c in enumerate(centers_of_mass): c_x, c_y, label = c - - patch = pad_centered_padded_patch(img.copy(), (c_x, c_y), (p_size, p_size), mask=mask, noise_intensity=noise_intensity) + + # if is_train: + patch = self.pad_centered_padded_patch(img.copy(), (c_x, c_y), (p_size, p_size), mask=mask, noise_intensity=noise_intensity) + # else: + # patch = self.pad_centered_padded_patch(img.copy(), (c_x, c_y), (p_size, p_size), mask=mask, noise_intensity=noise_intensity) + patches.append(patch) labels.append(label) @@ -459,7 +518,7 @@ def create_patch_dataset(self, imgs, masks, black_bg:bool, include_mask:bool, ** #max_patch_size = np.max([self.find_max_patch_size(mask) in masks]) max_patch_size = kwargs.get("data", {}).get("patch_size", 64) - num_of_channels = kwargs.get("data", {}).get("num_classes", 3) + num_of_channels = kwargs.get("num_classes", 3) patches, labels = [], [] is_train = masks[0].ndim == 3 @@ -470,19 +529,18 @@ def create_patch_dataset(self, imgs, masks, black_bg:bool, include_mask:bool, ** if is_train: # mask has dimension WxHxNum_of_channels for channel in range(num_of_channels): - - patch, label = get_centered_patches(img, msk, int(1.5 * img.shape[0] // 5), + patches, labels = self.get_centered_patches(img, msk, int(1.5 * img.shape[0] // 5), noise_intensity=noise_intensity) - patches.append(patch) - labels.append(label) + patches.extend(patches) + labels.extend(labels) # test dataset # mask has dimention WxH else: - patch, _ = get_centered_patches(img, msk, int(1.5 * img.shape[0] // 5), + patches, _ = self.get_centered_patches(img, msk, int(1.5 * img.shape[0] // 5), noise_intensity=noise_intensity) - patches.append(patch) + patches.extend(patches) if is_train: return patches, labels diff --git a/src/server/dcp_server/test.py b/src/server/dcp_server/test.py index d77a7d6..f675149 100644 --- a/src/server/dcp_server/test.py +++ b/src/server/dcp_server/test.py @@ -47,9 +47,9 @@ def get_dataset(dataset_path): msk = np.load("/home/ubuntu/data-centric-platform/src/server/dcp_server/data/mask.npy") classifier_model_config, classifier_train_config, classifier_eval_config = {}, {}, {} - segmentor_model_config = read_config('model', config_path = 'config.cfg') - segmentor_train_config = read_config('train', config_path = 'config.cfg') - segmentor_eval_config = read_config('eval', config_path = 'config.cfg') + segmentor_model_config = read_config("model", config_path = "config.cfg") + segmentor_train_config = read_config("train", config_path = "config.cfg") + segmentor_eval_config = read_config("eval", config_path = "config.cfg") patch_model = CellposePatchCNN( segmentor_model_config, segmentor_train_config, segmentor_eval_config, @@ -57,17 +57,19 @@ def get_dataset(dataset_path): images, masks = get_dataset("/home/ubuntu/data-centric-platform/src/server/dcp_server/data") - for i in tqdm(range(1)): - loss_train = patch_model.train(deepcopy(images), 
deepcopy(masks)) - assert(loss_train>1e-2) + # for i in tqdm(range(1)): + # loss_train = patch_model.train(deepcopy(images), deepcopy(masks)) + # assert(loss_train>1e-2) # instance segmentation mask (C, W, H) --> semantic multiclass segmentation mask (W, H) - for i in range(msk.shape[0]): - msk[i, ...][msk[i, ...] > 0] = i + 1 + # for i in range(msk.shape[0]): + # msk[i, ...][msk[i, ...] > 0] = i + 1 - msk = msk.sum(0) + # msk = msk.sum(0) - final_mask, jaccard_index = patch_model.eval(img, mask_test=torch.tensor(msk)) + # img = img.mean(axis=1, keepdims=True) + + final_mask, jaccard_index = patch_model.eval(img, instance_mask=msk) final_mask = final_mask.numpy() cv2.imwrite("/home/ubuntu/data-centric-platform/src/server/dcp_server/data/final_mask.jpg", 255*label2rgb(final_mask)) diff --git a/src/server/dcp_server/test_integration.py b/src/server/dcp_server/test_integration.py index 0ec002d..2b632c4 100644 --- a/src/server/dcp_server/test_integration.py +++ b/src/server/dcp_server/test_integration.py @@ -78,12 +78,12 @@ def test_eval_run(data_eval, patch_model): img, msk = data_eval # instance segmentation mask (C, W, H) --> semantic multiclass segmentation mask (W, H) - for i in range(msk.shape[0]): - msk[i, ...][msk[i, ...] > 0] = i + 1 + # for i in range(msk.shape[0]): + # msk[i, ...][msk[i, ...] > 0] = i + 1 - msk = msk.sum(0) + # msk = msk.sum(0) - final_mask, jaccard_index = patch_model.eval(img, mask_test=torch.tensor(msk)) + final_mask, jaccard_index = patch_model.eval(img, instance_mask=torch.tensor(msk)) final_mask = final_mask.numpy() cv2.imwrite("/home/ubuntu/data-centric-platform/src/server/dcp_server/data/final_mask.jpg", 255*label2rgb(final_mask)) From c0443b656f88e03d604d2880153681157a795870 Mon Sep 17 00:00:00 2001 From: Christina Bukas Date: Wed, 18 Oct 2023 18:51:18 +0200 Subject: [PATCH 09/47] added channel axis to rescale so mask retains channel dim --- src/server/dcp_server/fsimagestorage.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/server/dcp_server/fsimagestorage.py b/src/server/dcp_server/fsimagestorage.py index 222b964..98ad112 100644 --- a/src/server/dcp_server/fsimagestorage.py +++ b/src/server/dcp_server/fsimagestorage.py @@ -109,7 +109,7 @@ def get_image_size_properties(self, img, file_extension): # tif can be grayscale 2D or 2D RGB and RGBA if file_extension in (".jpg", ".jpeg", ".png") or (file_extension in (".tiff", ".tif") and len(orig_size)==2 or (len(orig_size)==3 and (orig_size[-1]==3 or orig_size[-1]==4))): height, width = orig_size[0], orig_size[1] - channel_ax = None + channel_ax = 2 # or 3D tiff grayscale elif file_extension in (".tiff", ".tif") and len(orig_size)==3: print('Warning: 3D image stack found. We are assuming your first dimension is your stack dimension. 
Please cross check this.') @@ -121,7 +121,7 @@ def get_image_size_properties(self, img, file_extension): return height, width, channel_ax - def rescale_image(self, img, height, width, channel_ax): + def rescale_image(self, img, height, width, channel_ax, order): """rescale image :param img: image @@ -137,7 +137,7 @@ def rescale_image(self, img, height, width, channel_ax): """ max_dim = max(height, width) rescale_factor = max_dim/512 - return rescale(img, 1/rescale_factor, channel_axis=channel_ax) + return rescale(img, 1/rescale_factor, order=order, channel_axis=channel_ax) def resize_image(self, img, height, width, order): """resize image @@ -166,6 +166,6 @@ def prepare_images_and_masks_for_training(self, train_img_mask_pairs): imgs=[] masks=[] for img_file, mask_file in train_img_mask_pairs: - imgs.append(rgb2gray(imread(img_file))) + imgs.append(imread(img_file)) masks.append(imread(mask_file)) return imgs, masks \ No newline at end of file From 76fde8beb018fdab01bd4637aac065d449824e91 Mon Sep 17 00:00:00 2001 From: Christina Bukas Date: Wed, 18 Oct 2023 18:51:35 +0200 Subject: [PATCH 10/47] adapted config for CellposePatchCNN model --- src/server/dcp_server/config.cfg | 53 ++++++++++++++++++++++++-------- 1 file changed, 41 insertions(+), 12 deletions(-) diff --git a/src/server/dcp_server/config.cfg b/src/server/dcp_server/config.cfg index 586e3ed..8af1781 100644 --- a/src/server/dcp_server/config.cfg +++ b/src/server/dcp_server/config.cfg @@ -1,29 +1,58 @@ { "setup": { - "segmentation": "GeneralSegmentation", - "accepted_types": [".jpg", ".jpeg", ".png", ".tiff", ".tif"], - "seg_name_string": "_seg" + "segmentation": "GeneralSegmentation", + "accepted_types": [".jpg", ".jpeg", ".png", ".tiff", ".tif"], + "seg_name_string": "_seg" }, + "service": { - "model_to_use": "CustomCellposeModel", + "model_to_use": "CellposePatchCNN", "save_model_path": "mytrainedmodel", "runner_name": "cellpose_runner", "service_name": "data-centric-platform", "port": 7010 }, - "model": { - "model_type":"cyto" + + "model": { + "segmentor": { + "model_type": "cyto" + }, + "classifier":{ + "in_channels": 3, + "num_classes": 3, + "black_bg": "False", + "include_mask": "False" + } }, + "data": { - "data_root": "/home/ubuntu/dcp-data", - "patch_size":64, - "noise_intensity":5, - "num_classes":3 + "data_root": "/home/ubuntu/dcp-data" }, + "train":{ - "n_epochs": 1, - "channels":[0] + "segmentor":{ + "n_epochs": 1, + "channels":[0] + }, + "classifier":{ + "train_data":{ + "patch_size": 64, + "noise_intensity": 5, + "num_classes": 3 + }, + "n_epochs": 1, + "lr": 0.001, + "batch_size": 1, + "optimizer": "Adam" + } }, + "eval":{ + "segmentor": "None", + "classifier": { + "data":{ + "patch_size": 64 + } + } } } \ No newline at end of file From 2755ad5e0bd779da5d1eab03d840a3d86f4cc130 Mon Sep 17 00:00:00 2001 From: Christina Bukas Date: Wed, 18 Oct 2023 18:52:18 +0200 Subject: [PATCH 11/47] adapted to allow for 2 channel mask (class+intance) to be resized correctly --- src/server/dcp_server/segmentationclasses.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/src/server/dcp_server/segmentationclasses.py b/src/server/dcp_server/segmentationclasses.py index e5aeb28..ce5a713 100644 --- a/src/server/dcp_server/segmentationclasses.py +++ b/src/server/dcp_server/segmentationclasses.py @@ -35,17 +35,14 @@ async def segment_image(self, input_path, list_of_images): img = self.imagestorage.load_image(img_filepath) # Get size properties height, width, channel_ax = 
self.imagestorage.get_image_size_properties(img, utils.get_file_extension(img_filepath)) - img = self.imagestorage.rescale_image(img, height, width, channel_ax) - + img = self.imagestorage.rescale_image(img, height, width, channel_ax, order=None) # Add channel ax into the model's evaluation parameters dictionary self.model.eval_config['z_axis'] = channel_ax - # Evaluate the model mask = await self.runner.evaluate.async_run(img = img, **self.model.eval_config) - # Resize the mask - mask = self.imagestorage.resize_image(mask, height, width, order=0) - + channel_ax = self.model.eval_config['z_axis'] + mask = self.imagestorage.rescale_image(mask, height, width, channel_ax, order=0) # Save segmentation seg_name = utils.get_path_stem(img_filepath) + setup_config['seg_name_string'] + '.tiff' self.imagestorage.save_image(os.path.join(input_path, seg_name), mask) @@ -65,7 +62,6 @@ async def train(self, input_path): return "No images and segs found" imgs, masks = self.imagestorage.prepare_images_and_masks_for_training(train_img_mask_pairs) - model_save_path = await self.runner.train.async_run(imgs, masks) return model_save_path From 1325572e18d5b8ff8f601adc98b9573ac2b7f41e Mon Sep 17 00:00:00 2001 From: Christina Bukas Date: Wed, 18 Oct 2023 18:53:23 +0200 Subject: [PATCH 12/47] major changes in CellClassifierFCNN and CellposePatchCNN to make compatible with DCP --- src/server/dcp_server/models.py | 371 ++++++++++++++------------------ 1 file changed, 167 insertions(+), 204 deletions(-) diff --git a/src/server/dcp_server/models.py b/src/server/dcp_server/models.py index c10921b..f1c1400 100644 --- a/src/server/dcp_server/models.py +++ b/src/server/dcp_server/models.py @@ -1,18 +1,12 @@ -import numpy as np -from scipy.ndimage import find_objects, center_of_mass -from torchmetrics import JaccardIndex - from cellpose import models, utils - import torch from torch import nn - from torch.optim import Adam from torch.utils.data import TensorDataset, DataLoader - from copy import deepcopy - from tqdm import tqdm +import numpy as np +from scipy.ndimage import find_objects, center_of_mass #from segment_anything import SamPredictor, sam_model_registry #from segment_anything.automatic_mask_generator import SamAutomaticMaskGenerator @@ -75,7 +69,7 @@ def masks_to_outlines(self, mask): """ return utils.masks_to_outlines(mask) #[True, False] outputs -class CellFullyConvClassifier(nn.Module): +class CellClassifierFCNN(nn.Module): ''' Fully convolutional classifier for cell images. 
@@ -88,12 +82,12 @@ class CellFullyConvClassifier(nn.Module): def __init__(self, model_config, train_config, eval_config): super().__init__() - self.in_channels = model_config.get("in_channels", 1) - self.num_classes = model_config.get("num_classes", 3) + self.in_channels = model_config["in_channels"] + self.num_classes = model_config["num_classes"] + 1 self.train_config = train_config self.eval_config = eval_config - + self.layer1 = nn.Sequential( nn.Conv2d(self.in_channels, 16, 3, 2, 5), nn.BatchNorm2d(16), @@ -114,9 +108,7 @@ def __init__(self, model_config, train_config, eval_config): nn.ReLU(), nn.Dropout2d(p=0.2), ) - self.final_conv = nn.Conv2d(128, self.num_classes, 1) - self.pooling = nn.AdaptiveMaxPool2d(1) def forward(self, x): @@ -128,75 +120,65 @@ def forward(self, x): x = self.final_conv(x) x = self.pooling(x) x = x.view(x.size(0), -1) - return x - def train (self, imgs, labels): - ## TODO should call forward repeatedly and perform the entire train loop - + def train (self, imgs, labels): """ input: 1) imgs - List[np.ndarray[np.uint8]] with shape (3, dx, dy) - 2) y - List[int] + 2) labels - List[int] """ - lr = self.train_config.get('lr', 0.001) - epochs = self.train_config.get('epochs', 1) - batch_size = self.train_config.get('batch_size', 1) - optimizer_class = self.train_config.get('optimizer', 'Adam') + lr = self.train_config['lr'] + epochs = self.train_config['n_epochs'] + batch_size = self.train_config['batch_size'] + # optimizer_class = self.train_config['optimizer'] # Convert input images and labels to tensors - imgs = torch.stack([ torch.from_numpy(img) for img in imgs]) - labels = torch.tensor(labels) + imgs = [(img-np.min(img))/(np.max(img)-np.min(img)) for img in imgs] + # convert to tensor + imgs = torch.stack([torch.from_numpy(img.astype(np.float32)) for img in imgs]) + imgs = torch.permute(imgs, (0, 3, 1, 2)) + # Your classification label mask + labels = torch.LongTensor([label for label in labels]) + # Convert to one-hot encoding + #labels = torch.nn.functional.one_hot(labels, self.num_classes) # Create a training dataset and dataloader train_dataset = TensorDataset(imgs, labels) train_dataloader = DataLoader(train_dataset, batch_size=batch_size) - # eval method evaluates a python string and returns an object, e.g. eval('print(1)') = 1 - # or eval('[1, 2, 3]') = [1, 2, 3] loss_fn = nn.CrossEntropyLoss() optimizer = Adam(params=self.parameters(), lr=lr) #eval(f'{optimizer_class}(params={self.parameters()}, lr={lr})') - - for _ in tqdm(range(epochs), desc="Running training"): - for i, data in enumerate(train_dataloader): - + # TODO check if we should replace self.parameters with super.parameters() + + for _ in tqdm(range(epochs), desc="Running CellClassifierFCNN training"): + self.epoch_loss = 0 + for data in train_dataloader: imgs, labels = data - imgs, labels = imgs.float() / 255, labels.long() - imgs = imgs.unsqueeze(1) - optimizer.zero_grad() preds = self.forward(imgs) - - y_hats = torch.argmax(preds, 1) loss = loss_fn(preds, labels) loss.backward() - optimizer.step() + self.epoch_loss += loss.item() - return loss + self.epoch_loss /= len(train_dataloader) - def eval(self, imgs): - ## TODO should call forward once, model is in eval mode, and return predicted masks + def eval(self, img): """ - Evaluate the model on the provided images and return predicted labels. + Evaluate the model on the provided image and return the predicted label. Input: - imgs: List[np.ndarray[np.uint8]] with shape (3, dx, dy) - Output: - labels: List of predicted labels. 
- """ - labels = [] - for img in imgs: - - img = torch.from_numpy(img).unsqueeze(0) - labels.append(self.forward(img)) - - return labels -# class CustomSAMModel(): -# # https://github.com/facebookresearch/segment-anything/blob/main/notebooks/automatic_mask_generator_example.ipynb -# def __init__(self): -# pass - + img: np.ndarray[np.uint8] + Output: y_hat - The predicted label + """ + # normalise + img = (img-np.min(img))/(np.max(img)-np.min(img)) + # convert to tensor + img = torch.permute(torch.tensor(img.astype(np.float32)), (2, 0, 1)).unsqueeze(0) + preds = self.forward(img) + y_hat = torch.argmax(preds, 1) + return y_hat class CellposePatchCNN(): @@ -204,24 +186,19 @@ class CellposePatchCNN(): Cellpose & patches of cells and then cnn to classify each patch """ - def __init__(self, - segmentor_model_config, segmentor_train_config, segmentor_eval_config, - classifier_model_config, classifier_train_config, classifier_eval_config ): - + def __init__(self, model_config, train_config, eval_config): - self.segmentor_model_config = segmentor_model_config - self.segmentor_train_config = segmentor_train_config - self.segmentor_eval_config = segmentor_eval_config - - self.classifier_model_config = classifier_model_config - self.classifier_train_config = classifier_train_config - self.classifier_eval_config = classifier_eval_config + self.model_config = model_config + self.train_config = train_config + self.eval_config = eval_config - # Initialize the classifier and the cellpose model - self.classifier = CellFullyConvClassifier( - classifier_model_config, classifier_train_config, classifier_eval_config) - self.segmentor = CustomCellposeModel( - segmentor_model_config, segmentor_train_config, segmentor_eval_config) + # Initialize the cellpose model and the classifier + self.segmentor = CustomCellposeModel(self.model_config["segmentor"], + self.train_config["segmentor"], + self.eval_config["segmentor"]) + self.classifier = CellClassifierFCNN(self.model_config["classifier"], + self.train_config["classifier"], + self.eval_config["classifier"]) def init_from_checkpoints(self, chpt_classifier=None, chpt_segmentor=None): """ @@ -235,58 +212,62 @@ def init_from_checkpoints(self, chpt_classifier=None, chpt_segmentor=None): def train(self, imgs, masks): - - # masks should have first channel as a cellpose mask and - # all other layers corresponds to the classes - ## TODO: take care of the images and masks preparation -> this step isn't for now @Mariia - ## TODO call create_patches (adjust create_patch_dataset function) - ## to prepare imgs and masks for training CNN - ## TODO call self.classifier.train(imgs, masks) - - num_classes = self.classifier_model_config.get("num_classes", 3) - - black_bg = self.classifier_model_config.get("black_bg", False) - include_mask = self.classifier_model_config.get("include_mask", False) - include_cellpose_mask = self.classifier_model_config.get("include_cellpose_mask", True) - - masks_1channel = [mask.sum(0) for mask in masks] - masks_classifier = [mask if mask.shape[-1] == num_classes else - mask.transpose(1, 2, 0) for mask in masks] - - self.segmentor.train(imgs, masks_1channel) - patches, labels = self.create_patch_dataset( - imgs, masks_classifier, black_bg, include_mask, **self.classifier_train_config - ) - - train_loss = self.classifier.train(patches, labels) - return train_loss - - - def eval(self, img, instance_mask, **eval_config): + # masks should have first channel as a cellpose mask and all other layers + # correspond to the classes, to prepare imgs and masks 
for training CNN + + # TODO I commented below lines. I think we should expect masks to have same size as output of eval, i.e. one channel instances, second channel classes + # in this case the shape of masks shall be [2, H, W] or [2, 3, H, W] for 3D + # In this case remove commented lines + # num_classes = self.model_config["classifier"]["num_classes"] + # masks_1channel = [mask.sum(0) for mask in masks] + # masks_classifier = [mask if mask.shape[-1] == num_classes else + # mask.transpose(1, 2, 0) for mask in masks] - ## TODO implement the eval pipeline, i.e. first call self.segmentor.eval, then split again into patches - ## using resulting seg and then call self.classifier.eval. The final mask which is returned should have - ## first channel the output of cellpose and the rest are the class channels - - - final_mask = self.get_prediction(img, instance_mask) - # instance segmentation mask (C, W, H) --> semantic multiclass segmentation mask (W, H) - for i in range(instance_mask.shape[0]): - instance_mask[i, ...][instance_mask[i, ...] > 0] = i + 1 - - seg_mask = instance_mask.sum(0) + # train cellpose + masks = np.array(masks) + masks_instances = list(masks[:,0, ...]) + self.segmentor.train(imgs, masks_instances) + # create patch dataset to train classifier + masks_classes = list(masks[:,1, ...]) + patches, labels = self.create_patch_dataset(imgs, masks_classes, masks_instances) + # train classifier + self.classifier.train(patches, labels) + #return # TODO - define if we need to return something - final_mask = torch.tensor(final_mask).to(int) - seg_mask = torch.tensor(seg_mask) - + def eval(self, img, **eval_config): - jaccard_score = JaccardIndex(task="multiclass", num_classes=4, ignore_index=0) + # TBD we assume image is either 2D [H, W] or 3D [H, W, C] + + # The final mask which is returned should have + # first channel the output of cellpose and the rest are the class channels + # TODO test case produces img with size HxW for eval and HxWx3 for train + with torch.no_grad(): + # get instance mask from segmentor + instance_mask = self.segmentor.eval(img) + # find coordinates of detected objects + locs = find_objects(instance_mask) + class_mask = np.zeros(instance_mask.shape) + # get patches centered around detected objects + patches, _ = self.get_centered_patches(img, + instance_mask, + self.eval_config["classifier"]["data"]["patch_size"], + noise_intensity=5) + # loop over patches and create classification mask + for idx, patch in enumerate(patches): + patch_class = self.classifier.eval(patch) + loc = locs[idx] + # Assign predicted class to corresponding location in final_mask + class_mask[loc] = patch_class.item() + 1 + # Apply mask to final_mask, retaining only regions where cellpose_mask is greater than 0 + class_mask = class_mask * (instance_mask > 0)#.long()) + final_mask = np.stack((instance_mask, class_mask)).astype(np.uint16) - jaccard_index = jaccard_score(final_mask.sum(0), seg_mask) - + self.eval_config['z_axis'] = 0 - return final_mask, jaccard_index + return final_mask + # REMOVE? replaced by code in eval + ''' def get_prediction(self, input_image, cellpose_mask): """ Performs object segmentation and classification on an input image using the Cellpose model and a classifier model. @@ -295,7 +276,6 @@ def get_prediction(self, input_image, cellpose_mask): image_path (str): The file path of the input image. model (CellposeModel): The Cellpose model used for object segmentation, instance segmenation mask. 
- Returns: tuple: A tuple containing the cellpose_mask and final_mask, representing the segmentation masks obtained from the Cellpose model and the combined segmentation and classification mask, respectively. @@ -335,7 +315,7 @@ def get_prediction(self, input_image, cellpose_mask): final_mask = final_mask * ((cellpose_mask > 0).long()) return final_mask - + ''' def find_max_patch_size(self, mask): # Find objects in the mask @@ -369,7 +349,12 @@ def find_max_patch_size(self, mask): return max_patch_size_edge - def pad_centered_padded_patch(self, x: np.ndarray, c, p, mask: np.ndarray=None, noise_intensity=None) -> np.ndarray: + def crop_centered_padded_patch(self, + x: np.ndarray, + c, + p, + mask: np.ndarray=None, + noise_intensity=None) -> np.ndarray: """ Crop a patch from an array `x` centered at coordinates `c` with size `p`, and apply padding if necessary. @@ -380,11 +365,9 @@ def pad_centered_padded_patch(self, x: np.ndarray, c, p, mask: np.ndarray=None, Returns: np.ndarray: The cropped patch with applied padding. - """ + """ + if mask.shape[0] < mask.shape[-1]: - if isinstance(mask,torch.Tensor): - mask = mask.permute(1,2,0).numpy() - else: mask = mask.transpose(1, 2, 0) height, width = p # Size of the patch @@ -397,53 +380,45 @@ def pad_centered_padded_patch(self, x: np.ndarray, c, p, mask: np.ndarray=None, right = left + width # Crop the patch from the input array - if mask is not None: - mask_ = mask.max(-1) if len(mask.shape) >= 3 else mask # central_label = mask_[c[0], c[1]] central_label = mask_[c[0]][c[1]] # Zero out values in the patch where the mask is not equal to the central label - # m = (mask_ != central_label) & (mask_ > 0) m = (mask_ != central_label) & (mask_ > 0) x[m] = 0 - if noise_intensity is not None: x[m] = np.random.normal(scale=noise_intensity, size=x[m].shape) - patch = x[max(top, 0):min(bottom, x.shape[0]), max(left, 0):min(right, x.shape[1])] - + patch = x[max(top, 0):min(bottom, x.shape[0]), max(left, 0):min(right, x.shape[1]), :] + if len(c) == 3: patch = patch[...,c[2]] - - # Calculate the required padding amounts + # Calculate the required padding amounts size_x, size_y = x.shape[1], x.shape[0] # Apply padding if necessary - if left < 0: + if left < 0: patch = np.hstack(( - np.random.normal(scale=noise_intensity, size=(patch.shape[0], abs(left))).astype(np.uint8), patch - )) - + np.random.normal(scale=noise_intensity, size=(patch.shape[0], abs(left), patch.shape[2])).astype(np.uint8), + patch)) # Apply padding on the right side if necessary - if right > size_x: + if right > size_x: patch = np.hstack(( - patch, np.random.normal(scale=noise_intensity, size=(patch.shape[0], right - size_x)).astype(np.uint8) - )) - + patch, + np.random.normal(scale=noise_intensity, size=(patch.shape[0], (right - size_x), patch.shape[2])).astype(np.uint8))) # Apply padding on the top side if necessary - if top < 0: + if top < 0: patch = np.vstack(( - np.random.normal(scale=noise_intensity, size=(abs(top), patch.shape[1])).astype(np.uint8), patch - )) - + np.random.normal(scale=noise_intensity, size=(abs(top), patch.shape[1], patch.shape[2])).astype(np.uint8), + patch)) # Apply padding on the bottom side if necessary - if bottom > size_y: + if bottom > size_y: patch = np.vstack(( - patch, np.random.normal(scale=noise_intensity, size=(bottom - size_y, patch.shape[1])).astype(np.uint8) - )) + patch, + np.random.normal(scale=noise_intensity, size=(bottom - size_y, patch.shape[1], patch.shape[2])).astype(np.uint8))) return patch @@ -460,15 +435,15 @@ def get_center_of_mass(self, 
mask: np.ndarray) -> np.ndarray: """ # Compute the centers of mass for each labeled object in the mask - centers_of_mass = [ - (int(x[0]), int(x[1]), int(x[2])) if mask.ndim >= 3 else (int(x[0]), int(x[1]), -1) - for x in center_of_mass(mask, mask, np.arange(1, mask.max() + 1)) - ] - - return centers_of_mass - - - def get_centered_patches(self, img, mask, p_size: int, noise_intensity=5): + return [(int(x[0]), int(x[1])) + for x in center_of_mass(mask, mask, np.arange(1, mask.max() + 1))] + + def get_centered_patches(self, + img, + mask, + p_size: int, + noise_intensity=5, + mask_class=None): ''' Extracts centered patches from the input image based on the centers of objects identified in the mask. @@ -482,27 +457,27 @@ def get_centered_patches(self, img, mask, p_size: int, noise_intensity=5): ''' patches, labels = [], [] - # is_train = mask.ndim == 3 - + # if image is 2D add an additional dim for channels + if img.ndim<3: img = img[:, :, np.newaxis] + if mask.ndim<3: mask = mask[:, :, np.newaxis] + # compute center of mass of objects centers_of_mass = self.get_center_of_mass(mask) - # Crop patches around each center of mass and save them - for i, c in enumerate(centers_of_mass): - c_x, c_y, label = c - - - # if is_train: - patch = self.pad_centered_padded_patch(img.copy(), (c_x, c_y), (p_size, p_size), mask=mask, noise_intensity=noise_intensity) - # else: - # patch = self.pad_centered_padded_patch(img.copy(), (c_x, c_y), (p_size, p_size), mask=mask, noise_intensity=noise_intensity) - + # Crop patches around each center of mass + for c in centers_of_mass: + c_x, c_y = c + patch = self.crop_centered_padded_patch(img.copy(), + (c_x, c_y), + (p_size, p_size), + mask=mask, + noise_intensity=noise_intensity) patches.append(patch) - labels.append(label) + if mask_class is not None: labels.append(mask_class[c[0]][c[1]]) return patches, labels - def create_patch_dataset(self, imgs, masks, black_bg:bool, include_mask:bool, **kwargs): + def create_patch_dataset(self, imgs, masks_classes, masks_instances): ''' - TODO: Split img and masks into patches of equal size which are centered around the cells. + Splits img and masks into patches of equal size which are centered around the cells. The algorithm should first run through all images to find the max cell size, and use the max cell size to define the patch size. All patches and masks should then be returned in the same format as imgs and masks (same type, i.e. check if tensor or np.array and same @@ -514,36 +489,24 @@ def create_patch_dataset(self, imgs, masks, black_bg:bool, include_mask:bool, ** include_mask (bool): Flag indicating whether to include the mask along with patches. 
''' - noise_intensity = kwargs.get("data", {}).get("noise_intensity", 5) - - #max_patch_size = np.max([self.find_max_patch_size(mask) in masks]) - max_patch_size = kwargs.get("data", {}).get("patch_size", 64) - num_of_channels = kwargs.get("num_classes", 3) + noise_intensity = self.train_config["classifier"]["train_data"]["noise_intensity"] + max_patch_size = self.train_config["classifier"]["train_data"]["patch_size"] + num_classes = self.train_config["classifier"]["train_data"]["num_classes"] patches, labels = [], [] - is_train = masks[0].ndim == 3 - - for img, msk in zip(imgs, masks): - - # training dataset - if is_train: - # mask has dimension WxHxNum_of_channels - for channel in range(num_of_channels): - patches, labels = self.get_centered_patches(img, msk, int(1.5 * img.shape[0] // 5), - noise_intensity=noise_intensity) - - patches.extend(patches) - labels.extend(labels) - - # test dataset - # mask has dimention WxH - else: - patches, _ = self.get_centered_patches(img, msk, int(1.5 * img.shape[0] // 5), - noise_intensity=noise_intensity) - patches.extend(patches) - - if is_train: - return patches, labels - - else: - return patches \ No newline at end of file + for img, mask_class, mask_instance in zip(imgs, masks_classes, masks_instances): + # Convert to one-hot encoding + # mask has dimension WxHxNum_of_channels + patch, label = self.get_centered_patches(img, + mask_instance, + self.train_config["classifier"]["train_data"]["patch_size"], + noise_intensity=noise_intensity, + mask_class=mask_class) + patches.extend(patch) + labels.extend(label) + return patches, labels + +# class CustomSAMModel(): +# # https://github.com/facebookresearch/segment-anything/blob/main/notebooks/automatic_mask_generator_example.ipynb +# def __init__(self): +# pass From 42c3e77f66cec9dc970b21b0fc3cb0f442d10e87 Mon Sep 17 00:00:00 2001 From: Christina Bukas Date: Thu, 19 Oct 2023 16:21:27 +0200 Subject: [PATCH 13/47] adapted integration tests for test and train of patch model --- src/server/dcp_server/models.py | 13 +++--- src/server/dcp_server/test_integration.py | 52 +++++++++++------------ 2 files changed, 31 insertions(+), 34 deletions(-) diff --git a/src/server/dcp_server/models.py b/src/server/dcp_server/models.py index f1c1400..a177af4 100644 --- a/src/server/dcp_server/models.py +++ b/src/server/dcp_server/models.py @@ -57,6 +57,7 @@ def train(self, imgs, masks): :type masks: List[np.ndarray] """ super().train(train_data=deepcopy(imgs), train_labels=masks, **self.train_config) + self.loss = self.loss_fn(masks, super().eval(imgs, self.eval_config)[0]) def masks_to_outlines(self, mask): """ get outlines of masks as a 0-1 array @@ -153,17 +154,17 @@ def train (self, imgs, labels): # TODO check if we should replace self.parameters with super.parameters() for _ in tqdm(range(epochs), desc="Running CellClassifierFCNN training"): - self.epoch_loss = 0 + self.loss = 0 for data in train_dataloader: imgs, labels = data optimizer.zero_grad() preds = self.forward(imgs) - loss = loss_fn(preds, labels) - loss.backward() + l = loss_fn(preds, labels) + l.backward() optimizer.step() - self.epoch_loss += loss.item() + self.loss += l.item() - self.epoch_loss /= len(train_dataloader) + self.loss /= len(train_dataloader) def eval(self, img): """ @@ -236,7 +237,7 @@ def train(self, imgs, masks): def eval(self, img, **eval_config): - # TBD we assume image is either 2D [H, W] or 3D [H, W, C] + # TBD we assume image is either 2D [H, W] or 3D [H, W, C] (see fsimage storage) # The final mask which is returned should have # 
first channel the output of cellpose and the rest are the class channels diff --git a/src/server/dcp_server/test_integration.py b/src/server/dcp_server/test_integration.py index 2b632c4..5bca348 100644 --- a/src/server/dcp_server/test_integration.py +++ b/src/server/dcp_server/test_integration.py @@ -2,6 +2,8 @@ import cv2 import sys import torch +from torchmetrics import JaccardIndex + import numpy as np from tqdm import tqdm @@ -41,19 +43,12 @@ def get_dataset(dataset_path): @pytest.fixture def patch_model(): - classifier_model_config = { - "chpt_path": - "/home/ubuntu/data-centric-platform/src/server/dcp_server/data/classifier_checkpoint.pth" - } - classifier_train_config, classifier_eval_config = {}, {} - - segmentor_model_config = read_config('model', config_path='config.cfg') - segmentor_train_config = read_config('train', config_path='config.cfg') - segmentor_eval_config = read_config('eval', config_path='config.cfg') - - patch_model = CellposePatchCNN( - segmentor_model_config, segmentor_train_config, segmentor_eval_config, - classifier_model_config, classifier_train_config, classifier_eval_config) + + model_config = read_config('model', config_path='config.cfg') + train_config = read_config('train', config_path='config.cfg') + eval_config = read_config('eval', config_path='config.cfg') + + patch_model = CellposePatchCNN(model_config, train_config, eval_config) return patch_model @pytest.fixture @@ -70,24 +65,25 @@ def data_eval(): def test_train_run(data_train, patch_model): images, masks = data_train - for _ in tqdm(range(1)): - loss_train = patch_model.train(deepcopy(images), deepcopy(masks)) - assert(loss_train>1e-2) + patch_model.train(images, masks) + assert(patch_model.segmentor.loss>1e-2) #TODO figure out appropriate value + assert(patch_model.classifier.loss>1e-2) def test_eval_run(data_eval, patch_model): - img, msk = data_eval - # instance segmentation mask (C, W, H) --> semantic multiclass segmentation mask (W, H) - # for i in range(msk.shape[0]): - # msk[i, ...][msk[i, ...] 
> 0] = i + 1
-
-    # msk = msk.sum(0)
-
-    final_mask, jaccard_index = patch_model.eval(img, instance_mask=torch.tensor(msk))
-    final_mask = final_mask.numpy()
-
-    cv2.imwrite("/home/ubuntu/data-centric-platform/src/server/dcp_server/data/final_mask.jpg", 255*label2rgb(final_mask))
-    assert(jaccard_index<0.6)
+    imgs, masks = data_eval
+    jaccard_index_instances = 0
+    jaccard_index_classes = 0
+    # instantiate the metrics once, then call them on (pred, target) pairs:
+    # foreground IoU for the instance channel, per-class IoU for the class
+    # channel (3 shape classes + background, as in models.py)
+    jaccard_binary = JaccardIndex(task="binary")
+    jaccard_multiclass = JaccardIndex(task="multiclass", num_classes=4, ignore_index=0)
+    for img, mask in zip(imgs, masks):
+        pred_mask = patch_model.eval(img).astype(np.int64)
+        mask = mask.astype(np.int64)
+        jaccard_index_instances += jaccard_binary(torch.tensor(pred_mask[0] > 0).long(), torch.tensor(mask[0] > 0).long())
+        jaccard_index_classes += jaccard_multiclass(torch.tensor(pred_mask[1]), torch.tensor(mask[1]))
+
+    jaccard_index_instances /= len(imgs)
+    assert(jaccard_index_instances<0.6)
+    jaccard_index_classes /= len(imgs)
+    assert(jaccard_index_classes<0.6)
+

From df065a417c677329e89bb3ef077b2fcc2cec477b Mon Sep 17 00:00:00 2001
From: Christina Bukas
Date: Thu, 19 Oct 2023 16:34:32 +0200
Subject: [PATCH 14/47] added tests to separate folder

---
 src/server/{dcp_server => test}/test.py             | 0
 src/server/{dcp_server => test}/test_integration.py | 0
 2 files changed, 0 insertions(+), 0 deletions(-)
 rename src/server/{dcp_server => test}/test.py (100%)
 rename src/server/{dcp_server => test}/test_integration.py (100%)

diff --git a/src/server/dcp_server/test.py b/src/server/test/test.py
similarity index 100%
rename from src/server/dcp_server/test.py
rename to src/server/test/test.py
diff --git a/src/server/dcp_server/test_integration.py b/src/server/test/test_integration.py
similarity index 100%
rename from src/server/dcp_server/test_integration.py
rename to src/server/test/test_integration.py

From d972364d8d41993dd233c5f928437c25cf423fd5 Mon Sep 17 00:00:00 2001
From: Mariia
Date: Fri, 27 Oct 2023 21:47:38 +0200
Subject: [PATCH 15/47] Added and tested synthetic data generation on-the-fly
 for integration tests.
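
The generator below composes noisy grayscale canvases from three shape
templates (triangle, circle, square) and returns matching label masks,
so the integration tests can build their data on the fly instead of
loading files from a hard-coded path. A minimal usage sketch (assuming
the test module sits next to synthetic_dataset.py; the exact fixture
wiring in test_integration.py may differ):

    from synthetic_dataset import get_synthetic_dataset

    # ten 512x512 canvases with up to 15 objects of each shape class
    images, masks = get_synthetic_dataset(num_samples=10, canvas_size=512)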
---
 src/server/test/shapes/circle.png    | Bin 0 -> 21504 bytes
 src/server/test/shapes/square.png    | Bin 0 -> 2048 bytes
 src/server/test/shapes/triangle.png  | Bin 0 -> 5120 bytes
 src/server/test/synthetic_dataset.py | 244 +++++++++++++++++++++++++++
 src/server/test/test.py              |  82 ---------
 src/server/test/test_integration.py  |  71 ++++----
 6 files changed, 279 insertions(+), 118 deletions(-)
 create mode 100644 src/server/test/shapes/circle.png
 create mode 100644 src/server/test/shapes/square.png
 create mode 100644 src/server/test/shapes/triangle.png
 create mode 100644 src/server/test/synthetic_dataset.py
 delete mode 100644 src/server/test/test.py

diff --git a/src/server/test/shapes/circle.png b/src/server/test/shapes/circle.png
new file mode 100644
index 0000000000000000000000000000000000000000..3d2fd3ee3e770b6217429c289dce92add1b17774
GIT binary patch
literal 21504
[base85-encoded PNG payload omitted]
[base85-encoded GIT binary patch payloads for circle.png (21504 bytes),
square.png (2048 bytes) and triangle.png (5120 bytes) omitted]
zEu43X_>dvL^Yeh%Hd!zfJ6|spuM2&|MZi6+1D4sQ>wW{iRLE_deQFVBoK{{x^@Z8| zP|bmrBg!rOJS!IbK4+xixW82sKuQtic21`JvYs;f0(5Te!y?KE>}n@w_Z||EePjGs z&?Gr+5iOHjZDTxGUSTQ&MiBh9h!$sP{4=+S+~i@$BR2}ZZltiL0zu*s9V;l>HLBe2 z>*t(bh?L-;W;!~)L+t@i*C5Hg1;+1#qTgdj117L;rgD9u^eFbU3jZGj)uwru%s3Ib zCoTKCVQ#WznGB_dT*c-{4Lq zcWX@&jX#U_*>kD`p35k-q13U$$jth{Mbk1VP(E0B2TFJ4;i~fLPul&P$PhmRKwQOO zSt$4XDX;?}v%Yc-bTQ2)25xKP7#OZUyPG88No#?r_M_E8`6ZOmBdYlA zWID{sjvzbk;yyd3$lw=s3Tup3G2ANFuDl^7cqXES0EGx}YH6BI&oHwH@kFki!0s2E zs>1X}2bR)VjdR88nas-+R{z}@Rj>3ZSGSCy9QY>^qS$S-E!yO9@_+UJ-_`&B0kYVa A9{>OV literal 0 HcmV?d00001 diff --git a/src/server/test/synthetic_dataset.py b/src/server/test/synthetic_dataset.py new file mode 100644 index 0000000..3712935 --- /dev/null +++ b/src/server/test/synthetic_dataset.py @@ -0,0 +1,244 @@ +import numpy as np +import cv2 +import matplotlib.pyplot as plt +import random +import os +import math + +from skimage.transform import rotate +import skimage.color as color +from skimage.util import random_noise +from skimage import io +import scipy.ndimage as ndi + +# set seed for reproducibility +seed_value = 2023 +random.seed(seed_value) +np.random.seed(seed_value) + + +def assign_unique_colors(labels, colors): + ''' + Assigns unique colors to each label in the given label array. + ''' + unique_labels = np.unique(labels) + # Create a dictionary to store the color assignment for each label + label_colors = {} + + # Iterate over the unique labels and assign colors + for label in unique_labels: + # Skip assigning colors if the label is 0 (background) + if label == 0: + continue + + # Check if the label is present in the labels + if label in labels: + # Assign the color to the label + color_index = label % len(colors) + label_colors[label] = colors[color_index] + + return label_colors + +def custom_label2rgb(labels, colors=['red', 'green', 'blue'], bg_label=0, alpha=0.5): + ''' + Converts a label array to an RGB image using assigned colors for each label. + ''' + + label_colors = assign_unique_colors(labels, colors) + + # Convert the labels to RGB using the assigned colors + rgb_image = np.zeros((labels.shape[0], labels.shape[1], 3), dtype=float) + for label in np.unique(labels): + mask = labels == label + if label in label_colors: + rgb = color.label2rgb(mask, colors=[label_colors[label]], bg_label=bg_label, alpha=alpha) + rgb_image += rgb + + return rgb_image + +def add_padding_for_rotation(image, angle): + ''' + Apply padding and rotation to an image. + The purpose of this function is to ensure that the rotated image fits within its original dimensions by adding padding, preventing any parts of the image from being cropped. + + Args: + image (numpy.ndarray): The input image. + angle (float): The rotation angle in degrees. 
+ ''' + + # Calculate rotated bounding box + h, w = image.shape[:2] + center = (w // 2, h // 2) + rotation_matrix = cv2.getRotationMatrix2D(center, angle, 1.0) + cos_theta = abs(rotation_matrix[0, 0]) + sin_theta = abs(rotation_matrix[0, 1]) + new_w = int((h * sin_theta) + (w * cos_theta)) + new_h = int((h * cos_theta) + (w * sin_theta)) + + # Calculate padding amounts + pad_w = (new_w - w) // 2 + pad_h = (new_h - h) // 2 + + # Add padding to the image + padded_image = cv2.copyMakeBorder(image, pad_h, pad_h, pad_w, pad_w, cv2.BORDER_CONSTANT) + + # Rotate the padded image + center = (padded_image.shape[1] // 2, padded_image.shape[0] // 2) + rotation_matrix = cv2.getRotationMatrix2D(center, angle, 1.0) + rotated_image = cv2.warpAffine(padded_image, rotation_matrix, (padded_image.shape[1], padded_image.shape[0])) + + return rotated_image + +def get_object_images(objects): + ''' + Load object images from file paths. + ''' + + object_images = [] + + for obj in objects: + img = cv2.imread(obj['path']) + # img = cv2.resize(img, obj['size']) + img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) + object_images.append(img) + + return object_images + +def generate_dataset(num_samples, objects, canvas_size, max_object_counts=None, noise_intensity=None, max_rotation_angle=None): + ''' + Generate a synthetic dataset with images and masks. + + Args: + num_samples (int): The number of samples to generate. + objects (list): List of object descriptions. + canvas_size (int): Size of the canvas to place objects on. + max_object_counts (list, optional): Maximum object counts for each class. Default is None. + noise_intensity (float, optional): intensity of the additional noise to the image + + ''' + + dataset_images = [] + dataset_masks = [] + + object_images = get_object_images(objects) + class_intensities = [ (obj['intensity'][0], obj['intensity'][1]) for obj in objects] + + + if len(object_images[0].shape) == 3: + num_of_img_channels = object_images[0].shape[-1] + else: + num_of_img_channels = 1 + + if max_object_counts is None: + max_object_counts = [10] * len(object_images) + + for _ in range(num_samples): + canvas = np.zeros((canvas_size, canvas_size, num_of_img_channels), dtype=np.uint8) + mask = np.zeros((canvas_size, canvas_size, len(object_images)), dtype=np.uint8) + + for object_index, object_img in enumerate(object_images): + + max_count = max_object_counts[object_index] + object_count = random.randint(1, max_count) + + for _ in range(object_count): + + object_size = random.randint(canvas_size//20, canvas_size//5) + + object_img_resized = cv2.resize(object_img, (object_size, object_size)) + # object_img_resized = (object_img_resized>0).astype(np.uint8)*(255 - object_size) + intensity_mean = (class_intensities[object_index][1] - class_intensities[object_index][0])/2 + intensity_scale = (class_intensities[object_index][1] - intensity_mean)/3 + class_intensity = np.random.normal(loc=intensity_mean, scale=intensity_scale) + class_intensity = np.clip(class_intensity, class_intensities[object_index][0], class_intensities[object_index][1]) + # class_intensity = random.randint(int(class_intensities[object_index][0]), int(class_intensities[object_index][1])) + object_img_resized = (object_img_resized>0).astype(np.uint8)*(class_intensity)*255 + + if num_of_img_channels == 1: + + if max_rotation_angle is not None: + # Randomly rotate the object image + rotation_angle = random.uniform(-max_rotation_angle, max_rotation_angle) + object_img_transformed = add_padding_for_rotation(object_img_resized, rotation_angle) + 
else: + object_img_transformed = object_img_resized + + object_size_x, object_size_y = object_img_transformed.shape + + + + object_mask = np.zeros((object_size_x, object_size_y), dtype=np.uint8) + + if num_of_img_channels == 1: # Grayscale image + object_mask[object_img_transformed > 0] = object_index + 1 + # object_img_resized = np.expand_dims(object_img_resized, axis=-1) + object_img_transformed = np.expand_dims(object_img_transformed, axis=-1) + else: # Color image with alpha channel + object_mask[object_img_resized[:, :, -1] > 0] = object_index + 1 + + + x = random.randint(0, canvas_size - object_size_x) + y = random.randint(0, canvas_size - object_size_y) + + intersecting_mask = mask[y:y + object_size_y, x:x + object_size_x].max(axis=-1) + if (intersecting_mask > 0).any(): + continue # Skip if there is an intersection with objects from other classes + + assert mask[y:y + object_size_y, x:x + object_size_x, object_index].shape == object_mask.shape + + canvas[y:y + object_size_y, x:x + object_size_x] = object_img_transformed + mask[y:y + object_size_y, x:x + object_size_x, object_index] = np.maximum( + mask[y:y + object_size_y, x:x + object_size_x, object_index], object_mask + ) + + + # Add noise to the canvas + if noise_intensity is not None: + + if num_of_img_channels == 1: + noise = np.random.normal(scale=noise_intensity, size=(canvas_size, canvas_size, 1)) + # noise = random_noise(canvas, mode='speckle', mean=noise_intensity) + + else: + noise = np.random.normal(scale=noise_intensity, size=(canvas_size, canvas_size, num_of_img_channels)) + noisy_canvas = canvas + noise.astype(np.uint8) + + dataset_images.append(noisy_canvas.squeeze(2)) + + else: + + dataset_images.append(canvas.squeeze(2)) + + mask = mask.max(axis=-1) + if len(mask.shape) == 2: + mask = custom_label2rgb(mask, colors=["red", "green", "blue"]) + mask = ndi.label(mask)[0] + else: + for j in range(mask.shape[-1]): + mask[..., j] = ndi.label(mask[..., j])[0] + mask = mask.transpose(2, 0, 1) + + dataset_masks.append(mask) + + return dataset_images, dataset_masks + +def get_synthetic_dataset(num_samples, canvas_size=512, max_object_counts=[15, 15, 15]): + objects = [ + { + 'name': 'triangle', + 'path': 'shapes/triangle.png', + 'intensity' : [0, 0.33] + }, + { + 'name': 'circle', + 'path': 'shapes/circle.png', + 'intensity' : [0.34, 0.66] + }, + { + 'name': 'square', + 'path': 'shapes/square.png', + 'intensity' : [0.67, 1.0] + }, + ] + images, masks = generate_dataset(num_samples, objects, canvas_size=canvas_size, max_object_counts=max_object_counts, noise_intensity=5, max_rotation_angle=30) + return images, masks \ No newline at end of file diff --git a/src/server/test/test.py b/src/server/test/test.py deleted file mode 100644 index f675149..0000000 --- a/src/server/test/test.py +++ /dev/null @@ -1,82 +0,0 @@ -import numpy as np -import torch - -from tqdm import tqdm - -import sys -import cv2 -import os - -from copy import deepcopy - -import sys -sys.path.append("../") - -from models import CellposePatchCNN -from dcp_server.utils import read_config -from skimage.color import label2rgb - -def get_dataset(dataset_path): - - images_path = os.path.join(dataset_path, "images") - masks_path = os.path.join(dataset_path, "masks") - - - images_files = [img for img in os.listdir(images_path)] - masks_files = [mask for mask in os.listdir(masks_path)] - - images, masks = [], [] - for img_file, mask_file in zip(images_files, masks_files): - - img_path = os.path.join(images_path, img_file) - mask_path = os.path.join(masks_path, 
mask_file) - - img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE) - msk = np.load(mask_path) - - images.append(img) - masks.append(msk) - - return images, masks - - - -if __name__=='__main__': - - img = cv2.imread("/home/ubuntu/data-centric-platform/src/server/dcp_server/data/img.jpg", cv2.IMREAD_GRAYSCALE) - msk = np.load("/home/ubuntu/data-centric-platform/src/server/dcp_server/data/mask.npy") - classifier_model_config, classifier_train_config, classifier_eval_config = {}, {}, {} - - segmentor_model_config = read_config("model", config_path = "config.cfg") - segmentor_train_config = read_config("train", config_path = "config.cfg") - segmentor_eval_config = read_config("eval", config_path = "config.cfg") - - patch_model = CellposePatchCNN( - segmentor_model_config, segmentor_train_config, segmentor_eval_config, - classifier_model_config, classifier_train_config, classifier_eval_config) - - images, masks = get_dataset("/home/ubuntu/data-centric-platform/src/server/dcp_server/data") - - # for i in tqdm(range(1)): - # loss_train = patch_model.train(deepcopy(images), deepcopy(masks)) - # assert(loss_train>1e-2) - - # instance segmentation mask (C, W, H) --> semantic multiclass segmentation mask (W, H) - # for i in range(msk.shape[0]): - # msk[i, ...][msk[i, ...] > 0] = i + 1 - - # msk = msk.sum(0) - - # img = img.mean(axis=1, keepdims=True) - - final_mask, jaccard_index = patch_model.eval(img, instance_mask=msk) - final_mask = final_mask.numpy() - - cv2.imwrite("/home/ubuntu/data-centric-platform/src/server/dcp_server/data/final_mask.jpg", 255*label2rgb(final_mask)) - print(jaccard_index) - - - - - - \ No newline at end of file diff --git a/src/server/test/test_integration.py b/src/server/test/test_integration.py index 5bca348..cec7e06 100644 --- a/src/server/test/test_integration.py +++ b/src/server/test/test_integration.py @@ -2,6 +2,8 @@ import cv2 import sys import torch +import random +import math from torchmetrics import JaccardIndex @@ -13,60 +15,38 @@ sys.path.append("../") from dcp_server.models import CellposePatchCNN from dcp_server.utils import read_config +from synthetic_dataset import get_synthetic_dataset import pytest - -def get_dataset(dataset_path): - - images_path = os.path.join(dataset_path, "images") - masks_path = os.path.join(dataset_path, "masks") - - - images_files = [img for img in os.listdir(images_path)] - masks_files = [mask for mask in os.listdir(masks_path)] - - images, masks = [], [] - for img_file, mask_file in zip(images_files, masks_files): - - img_path = os.path.join(images_path, img_file) - mask_path = os.path.join(masks_path, mask_file) - - img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE) - msk = np.load(mask_path) - - images.append(img) - masks.append(msk) - - return images, masks - - @pytest.fixture def patch_model(): + - model_config = read_config('model', config_path='config.cfg') - train_config = read_config('train', config_path='config.cfg') - eval_config = read_config('eval', config_path='config.cfg') + model_config = read_config('model', config_path='../dcp_server/config.cfg') + train_config = read_config('train', config_path='../dcp_server/config.cfg') + eval_config = read_config('eval', config_path='../dcp_server/config.cfg') patch_model = CellposePatchCNN(model_config, train_config, eval_config) return patch_model @pytest.fixture def data_train(): - images, masks = get_dataset("/home/ubuntu/data-centric-platform/src/server/dcp_server/data") + images, masks = get_synthetic_dataset(num_samples=2) return images, masks @pytest.fixture -def 
data_eval(): - img = cv2.imread("/home/ubuntu/data-centric-platform/src/server/dcp_server/data/img.jpg", cv2.IMREAD_GRAYSCALE) - msk = np.load("/home/ubuntu/data-centric-platform/src/server/dcp_server/data/mask.npy") +def data_eval(): + img, msk = get_synthetic_dataset(num_samples=1) return img, msk def test_train_run(data_train, patch_model): images, masks = data_train + patch_model.train(images, masks) - assert(patch_model.segmentor.loss>1e-2) #TODO figure out appropriate value + # CellposeModel eval doesn't work when images is a list (only on a single image) + # assert(patch_model.segmentor.loss>1e-2) #TODO figure out appropriate value assert(patch_model.classifier.loss>1e-2) def test_eval_run(data_eval, patch_model): @@ -74,10 +54,29 @@ def test_eval_run(data_eval, patch_model): imgs, masks = data_eval jaccard_index_instances = 0 jaccard_index_classes = 0 + + jaccard_metric_binary = JaccardIndex(task="multiclass", num_classes=2, average="macro", ignore_index=0) + jaccard_metric_multi = JaccardIndex(task="multiclass", num_classes=4, average="macro", ignore_index=0) + for img, mask in zip(imgs, masks): - pred_mask = patch_model.eval(img) - jaccard_index_instances += JaccardIndex(pred_mask[0], mask[0]) - jaccard_index_classes += JaccardIndex(pred_mask[1], mask[1]) + + #mask - instance multiclass segmentation (512, 512, 3) + #pred_mask - tuple of cellpose (512, 512), patch net multiclass segmentation (512, 512, 3) + + pred_mask = patch_model.eval(img) #, channels=[0,0]) + mask = (mask > 0) * np.arange(1, 4) + + pred_mask_bin = torch.tensor(pred_mask[0].astype(bool).astype(int)) + bin_mask = torch.tensor(mask.sum(-1).astype(bool).astype(int)) + + jaccard_index_instances += jaccard_metric_binary( + pred_mask_bin, + bin_mask + ) + jaccard_index_classes += jaccard_metric_multi( + torch.tensor(pred_mask[1].astype(int)), + torch.tensor(mask.sum(-1).astype(int)) + ) jaccard_index_instances /= len(imgs) assert(jaccard_index_instances<0.6) From 5ce4845c5e281879f28e56b0eacdb6293a77c0b1 Mon Sep 17 00:00:00 2001 From: Koren_Mariia <71977543+KorenMary@users.noreply.github.com> Date: Sun, 29 Oct 2023 13:03:56 +0100 Subject: [PATCH 16/47] Update requirements.txt --- src/server/requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/src/server/requirements.txt b/src/server/requirements.txt index 352bcf5..3879e66 100644 --- a/src/server/requirements.txt +++ b/src/server/requirements.txt @@ -1,3 +1,4 @@ cellpose>=2.2 bentoml>=1.0.13 scikit-image>=0.19.3 +torchmetrics>=0.11.4 From 9de2a22e52bea0d0a6cda33eefc289a7914bc759 Mon Sep 17 00:00:00 2001 From: Mariia Date: Sun, 29 Oct 2023 13:12:37 +0100 Subject: [PATCH 17/47] requirements modified --- src/server/dcp_server/config.cfg | 10 +- src/server/dcp_server/models.py | 139 ++++++--------------------- src/server/requirements.txt | 1 + src/server/test/synthetic_dataset.py | 6 -- src/server/test/test_integration.py | 9 -- 5 files changed, 36 insertions(+), 129 deletions(-) diff --git a/src/server/dcp_server/config.cfg b/src/server/dcp_server/config.cfg index 8af1781..1f35f84 100644 --- a/src/server/dcp_server/config.cfg +++ b/src/server/dcp_server/config.cfg @@ -18,7 +18,7 @@ "model_type": "cyto" }, "classifier":{ - "in_channels": 3, + "in_channels": 1, "num_classes": 3, "black_bg": "False", "include_mask": "False" @@ -32,7 +32,7 @@ "train":{ "segmentor":{ "n_epochs": 1, - "channels":[0] + "channels":[0,0] }, "classifier":{ "train_data":{ @@ -48,7 +48,11 @@ }, "eval":{ - "segmentor": "None", + "segmentor": { + "channels": [0,0], + "rescale": 1, + 
"batch_size": 1 + }, "classifier": { "data":{ "patch_size": 64 diff --git a/src/server/dcp_server/models.py b/src/server/dcp_server/models.py index a177af4..07b63ad 100644 --- a/src/server/dcp_server/models.py +++ b/src/server/dcp_server/models.py @@ -30,6 +30,7 @@ def __init__(self, model_config, train_config, eval_config): """ # Initialize the cellpose model + super().__init__(**model_config) self.train_config = train_config self.eval_config = eval_config @@ -55,9 +56,15 @@ def train(self, imgs, masks): :type imgs: List[np.ndarray] :param masks: masks of the given images (training labels) :type masks: List[np.ndarray] - """ - super().train(train_data=deepcopy(imgs), train_labels=masks, **self.train_config) - self.loss = self.loss_fn(masks, super().eval(imgs, self.eval_config)[0]) + """ + super().train( + train_data=deepcopy(imgs), + train_labels=masks, + min_train_masks=0, + **self.train_config + ) + + #self.loss = self.loss_fn(masks[0], super().eval(imgs[0], **self.eval_config)[0]) def masks_to_outlines(self, mask): """ get outlines of masks as a 0-1 array @@ -121,6 +128,7 @@ def forward(self, x): x = self.final_conv(x) x = self.pooling(x) x = x.view(x.size(0), -1) + return x def train (self, imgs, labels): @@ -159,6 +167,7 @@ def train (self, imgs, labels): imgs, labels = data optimizer.zero_grad() preds = self.forward(imgs) + l = loss_fn(preds, labels) l.backward() optimizer.step() @@ -194,23 +203,17 @@ def __init__(self, model_config, train_config, eval_config): self.eval_config = eval_config # Initialize the cellpose model and the classifier - self.segmentor = CustomCellposeModel(self.model_config["segmentor"], - self.train_config["segmentor"], - self.eval_config["segmentor"]) - self.classifier = CellClassifierFCNN(self.model_config["classifier"], - self.train_config["classifier"], - self.eval_config["classifier"]) - - def init_from_checkpoints(self, chpt_classifier=None, chpt_segmentor=None): - """ - Initialize the model from pre-trained checkpoints. - """ self.segmentor = CustomCellposeModel( - model_config={"gpu":torch.cuda.is_available(), "pretrained_model":chpt_segmentor} - ) - self.classifier.load_state_dict(torch.load(chpt_classifier)["model"]) - + self.model_config.get("segmentor", {}), + self.train_config.get("segmentor", {}), + self.eval_config.get("segmentor", {}) + ) + self.classifier = CellClassifierFCNN( + self.model_config.get("classifier", {}), + self.train_config.get("classifier", {}), + self.eval_config.get("classifier", {}) + ) def train(self, imgs, masks): # masks should have first channel as a cellpose mask and all other layers @@ -226,10 +229,11 @@ def train(self, imgs, masks): # train cellpose masks = np.array(masks) - masks_instances = list(masks[:,0, ...]) + masks_instances = [mask.sum(-1) for mask in masks] if masks[0].ndim == 3 else masks + masks_classes = [((mask > 0) * np.arange(1, 4)).sum(-1) for mask in masks] + self.segmentor.train(imgs, masks_instances) # create patch dataset to train classifier - masks_classes = list(masks[:,1, ...]) patches, labels = self.create_patch_dataset(imgs, masks_classes, masks_instances) # train classifier self.classifier.train(patches, labels) @@ -267,89 +271,6 @@ def eval(self, img, **eval_config): return final_mask - # REMOVE? replaced by code in eval - ''' - def get_prediction(self, input_image, cellpose_mask): - """ - Performs object segmentation and classification on an input image using the Cellpose model and a classifier model. - - Args: - image_path (str): The file path of the input image. 
- model (CellposeModel): The Cellpose model used for object segmentation, instance segmenation mask. - - Returns: - tuple: A tuple containing the cellpose_mask and final_mask, representing the segmentation masks obtained from - the Cellpose model and the combined segmentation and classification mask, respectively. - """ - - # Obtain segmentation mask using Cellpose model - - # Find objects in the cellpose_mask - locs = find_objects(cellpose_mask) - - # Get patches and labels based on object centroids - patches, labels = self.get_centered_patches(input_image, cellpose_mask, int(1.5 * input_image.shape[0] // 5), noise_intensity=5) - - labels = torch.tensor(labels) - labels_fit = [] - - final_mask = torch.zeros(cellpose_mask.shape) - - with torch.no_grad(): - for i, patch in enumerate(patches): - loc = locs[i] - - # Prepare image patch for classification - img = torch.tensor(patch.astype(np.float32)).unsqueeze(0).unsqueeze(0) / 255 - # img = img.mean(dim=1, keepdim=True) - - # Perform inference using model_classifier - logits = self.classifier(img) - - _, predicted = torch.max(logits, 1) - labels_fit.append(predicted) - - # Assign predicted class to corresponding location in final_mask - final_mask[loc] = predicted + 1 - - # Apply mask to final_mask, retaining only regions where cellpose_mask is greater than 0 - final_mask = final_mask * ((cellpose_mask > 0).long()) - - return final_mask - ''' - def find_max_patch_size(self, mask): - - # Find objects in the mask - objects = find_objects(mask) - - # Initialize variables to store the maximum patch size - max_patch_size = 0 - max_patch_indices = None - - # Iterate over the found objects - for obj in objects: - # Extract start and stop values from the slice object - slices = [s for s in obj] - start = [s.start for s in slices] - stop = [s.stop for s in slices] - - # Calculate the size of the patch along each axis - patch_size = tuple(stop[i] - start[i] for i in range(len(start))) - - # Calculate the total size (area) of the patch - total_size = 1 - for size in patch_size: - total_size *= size - - # Check if the current patch size is larger than the maximum - if total_size > max_patch_size: - max_patch_size = total_size - max_patch_indices = obj - - max_patch_size_edge = np.ceil(np.sqrt(max_patch_size)) - - return max_patch_size_edge - def crop_centered_padded_patch(self, x: np.ndarray, c, @@ -394,6 +315,7 @@ def crop_centered_padded_patch(self, patch = x[max(top, 0):min(bottom, x.shape[0]), max(left, 0):min(right, x.shape[1]), :] + # used during the training phase only if len(c) == 3: patch = patch[...,c[2]] @@ -483,21 +405,16 @@ def create_patch_dataset(self, imgs, masks_classes, masks_instances): the max cell size to define the patch size. All patches and masks should then be returned in the same format as imgs and masks (same type, i.e. check if tensor or np.array and same convention of dims, e.g. CxHxW) - Args: - imgs (): - masks (): - black_bg (bool): Flag indicating whether to use a black background for patches. - include_mask (bool): Flag indicating whether to include the mask along with patches. 
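+        Args:
+            imgs (List): Input images.
+            masks_classes (List): Per-pixel class masks (WxH), aligned with imgs.
+            masks_instances (List): Per-pixel instance masks (WxH), aligned with imgs.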
''' noise_intensity = self.train_config["classifier"]["train_data"]["noise_intensity"] - max_patch_size = self.train_config["classifier"]["train_data"]["patch_size"] - num_classes = self.train_config["classifier"]["train_data"]["num_classes"] patches, labels = [], [] for img, mask_class, mask_instance in zip(imgs, masks_classes, masks_instances): # Convert to one-hot encoding - # mask has dimension WxHxNum_of_channels + # mask_instance has dimension WxH + # mask_class has dimension WxH + patch, label = self.get_centered_patches(img, mask_instance, self.train_config["classifier"]["train_data"]["patch_size"], diff --git a/src/server/requirements.txt b/src/server/requirements.txt index 3879e66..8fab307 100644 --- a/src/server/requirements.txt +++ b/src/server/requirements.txt @@ -2,3 +2,4 @@ cellpose>=2.2 bentoml>=1.0.13 scikit-image>=0.19.3 torchmetrics>=0.11.4 +torch>=2.1.0 diff --git a/src/server/test/synthetic_dataset.py b/src/server/test/synthetic_dataset.py index 3712935..ccf3fee 100644 --- a/src/server/test/synthetic_dataset.py +++ b/src/server/test/synthetic_dataset.py @@ -1,14 +1,8 @@ import numpy as np import cv2 -import matplotlib.pyplot as plt import random -import os -import math -from skimage.transform import rotate import skimage.color as color -from skimage.util import random_noise -from skimage import io import scipy.ndimage as ndi # set seed for reproducibility diff --git a/src/server/test/test_integration.py b/src/server/test/test_integration.py index cec7e06..2cb1323 100644 --- a/src/server/test/test_integration.py +++ b/src/server/test/test_integration.py @@ -1,16 +1,7 @@ -import os -import cv2 import sys import torch -import random -import math from torchmetrics import JaccardIndex - - import numpy as np -from tqdm import tqdm -from copy import deepcopy -from skimage.color import label2rgb sys.path.append("../") from dcp_server.models import CellposePatchCNN From c5de9ba590aba6361e5ca29f9914b0263114ed16 Mon Sep 17 00:00:00 2001 From: Mariia Date: Sun, 29 Oct 2023 13:30:06 +0100 Subject: [PATCH 18/47] fixed test paths --- src/server/test/synthetic_dataset.py | 6 +++--- src/server/test/test_integration.py | 9 +++++---- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/src/server/test/synthetic_dataset.py b/src/server/test/synthetic_dataset.py index ccf3fee..e577ec8 100644 --- a/src/server/test/synthetic_dataset.py +++ b/src/server/test/synthetic_dataset.py @@ -220,17 +220,17 @@ def get_synthetic_dataset(num_samples, canvas_size=512, max_object_counts=[15, 1 objects = [ { 'name': 'triangle', - 'path': 'shapes/triangle.png', + 'path': 'test/shapes/triangle.png', 'intensity' : [0, 0.33] }, { 'name': 'circle', - 'path': 'shapes/circle.png', + 'path': 'test/shapes/circle.png', 'intensity' : [0.34, 0.66] }, { 'name': 'square', - 'path': 'shapes/square.png', + 'path': 'test/shapes/square.png', 'intensity' : [0.67, 1.0] }, ] diff --git a/src/server/test/test_integration.py b/src/server/test/test_integration.py index 2cb1323..7f6521a 100644 --- a/src/server/test/test_integration.py +++ b/src/server/test/test_integration.py @@ -3,7 +3,8 @@ from torchmetrics import JaccardIndex import numpy as np -sys.path.append("../") +sys.path.append(".") + from dcp_server.models import CellposePatchCNN from dcp_server.utils import read_config from synthetic_dataset import get_synthetic_dataset @@ -14,9 +15,9 @@ def patch_model(): - model_config = read_config('model', config_path='../dcp_server/config.cfg') - train_config = read_config('train', config_path='../dcp_server/config.cfg') - 
eval_config = read_config('eval', config_path='../dcp_server/config.cfg') + model_config = read_config('model', config_path='dcp_server/config.cfg') + train_config = read_config('train', config_path='dcp_server/config.cfg') + eval_config = read_config('eval', config_path='dcp_server/config.cfg') patch_model = CellposePatchCNN(model_config, train_config, eval_config) return patch_model From eb5c1f4eae353a37aba838945b2ceff145d3e427 Mon Sep 17 00:00:00 2001 From: Mariia Date: Mon, 30 Oct 2023 00:10:28 +0100 Subject: [PATCH 19/47] Updates according to code review --- src/server/dcp_server/config.cfg | 5 ++- src/server/dcp_server/models.py | 65 +++++++++++++++++++++++++--- src/server/dcp_server/service.py | 3 ++ src/server/test/synthetic_dataset.py | 2 +- 4 files changed, 65 insertions(+), 10 deletions(-) diff --git a/src/server/dcp_server/config.cfg b/src/server/dcp_server/config.cfg index 1f35f84..9565027 100644 --- a/src/server/dcp_server/config.cfg +++ b/src/server/dcp_server/config.cfg @@ -26,7 +26,7 @@ }, "data": { - "data_root": "/home/ubuntu/dcp-data" + "data_root": "D:/Helmholtz/dcp/data-centric-platform/data" }, "train":{ @@ -55,7 +55,8 @@ }, "classifier": { "data":{ - "patch_size": 64 + "patch_size": 64, + "noise_intensity": 5 } } } diff --git a/src/server/dcp_server/models.py b/src/server/dcp_server/models.py index 07b63ad..233b736 100644 --- a/src/server/dcp_server/models.py +++ b/src/server/dcp_server/models.py @@ -7,6 +7,7 @@ from tqdm import tqdm import numpy as np from scipy.ndimage import find_objects, center_of_mass +import cv2 #from segment_anything import SamPredictor, sam_model_registry #from segment_anything.automatic_mask_generator import SamAutomaticMaskGenerator @@ -83,8 +84,10 @@ class CellClassifierFCNN(nn.Module): Fully convolutional classifier for cell images. Args: - in_channels (int): Number of input channels. - num_classes (int): Number of output classes. + model_config (dict): Model configuration. + train_config (dict): Training configuration. + eval_config (dict): Evaluation configuration. 
+ ''' def __init__(self, model_config, train_config, eval_config): @@ -228,6 +231,10 @@ def train(self, imgs, masks): # mask.transpose(1, 2, 0) for mask in masks] # train cellpose + + if imgs[0].ndim == 3: + imgs = [cv2.cvtColor(img, cv2.COLOR_RGB2GRAY) for img in imgs] + masks = np.array(masks) masks_instances = [mask.sum(-1) for mask in masks] if masks[0].ndim == 3 else masks masks_classes = [((mask > 0) * np.arange(1, 4)).sum(-1) for mask in masks] @@ -237,7 +244,6 @@ def train(self, imgs, masks): patches, labels = self.create_patch_dataset(imgs, masks_classes, masks_instances) # train classifier self.classifier.train(patches, labels) - #return # TODO - define if we need to return something def eval(self, img, **eval_config): @@ -252,11 +258,19 @@ def eval(self, img, **eval_config): # find coordinates of detected objects locs = find_objects(instance_mask) class_mask = np.zeros(instance_mask.shape) + + + max_patch_size = self.eval_config["classifier"]["data"]["patch_size"] + noise_intensity = self.eval_config["classifier"]["data"]["noise_intensity"] + + if max_patch_size is None: + max_patch_size = self.find_max_patch_size(instance_mask) + # get patches centered around detected objects patches, _ = self.get_centered_patches(img, instance_mask, - self.eval_config["classifier"]["data"]["patch_size"], - noise_intensity=5) + max_patch_size, + noise_intensity=noise_intensity) # loop over patches and create classification mask for idx, patch in enumerate(patches): patch_class = self.classifier.eval(patch) @@ -316,6 +330,7 @@ def crop_centered_padded_patch(self, patch = x[max(top, 0):min(bottom, x.shape[0]), max(left, 0):min(right, x.shape[1]), :] # used during the training phase only + # c is (cx, cy, celltype = {0, 1, 2}) during training or (cx, cy) during inference if len(c) == 3: patch = patch[...,c[2]] @@ -397,17 +412,53 @@ def get_centered_patches(self, if mask_class is not None: labels.append(mask_class[c[0]][c[1]]) return patches, labels + + def find_max_patch_size(self, mask): + + # Find objects in the mask + objects = find_objects(mask) + + # Initialize variables to store the maximum patch size + max_patch_size = 0 + + # Iterate over the found objects + for obj in objects: + # Extract start and stop values from the slice object + slices = [s for s in obj] + start = [s.start for s in slices] + stop = [s.stop for s in slices] + + # Calculate the size of the patch along each axis + patch_size = tuple(stop[i] - start[i] for i in range(len(start))) + + # Calculate the total size (area) of the patch + total_size = 1 + for size in patch_size: + total_size *= size + + # Check if the current patch size is larger than the maximum + if total_size > max_patch_size: + max_patch_size = total_size + + max_patch_size_edge = np.ceil(np.sqrt(max_patch_size)) + + return max_patch_size_edge def create_patch_dataset(self, imgs, masks_classes, masks_instances): ''' Splits img and masks into patches of equal size which are centered around the cells. - The algorithm should first run through all images to find the max cell size, and use + If patch_size is not given, the algorithm should first run through all images to find the max cell size, and use the max cell size to define the patch size. All patches and masks should then be returned in the same format as imgs and masks (same type, i.e. check if tensor or np.array and same convention of dims, e.g. 
CxHxW) ''' noise_intensity = self.train_config["classifier"]["train_data"]["noise_intensity"] + max_patch_size = self.train_config["classifier"]["train_data"]["patch_size"] + + if max_patch_size is None: + max_patch_size = np.max([self.find_max_patch_size(mask) for mask in masks_instances]) + patches, labels = [], [] for img, mask_class, mask_instance in zip(imgs, masks_classes, masks_instances): @@ -417,7 +468,7 @@ def create_patch_dataset(self, imgs, masks_classes, masks_instances): patch, label = self.get_centered_patches(img, mask_instance, - self.train_config["classifier"]["train_data"]["patch_size"], + max_patch_size, noise_intensity=noise_intensity, mask_class=mask_class) patches.extend(patch) diff --git a/src/server/dcp_server/service.py b/src/server/dcp_server/service.py index 308fec3..8eae7ef 100644 --- a/src/server/dcp_server/service.py +++ b/src/server/dcp_server/service.py @@ -5,6 +5,8 @@ from dcp_server.serviceclasses import CustomBentoService, CustomRunnable from dcp_server.utils import read_config +import sys, inspect + models_module = __import__("models") segmentation_module = __import__("segmentationclasses") @@ -17,6 +19,7 @@ setup_config = read_config('setup', config_path = 'config.cfg') # instantiate the model + model_class = getattr(models_module, service_config['model_to_use']) model = model_class(model_config = model_config, train_config = train_config, eval_config = eval_config) custom_model_runner = t.cast( diff --git a/src/server/test/synthetic_dataset.py b/src/server/test/synthetic_dataset.py index e577ec8..4adabb5 100644 --- a/src/server/test/synthetic_dataset.py +++ b/src/server/test/synthetic_dataset.py @@ -235,4 +235,4 @@ def get_synthetic_dataset(num_samples, canvas_size=512, max_object_counts=[15, 1 }, ] images, masks = generate_dataset(num_samples, objects, canvas_size=canvas_size, max_object_counts=max_object_counts, noise_intensity=5, max_rotation_angle=30) - return images, masks \ No newline at end of file + return images, masks From 1c9734422abe0fe5ca468790cc31531f3efcd3ff Mon Sep 17 00:00:00 2001 From: Christina Bukas Date: Mon, 30 Oct 2023 11:45:31 +0100 Subject: [PATCH 20/47] updated eval config to include channel and z axis --- src/server/dcp_server/config.cfg | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/server/dcp_server/config.cfg b/src/server/dcp_server/config.cfg index 8af1781..8b0fffc 100644 --- a/src/server/dcp_server/config.cfg +++ b/src/server/dcp_server/config.cfg @@ -31,7 +31,7 @@ "train":{ "segmentor":{ - "n_epochs": 1, + "n_epochs": 7, "channels":[0] }, "classifier":{ @@ -40,7 +40,7 @@ "noise_intensity": 5, "num_classes": 3 }, - "n_epochs": 1, + "n_epochs": 8, "lr": 0.001, "batch_size": 1, "optimizer": "Adam" @@ -48,7 +48,10 @@ }, "eval":{ - "segmentor": "None", + "segmentor": { + "z_axis": null, + "channel_axis": null + }, "classifier": { "data":{ "patch_size": 64 From a3f9ab5ed44e678dc49a614ebe4ba2a4b348a580 Mon Sep 17 00:00:00 2001 From: Christina Bukas Date: Mon, 30 Oct 2023 11:46:02 +0100 Subject: [PATCH 21/47] updated get_image_size_properties to return z and channel axis --- src/server/dcp_server/fsimagestorage.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/src/server/dcp_server/fsimagestorage.py b/src/server/dcp_server/fsimagestorage.py index 98ad112..55c8b18 100644 --- a/src/server/dcp_server/fsimagestorage.py +++ b/src/server/dcp_server/fsimagestorage.py @@ -105,21 +105,30 @@ def get_image_size_properties(self, img, file_extension): """ orig_size = 
img.shape - # png and jpeg will be RGB by default and 2D + # png and jpeg will be RGB by default and 2D # tif can be grayscale 2D or 2D RGB and RGBA - if file_extension in (".jpg", ".jpeg", ".png") or (file_extension in (".tiff", ".tif") and len(orig_size)==2 or (len(orig_size)==3 and (orig_size[-1]==3 or orig_size[-1]==4))): + # RGB can be [C, H, W] or [H, W, C] + if file_extension in (".jpg", ".jpeg", ".png"): height, width = orig_size[0], orig_size[1] channel_ax = 2 + z_axis = None + elif file_extension in (".tiff", ".tif") and len(orig_size)==2: + channel_ax = None + z_axis = None + elif (len(orig_size)==3 and (orig_size[-1]==3 or orig_size[-1]==4)): + channel_ax = 2 + z_axis = None # or 3D tiff grayscale elif file_extension in (".tiff", ".tif") and len(orig_size)==3: print('Warning: 3D image stack found. We are assuming your first dimension is your stack dimension. Please cross check this.') height, width = orig_size[1], orig_size[2] - channel_ax = 0 + channel_ax = None + z_axis = 0 else: pass - return height, width, channel_ax + return height, width, channel_ax, z_axis def rescale_image(self, img, height, width, channel_ax, order): """rescale image From 7d128a2815ec991d43edeae8909d5939e94e03eb Mon Sep 17 00:00:00 2001 From: Christina Bukas Date: Mon, 30 Oct 2023 11:46:50 +0100 Subject: [PATCH 22/47] updated the way conifgs are passed to models --- src/server/dcp_server/models.py | 38 +++++++++++++++++---------------- 1 file changed, 20 insertions(+), 18 deletions(-) diff --git a/src/server/dcp_server/models.py b/src/server/dcp_server/models.py index a177af4..5e1e027 100644 --- a/src/server/dcp_server/models.py +++ b/src/server/dcp_server/models.py @@ -30,11 +30,11 @@ def __init__(self, model_config, train_config, eval_config): """ # Initialize the cellpose model - super().__init__(**model_config) - self.train_config = train_config - self.eval_config = eval_config + super().__init__(**model_config["segmentor"]) + self.train_config = train_config["segmentor"] + self.eval_config = eval_config["segmentor"] - def eval(self, img, **eval_config): + def eval(self, img): """Evaluate the model - find mask of the given image Calls the original eval function. @@ -44,9 +44,9 @@ def eval(self, img, **eval_config): :type z_axis: int :return: mask of the image, list of 2D arrays, or single 3D array (if do_3D=True) labelled image. :rtype: np.ndarray - """ - return super().eval(x=img, **eval_config)[0] # 0 to take only mask - + """ + return super().eval(x=img, **self.eval_config)[0] # 0 to take only mask + def train(self, imgs, masks): """Trains the given model Calls the original train function. 
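As a reading aid for the (height, width, channel_ax, z_axis) contract introduced above, here is a minimal standalone sketch of the extension-based inference that get_image_size_properties now performs; the helper name, the assertions and the error branch are illustrative only and not part of the repository:

    def infer_axes(shape, file_extension):
        # jpg/jpeg/png are read as RGB by default: channels last, no stack
        if file_extension in (".jpg", ".jpeg", ".png"):
            return 2, None  # (channel_ax, z_axis)
        # 2D tif: a single grayscale plane, neither channel nor stack axis
        if file_extension in (".tif", ".tiff") and len(shape) == 2:
            return None, None
        # 3D array whose last dimension is 3 or 4: RGB(A), channels last
        if len(shape) == 3 and shape[-1] in (3, 4):
            return 2, None
        # any other 3D tif: assumed to be a grayscale z-stack, stack axis first
        if file_extension in (".tif", ".tiff") and len(shape) == 3:
            return None, 0
        raise ValueError(f"unsupported image: shape={shape}, extension={file_extension}")

    assert infer_axes((512, 512, 3), ".png") == (2, None)
    assert infer_axes((5, 512, 512), ".tif") == (None, 0)
    assert infer_axes((512, 512), ".tif") == (None, None)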
@@ -57,7 +57,9 @@ def train(self, imgs, masks): :type masks: List[np.ndarray] """ super().train(train_data=deepcopy(imgs), train_labels=masks, **self.train_config) - self.loss = self.loss_fn(masks, super().eval(imgs, self.eval_config)[0]) + + #pred_masks = [self.eval(img, **self.eval_config) for img in masks] + #self.loss = self.loss_fn(masks, pred_masks) def masks_to_outlines(self, mask): """ get outlines of masks as a 0-1 array @@ -83,11 +85,11 @@ class CellClassifierFCNN(nn.Module): def __init__(self, model_config, train_config, eval_config): super().__init__() - self.in_channels = model_config["in_channels"] - self.num_classes = model_config["num_classes"] + 1 + self.in_channels = model_config["classifier"]["in_channels"] + self.num_classes = model_config["classifier"]["num_classes"] + 1 - self.train_config = train_config - self.eval_config = eval_config + self.train_config = train_config["classifier"] + self.eval_config = eval_config["classifier"] self.layer1 = nn.Sequential( nn.Conv2d(self.in_channels, 16, 3, 2, 5), @@ -194,12 +196,12 @@ def __init__(self, model_config, train_config, eval_config): self.eval_config = eval_config # Initialize the cellpose model and the classifier - self.segmentor = CustomCellposeModel(self.model_config["segmentor"], - self.train_config["segmentor"], - self.eval_config["segmentor"]) - self.classifier = CellClassifierFCNN(self.model_config["classifier"], - self.train_config["classifier"], - self.eval_config["classifier"]) + self.segmentor = CustomCellposeModel(self.model_config, + self.train_config, + self.eval_config) + self.classifier = CellClassifierFCNN(self.model_config, + self.train_config, + self.eval_config) def init_from_checkpoints(self, chpt_classifier=None, chpt_segmentor=None): """ From d05c84f2189c1cb2329668b94c1e0d20d64304d5 Mon Sep 17 00:00:00 2001 From: Christina Bukas Date: Mon, 30 Oct 2023 11:47:25 +0100 Subject: [PATCH 23/47] updated with z and channel axis args and removed eval_config from eval since it is in init --- src/server/dcp_server/segmentationclasses.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/server/dcp_server/segmentationclasses.py b/src/server/dcp_server/segmentationclasses.py index ce5a713..506e5df 100644 --- a/src/server/dcp_server/segmentationclasses.py +++ b/src/server/dcp_server/segmentationclasses.py @@ -34,14 +34,15 @@ async def segment_image(self, input_path, list_of_images): # Load the image img = self.imagestorage.load_image(img_filepath) # Get size properties - height, width, channel_ax = self.imagestorage.get_image_size_properties(img, utils.get_file_extension(img_filepath)) + height, width, channel_ax, z_axis = self.imagestorage.get_image_size_properties(img, utils.get_file_extension(img_filepath)) img = self.imagestorage.rescale_image(img, height, width, channel_ax, order=None) # Add channel ax into the model's evaluation parameters dictionary - self.model.eval_config['z_axis'] = channel_ax + self.model.eval_config['segmentor']['z_axis'] = z_axis + self.model.eval_config['segmentor']['channel_axis'] = channel_ax # Evaluate the model - mask = await self.runner.evaluate.async_run(img = img, **self.model.eval_config) + mask = await self.runner.evaluate.async_run(img = img) # Resize the mask - channel_ax = self.model.eval_config['z_axis'] + channel_ax = self.model.eval_config['segmentor']['channel_axis'] mask = self.imagestorage.rescale_image(mask, height, width, channel_ax, order=0) # Save segmentation seg_name = utils.get_path_stem(img_filepath) + 
setup_config['seg_name_string'] + '.tiff' From a06f0571498eda846c8c70fb48ae69244bc2bfee Mon Sep 17 00:00:00 2001 From: Christina Bukas Date: Mon, 30 Oct 2023 11:47:49 +0100 Subject: [PATCH 24/47] removed eval config from eval since it is in model init --- src/server/dcp_server/serviceclasses.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/src/server/dcp_server/serviceclasses.py b/src/server/dcp_server/serviceclasses.py index 8f62c0f..37779a2 100644 --- a/src/server/dcp_server/serviceclasses.py +++ b/src/server/dcp_server/serviceclasses.py @@ -25,7 +25,7 @@ def __init__(self, model, save_model_path): self.save_model_path = save_model_path @bentoml.Runnable.method(batchable=False) - def evaluate(self, img: np.ndarray, **eval_config) -> np.ndarray: + def evaluate(self, img: np.ndarray) -> np.ndarray: """Evaluate the model - find mask of the given image :param img: image to evaluate on @@ -36,7 +36,7 @@ def evaluate(self, img: np.ndarray, **eval_config) -> np.ndarray: :rtype: np.ndarray """ - mask = self.model.eval(img=img, **eval_config) + mask = self.model.eval(img=img) return mask @@ -51,9 +51,17 @@ def train(self, imgs: List[np.ndarray], masks: List[np.ndarray]) -> str: :return: path of the saved model :rtype: str """ - + s1 = self.model.segmentor.net.state_dict() + c1 = self.model.classifier.parameters() self.model.train(imgs, masks) - + s2 = self.model.segmentor.net.state_dict() + c2 = self.model.classifier.parameters() + if s1 == s2: print('S1 and S2 COMP: THEY ARE THE SAME!!!!!') + else: print('S1 and S2 COMP: THEY ARE NOOOT THE SAME!!!!!') + for p1, p2 in zip(c1, c2): + if p1.data.ne(p2.data).sum() > 0: + print("C1 and C2 NOT THE SAME") + break # Save the bentoml model bentoml.picklable_model.save_model(self.save_model_path, self.model) From ffe5adfef9c1fcb3d469d07e440f5164c62967bc Mon Sep 17 00:00:00 2001 From: hpelin Date: Tue, 31 Oct 2023 10:51:48 +0100 Subject: [PATCH 25/47] search seg bug fix --- src/client/dcp_client/app.py | 2 +- src/server/dcp_server/fsimagestorage.py | 6 +++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/client/dcp_client/app.py b/src/client/dcp_client/app.py index 3a7f927..b5e89c2 100644 --- a/src/client/dcp_client/app.py +++ b/src/client/dcp_client/app.py @@ -36,7 +36,7 @@ def search_segs(self, img_directory, cur_selected_img): """Returns a list of full paths of segmentations for an image""" # Take all segmentations of the image from the current directory: search_string = utils.get_path_stem(cur_selected_img) + '_seg' - seg_files = [file_name for file_name in os.listdir(img_directory) if search_string in file_name] + seg_files = [file_name for file_name in os.listdir(img_directory) if (search_string == utils.get_path_stem(file_name) or str(file_name).startswith(search_string))] return seg_files diff --git a/src/server/dcp_server/fsimagestorage.py b/src/server/dcp_server/fsimagestorage.py index 98ad112..fc6d2ee 100644 --- a/src/server/dcp_server/fsimagestorage.py +++ b/src/server/dcp_server/fsimagestorage.py @@ -60,7 +60,11 @@ def search_segs(self, cur_selected_img): img_directory = utils.get_path_parent(os.path.join(self.root_dir, cur_selected_img)) # Take all segmentations of the image from the current directory: search_string = utils.get_path_stem(cur_selected_img) + setup_config['seg_name_string'] - seg_files = [os.path.join(img_directory, file_name) for file_name in os.listdir(img_directory) if search_string in file_name] + #seg_files = [os.path.join(img_directory, file_name) for file_name in 
os.listdir(img_directory) if search_string in file_name] + # TODO: check where this is used - copied the command from app's search_segs function (to fix the 1_seg and 11_seg bug) + seg_files = [file_name for file_name in os.listdir(img_directory) if (search_string == utils.get_path_stem(file_name) or str(file_name).startswith(search_string))] + + return seg_files def get_image_seg_pairs(self, directory): From eb0d9c240cf2474e53808822301b284e5b603c1a Mon Sep 17 00:00:00 2001 From: hpelin Date: Tue, 31 Oct 2023 11:10:16 +0100 Subject: [PATCH 26/47] passed 'segmentor' into initi of CustomCellP. --- src/server/dcp_server/models.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/server/dcp_server/models.py b/src/server/dcp_server/models.py index 233b736..bde0a2d 100644 --- a/src/server/dcp_server/models.py +++ b/src/server/dcp_server/models.py @@ -31,8 +31,7 @@ def __init__(self, model_config, train_config, eval_config): """ # Initialize the cellpose model - - super().__init__(**model_config) + super().__init__(**model_config['segmentor']) self.train_config = train_config self.eval_config = eval_config From 6f6d2fc5d01ed58f62b4f22330a4ac4ff5f3430e Mon Sep 17 00:00:00 2001 From: Christina Bukas Date: Fri, 3 Nov 2023 16:31:42 +0100 Subject: [PATCH 27/47] moved image processing and dataset creation functions to utils --- src/server/dcp_server/utils.py | 182 ++++++++++++++++++++++++++++++++- 1 file changed, 181 insertions(+), 1 deletion(-) diff --git a/src/server/dcp_server/utils.py b/src/server/dcp_server/utils.py index 866b1b1..aa987c6 100644 --- a/src/server/dcp_server/utils.py +++ b/src/server/dcp_server/utils.py @@ -1,6 +1,7 @@ from pathlib import Path import json - +import numpy as np +from scipy.ndimage import find_objects, center_of_mass def read_config(name, config_path = 'config.cfg') -> dict: """Reads the configuration file @@ -31,3 +32,182 @@ def join_path(root_dir, filepath): return str(Path(root_dir, filepath)) def get_file_extension(file): return str(Path(file).suffix) + + +def crop_centered_padded_patch(x: np.ndarray, + c, + p, + mask: np.ndarray=None, + noise_intensity=None) -> np.ndarray: + """ + Crop a patch from an array `x` centered at coordinates `c` with size `p`, and apply padding if necessary. + + Args: + x (np.ndarray): The input array from which the patch will be cropped. + c (tuple): The coordinates (row, column, channel) at the center of the patch. + p (tuple): The size of the patch to be cropped (height, width). + + Returns: + np.ndarray: The cropped patch with applied padding. 
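+
+    Note:
+        Pixels belonging to other objects inside the patch, and any area of the
+        patch that falls outside the image borders, are filled with Gaussian
+        noise of scale `noise_intensity` when it is given.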
+ """ + + height, width = p # Size of the patch + + # Calculate the boundaries of the patch + top = c[0] - height // 2 + bottom = top + height + + left = c[1] - width // 2 + right = left + width + + # Crop the patch from the input array + if mask is not None: + mask_ = mask.max(-1) if len(mask.shape) >= 3 else mask + # central_label = mask_[c[0], c[1]] + central_label = mask_[c[0]][c[1]] + # Zero out values in the patch where the mask is not equal to the central label + # m = (mask_ != central_label) & (mask_ > 0) + m = (mask_ != central_label) & (mask_ > 0) + x[m] = 0 + if noise_intensity is not None: + x[m] = np.random.normal(scale=noise_intensity, size=x[m].shape) + + patch = x[max(top, 0):min(bottom, x.shape[0]), max(left, 0):min(right, x.shape[1]), :] + + # Calculate the required padding amounts + size_x, size_y = x.shape[1], x.shape[0] + + # Apply padding if necessary + if left < 0: + patch = np.hstack(( + np.random.normal(scale=noise_intensity, size=(patch.shape[0], abs(left), patch.shape[2])).astype(np.uint8), + patch)) + # Apply padding on the right side if necessary + if right > size_x: + patch = np.hstack(( + patch, + np.random.normal(scale=noise_intensity, size=(patch.shape[0], (right - size_x), patch.shape[2])).astype(np.uint8))) + # Apply padding on the top side if necessary + if top < 0: + patch = np.vstack(( + np.random.normal(scale=noise_intensity, size=(abs(top), patch.shape[1], patch.shape[2])).astype(np.uint8), + patch)) + # Apply padding on the bottom side if necessary + if bottom > size_y: + patch = np.vstack(( + patch, + np.random.normal(scale=noise_intensity, size=(bottom - size_y, patch.shape[1], patch.shape[2])).astype(np.uint8))) + + return patch + + +def get_center_of_mass(mask: np.ndarray) -> np.ndarray: + """ + Compute the centers of mass for each object in a mask. + + Args: + mask (np.ndarray): The input mask containing labeled objects. + + Returns: + np.ndarray: An array of coordinates (row, column, channel) representing the centers of mass for each object. + """ + + # Compute the centers of mass for each labeled object in the mask + return [(int(x[0]), int(x[1])) + for x in center_of_mass(mask, mask, np.arange(1, mask.max() + 1))] + +def get_centered_patches(img, + mask, + p_size: int, + noise_intensity=5, + mask_class=None): + + ''' + Extracts centered patches from the input image based on the centers of objects identified in the mask. + + Args: + img: The input image. + mask: The mask representing the objects in the image. + p_size (int): The size of the patches to extract. + noise_intensity: The intensity of noise to add to the patches. 
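+        mask_class: Optional class mask; when given, the class label at each
+            object's center of mass is collected alongside its patch.
+
+    Returns:
+        tuple: (patches, labels); labels stays empty when mask_class is None.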
+ + ''' + + patches, labels = [], [] + # if image is 2D add an additional dim for channels + if img.ndim<3: img = img[:, :, np.newaxis] + if mask.ndim<3: mask = mask[:, :, np.newaxis] + # compute center of mass of objects + centers_of_mass = get_center_of_mass(mask) + # Crop patches around each center of mass + for c in centers_of_mass: + c_x, c_y = c + patch = crop_centered_padded_patch(img.copy(), + (c_x, c_y), + (p_size, p_size), + mask=mask, + noise_intensity=noise_intensity) + patches.append(patch) + if mask_class is not None: labels.append(mask_class[c[0]][c[1]]) + + return patches, labels + +def get_objects(mask): + return find_objects(mask) + +def find_max_patch_size(mask): + + # Find objects in the mask + objects = get_objects(mask) + + # Initialize variables to store the maximum patch size + max_patch_size = 0 + + # Iterate over the found objects + for obj in objects: + # Extract start and stop values from the slice object + slices = [s for s in obj] + start = [s.start for s in slices] + stop = [s.stop for s in slices] + + # Calculate the size of the patch along each axis + patch_size = tuple(stop[i] - start[i] for i in range(len(start))) + + # Calculate the total size (area) of the patch + total_size = 1 + for size in patch_size: + total_size *= size + + # Check if the current patch size is larger than the maximum + if total_size > max_patch_size: + max_patch_size = total_size + + max_patch_size_edge = np.ceil(np.sqrt(max_patch_size)) + + return max_patch_size_edge + +def create_patch_dataset(imgs, masks_classes, masks_instances, noise_intensity, max_patch_size): + ''' + Splits img and masks into patches of equal size which are centered around the cells. + If patch_size is not given, the algorithm should first run through all images to find the max cell size, and use + the max cell size to define the patch size. All patches and masks should then be returned + in the same format as imgs and masks (same type, i.e. check if tensor or np.array and same + convention of dims, e.g. 
CxHxW) + ''' + + if max_patch_size is None: + max_patch_size = np.max([find_max_patch_size(mask) for mask in masks_instances]) + + + patches, labels = [], [] + for img, mask_class, mask_instance in zip(imgs, masks_classes, masks_instances): + # mask_instance has dimension WxH + # mask_class has dimension WxH + patch, label = get_centered_patches(img, + mask_instance, + max_patch_size, + noise_intensity=noise_intensity, + mask_class=mask_class) + patches.extend(patch) + labels.extend(label) + return patches, labels \ No newline at end of file From 460cdaf4ba699ad935618075793a8ba03452eba1 Mon Sep 17 00:00:00 2001 From: Christina Bukas Date: Fri, 3 Nov 2023 16:32:00 +0100 Subject: [PATCH 28/47] commenting lines which print model weights --- src/server/dcp_server/serviceclasses.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/server/dcp_server/serviceclasses.py b/src/server/dcp_server/serviceclasses.py index 37779a2..1515eac 100644 --- a/src/server/dcp_server/serviceclasses.py +++ b/src/server/dcp_server/serviceclasses.py @@ -51,9 +51,10 @@ def train(self, imgs: List[np.ndarray], masks: List[np.ndarray]) -> str: :return: path of the saved model :rtype: str """ - s1 = self.model.segmentor.net.state_dict() - c1 = self.model.classifier.parameters() + #s1 = self.model.segmentor.net.state_dict() + #c1 = self.model.classifier.parameters() self.model.train(imgs, masks) + ''' s2 = self.model.segmentor.net.state_dict() c2 = self.model.classifier.parameters() if s1 == s2: print('S1 and S2 COMP: THEY ARE THE SAME!!!!!') @@ -62,6 +63,7 @@ def train(self, imgs: List[np.ndarray], masks: List[np.ndarray]) -> str: if p1.data.ne(p2.data).sum() > 0: print("C1 and C2 NOT THE SAME") break + ''' # Save the bentoml model bentoml.picklable_model.save_model(self.save_model_path, self.model) From 266bb725f33bf59036565fa3ad0466dc372cb536 Mon Sep 17 00:00:00 2001 From: Christina Bukas Date: Fri, 3 Nov 2023 16:32:13 +0100 Subject: [PATCH 29/47] added comment on expected formats --- src/server/dcp_server/fsimagestorage.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/server/dcp_server/fsimagestorage.py b/src/server/dcp_server/fsimagestorage.py index 55c8b18..3259084 100644 --- a/src/server/dcp_server/fsimagestorage.py +++ b/src/server/dcp_server/fsimagestorage.py @@ -115,6 +115,7 @@ def get_image_size_properties(self, img, file_extension): elif file_extension in (".tiff", ".tif") and len(orig_size)==2: channel_ax = None z_axis = None + # if we have 3 dimensions and the third is size 3 or 4, then we assume it is the channel axis elif (len(orig_size)==3 and (orig_size[-1]==3 or orig_size[-1]==4)): channel_ax = 2 z_axis = None From a36ec9e299c5cbff759cc633c21ba5d02edc74d1 Mon Sep 17 00:00:00 2001 From: Christina Bukas Date: Fri, 3 Nov 2023 16:32:49 +0100 Subject: [PATCH 30/47] added new mask_channel_axis arg in config which stores channel axis of mask --- src/server/dcp_server/segmentationclasses.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/src/server/dcp_server/segmentationclasses.py b/src/server/dcp_server/segmentationclasses.py index 506e5df..e739cc5 100644 --- a/src/server/dcp_server/segmentationclasses.py +++ b/src/server/dcp_server/segmentationclasses.py @@ -35,15 +35,22 @@ async def segment_image(self, input_path, list_of_images): img = self.imagestorage.load_image(img_filepath) # Get size properties height, width, channel_ax, z_axis = self.imagestorage.get_image_size_properties(img, utils.get_file_extension(img_filepath)) - img = 
self.imagestorage.rescale_image(img, height, width, channel_ax, order=None) + img = self.imagestorage.rescale_image(img, + height, + width, + channel_ax, + order=None) # Add channel ax into the model's evaluation parameters dictionary self.model.eval_config['segmentor']['z_axis'] = z_axis self.model.eval_config['segmentor']['channel_axis'] = channel_ax # Evaluate the model mask = await self.runner.evaluate.async_run(img = img) # Resize the mask - channel_ax = self.model.eval_config['segmentor']['channel_axis'] - mask = self.imagestorage.rescale_image(mask, height, width, channel_ax, order=0) + mask = self.imagestorage.rescale_image(mask, + height, + width, + self.model.eval_config['mask_channel_axis'], + order=0) # Save segmentation seg_name = utils.get_path_stem(img_filepath) + setup_config['seg_name_string'] + '.tiff' self.imagestorage.save_image(os.path.join(input_path, seg_name), mask) From 3c86217caa9b98872c525f22a2e1897de5d42ae4 Mon Sep 17 00:00:00 2001 From: Mariia Date: Sun, 5 Nov 2023 19:34:46 +0100 Subject: [PATCH 31/47] tested on 3 channel images --- src/server/dcp_server/config.cfg | 2 +- src/server/dcp_server/models.py | 15 +++++++++------ src/server/test/synthetic_dataset.py | 1 - 3 files changed, 10 insertions(+), 8 deletions(-) diff --git a/src/server/dcp_server/config.cfg b/src/server/dcp_server/config.cfg index 9565027..538df5a 100644 --- a/src/server/dcp_server/config.cfg +++ b/src/server/dcp_server/config.cfg @@ -18,7 +18,7 @@ "model_type": "cyto" }, "classifier":{ - "in_channels": 1, + "in_channels": 3, "num_classes": 3, "black_bg": "False", "include_mask": "False" diff --git a/src/server/dcp_server/models.py b/src/server/dcp_server/models.py index 233b736..4278532 100644 --- a/src/server/dcp_server/models.py +++ b/src/server/dcp_server/models.py @@ -147,6 +147,8 @@ def train (self, imgs, labels): # optimizer_class = self.train_config['optimizer'] # Convert input images and labels to tensors + + # normalize images imgs = [(img-np.min(img))/(np.max(img)-np.min(img)) for img in imgs] # convert to tensor imgs = torch.stack([torch.from_numpy(img.astype(np.float32)) for img in imgs]) @@ -168,6 +170,7 @@ def train (self, imgs, labels): self.loss = 0 for data in train_dataloader: imgs, labels = data + optimizer.zero_grad() preds = self.forward(imgs) @@ -232,10 +235,10 @@ def train(self, imgs, masks): # train cellpose - if imgs[0].ndim == 3: - imgs = [cv2.cvtColor(img, cv2.COLOR_RGB2GRAY) for img in imgs] + imgs = [cv2.merge((img, img, img)) for img in imgs] masks = np.array(masks) + masks_instances = [mask.sum(-1) for mask in masks] if masks[0].ndim == 3 else masks masks_classes = [((mask > 0) * np.arange(1, 4)).sum(-1) for mask in masks] @@ -259,7 +262,6 @@ def eval(self, img, **eval_config): locs = find_objects(instance_mask) class_mask = np.zeros(instance_mask.shape) - max_patch_size = self.eval_config["classifier"]["data"]["patch_size"] noise_intensity = self.eval_config["classifier"]["data"]["noise_intensity"] @@ -331,8 +333,8 @@ def crop_centered_padded_patch(self, # used during the training phase only # c is (cx, cy, celltype = {0, 1, 2}) during training or (cx, cy) during inference - if len(c) == 3: - patch = patch[...,c[2]] + # if len(c) == 3: + # patch = patch[...,c[2]] # Calculate the required padding amounts size_x, size_y = x.shape[1], x.shape[0] @@ -409,6 +411,7 @@ def get_centered_patches(self, mask=mask, noise_intensity=noise_intensity) patches.append(patch) + # during the training step we store the true labels of the patch if mask_class is not None: 
labels.append(mask_class[c[0]][c[1]]) return patches, labels @@ -458,9 +461,9 @@ def create_patch_dataset(self, imgs, masks_classes, masks_instances): if max_patch_size is None: max_patch_size = np.max([self.find_max_patch_size(mask) for mask in masks_instances]) - patches, labels = [], [] + for img, mask_class, mask_instance in zip(imgs, masks_classes, masks_instances): # Convert to one-hot encoding # mask_instance has dimension WxH diff --git a/src/server/test/synthetic_dataset.py b/src/server/test/synthetic_dataset.py index 4adabb5..5156be5 100644 --- a/src/server/test/synthetic_dataset.py +++ b/src/server/test/synthetic_dataset.py @@ -117,7 +117,6 @@ def generate_dataset(num_samples, objects, canvas_size, max_object_counts=None, object_images = get_object_images(objects) class_intensities = [ (obj['intensity'][0], obj['intensity'][1]) for obj in objects] - if len(object_images[0].shape) == 3: num_of_img_channels = object_images[0].shape[-1] else: From 389dab0606585d412f2fc9a01ea167e8fa659d1c Mon Sep 17 00:00:00 2001 From: Mariia Date: Mon, 6 Nov 2023 00:15:24 +0100 Subject: [PATCH 32/47] mask paths fixed --- src/server/dcp_server/fsimagestorage.py | 2 +- src/server/dcp_server/models.py | 4 ++-- src/server/dcp_server/segmentationclasses.py | 5 +++-- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/src/server/dcp_server/fsimagestorage.py b/src/server/dcp_server/fsimagestorage.py index 42325e2..9281fe5 100644 --- a/src/server/dcp_server/fsimagestorage.py +++ b/src/server/dcp_server/fsimagestorage.py @@ -62,8 +62,8 @@ def search_segs(self, cur_selected_img): search_string = utils.get_path_stem(cur_selected_img) + setup_config['seg_name_string'] #seg_files = [os.path.join(img_directory, file_name) for file_name in os.listdir(img_directory) if search_string in file_name] # TODO: check where this is used - copied the command from app's search_segs function (to fix the 1_seg and 11_seg bug) - seg_files = [file_name for file_name in os.listdir(img_directory) if (search_string == utils.get_path_stem(file_name) or str(file_name).startswith(search_string))] + seg_files = [os.path.join(img_directory, file_name) for file_name in os.listdir(img_directory) if (search_string == utils.get_path_stem(file_name) or str(file_name).startswith(search_string))] return seg_files diff --git a/src/server/dcp_server/models.py b/src/server/dcp_server/models.py index 1304949..38faa72 100644 --- a/src/server/dcp_server/models.py +++ b/src/server/dcp_server/models.py @@ -241,10 +241,10 @@ def train(self, imgs, masks): masks = np.array(masks) - masks_instances = [mask.sum(-1) for mask in masks] if masks[0].ndim == 3 else masks + masks_instances = list(masks[:,0, ...]) #[mask.sum(-1) for mask in masks] if masks[0].ndim == 3 else masks self.segmentor.train(imgs, masks_instances) # create patch dataset to train classifier - masks_classes = [((mask > 0) * np.arange(1, 4)).sum(-1) for mask in masks] + masks_classes = list(masks[:,1, ...]) #[((mask > 0) * np.arange(1, 4)).sum(-1) for mask in masks] patches, labels = create_patch_dataset(imgs, masks_classes, masks_instances, diff --git a/src/server/dcp_server/segmentationclasses.py b/src/server/dcp_server/segmentationclasses.py index e739cc5..d0e6384 100644 --- a/src/server/dcp_server/segmentationclasses.py +++ b/src/server/dcp_server/segmentationclasses.py @@ -63,7 +63,8 @@ async def train(self, input_path): :type input_path: str :return: runner's train function output - path of the saved model :rtype: str - """ + """ + train_img_mask_pairs = 
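For context on the `search_segs` change above: the bare substring test is what produced the "1_seg and 11_seg bug" mentioned in the comment, since the query for image `1` also matched files belonging to image `11`. A minimal illustration:

    files = ["1_seg.tiff", "11_seg.tiff", "1_seg_classes.tiff"]
    search_string = "1_seg"
    # substring test: also matches "11_seg.tiff", where "1_seg" occurs at index 1
    print([f for f in files if search_string in f])
    # anchored test, as in the patch: only files of image "1" remain
    print([f for f in files if f.startswith(search_string)])
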
self.imagestorage.get_image_seg_pairs(input_path) if not train_img_mask_pairs: @@ -71,7 +72,7 @@ async def train(self, input_path): imgs, masks = self.imagestorage.prepare_images_and_masks_for_training(train_img_mask_pairs) model_save_path = await self.runner.train.async_run(imgs, masks) - + return model_save_path From b189b4099f47434221ed7ed46fe35854148428ad Mon Sep 17 00:00:00 2001 From: Mariia Date: Mon, 6 Nov 2023 11:53:17 +0100 Subject: [PATCH 33/47] changed the synthetic mask format to (2,512,512) in the pytest and tested --- src/server/test/synthetic_dataset.py | 6 +++--- src/server/test/test_integration.py | 32 ++++++++++++++++------------ 2 files changed, 21 insertions(+), 17 deletions(-) diff --git a/src/server/test/synthetic_dataset.py b/src/server/test/synthetic_dataset.py index 5156be5..256bdec 100644 --- a/src/server/test/synthetic_dataset.py +++ b/src/server/test/synthetic_dataset.py @@ -219,17 +219,17 @@ def get_synthetic_dataset(num_samples, canvas_size=512, max_object_counts=[15, 1 objects = [ { 'name': 'triangle', - 'path': 'test/shapes/triangle.png', + 'path': 'shapes/triangle.png', 'intensity' : [0, 0.33] }, { 'name': 'circle', - 'path': 'test/shapes/circle.png', + 'path': 'shapes/circle.png', 'intensity' : [0.34, 0.66] }, { 'name': 'square', - 'path': 'test/shapes/square.png', + 'path': 'shapes/square.png', 'intensity' : [0.67, 1.0] }, ] diff --git a/src/server/test/test_integration.py b/src/server/test/test_integration.py index 7f6521a..9bc14a0 100644 --- a/src/server/test/test_integration.py +++ b/src/server/test/test_integration.py @@ -1,3 +1,4 @@ +import os import sys import torch from torchmetrics import JaccardIndex @@ -14,10 +15,10 @@ @pytest.fixture def patch_model(): - - model_config = read_config('model', config_path='dcp_server/config.cfg') - train_config = read_config('train', config_path='dcp_server/config.cfg') - eval_config = read_config('eval', config_path='dcp_server/config.cfg') + print(os.getcwd()) + model_config = read_config('model', config_path='../dcp_server/config.cfg') + train_config = read_config('train', config_path='../dcp_server/config.cfg') + eval_config = read_config('eval', config_path='../dcp_server/config.cfg') patch_model = CellposePatchCNN(model_config, train_config, eval_config) return patch_model @@ -25,19 +26,23 @@ def patch_model(): @pytest.fixture def data_train(): images, masks = get_synthetic_dataset(num_samples=2) - return images, masks + masks_instances = [mask.sum(-1) for mask in masks] + masks_classes = [((mask > 0) * np.arange(1, 4)).sum(-1) for mask in masks] + masks_ = [np.stack((instances, classes)) for instances, classes in zip(masks_instances, masks_classes)] + return images, masks_ @pytest.fixture def data_eval(): img, msk = get_synthetic_dataset(num_samples=1) - return img, msk + msk = np.array(msk) + msk_ = np.stack((msk.sum(-1), ((msk > 0) * np.arange(1, 4)).sum(-1)), axis=0).transpose(1,0,2,3) + return img, msk_ def test_train_run(data_train, patch_model): images, masks = data_train patch_model.train(images, masks) - # CellposeModel eval doesn't work when images is a list (only on a single image) # assert(patch_model.segmentor.loss>1e-2) #TODO figure out appropriate value assert(patch_model.classifier.loss>1e-2) @@ -52,14 +57,13 @@ def test_eval_run(data_eval, patch_model): for img, mask in zip(imgs, masks): - #mask - instance multiclass segmentation (512, 512, 3) - #pred_mask - tuple of cellpose (512, 512), patch net multiclass segmentation (512, 512, 3) + #mask - instance segmentation mask + classes (2, 512, 
512) + #pred_mask - tuple of cellpose (512, 512), patch net multiclass segmentation (512, 512, 2) pred_mask = patch_model.eval(img) #, channels=[0,0]) - mask = (mask > 0) * np.arange(1, 4) - - pred_mask_bin = torch.tensor(pred_mask[0].astype(bool).astype(int)) - bin_mask = torch.tensor(mask.sum(-1).astype(bool).astype(int)) + + pred_mask_bin = torch.tensor((pred_mask[0]>0).astype(bool).astype(int)) + bin_mask = torch.tensor((mask[0]>0).astype(bool).astype(int)) jaccard_index_instances += jaccard_metric_binary( pred_mask_bin, @@ -67,7 +71,7 @@ def test_eval_run(data_eval, patch_model): ) jaccard_index_classes += jaccard_metric_multi( torch.tensor(pred_mask[1].astype(int)), - torch.tensor(mask.sum(-1).astype(int)) + torch.tensor(mask[1].astype(int)) ) jaccard_index_instances /= len(imgs) From 11d02edab25c6876826476b7699f432300b873c9 Mon Sep 17 00:00:00 2001 From: Mariia Date: Mon, 6 Nov 2023 13:59:33 +0100 Subject: [PATCH 34/47] tested train function --- src/client/dcp_client/config.cfg | 4 ++-- src/server/dcp_server/config.cfg | 2 +- src/server/dcp_server/models.py | 7 +++---- src/server/test/synthetic_dataset.py | 2 ++ src/server/test/test_integration.py | 5 ++--- 5 files changed, 10 insertions(+), 10 deletions(-) diff --git a/src/client/dcp_client/config.cfg b/src/client/dcp_client/config.cfg index 53b5db3..d7eb494 100644 --- a/src/client/dcp_client/config.cfg +++ b/src/client/dcp_client/config.cfg @@ -1,9 +1,9 @@ { "server":{ "user": "ubuntu", - "host": "local", + "host": "jusuf-vm2", "data-path": "/home/ubuntu/dcp-data", - "ip": "0.0.0.0", + "ip": "134.94.88.74", "port": 7010 } } \ No newline at end of file diff --git a/src/server/dcp_server/config.cfg b/src/server/dcp_server/config.cfg index 6eced06..e000e71 100644 --- a/src/server/dcp_server/config.cfg +++ b/src/server/dcp_server/config.cfg @@ -26,7 +26,7 @@ }, "data": { - "data_root": "D:/Helmholtz/dcp/data-centric-platform/data" + "data_root": "data" }, "train":{ diff --git a/src/server/dcp_server/models.py b/src/server/dcp_server/models.py index 38faa72..e44901e 100644 --- a/src/server/dcp_server/models.py +++ b/src/server/dcp_server/models.py @@ -239,12 +239,11 @@ def train(self, imgs, masks): imgs = [cv2.merge((img, img, img)) for img in imgs] - masks = np.array(masks) - - masks_instances = list(masks[:,0, ...]) #[mask.sum(-1) for mask in masks] if masks[0].ndim == 3 else masks + masks = np.array(masks) + masks_instances = list(masks[:,0,...]) #[mask.sum(-1) for mask in masks] if masks[0].ndim == 3 else masks self.segmentor.train(imgs, masks_instances) # create patch dataset to train classifier - masks_classes = list(masks[:,1, ...]) #[((mask > 0) * np.arange(1, 4)).sum(-1) for mask in masks] + masks_classes = list(masks[:,1,...]) #[((mask > 0) * np.arange(1, 4)).sum(-1) for mask in masks] patches, labels = create_patch_dataset(imgs, masks_classes, masks_instances, diff --git a/src/server/test/synthetic_dataset.py b/src/server/test/synthetic_dataset.py index 256bdec..e95dd2d 100644 --- a/src/server/test/synthetic_dataset.py +++ b/src/server/test/synthetic_dataset.py @@ -216,8 +216,10 @@ def generate_dataset(num_samples, objects, canvas_size, max_object_counts=None, return dataset_images, dataset_masks def get_synthetic_dataset(num_samples, canvas_size=512, max_object_counts=[15, 15, 15]): + objects = [ { + 'name': 'triangle', 'path': 'shapes/triangle.png', 'intensity' : [0, 0.33] diff --git a/src/server/test/test_integration.py b/src/server/test/test_integration.py index 9bc14a0..0be025f 100644 --- 
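For reference, a condensed sketch of the (2, H, W) mask convention the updated fixtures build, assuming the synthetic one-hot masks store instance ids in per-class channels:

    import numpy as np

    onehot = np.zeros((4, 4, 3), dtype=int)  # (H, W, num_classes)
    onehot[0, 0, 0] = 1                      # instance 1, class 1
    onehot[2, 2, 2] = 5                      # instance 5, class 3
    instances = onehot.sum(-1)                           # (H, W) instance ids
    classes = ((onehot > 0) * np.arange(1, 4)).sum(-1)   # (H, W) class ids in 1..3
    stacked = np.stack((instances, classes))             # (2, H, W)
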
a/src/server/test/test_integration.py +++ b/src/server/test/test_integration.py @@ -15,7 +15,6 @@ @pytest.fixture def patch_model(): - print(os.getcwd()) model_config = read_config('model', config_path='../dcp_server/config.cfg') train_config = read_config('train', config_path='../dcp_server/config.cfg') eval_config = read_config('eval', config_path='../dcp_server/config.cfg') @@ -25,7 +24,8 @@ def patch_model(): @pytest.fixture def data_train(): - images, masks = get_synthetic_dataset(num_samples=2) + images, masks = get_synthetic_dataset(num_samples=3) + masks = [np.array(mask) for mask in masks] masks_instances = [mask.sum(-1) for mask in masks] masks_classes = [((mask > 0) * np.arange(1, 4)).sum(-1) for mask in masks] masks_ = [np.stack((instances, classes)) for instances, classes in zip(masks_instances, masks_classes)] @@ -41,7 +41,6 @@ def data_eval(): def test_train_run(data_train, patch_model): images, masks = data_train - patch_model.train(images, masks) # assert(patch_model.segmentor.loss>1e-2) #TODO figure out appropriate value assert(patch_model.classifier.loss>1e-2) From fe52557a2d9cc5c8d79ec8652f0201fff02d8310 Mon Sep 17 00:00:00 2001 From: Mariia Date: Mon, 6 Nov 2023 14:12:39 +0100 Subject: [PATCH 35/47] changed the paths to the files due to test error --- src/server/test/synthetic_dataset.py | 10 +++++++--- src/server/test/test_integration.py | 6 +++--- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/src/server/test/synthetic_dataset.py b/src/server/test/synthetic_dataset.py index e95dd2d..a3ede96 100644 --- a/src/server/test/synthetic_dataset.py +++ b/src/server/test/synthetic_dataset.py @@ -1,6 +1,8 @@ import numpy as np import cv2 import random +import os +import sys import skimage.color as color import scipy.ndimage as ndi @@ -217,21 +219,23 @@ def generate_dataset(num_samples, objects, canvas_size, max_object_counts=None, def get_synthetic_dataset(num_samples, canvas_size=512, max_object_counts=[15, 15, 15]): + print(os.getcwd()) + assert 0 ==1 objects = [ { 'name': 'triangle', - 'path': 'shapes/triangle.png', + 'path': 'dcp_server/shapes/triangle.png', 'intensity' : [0, 0.33] }, { 'name': 'circle', - 'path': 'shapes/circle.png', + 'path': 'dcp_server/shapes/circle.png', 'intensity' : [0.34, 0.66] }, { 'name': 'square', - 'path': 'shapes/square.png', + 'path': 'dcp_server/shapes/square.png', 'intensity' : [0.67, 1.0] }, ] diff --git a/src/server/test/test_integration.py b/src/server/test/test_integration.py index 0be025f..8e6dda9 100644 --- a/src/server/test/test_integration.py +++ b/src/server/test/test_integration.py @@ -15,9 +15,9 @@ @pytest.fixture def patch_model(): - model_config = read_config('model', config_path='../dcp_server/config.cfg') - train_config = read_config('train', config_path='../dcp_server/config.cfg') - eval_config = read_config('eval', config_path='../dcp_server/config.cfg') + model_config = read_config('model', config_path='dcp_server/config.cfg') + train_config = read_config('train', config_path='dcp_server/config.cfg') + eval_config = read_config('eval', config_path='dcp_server/config.cfg') patch_model = CellposePatchCNN(model_config, train_config, eval_config) return patch_model From c60715471ca26b730ad5b8ca6ea39987329b470c Mon Sep 17 00:00:00 2001 From: Mariia Date: Mon, 6 Nov 2023 14:26:57 +0100 Subject: [PATCH 36/47] fixed the bug --- src/server/test/synthetic_dataset.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/server/test/synthetic_dataset.py b/src/server/test/synthetic_dataset.py index a3ede96..ac7dced 
100644 --- a/src/server/test/synthetic_dataset.py +++ b/src/server/test/synthetic_dataset.py @@ -219,8 +219,6 @@ def generate_dataset(num_samples, objects, canvas_size, max_object_counts=None, def get_synthetic_dataset(num_samples, canvas_size=512, max_object_counts=[15, 15, 15]): - print(os.getcwd()) - assert 0 ==1 objects = [ { From 58bebd316172fdb0beb7fa3b0e8e93b09704b653 Mon Sep 17 00:00:00 2001 From: Mariia Date: Mon, 6 Nov 2023 14:36:28 +0100 Subject: [PATCH 37/47] changed the paths for shapes --- src/server/test/synthetic_dataset.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/server/test/synthetic_dataset.py b/src/server/test/synthetic_dataset.py index ac7dced..27fbc5b 100644 --- a/src/server/test/synthetic_dataset.py +++ b/src/server/test/synthetic_dataset.py @@ -223,17 +223,17 @@ def get_synthetic_dataset(num_samples, canvas_size=512, max_object_counts=[15, 1 { 'name': 'triangle', - 'path': 'dcp_server/shapes/triangle.png', + 'path': 'test/shapes/triangle.png', 'intensity' : [0, 0.33] }, { 'name': 'circle', - 'path': 'dcp_server/shapes/circle.png', + 'path': 'test/shapes/circle.png', 'intensity' : [0.34, 0.66] }, { 'name': 'square', - 'path': 'dcp_server/shapes/square.png', + 'path': 'test/shapes/square.png', 'intensity' : [0.67, 1.0] }, ] From a74daf66a3e6bbff32dfd4cdb54b76739d7f05f1 Mon Sep 17 00:00:00 2001 From: Christina Bukas Date: Wed, 8 Nov 2023 21:47:07 +0100 Subject: [PATCH 38/47] added pytest to requirements --- src/client/requirements.txt | 3 ++- src/server/requirements.txt | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/client/requirements.txt b/src/client/requirements.txt index 4f420df..e92f6e4 100644 --- a/src/client/requirements.txt +++ b/src/client/requirements.txt @@ -1,2 +1,3 @@ napari[pyqt5]>=0.4.17 -bentoml[grpc]>=1.0.13 \ No newline at end of file +bentoml[grpc]>=1.0.13 +pytest>=7.4.3 \ No newline at end of file diff --git a/src/server/requirements.txt b/src/server/requirements.txt index 8fab307..57c2128 100644 --- a/src/server/requirements.txt +++ b/src/server/requirements.txt @@ -3,3 +3,4 @@ bentoml>=1.0.13 scikit-image>=0.19.3 torchmetrics>=0.11.4 torch>=2.1.0 +pytest>=7.4.3 From a0709125b2e16571a3203b9d7190d706af96abf5 Mon Sep 17 00:00:00 2001 From: Christina Bukas Date: Wed, 8 Nov 2023 21:48:58 +0100 Subject: [PATCH 39/47] changed convention to always convert image to grayscale on read, input channels to net adapted accordingly --- src/server/dcp_server/config.cfg | 4 +-- src/server/dcp_server/fsimagestorage.py | 33 ++++++++++--------------- 2 files changed, 15 insertions(+), 22 deletions(-) diff --git a/src/server/dcp_server/config.cfg b/src/server/dcp_server/config.cfg index e000e71..9a67ff1 100644 --- a/src/server/dcp_server/config.cfg +++ b/src/server/dcp_server/config.cfg @@ -6,7 +6,7 @@ }, "service": { - "model_to_use": "CellposePatchCNN", + "model_to_use": "CustomCellposeModel", "save_model_path": "mytrainedmodel", "runner_name": "cellpose_runner", "service_name": "data-centric-platform", @@ -18,7 +18,7 @@ "model_type": "cyto" }, "classifier":{ - "in_channels": 3, + "in_channels": 1, "num_classes": 3, "black_bg": "False", "include_mask": "False" diff --git a/src/server/dcp_server/fsimagestorage.py b/src/server/dcp_server/fsimagestorage.py index 9281fe5..84d6cf7 100644 --- a/src/server/dcp_server/fsimagestorage.py +++ b/src/server/dcp_server/fsimagestorage.py @@ -13,15 +13,17 @@ class FilesystemImageStorage(): def __init__(self, data_root): self.root_dir = data_root - def load_image(self, 
cur_selected_img): + def load_image(self, cur_selected_img, is_gray=True): """Load the image (using skiimage) :param cur_selected_img: full path of the image that needs to be loaded :type cur_selected_img: str :return: loaded image :rtype: ndarray - """ - return imread(os.path.join(self.root_dir , cur_selected_img)) + """ + try: + return imread(os.path.join(self.root_dir , cur_selected_img), as_gray=is_gray) + except ValueError: return None def save_image(self, to_save_path, img): """Save given image (using skiimage) @@ -110,32 +112,23 @@ def get_image_size_properties(self, img, file_extension): orig_size = img.shape # png and jpeg will be RGB by default and 2D - # tif can be grayscale 2D or 2D RGB and RGBA - # RGB can be [C, H, W] or [H, W, C] + # tif can be grayscale 2D or 3D [Z, H, W] + # image channels have already been removed in imread with is_gray=True if file_extension in (".jpg", ".jpeg", ".png"): height, width = orig_size[0], orig_size[1] - channel_ax = 2 z_axis = None elif file_extension in (".tiff", ".tif") and len(orig_size)==2: - channel_ax = None - z_axis = None - # if we have 3 dimensions and the third is size 3 or 4, then we assume it is the channel axis - elif (len(orig_size)==3 and (orig_size[-1]==3 or orig_size[-1]==4)): - channel_ax = 2 z_axis = None - # or 3D tiff grayscale + # if we have 3 dimensions the [Z, H, W] elif file_extension in (".tiff", ".tif") and len(orig_size)==3: print('Warning: 3D image stack found. We are assuming your first dimension is your stack dimension. Please cross check this.') - height, width = orig_size[1], orig_size[2] - channel_ax = None - z_axis = 0 - + z_axis = 0 else: - pass + print('File not currently supported. See documentation for accepted types') - return height, width, channel_ax, z_axis + return height, width, z_axis - def rescale_image(self, img, height, width, channel_ax, order): + def rescale_image(self, img, height, width, channel_ax=None, order=2): """rescale image :param img: image @@ -180,6 +173,6 @@ def prepare_images_and_masks_for_training(self, train_img_mask_pairs): imgs=[] masks=[] for img_file, mask_file in train_img_mask_pairs: - imgs.append(imread(img_file)) + imgs.append(self.load_image(img_file)) masks.append(imread(mask_file)) return imgs, masks \ No newline at end of file From 6e270dccd8a7890bc7df92650025f4121d6273d7 Mon Sep 17 00:00:00 2001 From: Christina Bukas Date: Wed, 8 Nov 2023 21:51:27 +0100 Subject: [PATCH 40/47] changed documentation for channel_axis --- src/server/dcp_server/fsimagestorage.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/server/dcp_server/fsimagestorage.py b/src/server/dcp_server/fsimagestorage.py index 84d6cf7..4c2c79a 100644 --- a/src/server/dcp_server/fsimagestorage.py +++ b/src/server/dcp_server/fsimagestorage.py @@ -106,7 +106,7 @@ def get_image_size_properties(self, img, file_extension): :return: size properties: - height - width - - channel_ax + - z_axis """ From bf62ccf3269cec22ab00402ad5e3f52086c70cd8 Mon Sep 17 00:00:00 2001 From: Christina Bukas Date: Wed, 8 Nov 2023 21:52:06 +0100 Subject: [PATCH 41/47] since grayscale convention changed arguments of channel_axis --- src/server/dcp_server/segmentationclasses.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/server/dcp_server/segmentationclasses.py b/src/server/dcp_server/segmentationclasses.py index d0e6384..e1213d5 100644 --- a/src/server/dcp_server/segmentationclasses.py +++ b/src/server/dcp_server/segmentationclasses.py @@ -34,15 +34,13 @@ async def segment_image(self, 
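With the grayscale-on-read convention above, only two layouts survive `load_image`: 2D [H, W] images and 3D stacks [Z, H, W]. A compact sketch of the resulting shape handling (assuming skimage's `imread`; the file name is hypothetical):

    from skimage.io import imread

    img = imread("example.tiff", as_gray=True)  # RGB/RGBA pages collapse to 2D floats
    if img.ndim == 2:
        height, width, z_axis = img.shape[0], img.shape[1], None
    elif img.ndim == 3:
        # first axis assumed to be the stack (Z) dimension, as the warning states
        height, width, z_axis = img.shape[1], img.shape[2], 0
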
input_path, list_of_images): # Load the image img = self.imagestorage.load_image(img_filepath) # Get size properties - height, width, channel_ax, z_axis = self.imagestorage.get_image_size_properties(img, utils.get_file_extension(img_filepath)) + height, width, z_axis = self.imagestorage.get_image_size_properties(img, utils.get_file_extension(img_filepath)) img = self.imagestorage.rescale_image(img, height, width, - channel_ax, order=None) # Add channel ax into the model's evaluation parameters dictionary self.model.eval_config['segmentor']['z_axis'] = z_axis - self.model.eval_config['segmentor']['channel_axis'] = channel_ax # Evaluate the model mask = await self.runner.evaluate.async_run(img = img) # Resize the mask From 5b60b1f8b70b595eab1aa2dc52c9bc66ac53b906 Mon Sep 17 00:00:00 2001 From: Christina Bukas Date: Wed, 8 Nov 2023 21:53:12 +0100 Subject: [PATCH 42/47] changed method to compute center of mass and return instance label as well --- src/server/dcp_server/utils.py | 47 ++++++++++++++++++++++++---------- 1 file changed, 33 insertions(+), 14 deletions(-) diff --git a/src/server/dcp_server/utils.py b/src/server/dcp_server/utils.py index aa987c6..86f5466 100644 --- a/src/server/dcp_server/utils.py +++ b/src/server/dcp_server/utils.py @@ -2,6 +2,7 @@ import json import numpy as np from scipy.ndimage import find_objects, center_of_mass +from skimage import measure def read_config(name, config_path = 'config.cfg') -> dict: """Reads the configuration file @@ -37,6 +38,7 @@ def get_file_extension(file): return str(Path(file).suffix) def crop_centered_padded_patch(x: np.ndarray, c, p, + l, mask: np.ndarray=None, noise_intensity=None) -> np.ndarray: """ @@ -44,8 +46,9 @@ def crop_centered_padded_patch(x: np.ndarray, Args: x (np.ndarray): The input array from which the patch will be cropped. - c (tuple): The coordinates (row, column, channel) at the center of the patch. + c (tuple): The coordinates (row, column) at the center of the patch. p (tuple): The size of the patch to be cropped (height, width). + l (int): The instance label of the mask at the patch Returns: np.ndarray: The cropped patch with applied padding. @@ -63,11 +66,9 @@ def crop_centered_padded_patch(x: np.ndarray, # Crop the patch from the input array if mask is not None: mask_ = mask.max(-1) if len(mask.shape) >= 3 else mask - # central_label = mask_[c[0], c[1]] - central_label = mask_[c[0]][c[1]] # Zero out values in the patch where the mask is not equal to the central label # m = (mask_ != central_label) & (mask_ > 0) - m = (mask_ != central_label) & (mask_ > 0) + m = (mask_ != l) & (mask_ > 0) x[m] = 0 if noise_intensity is not None: x[m] = np.random.normal(scale=noise_intensity, size=x[m].shape) @@ -101,7 +102,7 @@ def crop_centered_padded_patch(x: np.ndarray, return patch -def get_center_of_mass(mask: np.ndarray) -> np.ndarray: +def get_center_of_mass_and_label(mask: np.ndarray) -> np.ndarray: """ Compute the centers of mass for each object in a mask. @@ -109,12 +110,24 @@ def get_center_of_mass(mask: np.ndarray) -> np.ndarray: mask (np.ndarray): The input mask containing labeled objects. Returns: - np.ndarray: An array of coordinates (row, column, channel) representing the centers of mass for each object. + list of tuples: A list of coordinates (row, column) representing the centers of mass for each object. 
+ list of ints: Holds the label for each object in the mask """ # Compute the centers of mass for each labeled object in the mask + ''' return [(int(x[0]), int(x[1])) for x in center_of_mass(mask, mask, np.arange(1, mask.max() + 1))] + ''' + centers = [] + labels = [] + for region in measure.regionprops(mask): + center = region.centroid + centers.append((int(center[0]), int(center[1]))) + labels.append(region.label) + return centers, labels + + def get_centered_patches(img, mask, @@ -133,24 +146,30 @@ def get_centered_patches(img, ''' - patches, labels = [], [] + patches, instance_labels, class_labels = [], [], [] # if image is 2D add an additional dim for channels if img.ndim<3: img = img[:, :, np.newaxis] if mask.ndim<3: mask = mask[:, :, np.newaxis] # compute center of mass of objects - centers_of_mass = get_center_of_mass(mask) + centers_of_mass, instance_labels = get_center_of_mass_and_label(mask) # Crop patches around each center of mass - for c in centers_of_mass: + for c, l in zip(centers_of_mass, instance_labels): c_x, c_y = c patch = crop_centered_padded_patch(img.copy(), (c_x, c_y), (p_size, p_size), + l, mask=mask, noise_intensity=noise_intensity) patches.append(patch) - if mask_class is not None: labels.append(mask_class[c[0]][c[1]]) - - return patches, labels + if mask_class is not None: + # get the class instance for the specific object + instance_labels.append(l) + class_l = int(np.unique(mask_class[mask[:,:,0]==l])) + #-1 because labels from mask start from 1, we want classes to start from 0 + class_labels.append(class_l-1) + + return patches, instance_labels, class_labels def get_objects(mask): return find_objects(mask) @@ -203,11 +222,11 @@ def create_patch_dataset(imgs, masks_classes, masks_instances, noise_intensity, for img, mask_class, mask_instance in zip(imgs, masks_classes, masks_instances): # mask_instance has dimension WxH # mask_class has dimension WxH - patch, label = get_centered_patches(img, + patch, _, label = get_centered_patches(img, mask_instance, max_patch_size, noise_intensity=noise_intensity, mask_class=mask_class) patches.extend(patch) - labels.extend(label) + labels.extend(label) return patches, labels \ No newline at end of file From f7c795db638d87738d3d66d999db807b5764a72c Mon Sep 17 00:00:00 2001 From: Christina Bukas Date: Wed, 8 Nov 2023 21:55:20 +0100 Subject: [PATCH 43/47] removed merging of image, adapted class mask to ensure label corresponds to right object --- src/server/dcp_server/models.py | 31 ++++++++++--------------------- 1 file changed, 10 insertions(+), 21 deletions(-) diff --git a/src/server/dcp_server/models.py b/src/server/dcp_server/models.py index e44901e..3b184fc 100644 --- a/src/server/dcp_server/models.py +++ b/src/server/dcp_server/models.py @@ -6,7 +6,6 @@ from copy import deepcopy from tqdm import tqdm import numpy as np -import cv2 #from segment_anything import SamPredictor, sam_model_registry #from segment_anything.automatic_mask_generator import SamAutomaticMaskGenerator @@ -88,7 +87,7 @@ def __init__(self, model_config, train_config, eval_config): super().__init__() self.in_channels = model_config["classifier"]["in_channels"] - self.num_classes = model_config["classifier"]["num_classes"] + 1 + self.num_classes = model_config["classifier"]["num_classes"] self.train_config = train_config["classifier"] self.eval_config = eval_config["classifier"] @@ -125,7 +124,6 @@ def forward(self, x): x = self.final_conv(x) x = self.pooling(x) x = x.view(x.size(0), -1) - return x def train (self, imgs, labels): @@ -165,7 +163,7 @@ 
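The switch from `scipy.ndimage.center_of_mass` to `regionprops` above also removes a subtle lookup hazard: for a non-convex object the centroid can land on background or on another object, so reading the label back out of `mask_[c[0]][c[1]]` was not safe. A small demonstration:

    import numpy as np
    from skimage import measure

    mask = np.zeros((9, 9), dtype=int)
    mask[2:7, 2:7] = 1
    mask[3:6, 3:6] = 0                        # object 1 is a hollow square
    region = measure.regionprops(mask)[0]
    cy, cx = (int(c) for c in region.centroid)
    print(region.label, mask[cy, cx])         # prints "1 0": the centroid pixel is background
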
def train (self, imgs, labels): optimizer.zero_grad() preds = self.forward(imgs) - + l = loss_fn(preds, labels) l.backward() optimizer.step() @@ -234,11 +232,7 @@ def train(self, imgs, masks): :type masks: List[np.ndarray] of same shape as output of eval, i.e. one channel instances, second channel classes, so [2, H, W] or [2, 3, H, W] for 3D """ - # train cellpose - - imgs = [cv2.merge((img, img, img)) for img in imgs] - masks = np.array(masks) masks_instances = list(masks[:,0,...]) #[mask.sum(-1) for mask in masks] if masks[0].ndim == 3 else masks self.segmentor.train(imgs, masks_instances) @@ -253,36 +247,31 @@ def train(self, imgs, masks): self.classifier.train(patches, labels) def eval(self, img): - # TBD we assume image is either 2D [H, W] or [H, W, C] (see fsimage storage) + # TBD we assume image is either 2D [H, W] (see fsimage storage) # The final mask which is returned should have # first channel the output of cellpose and the rest are the class channels - # TODO test case produces img with size HxW for eval and HxWx3 for train with torch.no_grad(): # get instance mask from segmentor instance_mask = self.segmentor.eval(img) # find coordinates of detected objects - locs = get_objects(instance_mask) class_mask = np.zeros(instance_mask.shape) max_patch_size = self.eval_config["classifier"]["data"]["patch_size"] + if max_patch_size is None: max_patch_size = find_max_patch_size(instance_mask) noise_intensity = self.eval_config["classifier"]["data"]["noise_intensity"] - if max_patch_size is None: - max_patch_size = find_max_patch_size(instance_mask) - # get patches centered around detected objects - patches, _ = get_centered_patches(img, - instance_mask, - max_patch_size, - noise_intensity=noise_intensity) + patches, instance_labels, _ = get_centered_patches(img, + instance_mask, + max_patch_size, + noise_intensity=noise_intensity) # loop over patches and create classification mask for idx, patch in enumerate(patches): patch_class = self.classifier.eval(patch) # patch size should be HxWxC, e.g. 
64,64,3 - loc = locs[idx] # Assign predicted class to corresponding location in final_mask - class_mask[loc] = patch_class.item() + 1 + class_mask[instance_mask==instance_labels[idx]] = patch_class.item() + 1 # Apply mask to final_mask, retaining only regions where cellpose_mask is greater than 0 - class_mask = class_mask * (instance_mask > 0)#.long()) + #class_mask = class_mask * (instance_mask > 0)#.long()) final_mask = np.stack((instance_mask, class_mask), axis=self.eval_config['mask_channel_axis']).astype(np.uint16) # size 2xHxW return final_mask From f4f75f0381bc5a07d7f9ebb28dd2c26d65fc3b9c Mon Sep 17 00:00:00 2001 From: Christina Bukas Date: Wed, 8 Nov 2023 22:29:20 +0100 Subject: [PATCH 44/47] added min_train_masks arg --- src/server/dcp_server/config.cfg | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/server/dcp_server/config.cfg b/src/server/dcp_server/config.cfg index 9a67ff1..36a2306 100644 --- a/src/server/dcp_server/config.cfg +++ b/src/server/dcp_server/config.cfg @@ -6,7 +6,7 @@ }, "service": { - "model_to_use": "CustomCellposeModel", + "model_to_use": "CellposePatchCNN", "save_model_path": "mytrainedmodel", "runner_name": "cellpose_runner", "service_name": "data-centric-platform", @@ -32,7 +32,8 @@ "train":{ "segmentor":{ "n_epochs": 7, - "channels": [0,0] + "channels": [0,0], + "min_train_masks": 1 }, "classifier":{ "train_data":{ @@ -52,8 +53,7 @@ "z_axis": null, "channel_axis": null, "rescale": 1, - "batch_size": 1, - "channels": [0,0] + "batch_size": 1 }, "classifier": { "data":{ From e5f9afc4f0c43bb81feeb7113a82b99384c9bbdb Mon Sep 17 00:00:00 2001 From: Christina Bukas Date: Wed, 8 Nov 2023 22:29:49 +0100 Subject: [PATCH 45/47] included height and width computations for elif statements --- src/server/dcp_server/fsimagestorage.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/server/dcp_server/fsimagestorage.py b/src/server/dcp_server/fsimagestorage.py index 4c2c79a..bec1b56 100644 --- a/src/server/dcp_server/fsimagestorage.py +++ b/src/server/dcp_server/fsimagestorage.py @@ -118,11 +118,13 @@ def get_image_size_properties(self, img, file_extension): height, width = orig_size[0], orig_size[1] z_axis = None elif file_extension in (".tiff", ".tif") and len(orig_size)==2: + height, width = orig_size[0], orig_size[1] z_axis = None # if we have 3 dimensions the [Z, H, W] elif file_extension in (".tiff", ".tif") and len(orig_size)==3: - print('Warning: 3D image stack found. We are assuming your first dimension is your stack dimension. Please cross check this.') - z_axis = 0 + print('Warning: 3D image stack found. We are assuming your first dimension is your stack dimension. Please cross check this.') + height, width = orig_size[1], orig_size[2] + z_axis = 0 else: print('File not currently supported. 
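Writing the class through `instance_mask == instance_labels[idx]` instead of through the `find_objects` slices matters whenever bounding boxes overlap: a slice assignment also paints pixels of neighbouring objects inside the box. A minimal illustration:

    import numpy as np
    from scipy.ndimage import find_objects

    instance_mask = np.array([[1, 1, 0],
                              [1, 2, 2],
                              [0, 2, 2]])
    class_mask = np.zeros_like(instance_mask)
    class_mask[find_objects(instance_mask)[0]] = 7  # bbox of object 1 also covers a pixel of object 2
    exact = np.zeros_like(instance_mask)
    exact[instance_mask == 1] = 7                   # touches only object 1's pixels
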
See documentation for accepted types') From 10d4ed41d488049dff68dd82e99f5bb1e8d2650c Mon Sep 17 00:00:00 2001 From: Mariia Date: Sun, 12 Nov 2023 21:38:51 +0100 Subject: [PATCH 46/47] Add tests for multiple models including CustomCellposeModel --- src/server/dcp_server/models.py | 16 ++++++- src/server/test/test_integration.py | 74 +++++++++++++++++++++-------- 2 files changed, 68 insertions(+), 22 deletions(-) diff --git a/src/server/dcp_server/models.py b/src/server/dcp_server/models.py index 3b184fc..1b0d160 100644 --- a/src/server/dcp_server/models.py +++ b/src/server/dcp_server/models.py @@ -7,6 +7,8 @@ from tqdm import tqdm import numpy as np +from cellpose.metrics import aggregated_jaccard_index + #from segment_anything import SamPredictor, sam_model_registry #from segment_anything.automatic_mask_generator import SamAutomaticMaskGenerator @@ -54,10 +56,20 @@ def train(self, imgs, masks): :param masks: masks of the given images (training labels) :type masks: List[np.ndarray] """ + + if not isinstance(masks, np.ndarray): + masks = np.array(masks) + + if masks[0].shape[0] == 2: + masks = list(masks[:,0,...]) + super().train(train_data=deepcopy(imgs), train_labels=masks, **self.train_config["segmentor"]) + + pred_masks = [self.eval(img) for img in masks] + self.metric = aggregated_jaccard_index(masks, pred_masks) + # pred_masks = [self.eval(img) for img in masks] - #pred_masks = [self.eval(img, **self.eval_config) for img in masks] - #self.loss = self.loss_fn(masks, pred_masks) + # self.loss = self.loss_fn(masks, pred_masks) def masks_to_outlines(self, mask): """ get outlines of masks as a 0-1 array diff --git a/src/server/test/test_integration.py b/src/server/test/test_integration.py index 8e6dda9..eaad4c1 100644 --- a/src/server/test/test_integration.py +++ b/src/server/test/test_integration.py @@ -4,23 +4,39 @@ from torchmetrics import JaccardIndex import numpy as np +import inspect +# from importlib.machinery import SourceFileLoader + sys.path.append(".") -from dcp_server.models import CellposePatchCNN +import dcp_server.models as models from dcp_server.utils import read_config from synthetic_dataset import get_synthetic_dataset import pytest -@pytest.fixture -def patch_model(): - +# retrieve models names +model_classes = [ + cls_obj for cls_name, cls_obj in inspect.getmembers(models) \ + if inspect.isclass(cls_obj) \ + and cls_obj.__module__ == models.__name__ \ + and not cls_name.startswith("CellClassifier") + ] + +@pytest.fixture(params=model_classes) +def model_class(request): + return request.param + +@pytest.fixture() +def model(model_class): + model_config = read_config('model', config_path='dcp_server/config.cfg') train_config = read_config('train', config_path='dcp_server/config.cfg') eval_config = read_config('eval', config_path='dcp_server/config.cfg') - patch_model = CellposePatchCNN(model_config, train_config, eval_config) - return patch_model + model = model_class(model_config, train_config, eval_config) + + return model @pytest.fixture def data_train(): @@ -38,16 +54,26 @@ def data_eval(): msk_ = np.stack((msk.sum(-1), ((msk > 0) * np.arange(1, 4)).sum(-1)), axis=0).transpose(1,0,2,3) return img, msk_ -def test_train_run(data_train, patch_model): +def test_train_run(data_train, model): images, masks = data_train - patch_model.train(images, masks) + model.train(images, masks) # assert(patch_model.segmentor.loss>1e-2) #TODO figure out appropriate value - assert(patch_model.classifier.loss>1e-2) + + # retrieve the attribute names of the class of the current model + attrs 
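On the metric introduced above: cellpose's `aggregated_jaccard_index(masks_true, masks_pred)` returns one AJI value per image, with predictions expected to come from the images, while the comprehension `[self.eval(img) for img in masks]` appears to iterate over the masks instead. A usage sketch under that reading (`model`, `imgs` and `masks` assumed in scope):

    import numpy as np
    from cellpose.metrics import aggregated_jaccard_index

    pred_masks = [model.eval(img) for img in imgs]  # predictions from the images
    metric = float(np.mean(aggregated_jaccard_index(masks, pred_masks)))
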
= model.__class__.__dict__.keys() + + if "classifier" in attrs: + assert(model.classifier.loss>1e-2) + if "segmentor" in attrs: + assert(model.segmentor.loss>1e-2) + if "metric" in attrs: + assert(model.metric>1e-2) -def test_eval_run(data_eval, patch_model): +def test_eval_run(data_eval, model): imgs, masks = data_eval + jaccard_index_instances = 0 jaccard_index_classes = 0 @@ -59,28 +85,36 @@ def test_eval_run(data_eval, patch_model): #mask - instance segmentation mask + classes (2, 512, 512) #pred_mask - tuple of cellpose (512, 512), patch net multiclass segmentation (512, 512, 2) - pred_mask = patch_model.eval(img) #, channels=[0,0]) + pred_mask = model.eval(img) #, channels=[0,0]) - pred_mask_bin = torch.tensor((pred_mask[0]>0).astype(bool).astype(int)) + if pred_mask.ndim > 2: + pred_mask_bin = torch.tensor((pred_mask[0]>0).astype(bool).astype(int)) + else: + pred_mask_bin = torch.tensor((pred_mask > 0).astype(bool).astype(int)) + bin_mask = torch.tensor((mask[0]>0).astype(bool).astype(int)) jaccard_index_instances += jaccard_metric_binary( pred_mask_bin, bin_mask ) - jaccard_index_classes += jaccard_metric_multi( - torch.tensor(pred_mask[1].astype(int)), - torch.tensor(mask[1].astype(int)) - ) + + if pred_mask.ndim > 2: + + jaccard_index_classes += jaccard_metric_multi( + torch.tensor(pred_mask[1].astype(int)), + torch.tensor(mask[1].astype(int)) + ) jaccard_index_instances /= len(imgs) assert(jaccard_index_instances<0.6) - jaccard_index_classes /= len(imgs) - assert(jaccard_index_instances<0.6) - - + # for PatchCNN model + if pred_mask.ndim > 2: + jaccard_index_classes /= len(imgs) + assert(jaccard_index_classes<0.6) + From 90e7b2573845081fd7e30ca0a7df2f654a27b357 Mon Sep 17 00:00:00 2001 From: Mariia Date: Mon, 13 Nov 2023 17:35:07 +0100 Subject: [PATCH 47/47] fix the test in train part --- src/server/dcp_server/models.py | 3 ++- src/server/test/test_integration.py | 6 ++---- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/src/server/dcp_server/models.py b/src/server/dcp_server/models.py index 1b0d160..caf5226 100644 --- a/src/server/dcp_server/models.py +++ b/src/server/dcp_server/models.py @@ -66,7 +66,8 @@ def train(self, imgs, masks): super().train(train_data=deepcopy(imgs), train_labels=masks, **self.train_config["segmentor"]) pred_masks = [self.eval(img) for img in masks] - self.metric = aggregated_jaccard_index(masks, pred_masks) + print(len(pred_masks)) + self.metric = np.mean(aggregated_jaccard_index(masks, pred_masks)) # pred_masks = [self.eval(img) for img in masks] # self.loss = self.loss_fn(masks, pred_masks) diff --git a/src/server/test/test_integration.py b/src/server/test/test_integration.py index eaad4c1..be5d51f 100644 --- a/src/server/test/test_integration.py +++ b/src/server/test/test_integration.py @@ -40,7 +40,7 @@ def model(model_class): @pytest.fixture def data_train(): - images, masks = get_synthetic_dataset(num_samples=3) + images, masks = get_synthetic_dataset(num_samples=4) masks = [np.array(mask) for mask in masks] masks_instances = [mask.sum(-1) for mask in masks] masks_classes = [((mask > 0) * np.arange(1, 4)).sum(-1) for mask in masks] @@ -61,12 +61,10 @@ def test_train_run(data_train, model): # assert(patch_model.segmentor.loss>1e-2) #TODO figure out appropriate value # retrieve the attribute names of the class of the current model - attrs = model.__class__.__dict__.keys() + attrs = model.__dict__.keys() if "classifier" in attrs: assert(model.classifier.loss>1e-2) - if "segmentor" in attrs: - assert(model.segmentor.loss>1e-2) if 
"metric" in attrs: assert(model.metric>1e-2)