diff --git a/src/server/dcp_server/models.py b/src/server/dcp_server/models.py index cf015ee..b7e41e4 100644 --- a/src/server/dcp_server/models.py +++ b/src/server/dcp_server/models.py @@ -131,6 +131,9 @@ def __init__(self, model_config, train_config, eval_config): self.train_config = train_config["classifier"] self.eval_config = eval_config["classifier"] + + self.include_mask = model_config["classifier"]["include_mask"] + self.in_channels = self.in_channels + 1 if self.include_mask else self.in_channels self.layer1 = nn.Sequential( nn.Conv2d(self.in_channels, 16, 3, 2, 5), @@ -248,8 +251,8 @@ def __init__(self, model_config, train_config, eval_config, model_name): self.model_config = model_config self.train_config = train_config self.eval_config = eval_config + self.include_mask = self.model_config["classifier"]["include_mask"] self.model_name = model_name - self.classifier_class = self.model_config.get("classifier").get("model_class", "CellClassifierFCNN") # Initialize the cellpose model and the classifier @@ -267,6 +270,8 @@ def __init__(self, model_config, train_config, eval_config, model_name): self.classifier = CellClassifierShallowModel(self.model_config, self.train_config, self.eval_config) + # make sure include mask is set to False if we are using the random forest model + self.include_mask = False def update_configs(self, train_config, eval_config): self.train_config = train_config @@ -291,7 +296,8 @@ def train(self, imgs, masks): masks_classes, masks_instances, noise_intensity = self.train_config["classifier"]["train_data"]["noise_intensity"], - max_patch_size = self.train_config["classifier"]["train_data"]["patch_size"]) + max_patch_size = self.train_config["classifier"]["train_data"]["patch_size"], + include_mask = self.include_mask) x = patches if self.classifier_class == "RandomForest": x = create_dataset_for_rf(patches, patch_masks) @@ -319,7 +325,8 @@ def eval(self, img): patches, patch_masks, instance_labels, _ = get_centered_patches(img, instance_mask, max_patch_size, - noise_intensity=noise_intensity) + noise_intensity=noise_intensity, + include_mask=self.include_mask) x = patches if self.classifier_class == "RandomForest": x = create_dataset_for_rf(patches, patch_masks) diff --git a/src/server/dcp_server/utils.py b/src/server/dcp_server/utils.py index 3e22670..ff18b63 100644 --- a/src/server/dcp_server/utils.py +++ b/src/server/dcp_server/utils.py @@ -4,6 +4,7 @@ import numpy as np from scipy.ndimage import find_objects from skimage import measure +from copy import deepcopy import SimpleITK as sitk from radiomics import shape2D @@ -154,16 +155,19 @@ def get_centered_patches(img, mask, p_size: int, noise_intensity=5, - mask_class=None): + mask_class=None, + include_mask=False): ''' Extracts centered patches from the input image based on the centers of objects identified in the mask. Args: - img: The input image. - mask: The mask representing the objects in the image. + img (np.array): The input image. + mask (np.array): The mask representing the objects in the image. p_size (int): The size of the patches to extract. - noise_intensity: The intensity of noise to add to the patches. + noise_intensity (float): The intensity of noise to add to the patches. + mask_class (int): The class represented in the patch + include_mask (bool): Whether or not to include mask as input argument to model ''' @@ -182,6 +186,10 @@ def get_centered_patches(img, obj_label, mask=deepcopy(mask), noise_intensity=noise_intensity) + if include_mask: + patch_mask = 255 * (patch_mask > 0).astype(np.uint8) + patch = np.concatenate((patch, patch_mask), axis=-1) + patches.append(patch) patch_masks.append(patch_mask) if mask_class is not None: @@ -227,13 +235,14 @@ def find_max_patch_size(mask): return max_patch_size_edge -def create_patch_dataset(imgs, masks_classes, masks_instances, noise_intensity, max_patch_size): +def create_patch_dataset(imgs, masks_classes, masks_instances, noise_intensity, max_patch_size, include_mask): ''' Splits img and masks into patches of equal size which are centered around the cells. If patch_size is not given, the algorithm should first run through all images to find the max cell size, and use the max cell size to define the patch size. All patches and masks should then be returned in the same format as imgs and masks (same type, i.e. check if tensor or np.array and same convention of dims, e.g. CxHxW) + include_mask(bool) : Flag indicating whether to include the mask along with patches. ''' if max_patch_size is None: @@ -249,7 +258,7 @@ def create_patch_dataset(imgs, masks_classes, masks_instances, noise_intensity, max_patch_size, noise_intensity=noise_intensity, mask_class=mask_class, - ) + include_mask = include_mask) patches.extend(patch) patch_masks.extend(patch_mask) labels.extend(label) diff --git a/src/server/test/test_utils.py b/src/server/test/test_utils.py new file mode 100644 index 0000000..35678a2 --- /dev/null +++ b/src/server/test/test_utils.py @@ -0,0 +1,20 @@ +import numpy as np +import pytest +from dcp_server.utils import find_max_patch_size + +@pytest.fixture +def sample_mask(): + mask = np.zeros((10, 10), dtype=np.uint8) + mask[2:6, 3:7] = 1 + mask[7:9, 2:5] = 1 + return mask + +def test_find_max_patch_size(sample_mask): + # Test when the function is called with a sample mask + result = find_max_patch_size(sample_mask) + assert isinstance(result, float) + assert result > 0 + + + +