diff --git a/src/client/dcp_client/app.py b/src/client/dcp_client/app.py
index 3a7f927..b5e89c2 100644
--- a/src/client/dcp_client/app.py
+++ b/src/client/dcp_client/app.py
@@ -36,7 +36,7 @@ def search_segs(self, img_directory, cur_selected_img):
         """Returns a list of full paths of segmentations for an image"""
         # Take all segmentations of the image from the current directory:
         search_string = utils.get_path_stem(cur_selected_img) + '_seg'
-        seg_files = [file_name for file_name in os.listdir(img_directory) if search_string in file_name]
+        seg_files = [file_name for file_name in os.listdir(img_directory) if (search_string == utils.get_path_stem(file_name) or str(file_name).startswith(search_string))]
         return seg_files
diff --git a/src/client/requirements.txt b/src/client/requirements.txt
index 4f420df..e92f6e4 100644
--- a/src/client/requirements.txt
+++ b/src/client/requirements.txt
@@ -1,2 +1,3 @@
 napari[pyqt5]>=0.4.17
-bentoml[grpc]>=1.0.13
\ No newline at end of file
+bentoml[grpc]>=1.0.13
+pytest>=7.4.3
\ No newline at end of file
diff --git a/src/client/test/test_app.py b/src/client/test/test_app.py
index b9164c1..f2a29ca 100644
--- a/src/client/test/test_app.py
+++ b/src/client/test/test_app.py
@@ -1,62 +1,65 @@
 import os
+import sys
 from skimage import data
 from skimage.io import imsave
-import unittest
+import pytest
+
+sys.path.append("../")
 
 from dcp_client.app import Application
 from dcp_client.utils.bentoml_model import BentomlModel
 from dcp_client.utils.fsimagestorage import FilesystemImageStorage
 from dcp_client.utils.sync_src_dst import DataRSync
 
-class TestApplication(unittest.TestCase):
-
-    def test_run_train(self):
-        pass
-
-    def test_run_inference(self):
-        pass
-
-    def test_load_image(self):
-
-        img = data.astronaut()
-        img2 = data.cat()
-        os.mkdir('in_prog')
-        imsave('in_prog/test_img.png', img)
-        imsave('in_prog/test_img2.png', img2)
-        rsyncer = DataRSync(user_name="local",
-                            host_name="local",
-                            server_repo_path='.')
-        self.app = Application(BentomlModel(),
-                               rsyncer,
-                               FilesystemImageStorage(),
-                               "0.0.0.0",
-                               7010)
-
-        self.app.cur_selected_img = 'test_img.png'
-        self.app.cur_selected_path = 'in_prog'
-
-        img_test = self.app.load_image() # if image_name is None
-        self.assertEqual(img.all(), img_test.all())
-        img_test2 = self.app.load_image('test_img2.png') # if a filename is given
-        self.assertEqual(img2.all(), img_test2.all())
-
-        # delete everyting we created
-        os.remove('in_prog/test_img.png')
-        os.remove('in_prog/test_img2.png')
-        os.rmdir('in_prog')
-
-    def test_save_image(self):
-        pass
-
-    def test_move_images(self):
-        pass
-
-    def test_delete_images(self):
-        pass
-
-    def test_search_segs(self):
-        pass
+
+@pytest.fixture
+def app():
+    img = data.astronaut()
+    img2 = data.cat()
+    os.mkdir('in_prog')
+
+    imsave('in_prog/test_img.png', img)
+    imsave('in_prog/test_img2.png', img2)
+
+    rsyncer = DataRSync(user_name="local", host_name="local", server_repo_path='.')
+    app = Application(BentomlModel(), rsyncer, FilesystemImageStorage(), "0.0.0.0", 7010)
+
+    app.cur_selected_img = 'test_img.png'
+    app.cur_selected_path = 'in_prog'
+
+    return app, img, img2
+
+def test_load_image(app):
+    app, img, img2 = app  # unpack the app, img and img2 from the fixture
+
+    img_test = app.load_image()  # if image_name is None
+    assert (img == img_test).all()
+
+    img_test2 = app.load_image('test_img2.png')  # if a filename is given
+    assert (img2 == img_test2).all()
+
+    # delete everything we created
+    os.remove('in_prog/test_img.png')
+    os.remove('in_prog/test_img2.png')
+    os.rmdir('in_prog')
+
+def test_run_train():
+    pass
+
+def test_run_inference():
+    pass
+
+def test_save_image():
+    pass
+
+def test_move_images():
+    pass
+
+def test_delete_images():
+    pass
+
+def test_search_segs():
+    pass
+
-if __name__=='__main__':
-    unittest.main()
\ No newline at end of file
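The app fixture above creates in_prog/ in the working directory and leaves the cleanup to the test body, so a failing assertion leaks files on disk. A yield-style fixture built on pytest's built-in tmp_path would tear down automatically; a minimal sketch with the same constructor arguments as in the diff (not part of this PR, just an illustration):

    import pytest
    from skimage import data
    from skimage.io import imsave
    from dcp_client.app import Application
    from dcp_client.utils.bentoml_model import BentomlModel
    from dcp_client.utils.fsimagestorage import FilesystemImageStorage
    from dcp_client.utils.sync_src_dst import DataRSync

    @pytest.fixture
    def app(tmp_path):
        img = data.astronaut()
        imsave(str(tmp_path / 'test_img.png'), img)
        rsyncer = DataRSync(user_name="local", host_name="local", server_repo_path='.')
        app = Application(BentomlModel(), rsyncer, FilesystemImageStorage(), "0.0.0.0", 7010)
        app.cur_selected_img = 'test_img.png'
        app.cur_selected_path = str(tmp_path)
        yield app, img   # the test runs here
        # no manual os.remove/os.rmdir needed: pytest cleans up tmp_path
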
diff --git a/src/client/test/test_fsimagestorage.py b/src/client/test/test_fsimagestorage.py
index 8a6fb9f..275e5f0 100644
--- a/src/client/test/test_fsimagestorage.py
+++ b/src/client/test/test_fsimagestorage.py
@@ -1,51 +1,46 @@
 import os
+import pytest
 from skimage.io import imsave
 from skimage import data
-import unittest
 
 from dcp_client.utils.fsimagestorage import FilesystemImageStorage
 
-
-class TestFilesystemImageStorage(unittest.TestCase):
-
-    def test_load_image(self):
-        fis = FilesystemImageStorage()
-        img = data.astronaut()
-        fname = 'test_img.png'
-        imsave(fname, img)
-        img_test = fis.load_image('.', fname)
-        self.assertEqual(img.all(), img_test.all())
-        os.remove(fname)
-
-    def test_move_image(self):
-        fis = FilesystemImageStorage()
-        img = data.astronaut()
-        fname = 'test_img.png'
-        os.mkdir('temp')
-        imsave(fname, img)
-        fis.move_image('.', 'temp', fname)
-        self.assertTrue(os.path.exists('temp/test_img.png'))
-        os.remove('temp/test_img.png')
-        os.rmdir('temp')
-
-    def test_save_image(self):
-        fis = FilesystemImageStorage()
-        img = data.astronaut()
-        fname = 'test_img.png'
-        fis.save_image('.', fname, img)
-        self.assertTrue(os.path.exists(fname))
-        os.remove(fname)
-
-    def test_delete_image(self):
-        fis = FilesystemImageStorage()
-        img = data.astronaut()
-        fname = 'test_img.png'
-        os.mkdir('temp')
-        imsave('temp/test_img.png', img)
-        fis.delete_image('temp', fname)
-        self.assertFalse(os.path.exists('temp/test_img.png'))
-        os.rmdir('temp')
-
-
-if __name__=='__main__':
-    unittest.main()
\ No newline at end of file
+@pytest.fixture
+def fis():
+    return FilesystemImageStorage()
+
+@pytest.fixture
+def sample_image():
+    # Create a sample image
+    img = data.astronaut()
+    fname = 'test_img.png'
+    imsave(fname, img)
+    return fname
+
+def test_load_image(fis, sample_image):
+    img_test = fis.load_image('.', sample_image)
+    assert (img_test == data.astronaut()).all()
+    os.remove(sample_image)
+
+def test_move_image(fis, sample_image):
+    temp_dir = 'temp'
+    os.mkdir(temp_dir)
+    fis.move_image('.', temp_dir, sample_image)
+    assert os.path.exists(os.path.join(temp_dir, 'test_img.png'))
+    os.remove(os.path.join(temp_dir, 'test_img.png'))
+    os.rmdir(temp_dir)
+
+def test_save_image(fis):
+    img = data.astronaut()
+    fname = 'output.png'
+    fis.save_image('.', fname, img)
+    assert os.path.exists(fname)
+    os.remove(fname)
+
+def test_delete_image(fis, sample_image):
+    temp_dir = 'temp'
+    os.mkdir(temp_dir)
+    fis.move_image('.', temp_dir, sample_image)
+    fis.delete_image(temp_dir, 'test_img.png')
+    assert not os.path.exists(os.path.join(temp_dir, 'test_img.png'))
+    os.rmdir(temp_dir)
diff --git a/src/server/dcp_server/config.cfg b/src/server/dcp_server/config.cfg
index 01efeb1..36a2306 100644
--- a/src/server/dcp_server/config.cfg
+++ b/src/server/dcp_server/config.cfg
@@ -1,26 +1,66 @@
 {
-    "setup":{
-        "segmentation": "GeneralSegmentation",
-        "accepted_types": [".jpg", ".jpeg", ".png", ".tiff", ".tif"],
-        "seg_name_string": "_seg"
+    "setup": {
+        "segmentation": "GeneralSegmentation",
+        "accepted_types": [".jpg", ".jpeg", ".png", ".tiff", ".tif"],
+        "seg_name_string": "_seg"
     },
-    "service":{
-        "model_to_use": "CustomCellposeModel",
+
+    "service": {
+        "model_to_use": "CellposePatchCNN",
"save_model_path": "mytrainedmodel", "runner_name": "cellpose_runner", "service_name": "data-centric-platform", "port": 7010 }, - "model": { - "model_type":"cyto" + + "model": { + "segmentor": { + "model_type": "cyto" + }, + "classifier":{ + "in_channels": 1, + "num_classes": 3, + "black_bg": "False", + "include_mask": "False" + } }, + "data": { - "data_root": "/home/ubuntu/dcp-data" + "data_root": "data" }, + "train":{ - "n_epochs": 2, - "channels":[0] + "segmentor":{ + "n_epochs": 7, + "channels": [0,0], + "min_train_masks": 1 + }, + "classifier":{ + "train_data":{ + "patch_size": 64, + "noise_intensity": 5, + "num_classes": 3 + }, + "n_epochs": 8, + "lr": 0.001, + "batch_size": 1, + "optimizer": "Adam" + } }, + "eval":{ + "segmentor": { + "z_axis": null, + "channel_axis": null, + "rescale": 1, + "batch_size": 1 + }, + "classifier": { + "data":{ + "patch_size": 64, + "noise_intensity": 5 + } + }, + "mask_channel_axis": 0 } } \ No newline at end of file diff --git a/src/server/dcp_server/fsimagestorage.py b/src/server/dcp_server/fsimagestorage.py index 222b964..bec1b56 100644 --- a/src/server/dcp_server/fsimagestorage.py +++ b/src/server/dcp_server/fsimagestorage.py @@ -13,15 +13,17 @@ class FilesystemImageStorage(): def __init__(self, data_root): self.root_dir = data_root - def load_image(self, cur_selected_img): + def load_image(self, cur_selected_img, is_gray=True): """Load the image (using skiimage) :param cur_selected_img: full path of the image that needs to be loaded :type cur_selected_img: str :return: loaded image :rtype: ndarray - """ - return imread(os.path.join(self.root_dir , cur_selected_img)) + """ + try: + return imread(os.path.join(self.root_dir , cur_selected_img), as_gray=is_gray) + except ValueError: return None def save_image(self, to_save_path, img): """Save given image (using skiimage) @@ -60,7 +62,11 @@ def search_segs(self, cur_selected_img): img_directory = utils.get_path_parent(os.path.join(self.root_dir, cur_selected_img)) # Take all segmentations of the image from the current directory: search_string = utils.get_path_stem(cur_selected_img) + setup_config['seg_name_string'] - seg_files = [os.path.join(img_directory, file_name) for file_name in os.listdir(img_directory) if search_string in file_name] + #seg_files = [os.path.join(img_directory, file_name) for file_name in os.listdir(img_directory) if search_string in file_name] + # TODO: check where this is used - copied the command from app's search_segs function (to fix the 1_seg and 11_seg bug) + + seg_files = [os.path.join(img_directory, file_name) for file_name in os.listdir(img_directory) if (search_string == utils.get_path_stem(file_name) or str(file_name).startswith(search_string))] + return seg_files def get_image_seg_pairs(self, directory): @@ -100,28 +106,31 @@ def get_image_size_properties(self, img, file_extension): :return: size properties: - height - width - - channel_ax + - z_axis """ orig_size = img.shape - # png and jpeg will be RGB by default and 2D - # tif can be grayscale 2D or 2D RGB and RGBA - if file_extension in (".jpg", ".jpeg", ".png") or (file_extension in (".tiff", ".tif") and len(orig_size)==2 or (len(orig_size)==3 and (orig_size[-1]==3 or orig_size[-1]==4))): + # png and jpeg will be RGB by default and 2D + # tif can be grayscale 2D or 3D [Z, H, W] + # image channels have already been removed in imread with is_gray=True + if file_extension in (".jpg", ".jpeg", ".png"): height, width = orig_size[0], orig_size[1] - channel_ax = None - # or 3D tiff grayscale + z_axis = None + elif 
file_extension in (".tiff", ".tif") and len(orig_size)==2: + height, width = orig_size[0], orig_size[1] + z_axis = None + # if we have 3 dimensions the [Z, H, W] elif file_extension in (".tiff", ".tif") and len(orig_size)==3: - print('Warning: 3D image stack found. We are assuming your first dimension is your stack dimension. Please cross check this.') - height, width = orig_size[1], orig_size[2] - channel_ax = 0 - + print('Warning: 3D image stack found. We are assuming your first dimension is your stack dimension. Please cross check this.') + height, width = orig_size[1], orig_size[2] + z_axis = 0 else: - pass + print('File not currently supported. See documentation for accepted types') - return height, width, channel_ax + return height, width, z_axis - def rescale_image(self, img, height, width, channel_ax): + def rescale_image(self, img, height, width, channel_ax=None, order=2): """rescale image :param img: image @@ -137,7 +146,7 @@ def rescale_image(self, img, height, width, channel_ax): """ max_dim = max(height, width) rescale_factor = max_dim/512 - return rescale(img, 1/rescale_factor, channel_axis=channel_ax) + return rescale(img, 1/rescale_factor, order=order, channel_axis=channel_ax) def resize_image(self, img, height, width, order): """resize image @@ -166,6 +175,6 @@ def prepare_images_and_masks_for_training(self, train_img_mask_pairs): imgs=[] masks=[] for img_file, mask_file in train_img_mask_pairs: - imgs.append(rgb2gray(imread(img_file))) + imgs.append(self.load_image(img_file)) masks.append(imread(mask_file)) return imgs, masks \ No newline at end of file diff --git a/src/server/dcp_server/models.py b/src/server/dcp_server/models.py index 8125df2..caf5226 100644 --- a/src/server/dcp_server/models.py +++ b/src/server/dcp_server/models.py @@ -1,7 +1,18 @@ from cellpose import models, utils +import torch +from torch import nn +from torch.optim import Adam +from torch.utils.data import TensorDataset, DataLoader +from copy import deepcopy +from tqdm import tqdm +import numpy as np + +from cellpose.metrics import aggregated_jaccard_index + #from segment_anything import SamPredictor, sam_model_registry #from segment_anything.automatic_mask_generator import SamAutomaticMaskGenerator +from dcp_server.utils import get_centered_patches, find_max_patch_size, create_patch_dataset, get_objects class CustomCellposeModel(models.CellposeModel): """Custom cellpose model inheriting the attributes and functions from the original CellposeModel and implementing @@ -21,23 +32,21 @@ def __init__(self, model_config, train_config, eval_config): """ # Initialize the cellpose model - super().__init__(**model_config) + super().__init__(**model_config["segmentor"]) self.train_config = train_config self.eval_config = eval_config - def eval(self, img, **eval_config): + def eval(self, img): """Evaluate the model - find mask of the given image Calls the original eval function. :param img: image to evaluate on :type img: np.ndarray - :param z_axis: z dimension (optional, default is None) - :type z_axis: int :return: mask of the image, list of 2D arrays, or single 3D array (if do_3D=True) labelled image. :rtype: np.ndarray - """ - return super().eval(x=img, **eval_config)[0] # 0 to take only mask - + """ + return super().eval(x=img, **self.eval_config["segmentor"])[0] # 0 to take only mask + def train(self, imgs, masks): """Trains the given model Calls the original train function. 
@@ -47,7 +56,17 @@ def train(self, imgs, masks):
         :param masks: masks of the given images (training labels)
         :type masks: List[np.ndarray]
         """
-        super().train(train_data=imgs, train_labels=masks, **self.train_config)
+
+        if not isinstance(masks, np.ndarray):
+            masks = np.array(masks)
+
+        if masks[0].shape[0] == 2:
+            masks = list(masks[:,0,...])
+
+        super().train(train_data=deepcopy(imgs), train_labels=masks, **self.train_config["segmentor"])
+
+        pred_masks = [self.eval(img) for img in imgs]
+        self.metric = np.mean(aggregated_jaccard_index(masks, pred_masks))
 
     def masks_to_outlines(self, mask):
         """ get outlines of masks as a 0-1 array
@@ -61,11 +80,213 @@ def masks_to_outlines(self, mask):
         return utils.masks_to_outlines(mask) #[True, False] outputs
 
 
+class CellClassifierFCNN(nn.Module):
+
+    '''
+    Fully convolutional classifier for cell images.
+
+    Args:
+        model_config (dict): Model configuration.
+        train_config (dict): Training configuration.
+        eval_config (dict): Evaluation configuration.
+
+    '''
+
+    def __init__(self, model_config, train_config, eval_config):
+        super().__init__()
+
+        self.in_channels = model_config["classifier"]["in_channels"]
+        self.num_classes = model_config["classifier"]["num_classes"]
+
+        self.train_config = train_config["classifier"]
+        self.eval_config = eval_config["classifier"]
+
+        self.layer1 = nn.Sequential(
+            nn.Conv2d(self.in_channels, 16, 3, 2, 5),
+            nn.BatchNorm2d(16),
+            nn.ReLU(),
+            nn.Dropout2d(p=0.2),
+        )
+
+        self.layer2 = nn.Sequential(
+            nn.Conv2d(16, 64, 3, 1, 1),
+            nn.BatchNorm2d(64),
+            nn.ReLU(),
+            nn.Dropout2d(p=0.2),
+        )
+
+        self.layer3 = nn.Sequential(
+            nn.Conv2d(64, 128, 3, 2, 4),
+            nn.BatchNorm2d(128),
+            nn.ReLU(),
+            nn.Dropout2d(p=0.2),
+        )
+        self.final_conv = nn.Conv2d(128, self.num_classes, 1)
+        self.pooling = nn.AdaptiveMaxPool2d(1)
+
+    def forward(self, x):
+
+        x = self.layer1(x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+
+        x = self.final_conv(x)
+        x = self.pooling(x)
+        x = x.view(x.size(0), -1)
+        return x
+
+    def train(self, imgs, labels):
+        """
+        input:
+        1) imgs - List[np.ndarray[np.uint8]] with shape (dx, dy, C)
+        2) labels - List[int]
+        """
+
+        lr = self.train_config['lr']
+        epochs = self.train_config['n_epochs']
+        batch_size = self.train_config['batch_size']
+        # optimizer_class = self.train_config['optimizer']  # see the optimizer sketch after the diff
+
+        # Convert input images and labels to tensors
+
+        # normalize images
+        imgs = [(img-np.min(img))/(np.max(img)-np.min(img)) for img in imgs]
+        # convert to tensor
+        imgs = torch.stack([torch.from_numpy(img.astype(np.float32)) for img in imgs])
+        imgs = torch.permute(imgs, (0, 3, 1, 2))
+        # convert the classification labels to a tensor
+        labels = torch.LongTensor([label for label in labels])
+
+        # Create a training dataset and dataloader
+        train_dataset = TensorDataset(imgs, labels)
+        train_dataloader = DataLoader(train_dataset, batch_size=batch_size)
+
+        loss_fn = nn.CrossEntropyLoss()
+        optimizer = Adam(params=self.parameters(), lr=lr)
+        # TODO check if we should replace self.parameters with super.parameters()
+
+        for _ in tqdm(range(epochs), desc="Running CellClassifierFCNN training"):
+            self.loss = 0
+            for data in train_dataloader:
+                imgs, labels = data
+
+                optimizer.zero_grad()
+                preds = self.forward(imgs)
+
+                l = loss_fn(preds, labels)
+                l.backward()
+                optimizer.step()
+                self.loss += l.item()
+
+            self.loss /= len(train_dataloader)
+
+    def eval(self, img):
+        """
+        Evaluate the model on the provided image and return the predicted label.
+        Input:
+            img: np.ndarray[np.uint8]
+        Output: y_hat - the predicted label
+        """
+        # normalize
+        img = (img-np.min(img))/(np.max(img)-np.min(img))
+        # convert to tensor
+        img = torch.permute(torch.tensor(img.astype(np.float32)), (2, 0, 1)).unsqueeze(0)
+        preds = self.forward(img)
+        y_hat = torch.argmax(preds, 1)
+        return y_hat
+
+
+class CellposePatchCNN():
+
+    """
+    Segment with Cellpose, extract a patch around each detected cell and classify each patch with a CNN.
+    """
+
+    def __init__(self, model_config, train_config, eval_config):
+
+        self.model_config = model_config
+        self.train_config = train_config
+        self.eval_config = eval_config
+
+        # Initialize the cellpose model and the classifier
+        self.segmentor = CustomCellposeModel(self.model_config,
+                                             self.train_config,
+                                             self.eval_config)
+        self.classifier = CellClassifierFCNN(self.model_config,
+                                             self.train_config,
+                                             self.eval_config)
+
+    def init_from_checkpoints(self, chpt_classifier=None, chpt_segmentor=None):
+        """
+        Initialize the model from pre-trained checkpoints.
+        """
+        self.segmentor = CustomCellposeModel(
+            self.model_config.get("segmentor", {}),
+            self.train_config.get("segmentor", {}),
+            self.eval_config.get("segmentor", {})
+        )
+        self.classifier = CellClassifierFCNN(
+            self.model_config.get("classifier", {}),
+            self.train_config.get("classifier", {}),
+            self.eval_config.get("classifier", {})
+        )
+
+    def train(self, imgs, masks):
+        """Trains the given model. First trains the segmentor and then the classifier.
+
+        :param imgs: images to train on (training data)
+        :type imgs: List[np.ndarray]
+        :param masks: masks of the given images (training labels)
+        :type masks: List[np.ndarray] of same shape as output of eval, i.e. one channel instances,
+            second channel classes, so [2, H, W] or [2, 3, H, W] for 3D
+        """
+        # train cellpose
+        masks = np.array(masks)
+        masks_instances = list(masks[:,0,...])
+        self.segmentor.train(imgs, masks_instances)
+        # create patch dataset to train classifier
+        masks_classes = list(masks[:,1,...])
+        patches, labels = create_patch_dataset(imgs,
+                                               masks_classes,
+                                               masks_instances,
+                                               noise_intensity = self.train_config["classifier"]["train_data"]["noise_intensity"],
+                                               max_patch_size = self.train_config["classifier"]["train_data"]["patch_size"])
+        # train classifier
+        self.classifier.train(patches, labels)
+
+    def eval(self, img):
+        # We assume the input image is 2D [H, W] (see fsimagestorage).
+        # The returned mask has the cellpose output as its first channel
+        # and the class channel after it.
+        with torch.no_grad():
+            # get instance mask from segmentor
+            instance_mask = self.segmentor.eval(img)
+            # find coordinates of detected objects
+            class_mask = np.zeros(instance_mask.shape)
+
+            max_patch_size = self.eval_config["classifier"]["data"]["patch_size"]
+            if max_patch_size is None: max_patch_size = find_max_patch_size(instance_mask)
+            noise_intensity = self.eval_config["classifier"]["data"]["noise_intensity"]
+
+            # get patches centered around detected objects
+            patches, instance_labels, _ = get_centered_patches(img,
+                                                               instance_mask,
+                                                               max_patch_size,
+                                                               noise_intensity=noise_intensity)
+            # loop over patches and create classification mask
+            for idx, patch in enumerate(patches):
+                patch_class = self.classifier.eval(patch) # patch size should be HxWxC, e.g. 64,64,3
+                # Assign the predicted class to the corresponding object in the class mask
+                class_mask[instance_mask==instance_labels[idx]] = patch_class.item() + 1
+            # Stack instance and class masks; classes are only set where cellpose found a cell
+            final_mask = np.stack((instance_mask, class_mask), axis=self.eval_config['mask_channel_axis']).astype(np.uint16) # size 2xHxW
+
+        return final_mask
+
+
 # class CustomSAMModel():
 # # https://github.com/facebookresearch/segment-anything/blob/main/notebooks/automatic_mask_generator_example.ipynb
 #     def __init__(self):
 #         pass
-
-
-
diff --git a/src/server/dcp_server/segmentationclasses.py b/src/server/dcp_server/segmentationclasses.py
index e5aeb28..e1213d5 100644
--- a/src/server/dcp_server/segmentationclasses.py
+++ b/src/server/dcp_server/segmentationclasses.py
@@ -34,18 +34,21 @@ async def segment_image(self, input_path, list_of_images):
             # Load the image
             img = self.imagestorage.load_image(img_filepath)
             # Get size properties
-            height, width, channel_ax = self.imagestorage.get_image_size_properties(img, utils.get_file_extension(img_filepath))
-            img = self.imagestorage.rescale_image(img, height, width, channel_ax)
-
+            height, width, z_axis = self.imagestorage.get_image_size_properties(img, utils.get_file_extension(img_filepath))
+            img = self.imagestorage.rescale_image(img,
+                                                  height,
+                                                  width,
+                                                  order=None)
             # Add the z axis into the model's evaluation parameters dictionary
-            self.model.eval_config['z_axis'] = channel_ax
-
+            self.model.eval_config['segmentor']['z_axis'] = z_axis
             # Evaluate the model
-            mask = await self.runner.evaluate.async_run(img = img, **self.model.eval_config)
-
+            mask = await self.runner.evaluate.async_run(img = img)
             # Rescale the mask back to the original image size
-            mask = self.imagestorage.resize_image(mask, height, width, order=0)
-
+            mask = self.imagestorage.rescale_image(mask,
+                                                   height,
+                                                   width,
+                                                   self.model.eval_config['mask_channel_axis'],
+                                                   order=0)
             # Save segmentation
             seg_name = utils.get_path_stem(img_filepath) + setup_config['seg_name_string'] + '.tiff'
             self.imagestorage.save_image(os.path.join(input_path, seg_name), mask)
@@ -58,16 +61,16 @@ async def train(self, input_path):
         :type input_path: str
         :return: runner's train function output - path of the saved model
         :rtype: str
-        """
+        """
+
         train_img_mask_pairs = self.imagestorage.get_image_seg_pairs(input_path)
 
         if not train_img_mask_pairs:
             return "No images and segs found"
 
         imgs, masks = self.imagestorage.prepare_images_and_masks_for_training(train_img_mask_pairs)
-
         model_save_path = await self.runner.train.async_run(imgs, masks)
-
+
         return model_save_path
diff --git a/src/server/dcp_server/service.py b/src/server/dcp_server/service.py
index 308fec3..8eae7ef 100644
--- a/src/server/dcp_server/service.py
+++ b/src/server/dcp_server/service.py
@@ -5,6 +5,8 @@ from dcp_server.serviceclasses import CustomBentoService, CustomRunnable
 from dcp_server.utils import read_config
 
+import sys, inspect
+
 models_module = __import__("models")
 segmentation_module = __import__("segmentationclasses")
@@ -17,6 +19,7 @@ setup_config = read_config('setup', config_path = 'config.cfg')
 
 # instantiate the model
+
 model_class = getattr(models_module, service_config['model_to_use'])
 model = model_class(model_config = model_config, train_config = train_config, eval_config = eval_config)
 custom_model_runner = t.cast(
diff --git a/src/server/dcp_server/serviceclasses.py b/src/server/dcp_server/serviceclasses.py
index 8f62c0f..1515eac 100644
--- a/src/server/dcp_server/serviceclasses.py
+++ b/src/server/dcp_server/serviceclasses.py
@@ -25,7 +25,7 @@ def __init__(self, model, save_model_path):
         self.save_model_path = save_model_path
 
     @bentoml.Runnable.method(batchable=False)
-    def evaluate(self, img: np.ndarray, **eval_config) -> np.ndarray:
+    def evaluate(self, img: np.ndarray) -> np.ndarray:
         """Evaluate the model - find mask of the given image
 
         :param img: image to evaluate on
@@ -36,7 +36,7 @@ def evaluate(self, img: np.ndarray, **eval_config) -> np.ndarray:
         :rtype: np.ndarray
         """
-        mask = self.model.eval(img=img, **eval_config)
+        mask = self.model.eval(img=img)
         return mask
 
@@ -51,9 +51,8 @@ def train(self, imgs: List[np.ndarray], masks: List[np.ndarray]) -> str:
         :return: path of the saved model
         :rtype: str
         """
-
         self.model.train(imgs, masks)
-
+
         # Save the bentoml model
         bentoml.picklable_model.save_model(self.save_model_path, self.model)
diff --git a/src/server/dcp_server/utils.py b/src/server/dcp_server/utils.py
index 866b1b1..86f5466 100644
--- a/src/server/dcp_server/utils.py
+++ b/src/server/dcp_server/utils.py
@@ -1,6 +1,8 @@
 from pathlib import Path
 import json
-
+import numpy as np
+from scipy.ndimage import find_objects, center_of_mass
+from skimage import measure
 
 def read_config(name, config_path = 'config.cfg') -> dict:
     """Reads the configuration file
@@ -31,3 +33,197 @@ def join_path(root_dir, filepath):
     return str(Path(root_dir, filepath))
 
 def get_file_extension(file): return str(Path(file).suffix)
+
+
+def crop_centered_padded_patch(x: np.ndarray,
+                               c,
+                               p,
+                               l,
+                               mask: np.ndarray=None,
+                               noise_intensity=None) -> np.ndarray:
+    """
+    Crop a patch from an array `x` centered at coordinates `c` with size `p`, and apply padding if necessary.
+
+    Args:
+        x (np.ndarray): The input array from which the patch will be cropped.
+        c (tuple): The coordinates (row, column) at the center of the patch.
+        p (tuple): The size of the patch to be cropped (height, width).
+        l (int): The instance label of the mask at the patch.
+        mask (np.ndarray, optional): The instance mask; other objects in the patch are zeroed or noised out.
+        noise_intensity (float, optional): Scale of the Gaussian noise used for padding and for masked-out pixels.
+
+    Returns:
+        np.ndarray: The cropped patch with applied padding.
+    """
+
+    height, width = p  # size of the patch
+
+    # Calculate the boundaries of the patch
+    top = c[0] - height // 2
+    bottom = top + height
+
+    left = c[1] - width // 2
+    right = left + width
+
+    if mask is not None:
+        mask_ = mask.max(-1) if len(mask.shape) >= 3 else mask
+        # Zero out values in the patch where the mask is not equal to the central label
+        m = (mask_ != l) & (mask_ > 0)
+        x[m] = 0
+        if noise_intensity is not None:
+            x[m] = np.random.normal(scale=noise_intensity, size=x[m].shape)
+
+    # Crop the patch from the input array
+    patch = x[max(top, 0):min(bottom, x.shape[0]), max(left, 0):min(right, x.shape[1]), :]
+
+    # Calculate the required padding amounts
+    size_x, size_y = x.shape[1], x.shape[0]
+
+    # Apply padding on the left side if necessary
+    if left < 0:
+        patch = np.hstack((
+            np.random.normal(scale=noise_intensity, size=(patch.shape[0], abs(left), patch.shape[2])).astype(np.uint8),
+            patch))
+    # Apply padding on the right side if necessary
+    if right > size_x:
+        patch = np.hstack((
+            patch,
+            np.random.normal(scale=noise_intensity, size=(patch.shape[0], (right - size_x), patch.shape[2])).astype(np.uint8)))
+    # Apply padding on the top side if necessary
+    if top < 0:
+        patch = np.vstack((
+            np.random.normal(scale=noise_intensity, size=(abs(top), patch.shape[1], patch.shape[2])).astype(np.uint8),
+            patch))
+    # Apply padding on the bottom side if necessary
+    if bottom > size_y:
+        patch = np.vstack((
+            patch,
+            np.random.normal(scale=noise_intensity, size=(bottom - size_y, patch.shape[1], patch.shape[2])).astype(np.uint8)))
+
+    return patch
+
+
+def get_center_of_mass_and_label(mask: np.ndarray):
+    """
+    Compute the centers of mass for each object in a mask.
+
+    Args:
+        mask (np.ndarray): The input mask containing labeled objects.
+
+    Returns:
+        list of tuples: A list of coordinates (row, column) representing the centers of mass for each object.
+        list of ints: Holds the label for each object in the mask.
+    """
+    centers = []
+    labels = []
+    for region in measure.regionprops(mask):
+        center = region.centroid
+        centers.append((int(center[0]), int(center[1])))
+        labels.append(region.label)
+    return centers, labels
+
+
+def get_centered_patches(img,
+                         mask,
+                         p_size: int,
+                         noise_intensity=5,
+                         mask_class=None):
+    '''
+    Extracts centered patches from the input image based on the centers of objects identified in the mask.
+
+    Args:
+        img: The input image.
+        mask: The mask representing the objects in the image.
+        p_size (int): The size of the patches to extract.
+        noise_intensity: The intensity of noise to add to the patches.
+        mask_class (optional): The class mask used to derive one class label per patch.
+    '''
+
+    patches, instance_labels, class_labels = [], [], []
+    # if image is 2D add an additional dim for channels
+    if img.ndim<3: img = img[:, :, np.newaxis]
+    if mask.ndim<3: mask = mask[:, :, np.newaxis]
+    # compute center of mass of objects
+    centers_of_mass, instance_labels = get_center_of_mass_and_label(mask)
+    # Crop patches around each center of mass
+    for c, l in zip(centers_of_mass, instance_labels):
+        c_x, c_y = c
+        patch = crop_centered_padded_patch(img.copy(),
+                                           (c_x, c_y),
+                                           (p_size, p_size),
+                                           l,
+                                           mask=mask,
+                                           noise_intensity=noise_intensity)
+        patches.append(patch)
+        if mask_class is not None:
+            # get the class of the specific object instance
+            class_l = int(np.unique(mask_class[mask[:,:,0]==l]))
+            # -1 because labels from mask start from 1, we want classes to start from 0
+            class_labels.append(class_l-1)
+
+    return patches, instance_labels, class_labels
+
+def get_objects(mask):
+    return find_objects(mask)
+
+def find_max_patch_size(mask):
+
+    # Find objects in the mask
+    objects = get_objects(mask)
+
+    # Initialize variables to store the maximum patch size
+    max_patch_size = 0
+
+    # Iterate over the found objects
+    for obj in objects:
+        # Extract start and stop values from the slice object
+        slices = [s for s in obj]
+        start = [s.start for s in slices]
+        stop = [s.stop for s in slices]
+
+        # Calculate the size of the patch along each axis
+        patch_size = tuple(stop[i] - start[i] for i in range(len(start)))
+
+        # Calculate the total size (area) of the patch
+        total_size = 1
+        for size in patch_size:
+            total_size *= size
+
+        # Check if the current patch size is larger than the maximum
+        if total_size > max_patch_size:
+            max_patch_size = total_size
+
+    max_patch_size_edge = np.ceil(np.sqrt(max_patch_size))
+
+    return max_patch_size_edge
+
+def create_patch_dataset(imgs, masks_classes, masks_instances, noise_intensity, max_patch_size):
+    '''
+    Splits img and masks into patches of equal size which are centered around the cells.
+    If patch_size is not given, the algorithm should first run through all images to find the max cell size, and use
+    the max cell size to define the patch size. All patches and masks should then be returned
+    in the same format as imgs and masks (same type, i.e. check if tensor or np.array and same
+    convention of dims, e.g. CxHxW)
+    '''
+
+    if max_patch_size is None:
+        max_patch_size = np.max([find_max_patch_size(mask) for mask in masks_instances])
+
+    patches, labels = [], []
+    for img, mask_class, mask_instance in zip(imgs, masks_classes, masks_instances):
+        # mask_instance has dimension WxH
+        # mask_class has dimension WxH
+        patch, _, label = get_centered_patches(img,
+                                               mask_instance,
+                                               max_patch_size,
+                                               noise_intensity=noise_intensity,
+                                               mask_class=mask_class)
+        patches.extend(patch)
+        labels.extend(label)
+    return patches, labels
\ No newline at end of file
diff --git a/src/server/requirements.txt b/src/server/requirements.txt
index 352bcf5..57c2128 100644
--- a/src/server/requirements.txt
+++ b/src/server/requirements.txt
@@ -1,3 +1,6 @@
 cellpose>=2.2
 bentoml>=1.0.13
 scikit-image>=0.19.3
+torchmetrics>=0.11.4
+torch>=2.1.0
+pytest>=7.4.3
diff --git a/src/server/test/shapes/circle.png b/src/server/test/shapes/circle.png
new file mode 100644
index 0000000..3d2fd3e
Binary files /dev/null and b/src/server/test/shapes/circle.png differ
diff --git a/src/server/test/shapes/square.png b/src/server/test/shapes/square.png
new file mode 100644
index 0000000..8af926d
Binary files /dev/null and b/src/server/test/shapes/square.png differ
diff --git a/src/server/test/shapes/triangle.png b/src/server/test/shapes/triangle.png
new file mode 100644
index 0000000..8ed8ba9
Binary files /dev/null and b/src/server/test/shapes/triangle.png differ
diff --git a/src/server/test/synthetic_dataset.py b/src/server/test/synthetic_dataset.py
new file mode 100644
index 0000000..27fbc5b
--- /dev/null
+++ b/src/server/test/synthetic_dataset.py
@@ -0,0 +1,237 @@
+import numpy as np
+import cv2
+import random
+import os
+import sys
+
+import skimage.color as color
+import scipy.ndimage as ndi
+
+# set seed for reproducibility
+seed_value = 2023
+random.seed(seed_value)
+np.random.seed(seed_value)
+
+
+def assign_unique_colors(labels, colors):
+    '''
+    Assigns unique colors to each label in the given label array.
+    '''
+    unique_labels = np.unique(labels)
+    # Create a dictionary to store the color assignment for each label
+    label_colors = {}
+
+    # Iterate over the unique labels and assign colors
+    for label in unique_labels:
+        # Skip assigning colors if the label is 0 (background)
+        if label == 0:
+            continue
+
+        # Check if the label is present in the labels
+        if label in labels:
+            # Assign the color to the label
+            color_index = label % len(colors)
+            label_colors[label] = colors[color_index]
+
+    return label_colors
+
+def custom_label2rgb(labels, colors=['red', 'green', 'blue'], bg_label=0, alpha=0.5):
+    '''
+    Converts a label array to an RGB image using assigned colors for each label.
+    '''
+    label_colors = assign_unique_colors(labels, colors)
+
+    # Convert the labels to RGB using the assigned colors
+    rgb_image = np.zeros((labels.shape[0], labels.shape[1], 3), dtype=float)
+    for label in np.unique(labels):
+        mask = labels == label
+        if label in label_colors:
+            rgb = color.label2rgb(mask, colors=[label_colors[label]], bg_label=bg_label, alpha=alpha)
+            rgb_image += rgb
+
+    return rgb_image
+
+def add_padding_for_rotation(image, angle):
+    '''
+    Apply padding and rotation to an image.
+    The purpose of this function is to ensure that the rotated image fits within its original
+    dimensions by adding padding, preventing any parts of the image from being cropped.
+
+    Args:
+        image (numpy.ndarray): The input image.
+        angle (float): The rotation angle in degrees.
+    '''
+
+    # Calculate rotated bounding box
+    h, w = image.shape[:2]
+    center = (w // 2, h // 2)
+    rotation_matrix = cv2.getRotationMatrix2D(center, angle, 1.0)
+    cos_theta = abs(rotation_matrix[0, 0])
+    sin_theta = abs(rotation_matrix[0, 1])
+    new_w = int((h * sin_theta) + (w * cos_theta))
+    new_h = int((h * cos_theta) + (w * sin_theta))
+
+    # Calculate padding amounts
+    pad_w = (new_w - w) // 2
+    pad_h = (new_h - h) // 2
+
+    # Add padding to the image
+    padded_image = cv2.copyMakeBorder(image, pad_h, pad_h, pad_w, pad_w, cv2.BORDER_CONSTANT)
+
+    # Rotate the padded image
+    center = (padded_image.shape[1] // 2, padded_image.shape[0] // 2)
+    rotation_matrix = cv2.getRotationMatrix2D(center, angle, 1.0)
+    rotated_image = cv2.warpAffine(padded_image, rotation_matrix, (padded_image.shape[1], padded_image.shape[0]))
+
+    return rotated_image
+
+def get_object_images(objects):
+    '''
+    Load object images from file paths.
+    '''
+    object_images = []
+
+    for obj in objects:
+        img = cv2.imread(obj['path'])
+        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+        object_images.append(img)
+
+    return object_images
+
+def generate_dataset(num_samples, objects, canvas_size, max_object_counts=None, noise_intensity=None, max_rotation_angle=None):
+    '''
+    Generate a synthetic dataset with images and masks.
+
+    Args:
+        num_samples (int): The number of samples to generate.
+        objects (list): List of object descriptions.
+        canvas_size (int): Size of the canvas to place objects on.
+        max_object_counts (list, optional): Maximum object counts for each class. Default is None.
+        noise_intensity (float, optional): Intensity of the additional noise added to the image.
+        max_rotation_angle (float, optional): Maximum absolute rotation angle applied to the objects.
+    '''
+
+    dataset_images = []
+    dataset_masks = []
+
+    object_images = get_object_images(objects)
+    class_intensities = [(obj['intensity'][0], obj['intensity'][1]) for obj in objects]
+
+    if len(object_images[0].shape) == 3:
+        num_of_img_channels = object_images[0].shape[-1]
+    else:
+        num_of_img_channels = 1
+
+    if max_object_counts is None:
+        max_object_counts = [10] * len(object_images)
+
+    for _ in range(num_samples):
+        canvas = np.zeros((canvas_size, canvas_size, num_of_img_channels), dtype=np.uint8)
+        mask = np.zeros((canvas_size, canvas_size, len(object_images)), dtype=np.uint8)
+
+        for object_index, object_img in enumerate(object_images):
+
+            max_count = max_object_counts[object_index]
+            object_count = random.randint(1, max_count)
+
+            for _ in range(object_count):
+
+                object_size = random.randint(canvas_size//20, canvas_size//5)
+
+                object_img_resized = cv2.resize(object_img, (object_size, object_size))
+                intensity_mean = (class_intensities[object_index][1] - class_intensities[object_index][0])/2
+                intensity_scale = (class_intensities[object_index][1] - intensity_mean)/3
+                class_intensity = np.random.normal(loc=intensity_mean, scale=intensity_scale)
+                class_intensity = np.clip(class_intensity, class_intensities[object_index][0], class_intensities[object_index][1])
+                object_img_resized = (object_img_resized>0).astype(np.uint8)*(class_intensity)*255
+
+                if num_of_img_channels == 1:
+
+                    if max_rotation_angle is not None:
+                        # Randomly rotate the object image
+                        rotation_angle = random.uniform(-max_rotation_angle, max_rotation_angle)
+                        object_img_transformed = add_padding_for_rotation(object_img_resized, rotation_angle)
+                    else:
+                        object_img_transformed = object_img_resized
+
+                    object_size_x, object_size_y = object_img_transformed.shape
+
+                object_mask = np.zeros((object_size_x, object_size_y), dtype=np.uint8)
+
+                if num_of_img_channels == 1:  # Grayscale image
+                    object_mask[object_img_transformed > 0] = object_index + 1
+                    object_img_transformed = np.expand_dims(object_img_transformed, axis=-1)
+                else:  # Color image with alpha channel
+                    object_mask[object_img_resized[:, :, -1] > 0] = object_index + 1
+
+                x = random.randint(0, canvas_size - object_size_x)
+                y = random.randint(0, canvas_size - object_size_y)
+
+                intersecting_mask = mask[y:y + object_size_y, x:x + object_size_x].max(axis=-1)
+                if (intersecting_mask > 0).any():
+                    continue  # Skip if there is an intersection with objects from other classes
+
+                assert mask[y:y + object_size_y, x:x + object_size_x, object_index].shape == object_mask.shape
+
+                canvas[y:y + object_size_y, x:x + object_size_x] = object_img_transformed
+                mask[y:y + object_size_y, x:x + object_size_x, object_index] = np.maximum(
+                    mask[y:y + object_size_y, x:x + object_size_x, object_index], object_mask
+                )
+
+        # Add noise to the canvas
+        if noise_intensity is not None:
+
+            if num_of_img_channels == 1:
+                noise = np.random.normal(scale=noise_intensity, size=(canvas_size, canvas_size, 1))
+            else:
+                noise = np.random.normal(scale=noise_intensity, size=(canvas_size, canvas_size, num_of_img_channels))
+            noisy_canvas = canvas + noise.astype(np.uint8)
+
+            dataset_images.append(noisy_canvas.squeeze(2))
+
+        else:
+            dataset_images.append(canvas.squeeze(2))
+
+        mask = mask.max(axis=-1)
+        if len(mask.shape) == 2:
+            mask = custom_label2rgb(mask, colors=["red", "green", "blue"])
+            mask = ndi.label(mask)[0]
+        else:
+            for j in range(mask.shape[-1]):
+                mask[..., j] = ndi.label(mask[..., j])[0]
+            mask = mask.transpose(2, 0, 1)
+
+        dataset_masks.append(mask)
+
+    return dataset_images, dataset_masks
+
+def get_synthetic_dataset(num_samples, canvas_size=512, max_object_counts=[15, 15, 15]):
+
+    objects = [
+        {
+            'name': 'triangle',
+            'path': 'test/shapes/triangle.png',
+            'intensity': [0, 0.33]
+        },
+        {
+            'name': 'circle',
+            'path': 'test/shapes/circle.png',
+            'intensity': [0.34, 0.66]
+        },
+        {
+            'name': 'square',
+            'path': 'test/shapes/square.png',
+            'intensity': [0.67, 1.0]
+        },
+    ]
+    images, masks = generate_dataset(num_samples, objects, canvas_size=canvas_size, max_object_counts=max_object_counts, noise_intensity=5, max_rotation_angle=30)
+    return images, masks
diff --git a/src/server/test/test_integration.py b/src/server/test/test_integration.py
new file mode 100644
index 0000000..be5d51f
--- /dev/null
+++ b/src/server/test/test_integration.py
@@ -0,0 +1,114 @@
+import os
+import sys
+import torch
+from torchmetrics import JaccardIndex
+import numpy as np
+
+import inspect
+
+sys.path.append(".")
+
+import dcp_server.models as models
+from dcp_server.utils import read_config
+from synthetic_dataset import get_synthetic_dataset
+
+import pytest
+
+# retrieve model names
+model_classes = [
+    cls_obj for cls_name, cls_obj in inspect.getmembers(models) \
+        if inspect.isclass(cls_obj) \
+        and cls_obj.__module__ == models.__name__ \
+        and not cls_name.startswith("CellClassifier")
+    ]
+
+@pytest.fixture(params=model_classes)
+def model_class(request):
+    return request.param
+
+@pytest.fixture()
+def model(model_class):
+
+    model_config = read_config('model', config_path='dcp_server/config.cfg')
+    train_config = read_config('train', config_path='dcp_server/config.cfg')
+    eval_config = read_config('eval', config_path='dcp_server/config.cfg')
+
+    model = model_class(model_config, train_config, eval_config)
+
+    return model
+
+@pytest.fixture
+def data_train():
+    images, masks = get_synthetic_dataset(num_samples=4)
+    masks = [np.array(mask) for mask in masks]
+    masks_instances = [mask.sum(-1) for mask in masks]
+    masks_classes = [((mask > 0) * np.arange(1, 4)).sum(-1) for mask in masks]
+    masks_ = [np.stack((instances, classes)) for instances, classes in zip(masks_instances, masks_classes)]
+    return images, masks_
+
+@pytest.fixture
+def data_eval():
+    img, msk = get_synthetic_dataset(num_samples=1)
+    msk = np.array(msk)
+    msk_ = np.stack((msk.sum(-1), ((msk > 0) * np.arange(1, 4)).sum(-1)), axis=0).transpose(1,0,2,3)
+    return img, msk_
+
+def test_train_run(data_train, model):
+
+    images, masks = data_train
+    model.train(images, masks)
+    # assert(model.segmentor.loss>1e-2) # TODO figure out an appropriate value
+
+    # retrieve the attribute names of the class of the current model
+    attrs = model.__dict__.keys()
+
+    if "classifier" in attrs:
+        assert(model.classifier.loss>1e-2)
+    if "metric" in attrs:
+        assert(model.metric>1e-2)
+
+def test_eval_run(data_eval, model):
+
+    imgs, masks = data_eval
+
+    jaccard_index_instances = 0
+    jaccard_index_classes = 0
+
+    jaccard_metric_binary = JaccardIndex(task="multiclass", num_classes=2, average="macro", ignore_index=0)
+    jaccard_metric_multi = JaccardIndex(task="multiclass", num_classes=4, average="macro", ignore_index=0)
+
+    for img, mask in zip(imgs, masks):
+
+        # mask - instance segmentation mask + classes (2, 512, 512)
+        # pred_mask - cellpose output (512, 512), or a 2-channel instance + class mask for the patch model
+
+        pred_mask = model.eval(img)
+
+        if pred_mask.ndim > 2:
+            pred_mask_bin = torch.tensor((pred_mask[0]>0).astype(bool).astype(int))
+        else:
+            pred_mask_bin = torch.tensor((pred_mask > 0).astype(bool).astype(int))
+
+        bin_mask = torch.tensor((mask[0]>0).astype(bool).astype(int))
+
+        jaccard_index_instances += jaccard_metric_binary(
+            pred_mask_bin,
+            bin_mask
+        )
+
+        if pred_mask.ndim > 2:
+            jaccard_index_classes += jaccard_metric_multi(
+                torch.tensor(pred_mask[1].astype(int)),
+                torch.tensor(mask[1].astype(int))
+            )
+
+    jaccard_index_instances /= len(imgs)
+    assert(jaccard_index_instances<0.6)
+
+    # for the PatchCNN model
+    if pred_mask.ndim > 2:
+        jaccard_index_classes /= len(imgs)
+        assert(jaccard_index_classes<0.6)
\ No newline at end of file
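
The commented-out optimizer_class lookup in CellClassifierFCNN.train (together with the "optimizer": "Adam" key in config.cfg) suggests the optimizer is meant to become configurable. One way to resolve the name without eval(); a sketch under that assumption, not part of this diff:

    import torch

    def make_optimizer(params, train_config):
        # "Adam" -> torch.optim.Adam; unknown names fail early with AttributeError
        opt_cls = getattr(torch.optim, train_config.get("optimizer", "Adam"))
        return opt_cls(params, lr=train_config["lr"])

    # usage inside train(): optimizer = make_optimizer(self.parameters(), self.train_config)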
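
For reference, the geometry in utils.crop_centered_padded_patch: a patch of size p centred at c can hang over the image border, and the overhang is filled with Gaussian noise. A worked example with the numbers spelled out (noise scale 5, as in config.cfg):

    import numpy as np

    img = np.zeros((512, 512, 1), dtype=np.uint8)
    h = w = 64
    c = (10, 500)                                # (row, col) centre near the top-right corner
    top, left = c[0] - h // 2, c[1] - w // 2     # -22, 468
    bottom, right = top + h, left + w            #  42, 532
    patch = img[max(top, 0):min(bottom, 512), max(left, 0):min(right, 512), :]  # (42, 44, 1)
    noise = lambda s: np.random.normal(scale=5, size=s).astype(np.uint8)
    patch = np.hstack((patch, noise((patch.shape[0], right - 512, 1))))  # pad right -> (42, 64, 1)
    patch = np.vstack((noise((abs(top), patch.shape[1], 1)), patch))     # pad top   -> (64, 64, 1)
    assert patch.shape == (64, 64, 1)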
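
Likewise, utils.find_max_patch_size boils down to: take the largest bounding-box area over all objects and return the side of the square with that area, rounded up. A toy check of that behaviour:

    import numpy as np
    from scipy.ndimage import find_objects

    mask = np.zeros((8, 8), dtype=int)
    mask[1:4, 1:3] = 1                 # 3x2 bounding box -> area 6
    mask[5:8, 4:8] = 2                 # 3x4 bounding box -> area 12
    areas = [(sl[0].stop - sl[0].start) * (sl[1].stop - sl[1].start)
             for sl in find_objects(mask)]
    edge = int(np.ceil(np.sqrt(max(areas))))
    assert edge == 4                   # 4x4 square patches would cover every object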
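
Finally, the two-channel mask layout that CellposePatchCNN.train expects (channel 0 instance labels, channel 1 class labels) is assembled from the synthetic masks exactly as in the data_train fixture above:

    import numpy as np
    from synthetic_dataset import get_synthetic_dataset

    images, masks = get_synthetic_dataset(num_samples=4)            # one labelled channel per class
    masks = [np.array(m) for m in masks]
    instances = [m.sum(-1) for m in masks]                          # (H, W) instance labels
    classes = [((m > 0) * np.arange(1, 4)).sum(-1) for m in masks]  # (H, W) class ids 1..3
    train_masks = [np.stack((i, c)) for i, c in zip(instances, classes)]  # each (2, H, W)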