diff --git a/.gitignore b/.gitignore index 2c8d21e..8d5b07d 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,9 @@ curated/ # model dir *mytrainedmodel/ +#configs +src/client/dcp_client/config.cfg + # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] @@ -152,4 +155,4 @@ docs/ test-napari.pub data/ BentoML/ -models/ \ No newline at end of file +models/ diff --git a/src/server/dcp_server/config.cfg b/src/server/dcp_server/config.cfg index 36a2306..a3efe5d 100644 --- a/src/server/dcp_server/config.cfg +++ b/src/server/dcp_server/config.cfg @@ -6,8 +6,8 @@ }, "service": { - "model_to_use": "CellposePatchCNN", - "save_model_path": "mytrainedmodel", + "model_to_use": "CustomCellposeModel", + "save_model_path": "mito", "runner_name": "cellpose_runner", "service_name": "data-centric-platform", "port": 7010 @@ -31,7 +31,7 @@ "train":{ "segmentor":{ - "n_epochs": 7, + "n_epochs": 10, "channels": [0,0], "min_train_masks": 1 }, diff --git a/src/server/dcp_server/models.py b/src/server/dcp_server/models.py index caf5226..d4bba11 100644 --- a/src/server/dcp_server/models.py +++ b/src/server/dcp_server/models.py @@ -12,9 +12,9 @@ #from segment_anything import SamPredictor, sam_model_registry #from segment_anything.automatic_mask_generator import SamAutomaticMaskGenerator -from dcp_server.utils import get_centered_patches, find_max_patch_size, create_patch_dataset, get_objects +from dcp_server.utils import get_centered_patches, find_max_patch_size, create_patch_dataset -class CustomCellposeModel(models.CellposeModel): +class CustomCellposeModel(models.CellposeModel, nn.Module): """Custom cellpose model inheriting the attributes and functions from the original CellposeModel and implementing additional attributes and methods needed for this project. """ @@ -32,7 +32,14 @@ def __init__(self, model_config, train_config, eval_config): """ # Initialize the cellpose model - super().__init__(**model_config["segmentor"]) + #super().__init__(**model_config["segmentor"]) + nn.Module.__init__(self) + models.CellposeModel.__init__(self, **model_config["segmentor"]) + self.mkldnn = False # otherwise we get error with saving model + self.train_config = train_config + self.eval_config = eval_config + + def update_configs(self, train_config, eval_config): self.train_config = train_config self.eval_config = eval_config @@ -66,10 +73,7 @@ def train(self, imgs, masks): super().train(train_data=deepcopy(imgs), train_labels=masks, **self.train_config["segmentor"]) pred_masks = [self.eval(img) for img in masks] - print(len(pred_masks)) self.metric = np.mean(aggregated_jaccard_index(masks, pred_masks)) - # pred_masks = [self.eval(img) for img in masks] - # self.loss = self.loss_fn(masks, pred_masks) def masks_to_outlines(self, mask): @@ -87,7 +91,7 @@ def masks_to_outlines(self, mask): class CellClassifierFCNN(nn.Module): ''' - Fully convolutional classifier for cell images. + Fully convolutional classifier for cell images. NOTE -> This model cannot be used as a standalone model in DCP Args: model_config (dict): Model configuration. @@ -128,6 +132,10 @@ def __init__(self, model_config, train_config, eval_config): self.final_conv = nn.Conv2d(128, self.num_classes, 1) self.pooling = nn.AdaptiveMaxPool2d(1) + def update_configs(self, train_config, eval_config): + self.train_config = train_config + self.eval_config = eval_config + def forward(self, x): x = self.layer1(x) @@ -200,14 +208,15 @@ def eval(self, img): return y_hat -class CellposePatchCNN(): +class CellposePatchCNN(nn.Module): """ Cellpose & patches of cells and then cnn to classify each patch """ def __init__(self, model_config, train_config, eval_config): - + super().__init__() + self.model_config = model_config self.train_config = train_config self.eval_config = eval_config @@ -220,22 +229,10 @@ def __init__(self, model_config, train_config, eval_config): self.train_config, self.eval_config) - def init_from_checkpoints(self, chpt_classifier=None, chpt_segmentor=None): - """ - Initialize the model from pre-trained checkpoints. - """ - - self.segmentor = CustomCellposeModel( - self.model_config.get("segmentor", {}), - self.train_config.get("segmentor", {}), - self.eval_config.get("segmentor", {}) - ) - self.classifier = CellClassifierFCNN( - self.model_config.get("classifier", {}), - self.train_config.get("classifier", {}), - self.eval_config.get("classifier", {}) - ) - + def update_configs(self, train_config, eval_config): + self.train_config = train_config + self.eval_config = eval_config + def train(self, imgs, masks): """Trains the given model. First trains the segmentor and then the clasiffier. diff --git a/src/server/dcp_server/serviceclasses.py b/src/server/dcp_server/serviceclasses.py index 1515eac..53f30cc 100644 --- a/src/server/dcp_server/serviceclasses.py +++ b/src/server/dcp_server/serviceclasses.py @@ -4,6 +4,8 @@ from bentoml.io import Text, NumpyNdarray from typing import List +from dcp_server import models as DCPModels + class CustomRunnable(bentoml.Runnable): ''' @@ -16,13 +18,16 @@ class CustomRunnable(bentoml.Runnable): def __init__(self, model, save_model_path): """Constructs all the necessary attributes for the CustomRunnable. - :param model: model to be trained or evaluated + :param model: model to be trained or evaluated - will be one of classes in models.py :param save_model_path: full path of the model object that it will be saved into :type save_model_path: str """ self.model = model self.save_model_path = save_model_path + # load model if it already exists to continue training from there? + if self.save_model_path in [model.tag.name for model in bentoml.models.list()]: + self.model = bentoml.pytorch.load_model(self.save_model_path+":latest") @bentoml.Runnable.method(batchable=False) def evaluate(self, img: np.ndarray) -> np.ndarray: @@ -34,8 +39,10 @@ def evaluate(self, img: np.ndarray) -> np.ndarray: :type z_axis: int :return: mask of the image, list of 2D arrays, or single 3D array (if do_3D=True) labelled image. :rtype: np.ndarray - """ - + """ + # load the latest model if it is available (in case train has already occured) + if self.save_model_path in [model.tag.name for model in bentoml.models.list()]: + self.model = bentoml.pytorch.load_model(self.save_model_path+":latest") mask = self.model.eval(img=img) return mask @@ -51,24 +58,15 @@ def train(self, imgs: List[np.ndarray], masks: List[np.ndarray]) -> str: :return: path of the saved model :rtype: str """ - #s1 = self.model.segmentor.net.state_dict() - #c1 = self.model.classifier.parameters() self.model.train(imgs, masks) - ''' - s2 = self.model.segmentor.net.state_dict() - c2 = self.model.classifier.parameters() - if s1 == s2: print('S1 and S2 COMP: THEY ARE THE SAME!!!!!') - else: print('S1 and S2 COMP: THEY ARE NOOOT THE SAME!!!!!') - for p1, p2 in zip(c1, c2): - if p1.data.ne(p2.data).sum() > 0: - print("C1 and C2 NOT THE SAME") - break - ''' # Save the bentoml model - bentoml.picklable_model.save_model(self.save_model_path, self.model) + #bentoml.picklable_model.save_model(self.save_model_path, self.model) + bentoml.pytorch.save_model(self.save_model_path, # Model name in the local Model Store + self.model, # Model instance being saved + external_modules=[DCPModels] + ) return self.save_model_path - class CustomBentoService(): """BentoML Service class. Contains all the functions necessary to serve the service with BentoML @@ -123,12 +121,9 @@ async def train(input_path): :rtype: str """ print("Calling retrain from server.") - # Train the model model_path = await self.segmentation.train(input_path) - msg = "Success! Trained model saved in: " + model_path - return msg return svc diff --git a/src/server/test/test_config.cfg b/src/server/test/test_config.cfg new file mode 100644 index 0000000..04073ed --- /dev/null +++ b/src/server/test/test_config.cfg @@ -0,0 +1,67 @@ +{ + "setup": { + "segmentation": "GeneralSegmentation", + "accepted_types": [".jpg", ".jpeg", ".png", ".tiff", ".tif"], + "seg_name_string": "_seg" + }, + + "service": { + "model_to_use": "CustomCellposeModel", + "save_model_path": "mito", + "runner_name": "cellpose_runner", + "service_name": "data-centric-platform", + "port": 7010 + }, + + "model": { + "segmentor": { + "model_type": "cyto" + }, + "classifier":{ + "in_channels": 1, + "num_classes": 3, + "black_bg": "False", + "include_mask": "False" + } + }, + + "data": { + "data_root": "data" + }, + + "train":{ + "segmentor":{ + "n_epochs": 20, + "channels": [0,0], + "min_train_masks": 1, + "learning_rate":0.01 + }, + "classifier":{ + "train_data":{ + "patch_size": 64, + "noise_intensity": 5, + "num_classes": 3 + }, + "n_epochs": 8, + "lr": 0.001, + "batch_size": 1, + "optimizer": "Adam" + } + }, + + "eval":{ + "segmentor": { + "z_axis": null, + "channel_axis": null, + "rescale": 1, + "batch_size": 1 + }, + "classifier": { + "data":{ + "patch_size": 64, + "noise_intensity": 5 + } + }, + "mask_channel_axis": 0 + } +} \ No newline at end of file diff --git a/src/server/test/test_integration.py b/src/server/test/test_integration.py index be5d51f..05cf39e 100644 --- a/src/server/test/test_integration.py +++ b/src/server/test/test_integration.py @@ -3,6 +3,7 @@ import torch from torchmetrics import JaccardIndex import numpy as np +import random import inspect # from importlib.machinery import SourceFileLoader @@ -15,13 +16,19 @@ import pytest +seed_value = 2023 +random.seed(seed_value) +torch.manual_seed(seed_value) +np.random.seed(seed_value) + # retrieve models names model_classes = [ cls_obj for cls_name, cls_obj in inspect.getmembers(models) \ if inspect.isclass(cls_obj) \ and cls_obj.__module__ == models.__name__ \ and not cls_name.startswith("CellClassifier") - ] + ] + @pytest.fixture(params=model_classes) def model_class(request): @@ -30,9 +37,9 @@ def model_class(request): @pytest.fixture() def model(model_class): - model_config = read_config('model', config_path='dcp_server/config.cfg') - train_config = read_config('train', config_path='dcp_server/config.cfg') - eval_config = read_config('eval', config_path='dcp_server/config.cfg') + model_config = read_config('model', config_path='test/test_config.cfg') + train_config = read_config('train', config_path='test/test_config.cfg') + eval_config = read_config('eval', config_path='test/test_config.cfg') model = model_class(model_config, train_config, eval_config) @@ -40,6 +47,7 @@ def model(model_class): @pytest.fixture def data_train(): + images, masks = get_synthetic_dataset(num_samples=4) masks = [np.array(mask) for mask in masks] masks_instances = [mask.sum(-1) for mask in masks] @@ -54,23 +62,76 @@ def data_eval(): msk_ = np.stack((msk.sum(-1), ((msk > 0) * np.arange(1, 4)).sum(-1)), axis=0).transpose(1,0,2,3) return img, msk_ -def test_train_run(data_train, model): +# def test_train_run(data_train, model): - images, masks = data_train - model.train(images, masks) - # assert(patch_model.segmentor.loss>1e-2) #TODO figure out appropriate value +# images, masks = data_train +# model.train(images, masks) +# # assert(patch_model.segmentor.loss>1e-2) #TODO figure out appropriate value - # retrieve the attribute names of the class of the current model - attrs = model.__dict__.keys() +# # retrieve the attribute names of the class of the current model +# attrs = model.__dict__.keys() - if "classifier" in attrs: - assert(model.classifier.loss>1e-2) - if "metric" in attrs: - assert(model.metric>1e-2) +# if "classifier" in attrs: +# assert(model.classifier.loss<0.4) +# if "metric" in attrs: +# assert(model.metric>0.1) +# if "loss" in attrs: +# assert(model.loss<0.3) + +# def test_eval_run(data_train, data_eval, model): + +# images, masks = data_train +# model.train(images, masks) + +# imgs_test, masks_test = data_eval + +# jaccard_index_instances = 0 +# jaccard_index_classes = 0 + +# jaccard_metric_binary = JaccardIndex(task="multiclass", num_classes=2, average="macro", ignore_index=0) +# jaccard_metric_multi = JaccardIndex(task="multiclass", num_classes=4, average="macro", ignore_index=0) + +# for img, mask in zip(imgs_test, masks_test): + +# #mask - instance segmentation mask + classes (2, 512, 512) +# #pred_mask - tuple of cellpose (512, 512), patch net multiclass segmentation (512, 512, 2) + +# pred_mask = model.eval(img) #, channels=[0,0]) + +# if pred_mask.ndim > 2: +# pred_mask_bin = torch.tensor((pred_mask[0]>0).astype(bool).astype(int)) +# else: +# pred_mask_bin = torch.tensor((pred_mask > 0).astype(bool).astype(int)) + +# bin_mask = torch.tensor((mask[0]>0).astype(bool).astype(int)) + +# jaccard_index_instances += jaccard_metric_binary( +# pred_mask_bin, +# bin_mask +# ) + +# if pred_mask.ndim > 2: + +# jaccard_index_classes += jaccard_metric_multi( +# torch.tensor(pred_mask[1].astype(int)), +# torch.tensor(mask[1].astype(int)) +# ) -def test_eval_run(data_eval, model): +# jaccard_index_instances /= len(imgs_test) +# assert(jaccard_index_instances>0.2) + +# # for PatchCNN model +# if pred_mask.ndim > 2: + +# jaccard_index_classes /= len(imgs_test) +# assert(jaccard_index_classes>0.1) - imgs, masks = data_eval +def test_train_eval_run(data_train, data_eval, model): + + images, masks = data_train + model.train(images, masks) + + imgs_test, masks_test = data_eval jaccard_index_instances = 0 jaccard_index_classes = 0 @@ -78,12 +139,12 @@ def test_eval_run(data_eval, model): jaccard_metric_binary = JaccardIndex(task="multiclass", num_classes=2, average="macro", ignore_index=0) jaccard_metric_multi = JaccardIndex(task="multiclass", num_classes=4, average="macro", ignore_index=0) - for img, mask in zip(imgs, masks): + for img, mask in zip(imgs_test, masks_test): #mask - instance segmentation mask + classes (2, 512, 512) #pred_mask - tuple of cellpose (512, 512), patch net multiclass segmentation (512, 512, 2) - pred_mask = model.eval(img) #, channels=[0,0]) + pred_mask = model.eval(img) if pred_mask.ndim > 2: pred_mask_bin = torch.tensor((pred_mask[0]>0).astype(bool).astype(int)) @@ -104,19 +165,21 @@ def test_eval_run(data_eval, model): torch.tensor(mask[1].astype(int)) ) - jaccard_index_instances /= len(imgs) - assert(jaccard_index_instances<0.6) - - # for PatchCNN model - if pred_mask.ndim > 2: - - jaccard_index_classes /= len(imgs) - assert(jaccard_index_classes<0.6) - - + jaccard_index_instances /= len(imgs_test) + assert(jaccard_index_instances>0.2) + # retrieve the attribute names of the class of the current model + attrs = model.__dict__.keys() + if "classifier" in attrs: + assert(model.classifier.loss<0.4) + if "metric" in attrs: + assert(model.metric>0.1) + if "loss" in attrs: + assert(model.loss<0.3) - + # for PatchCNN model + if pred_mask.ndim > 2: - \ No newline at end of file + jaccard_index_classes /= len(imgs_test) + assert(jaccard_index_classes>0.1) \ No newline at end of file