From 79c495248f81a13f512b0d92992823f856498e89 Mon Sep 17 00:00:00 2001 From: Christina Bukas Date: Wed, 22 Nov 2023 13:51:01 +0100 Subject: [PATCH 01/11] adapted classes to inherit from nn.module --- src/server/dcp_server/models.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/src/server/dcp_server/models.py b/src/server/dcp_server/models.py index caf5226..1bf9047 100644 --- a/src/server/dcp_server/models.py +++ b/src/server/dcp_server/models.py @@ -12,9 +12,9 @@ #from segment_anything import SamPredictor, sam_model_registry #from segment_anything.automatic_mask_generator import SamAutomaticMaskGenerator -from dcp_server.utils import get_centered_patches, find_max_patch_size, create_patch_dataset, get_objects +from dcp_server.utils import get_centered_patches, find_max_patch_size, create_patch_dataset -class CustomCellposeModel(models.CellposeModel): +class CustomCellposeModel(models.CellposeModel, nn.Module): """Custom cellpose model inheriting the attributes and functions from the original CellposeModel and implementing additional attributes and methods needed for this project. """ @@ -32,7 +32,9 @@ def __init__(self, model_config, train_config, eval_config): """ # Initialize the cellpose model - super().__init__(**model_config["segmentor"]) + #super().__init__(**model_config["segmentor"]) + nn.Module.__init__(self) + models.CellposeModel.__init__(self, **model_config["segmentor"]) self.train_config = train_config self.eval_config = eval_config @@ -200,14 +202,15 @@ def eval(self, img): return y_hat -class CellposePatchCNN(): +class CellposePatchCNN(nn.Module): """ Cellpose & patches of cells and then cnn to classify each patch """ def __init__(self, model_config, train_config, eval_config): - + super().__init__() + self.model_config = model_config self.train_config = train_config self.eval_config = eval_config From c4a903c58cd62eaf0007dedd8e85d7cb480404ec Mon Sep 17 00:00:00 2001 From: Christina Bukas Date: Wed, 22 Nov 2023 13:51:37 +0100 Subject: [PATCH 02/11] use bentoml pytorch module for saving model --- src/server/dcp_server/serviceclasses.py | 21 +++++++++------------ 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/src/server/dcp_server/serviceclasses.py b/src/server/dcp_server/serviceclasses.py index 1515eac..511999e 100644 --- a/src/server/dcp_server/serviceclasses.py +++ b/src/server/dcp_server/serviceclasses.py @@ -16,7 +16,7 @@ class CustomRunnable(bentoml.Runnable): def __init__(self, model, save_model_path): """Constructs all the necessary attributes for the CustomRunnable. - :param model: model to be trained or evaluated + :param model: model to be trained or evaluated - will be one of classes in models.py :param save_model_path: full path of the model object that it will be saved into :type save_model_path: str """ @@ -54,18 +54,15 @@ def train(self, imgs: List[np.ndarray], masks: List[np.ndarray]) -> str: #s1 = self.model.segmentor.net.state_dict() #c1 = self.model.classifier.parameters() self.model.train(imgs, masks) - ''' - s2 = self.model.segmentor.net.state_dict() - c2 = self.model.classifier.parameters() - if s1 == s2: print('S1 and S2 COMP: THEY ARE THE SAME!!!!!') - else: print('S1 and S2 COMP: THEY ARE NOOOT THE SAME!!!!!') - for p1, p2 in zip(c1, c2): - if p1.data.ne(p2.data).sum() > 0: - print("C1 and C2 NOT THE SAME") - break - ''' # Save the bentoml model - bentoml.picklable_model.save_model(self.save_model_path, self.model) + #bentoml.picklable_model.save_model(self.save_model_path, self.model) + bentoml.pytorch.save_model(self.save_model_path, # Model name in the local Model Store + self.model, # Model instance being saved + labels={ # User-defined labels for managing models in BentoCloud + "owner": "ai-consultants", + "stage": "dev", + }, + ) return self.save_model_path From 943a40fb95e169edb072cd115dc6ea73cda690dd Mon Sep 17 00:00:00 2001 From: Christina Bukas Date: Wed, 22 Nov 2023 15:29:09 +0100 Subject: [PATCH 03/11] adding client config to git ignore to ignore local runs --- .gitignore | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 2c8d21e..8d5b07d 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,9 @@ curated/ # model dir *mytrainedmodel/ +#configs +src/client/dcp_client/config.cfg + # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] @@ -152,4 +155,4 @@ docs/ test-napari.pub data/ BentoML/ -models/ \ No newline at end of file +models/ From 69e64a8f839818e2eb00efd3dd36f514308413a1 Mon Sep 17 00:00:00 2001 From: Christina Bukas Date: Wed, 22 Nov 2023 15:30:20 +0100 Subject: [PATCH 04/11] adding argument to load bento model --- src/server/dcp_server/config.cfg | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/server/dcp_server/config.cfg b/src/server/dcp_server/config.cfg index 36a2306..6eb40c1 100644 --- a/src/server/dcp_server/config.cfg +++ b/src/server/dcp_server/config.cfg @@ -6,8 +6,9 @@ }, "service": { - "model_to_use": "CellposePatchCNN", - "save_model_path": "mytrainedmodel", + "model_to_use": "CustomCellposeModel", + "load_latest_model": null, + "save_model_path": "mito", "runner_name": "cellpose_runner", "service_name": "data-centric-platform", "port": 7010 @@ -31,7 +32,7 @@ "train":{ "segmentor":{ - "n_epochs": 7, + "n_epochs": 50, "channels": [0,0], "min_train_masks": 1 }, From 744e606821362ee082510413ca88f27b8b265578 Mon Sep 17 00:00:00 2001 From: Christina Bukas Date: Wed, 22 Nov 2023 15:30:51 +0100 Subject: [PATCH 05/11] adding option to load bento model and update configs with latest --- src/server/dcp_server/models.py | 30 +++++++++++++----------------- src/server/dcp_server/service.py | 8 +++++++- 2 files changed, 20 insertions(+), 18 deletions(-) diff --git a/src/server/dcp_server/models.py b/src/server/dcp_server/models.py index 1bf9047..248b70b 100644 --- a/src/server/dcp_server/models.py +++ b/src/server/dcp_server/models.py @@ -37,6 +37,10 @@ def __init__(self, model_config, train_config, eval_config): models.CellposeModel.__init__(self, **model_config["segmentor"]) self.train_config = train_config self.eval_config = eval_config + + def update_configs(self, train_config, eval_config): + self.train_config = train_config + self.eval_config = eval_config def eval(self, img): """Evaluate the model - find mask of the given image @@ -89,7 +93,7 @@ def masks_to_outlines(self, mask): class CellClassifierFCNN(nn.Module): ''' - Fully convolutional classifier for cell images. + Fully convolutional classifier for cell images. NOTE -> This model cannot be used as a standalone model in DCP Args: model_config (dict): Model configuration. @@ -130,6 +134,10 @@ def __init__(self, model_config, train_config, eval_config): self.final_conv = nn.Conv2d(128, self.num_classes, 1) self.pooling = nn.AdaptiveMaxPool2d(1) + def update_configs(self, train_config, eval_config): + self.train_config = train_config + self.eval_config = eval_config + def forward(self, x): x = self.layer1(x) @@ -223,22 +231,10 @@ def __init__(self, model_config, train_config, eval_config): self.train_config, self.eval_config) - def init_from_checkpoints(self, chpt_classifier=None, chpt_segmentor=None): - """ - Initialize the model from pre-trained checkpoints. - """ - - self.segmentor = CustomCellposeModel( - self.model_config.get("segmentor", {}), - self.train_config.get("segmentor", {}), - self.eval_config.get("segmentor", {}) - ) - self.classifier = CellClassifierFCNN( - self.model_config.get("classifier", {}), - self.train_config.get("classifier", {}), - self.eval_config.get("classifier", {}) - ) - + def update_configs(self, train_config, eval_config): + self.train_config = train_config + self.eval_config = eval_config + def train(self, imgs, masks): """Trains the given model. First trains the segmentor and then the clasiffier. diff --git a/src/server/dcp_server/service.py b/src/server/dcp_server/service.py index 8eae7ef..d6677de 100644 --- a/src/server/dcp_server/service.py +++ b/src/server/dcp_server/service.py @@ -21,7 +21,13 @@ # instantiate the model model_class = getattr(models_module, service_config['model_to_use']) -model = model_class(model_config = model_config, train_config = train_config, eval_config = eval_config) +# load latest model if specified in config +if service_config['load_latest_model']: + model = bentoml.pytorch.load_model(service_config['save_model_path']+':latest') + model.update_configs(train_config = train_config, eval_config = eval_config) +# else initialise with random weights +else: + model = model_class(model_config = model_config, train_config = train_config, eval_config = eval_config) custom_model_runner = t.cast( "CustomRunner", bentoml.Runner(CustomRunnable, name=service_config['runner_name'], runnable_init_params={"model": model, "save_model_path": service_config['save_model_path']}) From 007dfd5529f250f2ecd161d23876e8e28259a152 Mon Sep 17 00:00:00 2001 From: Christina Bukas Date: Wed, 22 Nov 2023 15:31:15 +0100 Subject: [PATCH 06/11] removed old commented out code for testing --- src/server/dcp_server/serviceclasses.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/server/dcp_server/serviceclasses.py b/src/server/dcp_server/serviceclasses.py index 511999e..a203c1f 100644 --- a/src/server/dcp_server/serviceclasses.py +++ b/src/server/dcp_server/serviceclasses.py @@ -51,8 +51,6 @@ def train(self, imgs: List[np.ndarray], masks: List[np.ndarray]) -> str: :return: path of the saved model :rtype: str """ - #s1 = self.model.segmentor.net.state_dict() - #c1 = self.model.classifier.parameters() self.model.train(imgs, masks) # Save the bentoml model #bentoml.picklable_model.save_model(self.save_model_path, self.model) @@ -120,12 +118,9 @@ async def train(input_path): :rtype: str """ print("Calling retrain from server.") - # Train the model model_path = await self.segmentation.train(input_path) - msg = "Success! Trained model saved in: " + model_path - return msg return svc From 6afed320d0dad25e63cc7a526f5d89d90d17ace6 Mon Sep 17 00:00:00 2001 From: Christina Bukas Date: Fri, 24 Nov 2023 14:22:23 +0100 Subject: [PATCH 07/11] setting cellpose mkldnn parameter to False as hinders saving --- src/server/dcp_server/models.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/server/dcp_server/models.py b/src/server/dcp_server/models.py index 248b70b..d4bba11 100644 --- a/src/server/dcp_server/models.py +++ b/src/server/dcp_server/models.py @@ -35,6 +35,7 @@ def __init__(self, model_config, train_config, eval_config): #super().__init__(**model_config["segmentor"]) nn.Module.__init__(self) models.CellposeModel.__init__(self, **model_config["segmentor"]) + self.mkldnn = False # otherwise we get error with saving model self.train_config = train_config self.eval_config = eval_config @@ -72,10 +73,7 @@ def train(self, imgs, masks): super().train(train_data=deepcopy(imgs), train_labels=masks, **self.train_config["segmentor"]) pred_masks = [self.eval(img) for img in masks] - print(len(pred_masks)) self.metric = np.mean(aggregated_jaccard_index(masks, pred_masks)) - # pred_masks = [self.eval(img) for img in masks] - # self.loss = self.loss_fn(masks, pred_masks) def masks_to_outlines(self, mask): From b44357281a484dac4db1da6d0862abb8b9964fd0 Mon Sep 17 00:00:00 2001 From: Christina Bukas Date: Fri, 24 Nov 2023 14:23:29 +0100 Subject: [PATCH 08/11] removing load latest save model, happends by default in eval if exists --- src/server/dcp_server/config.cfg | 3 +-- src/server/dcp_server/service.py | 8 +------- 2 files changed, 2 insertions(+), 9 deletions(-) diff --git a/src/server/dcp_server/config.cfg b/src/server/dcp_server/config.cfg index 6eb40c1..9517782 100644 --- a/src/server/dcp_server/config.cfg +++ b/src/server/dcp_server/config.cfg @@ -7,7 +7,6 @@ "service": { "model_to_use": "CustomCellposeModel", - "load_latest_model": null, "save_model_path": "mito", "runner_name": "cellpose_runner", "service_name": "data-centric-platform", @@ -62,6 +61,6 @@ "noise_intensity": 5 } }, - "mask_channel_axis": 0 + "mask_channel_axis": null } } \ No newline at end of file diff --git a/src/server/dcp_server/service.py b/src/server/dcp_server/service.py index d6677de..8eae7ef 100644 --- a/src/server/dcp_server/service.py +++ b/src/server/dcp_server/service.py @@ -21,13 +21,7 @@ # instantiate the model model_class = getattr(models_module, service_config['model_to_use']) -# load latest model if specified in config -if service_config['load_latest_model']: - model = bentoml.pytorch.load_model(service_config['save_model_path']+':latest') - model.update_configs(train_config = train_config, eval_config = eval_config) -# else initialise with random weights -else: - model = model_class(model_config = model_config, train_config = train_config, eval_config = eval_config) +model = model_class(model_config = model_config, train_config = train_config, eval_config = eval_config) custom_model_runner = t.cast( "CustomRunner", bentoml.Runner(CustomRunnable, name=service_config['runner_name'], runnable_init_params={"model": model, "save_model_path": service_config['save_model_path']}) From 91e4d6367ca060e1fd9fa7ad461f213a26fcae13 Mon Sep 17 00:00:00 2001 From: Christina Bukas Date: Fri, 24 Nov 2023 14:25:10 +0100 Subject: [PATCH 09/11] changed model saving and include reloading on eval and init --- src/server/dcp_server/serviceclasses.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/src/server/dcp_server/serviceclasses.py b/src/server/dcp_server/serviceclasses.py index a203c1f..53f30cc 100644 --- a/src/server/dcp_server/serviceclasses.py +++ b/src/server/dcp_server/serviceclasses.py @@ -4,6 +4,8 @@ from bentoml.io import Text, NumpyNdarray from typing import List +from dcp_server import models as DCPModels + class CustomRunnable(bentoml.Runnable): ''' @@ -23,6 +25,9 @@ def __init__(self, model, save_model_path): self.model = model self.save_model_path = save_model_path + # load model if it already exists to continue training from there? + if self.save_model_path in [model.tag.name for model in bentoml.models.list()]: + self.model = bentoml.pytorch.load_model(self.save_model_path+":latest") @bentoml.Runnable.method(batchable=False) def evaluate(self, img: np.ndarray) -> np.ndarray: @@ -34,8 +39,10 @@ def evaluate(self, img: np.ndarray) -> np.ndarray: :type z_axis: int :return: mask of the image, list of 2D arrays, or single 3D array (if do_3D=True) labelled image. :rtype: np.ndarray - """ - + """ + # load the latest model if it is available (in case train has already occured) + if self.save_model_path in [model.tag.name for model in bentoml.models.list()]: + self.model = bentoml.pytorch.load_model(self.save_model_path+":latest") mask = self.model.eval(img=img) return mask @@ -56,14 +63,10 @@ def train(self, imgs: List[np.ndarray], masks: List[np.ndarray]) -> str: #bentoml.picklable_model.save_model(self.save_model_path, self.model) bentoml.pytorch.save_model(self.save_model_path, # Model name in the local Model Store self.model, # Model instance being saved - labels={ # User-defined labels for managing models in BentoCloud - "owner": "ai-consultants", - "stage": "dev", - }, - ) + external_modules=[DCPModels] + ) return self.save_model_path - class CustomBentoService(): """BentoML Service class. Contains all the functions necessary to serve the service with BentoML From 86c143a5d7b75cd1d07a0f10e72dd6d745d765aa Mon Sep 17 00:00:00 2001 From: Christina Bukas Date: Fri, 24 Nov 2023 15:21:56 +0100 Subject: [PATCH 10/11] changing mask_channel_axis for tests to pass --- src/server/dcp_server/config.cfg | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/server/dcp_server/config.cfg b/src/server/dcp_server/config.cfg index 9517782..a3efe5d 100644 --- a/src/server/dcp_server/config.cfg +++ b/src/server/dcp_server/config.cfg @@ -31,7 +31,7 @@ "train":{ "segmentor":{ - "n_epochs": 50, + "n_epochs": 10, "channels": [0,0], "min_train_masks": 1 }, @@ -61,6 +61,6 @@ "noise_intensity": 5 } }, - "mask_channel_axis": null + "mask_channel_axis": 0 } } \ No newline at end of file From d61d6fccab05fb37d3ca8d0ef0bba607eaa2bcfe Mon Sep 17 00:00:00 2001 From: Mariia Date: Tue, 5 Dec 2023 23:14:59 +0100 Subject: [PATCH 11/11] Merge train and eval tests --- src/server/test/test_config.cfg | 67 +++++++++++++++ src/server/test/test_integration.py | 123 +++++++++++++++++++++------- 2 files changed, 160 insertions(+), 30 deletions(-) create mode 100644 src/server/test/test_config.cfg diff --git a/src/server/test/test_config.cfg b/src/server/test/test_config.cfg new file mode 100644 index 0000000..04073ed --- /dev/null +++ b/src/server/test/test_config.cfg @@ -0,0 +1,67 @@ +{ + "setup": { + "segmentation": "GeneralSegmentation", + "accepted_types": [".jpg", ".jpeg", ".png", ".tiff", ".tif"], + "seg_name_string": "_seg" + }, + + "service": { + "model_to_use": "CustomCellposeModel", + "save_model_path": "mito", + "runner_name": "cellpose_runner", + "service_name": "data-centric-platform", + "port": 7010 + }, + + "model": { + "segmentor": { + "model_type": "cyto" + }, + "classifier":{ + "in_channels": 1, + "num_classes": 3, + "black_bg": "False", + "include_mask": "False" + } + }, + + "data": { + "data_root": "data" + }, + + "train":{ + "segmentor":{ + "n_epochs": 20, + "channels": [0,0], + "min_train_masks": 1, + "learning_rate":0.01 + }, + "classifier":{ + "train_data":{ + "patch_size": 64, + "noise_intensity": 5, + "num_classes": 3 + }, + "n_epochs": 8, + "lr": 0.001, + "batch_size": 1, + "optimizer": "Adam" + } + }, + + "eval":{ + "segmentor": { + "z_axis": null, + "channel_axis": null, + "rescale": 1, + "batch_size": 1 + }, + "classifier": { + "data":{ + "patch_size": 64, + "noise_intensity": 5 + } + }, + "mask_channel_axis": 0 + } +} \ No newline at end of file diff --git a/src/server/test/test_integration.py b/src/server/test/test_integration.py index be5d51f..05cf39e 100644 --- a/src/server/test/test_integration.py +++ b/src/server/test/test_integration.py @@ -3,6 +3,7 @@ import torch from torchmetrics import JaccardIndex import numpy as np +import random import inspect # from importlib.machinery import SourceFileLoader @@ -15,13 +16,19 @@ import pytest +seed_value = 2023 +random.seed(seed_value) +torch.manual_seed(seed_value) +np.random.seed(seed_value) + # retrieve models names model_classes = [ cls_obj for cls_name, cls_obj in inspect.getmembers(models) \ if inspect.isclass(cls_obj) \ and cls_obj.__module__ == models.__name__ \ and not cls_name.startswith("CellClassifier") - ] + ] + @pytest.fixture(params=model_classes) def model_class(request): @@ -30,9 +37,9 @@ def model_class(request): @pytest.fixture() def model(model_class): - model_config = read_config('model', config_path='dcp_server/config.cfg') - train_config = read_config('train', config_path='dcp_server/config.cfg') - eval_config = read_config('eval', config_path='dcp_server/config.cfg') + model_config = read_config('model', config_path='test/test_config.cfg') + train_config = read_config('train', config_path='test/test_config.cfg') + eval_config = read_config('eval', config_path='test/test_config.cfg') model = model_class(model_config, train_config, eval_config) @@ -40,6 +47,7 @@ def model(model_class): @pytest.fixture def data_train(): + images, masks = get_synthetic_dataset(num_samples=4) masks = [np.array(mask) for mask in masks] masks_instances = [mask.sum(-1) for mask in masks] @@ -54,23 +62,76 @@ def data_eval(): msk_ = np.stack((msk.sum(-1), ((msk > 0) * np.arange(1, 4)).sum(-1)), axis=0).transpose(1,0,2,3) return img, msk_ -def test_train_run(data_train, model): +# def test_train_run(data_train, model): - images, masks = data_train - model.train(images, masks) - # assert(patch_model.segmentor.loss>1e-2) #TODO figure out appropriate value +# images, masks = data_train +# model.train(images, masks) +# # assert(patch_model.segmentor.loss>1e-2) #TODO figure out appropriate value - # retrieve the attribute names of the class of the current model - attrs = model.__dict__.keys() +# # retrieve the attribute names of the class of the current model +# attrs = model.__dict__.keys() - if "classifier" in attrs: - assert(model.classifier.loss>1e-2) - if "metric" in attrs: - assert(model.metric>1e-2) +# if "classifier" in attrs: +# assert(model.classifier.loss<0.4) +# if "metric" in attrs: +# assert(model.metric>0.1) +# if "loss" in attrs: +# assert(model.loss<0.3) + +# def test_eval_run(data_train, data_eval, model): + +# images, masks = data_train +# model.train(images, masks) + +# imgs_test, masks_test = data_eval + +# jaccard_index_instances = 0 +# jaccard_index_classes = 0 + +# jaccard_metric_binary = JaccardIndex(task="multiclass", num_classes=2, average="macro", ignore_index=0) +# jaccard_metric_multi = JaccardIndex(task="multiclass", num_classes=4, average="macro", ignore_index=0) + +# for img, mask in zip(imgs_test, masks_test): + +# #mask - instance segmentation mask + classes (2, 512, 512) +# #pred_mask - tuple of cellpose (512, 512), patch net multiclass segmentation (512, 512, 2) + +# pred_mask = model.eval(img) #, channels=[0,0]) + +# if pred_mask.ndim > 2: +# pred_mask_bin = torch.tensor((pred_mask[0]>0).astype(bool).astype(int)) +# else: +# pred_mask_bin = torch.tensor((pred_mask > 0).astype(bool).astype(int)) + +# bin_mask = torch.tensor((mask[0]>0).astype(bool).astype(int)) + +# jaccard_index_instances += jaccard_metric_binary( +# pred_mask_bin, +# bin_mask +# ) + +# if pred_mask.ndim > 2: + +# jaccard_index_classes += jaccard_metric_multi( +# torch.tensor(pred_mask[1].astype(int)), +# torch.tensor(mask[1].astype(int)) +# ) -def test_eval_run(data_eval, model): +# jaccard_index_instances /= len(imgs_test) +# assert(jaccard_index_instances>0.2) + +# # for PatchCNN model +# if pred_mask.ndim > 2: + +# jaccard_index_classes /= len(imgs_test) +# assert(jaccard_index_classes>0.1) - imgs, masks = data_eval +def test_train_eval_run(data_train, data_eval, model): + + images, masks = data_train + model.train(images, masks) + + imgs_test, masks_test = data_eval jaccard_index_instances = 0 jaccard_index_classes = 0 @@ -78,12 +139,12 @@ def test_eval_run(data_eval, model): jaccard_metric_binary = JaccardIndex(task="multiclass", num_classes=2, average="macro", ignore_index=0) jaccard_metric_multi = JaccardIndex(task="multiclass", num_classes=4, average="macro", ignore_index=0) - for img, mask in zip(imgs, masks): + for img, mask in zip(imgs_test, masks_test): #mask - instance segmentation mask + classes (2, 512, 512) #pred_mask - tuple of cellpose (512, 512), patch net multiclass segmentation (512, 512, 2) - pred_mask = model.eval(img) #, channels=[0,0]) + pred_mask = model.eval(img) if pred_mask.ndim > 2: pred_mask_bin = torch.tensor((pred_mask[0]>0).astype(bool).astype(int)) @@ -104,19 +165,21 @@ def test_eval_run(data_eval, model): torch.tensor(mask[1].astype(int)) ) - jaccard_index_instances /= len(imgs) - assert(jaccard_index_instances<0.6) - - # for PatchCNN model - if pred_mask.ndim > 2: - - jaccard_index_classes /= len(imgs) - assert(jaccard_index_classes<0.6) - - + jaccard_index_instances /= len(imgs_test) + assert(jaccard_index_instances>0.2) + # retrieve the attribute names of the class of the current model + attrs = model.__dict__.keys() + if "classifier" in attrs: + assert(model.classifier.loss<0.4) + if "metric" in attrs: + assert(model.metric>0.1) + if "loss" in attrs: + assert(model.loss<0.3) - + # for PatchCNN model + if pred_mask.ndim > 2: - \ No newline at end of file + jaccard_index_classes /= len(imgs_test) + assert(jaccard_index_classes>0.1) \ No newline at end of file