From 39d9e7dfac19f6738164a0ff3d7af16497498d92 Mon Sep 17 00:00:00 2001 From: Mariia Date: Sat, 6 Jan 2024 18:20:12 +0100 Subject: [PATCH 1/5] Add include mask option to the classifier and tested it --- src/server/dcp_server/config.cfg | 2 +- src/server/dcp_server/models.py | 11 +++++++-- src/server/dcp_server/utils.py | 38 ++++++++++++++++++++++++++------ 3 files changed, 41 insertions(+), 10 deletions(-) diff --git a/src/server/dcp_server/config.cfg b/src/server/dcp_server/config.cfg index 1771417..1af83c1 100644 --- a/src/server/dcp_server/config.cfg +++ b/src/server/dcp_server/config.cfg @@ -21,7 +21,7 @@ "in_channels": 1, "num_classes": 3, "black_bg": "False", - "include_mask": "False" + "include_mask": "True" } }, diff --git a/src/server/dcp_server/models.py b/src/server/dcp_server/models.py index d4bba11..3c8ba1e 100644 --- a/src/server/dcp_server/models.py +++ b/src/server/dcp_server/models.py @@ -108,6 +108,9 @@ def __init__(self, model_config, train_config, eval_config): self.train_config = train_config["classifier"] self.eval_config = eval_config["classifier"] + + self.include_mask = model_config["classifier"]["include_mask"] + self.in_channels = self.in_channels + 1 if self.include_mask else self.in_channels self.layer1 = nn.Sequential( nn.Conv2d(self.in_channels, 16, 3, 2, 5), @@ -220,6 +223,7 @@ def __init__(self, model_config, train_config, eval_config): self.model_config = model_config self.train_config = train_config self.eval_config = eval_config + self.include_mask = self.model_config["classifier"]["include_mask"] # Initialize the cellpose model and the classifier self.segmentor = CustomCellposeModel(self.model_config, @@ -252,7 +256,9 @@ def train(self, imgs, masks): masks_classes, masks_instances, noise_intensity = self.train_config["classifier"]["train_data"]["noise_intensity"], - max_patch_size = self.train_config["classifier"]["train_data"]["patch_size"]) + max_patch_size = self.train_config["classifier"]["train_data"]["patch_size"], + include_mask = self.include_mask + ) # train classifier self.classifier.train(patches, labels) @@ -274,7 +280,8 @@ def eval(self, img): patches, instance_labels, _ = get_centered_patches(img, instance_mask, max_patch_size, - noise_intensity=noise_intensity) + noise_intensity=noise_intensity, + include_mask=self.include_mask) # loop over patches and create classification mask for idx, patch in enumerate(patches): patch_class = self.classifier.eval(patch) # patch size should be HxWxC, e.g. 64,64,3 diff --git a/src/server/dcp_server/utils.py b/src/server/dcp_server/utils.py index 86f5466..335517c 100644 --- a/src/server/dcp_server/utils.py +++ b/src/server/dcp_server/utils.py @@ -74,7 +74,8 @@ def crop_centered_padded_patch(x: np.ndarray, x[m] = np.random.normal(scale=noise_intensity, size=x[m].shape) patch = x[max(top, 0):min(bottom, x.shape[0]), max(left, 0):min(right, x.shape[1]), :] - + if mask is not None: + mask = mask[max(top, 0):min(bottom, x.shape[0]), max(left, 0):min(right, x.shape[1]), :] # Calculate the required padding amounts size_x, size_y = x.shape[1], x.shape[0] @@ -83,23 +84,39 @@ def crop_centered_padded_patch(x: np.ndarray, patch = np.hstack(( np.random.normal(scale=noise_intensity, size=(patch.shape[0], abs(left), patch.shape[2])).astype(np.uint8), patch)) + if mask is not None: + mask = np.hstack(( + np.zeros((mask.shape[0], abs(left), mask.shape[2])).astype(np.uint8), + mask)) # Apply padding on the right side if necessary if right > size_x: patch = np.hstack(( patch, np.random.normal(scale=noise_intensity, size=(patch.shape[0], (right - size_x), patch.shape[2])).astype(np.uint8))) + if mask is not None: + mask = np.hstack(( + mask, + np.zeros((mask.shape[0], (right - size_x), mask.shape[2])).astype(np.uint8))) # Apply padding on the top side if necessary if top < 0: patch = np.vstack(( np.random.normal(scale=noise_intensity, size=(abs(top), patch.shape[1], patch.shape[2])).astype(np.uint8), patch)) + if mask is not None: + mask = np.vstack(( + np.zeros((abs(top), mask.shape[1], mask.shape[2])).astype(np.uint8), + mask)) # Apply padding on the bottom side if necessary if bottom > size_y: patch = np.vstack(( patch, np.random.normal(scale=noise_intensity, size=(bottom - size_y, patch.shape[1], patch.shape[2])).astype(np.uint8))) - - return patch + if mask is not None: + mask = np.vstack(( + mask, + np.zeros((bottom - size_y, mask.shape[1], mask.shape[2])).astype(np.uint8))) + + return patch, mask def get_center_of_mass_and_label(mask: np.ndarray) -> np.ndarray: @@ -133,7 +150,8 @@ def get_centered_patches(img, mask, p_size: int, noise_intensity=5, - mask_class=None): + mask_class=None, + include_mask=False): ''' Extracts centered patches from the input image based on the centers of objects identified in the mask. @@ -155,12 +173,16 @@ def get_centered_patches(img, # Crop patches around each center of mass for c, l in zip(centers_of_mass, instance_labels): c_x, c_y = c - patch = crop_centered_padded_patch(img.copy(), + patch, patch_mask = crop_centered_padded_patch(img.copy(), (c_x, c_y), (p_size, p_size), l, mask=mask, noise_intensity=noise_intensity) + if include_mask: + patch = np.concatenate((patch, patch_mask), axis=-1) + + patches.append(patch) if mask_class is not None: # get the class instance for the specific object @@ -205,13 +227,14 @@ def find_max_patch_size(mask): return max_patch_size_edge -def create_patch_dataset(imgs, masks_classes, masks_instances, noise_intensity, max_patch_size): +def create_patch_dataset(imgs, masks_classes, masks_instances, noise_intensity, max_patch_size, include_mask): ''' Splits img and masks into patches of equal size which are centered around the cells. If patch_size is not given, the algorithm should first run through all images to find the max cell size, and use the max cell size to define the patch size. All patches and masks should then be returned in the same format as imgs and masks (same type, i.e. check if tensor or np.array and same convention of dims, e.g. CxHxW) + include_mask(bool) : Flag indicating whether to include the mask along with patches. ''' if max_patch_size is None: @@ -226,7 +249,8 @@ def create_patch_dataset(imgs, masks_classes, masks_instances, noise_intensity, mask_instance, max_patch_size, noise_intensity=noise_intensity, - mask_class=mask_class) + mask_class=mask_class, + include_mask = include_mask) patches.extend(patch) labels.extend(label) return patches, labels \ No newline at end of file From 0488e8decb1d08de54fbfa56e19e5c29407b14c8 Mon Sep 17 00:00:00 2001 From: Mariia Date: Sun, 7 Jan 2024 21:13:10 +0100 Subject: [PATCH 2/5] Mask scale fix --- src/server/dcp_server/utils.py | 6 +++++- src/server/test/test_config.cfg | 2 +- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/server/dcp_server/utils.py b/src/server/dcp_server/utils.py index 335517c..8b6d07f 100644 --- a/src/server/dcp_server/utils.py +++ b/src/server/dcp_server/utils.py @@ -3,6 +3,7 @@ import numpy as np from scipy.ndimage import find_objects, center_of_mass from skimage import measure +from copy import deepcopy def read_config(name, config_path = 'config.cfg') -> dict: """Reads the configuration file @@ -72,6 +73,8 @@ def crop_centered_padded_patch(x: np.ndarray, x[m] = 0 if noise_intensity is not None: x[m] = np.random.normal(scale=noise_intensity, size=x[m].shape) + if mask is not None: + mask[m] = 0 patch = x[max(top, 0):min(bottom, x.shape[0]), max(left, 0):min(right, x.shape[1]), :] if mask is not None: @@ -177,9 +180,10 @@ def get_centered_patches(img, (c_x, c_y), (p_size, p_size), l, - mask=mask, + mask=deepcopy(mask), noise_intensity=noise_intensity) if include_mask: + patch_mask = 255 * (patch_mask > 0).astype(np.uint8) patch = np.concatenate((patch, patch_mask), axis=-1) diff --git a/src/server/test/test_config.cfg b/src/server/test/test_config.cfg index 04073ed..4373d04 100644 --- a/src/server/test/test_config.cfg +++ b/src/server/test/test_config.cfg @@ -21,7 +21,7 @@ "in_channels": 1, "num_classes": 3, "black_bg": "False", - "include_mask": "False" + "include_mask": "True" } }, From def908be78a4819051c64eb134dbd62fdcf271a7 Mon Sep 17 00:00:00 2001 From: Mariia Date: Wed, 17 Jan 2024 13:22:08 +0100 Subject: [PATCH 3/5] Add test for max patch size --- src/server/test/test_utils.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) create mode 100644 src/server/test/test_utils.py diff --git a/src/server/test/test_utils.py b/src/server/test/test_utils.py new file mode 100644 index 0000000..35678a2 --- /dev/null +++ b/src/server/test/test_utils.py @@ -0,0 +1,20 @@ +import numpy as np +import pytest +from dcp_server.utils import find_max_patch_size + +@pytest.fixture +def sample_mask(): + mask = np.zeros((10, 10), dtype=np.uint8) + mask[2:6, 3:7] = 1 + mask[7:9, 2:5] = 1 + return mask + +def test_find_max_patch_size(sample_mask): + # Test when the function is called with a sample mask + result = find_max_patch_size(sample_mask) + assert isinstance(result, float) + assert result > 0 + + + + From e7566692ec45bda92f9b2c5729c5182eb10bed56 Mon Sep 17 00:00:00 2001 From: Mariia Date: Sun, 21 Jan 2024 21:56:55 +0100 Subject: [PATCH 4/5] Change include_mask parameter to False --- src/server/dcp_server/config.cfg | 2 +- src/server/test/test_config.cfg | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/server/dcp_server/config.cfg b/src/server/dcp_server/config.cfg index 1af83c1..1771417 100644 --- a/src/server/dcp_server/config.cfg +++ b/src/server/dcp_server/config.cfg @@ -21,7 +21,7 @@ "in_channels": 1, "num_classes": 3, "black_bg": "False", - "include_mask": "True" + "include_mask": "False" } }, diff --git a/src/server/test/test_config.cfg b/src/server/test/test_config.cfg index 4373d04..04073ed 100644 --- a/src/server/test/test_config.cfg +++ b/src/server/test/test_config.cfg @@ -21,7 +21,7 @@ "in_channels": 1, "num_classes": 3, "black_bg": "False", - "include_mask": "True" + "include_mask": "False" } }, From e37b65c578235f1cccda1657d2d1356987630ac7 Mon Sep 17 00:00:00 2001 From: Christina Bukas Date: Wed, 31 Jan 2024 10:55:15 +0100 Subject: [PATCH 5/5] disable inlcude mask for RF model --- src/server/dcp_server/models.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/server/dcp_server/models.py b/src/server/dcp_server/models.py index 1f37256..b7e41e4 100644 --- a/src/server/dcp_server/models.py +++ b/src/server/dcp_server/models.py @@ -270,6 +270,8 @@ def __init__(self, model_config, train_config, eval_config, model_name): self.classifier = CellClassifierShallowModel(self.model_config, self.train_config, self.eval_config) + # make sure include mask is set to False if we are using the random forest model + self.include_mask = False def update_configs(self, train_config, eval_config): self.train_config = train_config