Skip to content

Commit

Permalink
Merge pull request #66 from HelmholtzAI-Consultants-Munich/include-mask
Browse files Browse the repository at this point in the history
Add include mask option to the classifier and tested it
  • Loading branch information
christinab12 authored Jan 31, 2024
2 parents e904d6b + e37b65c commit 4d14a2b
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 9 deletions.
13 changes: 10 additions & 3 deletions src/server/dcp_server/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,9 @@ def __init__(self, model_config, train_config, eval_config):

self.train_config = train_config["classifier"]
self.eval_config = eval_config["classifier"]

self.include_mask = model_config["classifier"]["include_mask"]
self.in_channels = self.in_channels + 1 if self.include_mask else self.in_channels

self.layer1 = nn.Sequential(
nn.Conv2d(self.in_channels, 16, 3, 2, 5),
Expand Down Expand Up @@ -248,8 +251,8 @@ def __init__(self, model_config, train_config, eval_config, model_name):
self.model_config = model_config
self.train_config = train_config
self.eval_config = eval_config
self.include_mask = self.model_config["classifier"]["include_mask"]
self.model_name = model_name

self.classifier_class = self.model_config.get("classifier").get("model_class", "CellClassifierFCNN")

# Initialize the cellpose model and the classifier
Expand All @@ -267,6 +270,8 @@ def __init__(self, model_config, train_config, eval_config, model_name):
self.classifier = CellClassifierShallowModel(self.model_config,
self.train_config,
self.eval_config)
# make sure include mask is set to False if we are using the random forest model
self.include_mask = False

def update_configs(self, train_config, eval_config):
self.train_config = train_config
Expand All @@ -291,7 +296,8 @@ def train(self, imgs, masks):
masks_classes,
masks_instances,
noise_intensity = self.train_config["classifier"]["train_data"]["noise_intensity"],
max_patch_size = self.train_config["classifier"]["train_data"]["patch_size"])
max_patch_size = self.train_config["classifier"]["train_data"]["patch_size"],
include_mask = self.include_mask)
x = patches
if self.classifier_class == "RandomForest":
x = create_dataset_for_rf(patches, patch_masks)
Expand Down Expand Up @@ -319,7 +325,8 @@ def eval(self, img):
patches, patch_masks, instance_labels, _ = get_centered_patches(img,
instance_mask,
max_patch_size,
noise_intensity=noise_intensity)
noise_intensity=noise_intensity,
include_mask=self.include_mask)
x = patches
if self.classifier_class == "RandomForest":
x = create_dataset_for_rf(patches, patch_masks)
Expand Down
21 changes: 15 additions & 6 deletions src/server/dcp_server/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import numpy as np
from scipy.ndimage import find_objects
from skimage import measure
from copy import deepcopy
import SimpleITK as sitk
from radiomics import shape2D

Expand Down Expand Up @@ -154,16 +155,19 @@ def get_centered_patches(img,
mask,
p_size: int,
noise_intensity=5,
mask_class=None):
mask_class=None,
include_mask=False):

'''
Extracts centered patches from the input image based on the centers of objects identified in the mask.
Args:
img: The input image.
mask: The mask representing the objects in the image.
img (np.array): The input image.
mask (np.array): The mask representing the objects in the image.
p_size (int): The size of the patches to extract.
noise_intensity: The intensity of noise to add to the patches.
noise_intensity (float): The intensity of noise to add to the patches.
mask_class (int): The class represented in the patch
include_mask (bool): Whether or not to include mask as input argument to model
'''

Expand All @@ -182,6 +186,10 @@ def get_centered_patches(img,
obj_label,
mask=deepcopy(mask),
noise_intensity=noise_intensity)
if include_mask:
patch_mask = 255 * (patch_mask > 0).astype(np.uint8)
patch = np.concatenate((patch, patch_mask), axis=-1)

patches.append(patch)
patch_masks.append(patch_mask)
if mask_class is not None:
Expand Down Expand Up @@ -227,13 +235,14 @@ def find_max_patch_size(mask):

return max_patch_size_edge

def create_patch_dataset(imgs, masks_classes, masks_instances, noise_intensity, max_patch_size):
def create_patch_dataset(imgs, masks_classes, masks_instances, noise_intensity, max_patch_size, include_mask):
'''
Splits img and masks into patches of equal size which are centered around the cells.
If patch_size is not given, the algorithm should first run through all images to find the max cell size, and use
the max cell size to define the patch size. All patches and masks should then be returned
in the same format as imgs and masks (same type, i.e. check if tensor or np.array and same
convention of dims, e.g. CxHxW)
include_mask(bool) : Flag indicating whether to include the mask along with patches.
'''

if max_patch_size is None:
Expand All @@ -249,7 +258,7 @@ def create_patch_dataset(imgs, masks_classes, masks_instances, noise_intensity,
max_patch_size,
noise_intensity=noise_intensity,
mask_class=mask_class,
)
include_mask = include_mask)
patches.extend(patch)
patch_masks.extend(patch_mask)
labels.extend(label)
Expand Down
20 changes: 20 additions & 0 deletions src/server/test/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import numpy as np
import pytest
from dcp_server.utils import find_max_patch_size

@pytest.fixture
def sample_mask():
mask = np.zeros((10, 10), dtype=np.uint8)
mask[2:6, 3:7] = 1
mask[7:9, 2:5] = 1
return mask

def test_find_max_patch_size(sample_mask):
# Test when the function is called with a sample mask
result = find_max_patch_size(sample_mask)
assert isinstance(result, float)
assert result > 0




0 comments on commit 4d14a2b

Please sign in to comment.