Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add include mask option to the classifier and tested it #66

Merged
merged 6 commits into from
Jan 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions src/server/dcp_server/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,9 @@ def __init__(self, model_config, train_config, eval_config):

self.train_config = train_config["classifier"]
self.eval_config = eval_config["classifier"]

self.include_mask = model_config["classifier"]["include_mask"]
self.in_channels = self.in_channels + 1 if self.include_mask else self.in_channels

self.layer1 = nn.Sequential(
nn.Conv2d(self.in_channels, 16, 3, 2, 5),
Expand Down Expand Up @@ -248,8 +251,8 @@ def __init__(self, model_config, train_config, eval_config, model_name):
self.model_config = model_config
self.train_config = train_config
self.eval_config = eval_config
self.include_mask = self.model_config["classifier"]["include_mask"]
self.model_name = model_name

self.classifier_class = self.model_config.get("classifier").get("model_class", "CellClassifierFCNN")

# Initialize the cellpose model and the classifier
Expand All @@ -267,6 +270,8 @@ def __init__(self, model_config, train_config, eval_config, model_name):
self.classifier = CellClassifierShallowModel(self.model_config,
self.train_config,
self.eval_config)
# make sure include mask is set to False if we are using the random forest model
self.include_mask = False

def update_configs(self, train_config, eval_config):
self.train_config = train_config
Expand All @@ -291,7 +296,8 @@ def train(self, imgs, masks):
masks_classes,
masks_instances,
noise_intensity = self.train_config["classifier"]["train_data"]["noise_intensity"],
max_patch_size = self.train_config["classifier"]["train_data"]["patch_size"])
max_patch_size = self.train_config["classifier"]["train_data"]["patch_size"],
include_mask = self.include_mask)
x = patches
if self.classifier_class == "RandomForest":
x = create_dataset_for_rf(patches, patch_masks)
Expand Down Expand Up @@ -319,7 +325,8 @@ def eval(self, img):
patches, patch_masks, instance_labels, _ = get_centered_patches(img,
instance_mask,
max_patch_size,
noise_intensity=noise_intensity)
noise_intensity=noise_intensity,
include_mask=self.include_mask)
x = patches
if self.classifier_class == "RandomForest":
x = create_dataset_for_rf(patches, patch_masks)
Expand Down
21 changes: 15 additions & 6 deletions src/server/dcp_server/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import numpy as np
from scipy.ndimage import find_objects
from skimage import measure
from copy import deepcopy
import SimpleITK as sitk
from radiomics import shape2D

Expand Down Expand Up @@ -154,16 +155,19 @@ def get_centered_patches(img,
mask,
p_size: int,
noise_intensity=5,
mask_class=None):
mask_class=None,
include_mask=False):

'''
Extracts centered patches from the input image based on the centers of objects identified in the mask.

Args:
img: The input image.
mask: The mask representing the objects in the image.
img (np.array): The input image.
mask (np.array): The mask representing the objects in the image.
p_size (int): The size of the patches to extract.
noise_intensity: The intensity of noise to add to the patches.
noise_intensity (float): The intensity of noise to add to the patches.
mask_class (int): The class represented in the patch
include_mask (bool): Whether or not to include mask as input argument to model

'''

Expand All @@ -182,6 +186,10 @@ def get_centered_patches(img,
obj_label,
mask=deepcopy(mask),
noise_intensity=noise_intensity)
if include_mask:
patch_mask = 255 * (patch_mask > 0).astype(np.uint8)
patch = np.concatenate((patch, patch_mask), axis=-1)

patches.append(patch)
patch_masks.append(patch_mask)
if mask_class is not None:
Expand Down Expand Up @@ -227,13 +235,14 @@ def find_max_patch_size(mask):

return max_patch_size_edge

def create_patch_dataset(imgs, masks_classes, masks_instances, noise_intensity, max_patch_size):
def create_patch_dataset(imgs, masks_classes, masks_instances, noise_intensity, max_patch_size, include_mask):
'''
Splits img and masks into patches of equal size which are centered around the cells.
If patch_size is not given, the algorithm should first run through all images to find the max cell size, and use
the max cell size to define the patch size. All patches and masks should then be returned
in the same format as imgs and masks (same type, i.e. check if tensor or np.array and same
convention of dims, e.g. CxHxW)
include_mask(bool) : Flag indicating whether to include the mask along with patches.
'''

if max_patch_size is None:
Expand All @@ -249,7 +258,7 @@ def create_patch_dataset(imgs, masks_classes, masks_instances, noise_intensity,
max_patch_size,
noise_intensity=noise_intensity,
mask_class=mask_class,
)
include_mask = include_mask)
patches.extend(patch)
patch_masks.extend(patch_mask)
labels.extend(label)
Expand Down
20 changes: 20 additions & 0 deletions src/server/test/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import numpy as np
import pytest
from dcp_server.utils import find_max_patch_size

@pytest.fixture
def sample_mask():
mask = np.zeros((10, 10), dtype=np.uint8)
mask[2:6, 3:7] = 1
mask[7:9, 2:5] = 1
return mask

def test_find_max_patch_size(sample_mask):
# Test when the function is called with a sample mask
result = find_max_patch_size(sample_mask)
assert isinstance(result, float)
assert result > 0




Loading