diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 07ea746..255b0ac 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -77,7 +77,7 @@ jobs: strategy: matrix: platform: [ubuntu-latest, windows-latest, macos-latest] - python-version: [3.8, 3.9, "3.10"] + python-version: [3.9, "3.10"] steps: - name: Checkout Repository @@ -91,8 +91,10 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - python -m pip install setuptools + python -m pip install --upgrade setuptools + pip install numpy pip install pytest + pip install wheel pip install coverage pip install -e ".[testing]" working-directory: src/server diff --git a/src/server/dcp_server/config.cfg b/src/server/dcp_server/config.cfg index aa52b21..74ddc83 100644 --- a/src/server/dcp_server/config.cfg +++ b/src/server/dcp_server/config.cfg @@ -18,6 +18,7 @@ "model_type": "cyto" }, "classifier":{ + "model_class": "RandomForest", "in_channels": 1, "num_classes": 3, "features":[64,128,256,512], diff --git a/src/server/dcp_server/models.py b/src/server/dcp_server/models.py index 5a0ef61..cf015ee 100644 --- a/src/server/dcp_server/models.py +++ b/src/server/dcp_server/models.py @@ -3,17 +3,22 @@ from torch import nn from torch.optim import Adam from torch.utils.data import TensorDataset, DataLoader +from torchmetrics import F1Score from copy import deepcopy from tqdm import tqdm import numpy as np from scipy.ndimage import label -from cellpose.metrics import aggregated_jaccard_index +from sklearn.ensemble import RandomForestClassifier +from sklearn.metrics import f1_score, log_loss +from sklearn.exceptions import NotFittedError +from cellpose.metrics import aggregated_jaccard_index +from cellpose.dynamics import labels_to_flows #from segment_anything import SamPredictor, sam_model_registry #from segment_anything.automatic_mask_generator import SamAutomaticMaskGenerator -from dcp_server.utils import get_centered_patches, find_max_patch_size, create_patch_dataset +from dcp_server.utils import get_centered_patches, find_max_patch_size, create_patch_dataset, create_dataset_for_rf class CustomCellposeModel(models.CellposeModel, nn.Module): """Custom cellpose model inheriting the attributes and functions from the original CellposeModel and implementing @@ -39,6 +44,7 @@ def __init__(self, model_config, train_config, eval_config, model_name): self.mkldnn = False # otherwise we get error with saving model self.train_config = train_config self.eval_config = eval_config + self.loss = 1e6 self.model_name = model_name def update_configs(self, train_config, eval_config): @@ -71,12 +77,27 @@ def train(self, imgs, masks): if masks[0].shape[0] == 2: masks = list(masks[:,0,...]) - super().train(train_data=deepcopy(imgs), train_labels=masks, **self.train_config["segmentor"]) + + # compute loss and metric + true_bin_masks = [mask>0 for mask in masks] # get binary masks + true_flows = labels_to_flows(masks) # get cellpose flows + # get predicted flows and cell probability + pred_masks = [] + pred_flows = [] + true_lbl = [] + for idx, img in enumerate(imgs): + mask, flows, _ = super().eval(x=img, **self.eval_config["segmentor"]) + pred_masks.append(mask) + pred_flows.append(np.stack([flows[1][0], flows[1][1], flows[2]])) # stack cell probability map, horizontal and vertical flow + true_lbl.append(np.stack([true_bin_masks[idx], true_flows[idx][2], true_flows[idx][3]])) - pred_masks = [self.eval(img) for img in masks] - self.metric = np.mean(aggregated_jaccard_index(masks, pred_masks)) # TODO move metric computation - # self.loss = self.loss_fn(masks, pred_masks) + true_lbl = np.stack(true_lbl) + pred_flows=np.stack(pred_flows) + pred_flows = torch.from_numpy(pred_flows).float().to('cpu') + # compute loss, combination of mse for flows and bce for cell probability + self.loss = self.loss_fn(true_lbl, pred_flows) + self.metric = np.mean(aggregated_jaccard_index(masks, pred_masks)) def masks_to_outlines(self, mask): """ get outlines of masks as a 0-1 array @@ -105,8 +126,8 @@ class CellClassifierFCNN(nn.Module): def __init__(self, model_config, train_config, eval_config): super().__init__() - self.in_channels = model_config["classifier"]["in_channels"] - self.num_classes = model_config["classifier"]["num_classes"] + self.in_channels = model_config["classifier"].get("in_channels",1) + self.num_classes = model_config["classifier"].get("num_classes",3) self.train_config = train_config["classifier"] self.eval_config = eval_config["classifier"] @@ -134,6 +155,8 @@ def __init__(self, model_config, train_config, eval_config): self.final_conv = nn.Conv2d(128, self.num_classes, 1) self.pooling = nn.AdaptiveMaxPool2d(1) + self.metric_fn = F1Score(num_classes=self.num_classes, task="multiclass") + def update_configs(self, train_config, eval_config): self.train_config = train_config self.eval_config = eval_config @@ -180,7 +203,7 @@ def train (self, imgs, labels): # TODO check if we should replace self.parameters with super.parameters() for _ in tqdm(range(epochs), desc="Running CellClassifierFCNN training"): - self.loss = 0 + self.loss, self.metric = 0, 0 for data in train_dataloader: imgs, labels = data @@ -192,7 +215,10 @@ def train (self, imgs, labels): optimizer.step() self.loss += l.item() + self.metric += self.metric_fn(preds, labels) + self.loss /= len(train_dataloader) + self.metric /= len(train_dataloader) def eval(self, img): """ @@ -224,15 +250,24 @@ def __init__(self, model_config, train_config, eval_config, model_name): self.eval_config = eval_config self.model_name = model_name + self.classifier_class = self.model_config.get("classifier").get("model_class", "CellClassifierFCNN") + # Initialize the cellpose model and the classifier self.segmentor = CustomCellposeModel(self.model_config, self.train_config, self.eval_config, "Cellpose") - self.classifier = CellClassifierFCNN(self.model_config, - self.train_config, - self.eval_config) - + + if self.classifier_class == "FCNN": + self.classifier = CellClassifierFCNN(self.model_config, + self.train_config, + self.eval_config) + + elif self.classifier_class == "RandomForest": + self.classifier = CellClassifierShallowModel(self.model_config, + self.train_config, + self.eval_config) + def update_configs(self, train_config, eval_config): self.train_config = train_config self.eval_config = eval_config @@ -249,19 +284,25 @@ def train(self, imgs, masks): # train cellpose masks = np.array(masks) masks_instances = list(masks[:,0,...]) #[mask.sum(-1) for mask in masks] if masks[0].ndim == 3 else masks - self.segmentor.train(imgs, masks_instances) + self.segmentor.train(deepcopy(imgs), masks_instances) # create patch dataset to train classifier masks_classes = list(masks[:,1,...]) #[((mask > 0) * np.arange(1, 4)).sum(-1) for mask in masks] - patches, labels = create_patch_dataset(imgs, - masks_classes, - masks_instances, - noise_intensity = self.train_config["classifier"]["train_data"]["noise_intensity"], - max_patch_size = self.train_config["classifier"]["train_data"]["patch_size"]) + patches, patch_masks, labels = create_patch_dataset(imgs, + masks_classes, + masks_instances, + noise_intensity = self.train_config["classifier"]["train_data"]["noise_intensity"], + max_patch_size = self.train_config["classifier"]["train_data"]["patch_size"]) + x = patches + if self.classifier_class == "RandomForest": + x = create_dataset_for_rf(patches, patch_masks) # train classifier - self.classifier.train(patches, labels) + self.classifier.train(x, labels) + # and compute metric and loss + self.metric = (self.segmentor.metric + self.classifier.metric) / 2 + self.loss = (self.segmentor.loss + self.classifier.loss)/2 def eval(self, img): - # TBD we assume image is either 2D [H, W] (see fsimage storage) + # TBD we assume image is 2D [H, W] (see fsimage storage) # The final mask which is returned should have # first channel the output of cellpose and the rest are the class channels with torch.no_grad(): @@ -275,22 +316,57 @@ def eval(self, img): noise_intensity = self.eval_config["classifier"]["data"]["noise_intensity"] # get patches centered around detected objects - patches, instance_labels, _ = get_centered_patches(img, + patches, patch_masks, instance_labels, _ = get_centered_patches(img, instance_mask, max_patch_size, noise_intensity=noise_intensity) + x = patches + if self.classifier_class == "RandomForest": + x = create_dataset_for_rf(patches, patch_masks) # loop over patches and create classification mask - for idx, patch in enumerate(patches): - patch_class = self.classifier.eval(patch) # patch size should be HxWxC, e.g. 64,64,3 + for idx in range(len(x)): + patch_class = self.classifier.eval(x[idx]) # Assign predicted class to corresponding location in final_mask - class_mask[instance_mask==instance_labels[idx]] = patch_class.item() + 1 + patch_class = patch_class.item() if isinstance(patch_class, torch.Tensor) else patch_class + class_mask[instance_mask==instance_labels[idx]] = patch_class + 1 # Apply mask to final_mask, retaining only regions where cellpose_mask is greater than 0 - #class_mask = class_mask * (instance_mask > 0)#.long()) final_mask = np.stack((instance_mask, class_mask), axis=self.eval_config['mask_channel_axis']).astype(np.uint16) # size 2xHxW return final_mask +class CellClassifierShallowModel: + + def __init__(self, model_config, train_config, eval_config): + + self.model_config = model_config + self.train_config = train_config + self.eval_config = eval_config + + self.model = RandomForestClassifier() # TODO chnage config so RandomForestClassifier accepts input params + + + def train(self, X_train, y_train): + self.model.fit(X_train,y_train) + + y_hat = self.model.predict(X_train) + y_hat_proba = self.model.predict_proba(X_train) + + self.metric = f1_score(y_train, y_hat, average='micro') + # Binary Cross Entrop Loss + self.loss = log_loss(y_train, y_hat_proba) + + + def eval(self, X_test): + + X_test = X_test.reshape(1,-1) + + try: + y_hat = self.model.predict(X_test) + except NotFittedError as e: + y_hat = np.zeros(X_test.shape[0]) + + return y_hat class UNet(nn.Module): diff --git a/src/server/dcp_server/serviceclasses.py b/src/server/dcp_server/serviceclasses.py index 410ee42..72a1e35 100644 --- a/src/server/dcp_server/serviceclasses.py +++ b/src/server/dcp_server/serviceclasses.py @@ -48,7 +48,7 @@ def evaluate(self, img: np.ndarray) -> np.ndarray: def check_and_load_model(self): bento_model_list = [model.tag.name for model in bentoml.models.list()] if self.save_model_path in bento_model_list: - loaded_model = bentoml.pytorch.load_model(self.save_model_path+":latest") + loaded_model = bentoml.picklable_model.load_model(self.save_model_path+":latest") assert loaded_model.__class__.__name__ == self.model.__class__.__name__, 'Check your config, loaded model and model to use not the same!' self.model = loaded_model @@ -65,11 +65,15 @@ def train(self, imgs: List[np.ndarray], masks: List[np.ndarray]) -> str: """ self.model.train(imgs, masks) # Save the bentoml model - #bentoml.picklable_model.save_model(self.save_model_path, self.model) - bentoml.pytorch.save_model(self.save_model_path, # Model name in the local Model Store - self.model, # Model instance being saved - external_modules=[DCPModels] - ) + bentoml.picklable_model.save_model( + self.save_model_path, + self.model, + external_modules=[DCPModels], + ) + # bentoml.pytorch.save_model(self.save_model_path, # Model name in the local Model Store + # self.model, # Model instance being saved + # external_modules=[DCPModels] + # ) return self.save_model_path diff --git a/src/server/dcp_server/utils.py b/src/server/dcp_server/utils.py index 86f5466..3e22670 100644 --- a/src/server/dcp_server/utils.py +++ b/src/server/dcp_server/utils.py @@ -1,8 +1,11 @@ from pathlib import Path import json +from copy import deepcopy import numpy as np -from scipy.ndimage import find_objects, center_of_mass +from scipy.ndimage import find_objects from skimage import measure +import SimpleITK as sitk +from radiomics import shape2D def read_config(name, config_path = 'config.cfg') -> dict: """Reads the configuration file @@ -35,71 +38,89 @@ def join_path(root_dir, filepath): return str(Path(root_dir, filepath)) def get_file_extension(file): return str(Path(file).suffix) -def crop_centered_padded_patch(x: np.ndarray, - c, - p, - l, +def crop_centered_padded_patch(img: np.ndarray, + patch_center_xy, + patch_size, + obj_label, mask: np.ndarray=None, noise_intensity=None) -> np.ndarray: """ Crop a patch from an array `x` centered at coordinates `c` with size `p`, and apply padding if necessary. Args: - x (np.ndarray): The input array from which the patch will be cropped. - c (tuple): The coordinates (row, column) at the center of the patch. - p (tuple): The size of the patch to be cropped (height, width). - l (int): The instance label of the mask at the patch + img (np.ndarray): The input array from which the patch will be cropped. + patch_center_xy (tuple): The coordinates (row, column) at the center of the patch. + patch_size (tuple): The size of the patch to be cropped (height, width). + obj_label (int): The instance label of the mask at the patch + mask (np.ndarray, optional): The mask array that asociated with the array x; + mask is used during training to mask out non-central elements; + for RandomForest, it is used to calculate pyradiomics features. + noise_intensity (float, optional): Intensity of noise to be added to the background. Returns: np.ndarray: The cropped patch with applied padding. """ - height, width = p # Size of the patch - + height, width = patch_size # Size of the patch + img_height, img_width = img.shape[0], img.shape[1] # Size of the input image + # Calculate the boundaries of the patch - top = c[0] - height // 2 + top = patch_center_xy[0] - height // 2 bottom = top + height - - left = c[1] - width // 2 + left = patch_center_xy[1] - width // 2 right = left + width # Crop the patch from the input array if mask is not None: mask_ = mask.max(-1) if len(mask.shape) >= 3 else mask # Zero out values in the patch where the mask is not equal to the central label - # m = (mask_ != central_label) & (mask_ > 0) - m = (mask_ != l) & (mask_ > 0) - x[m] = 0 - if noise_intensity is not None: - x[m] = np.random.normal(scale=noise_intensity, size=x[m].shape) - - patch = x[max(top, 0):min(bottom, x.shape[0]), max(left, 0):min(right, x.shape[1]), :] - - # Calculate the required padding amounts - size_x, size_y = x.shape[1], x.shape[0] - - # Apply padding if necessary + mask_other_objs = (mask_ != obj_label) & (mask_ > 0) + img[mask_other_objs] = 0 + # Add random noise at locations where other objects are present if noise_intensity is given + if noise_intensity is not None: img[mask_other_objs] = np.random.normal(scale=noise_intensity, size=img[mask_other_objs].shape) + mask[mask_other_objs] = 0 + # crop the mask + mask = mask[max(top, 0):min(bottom, img_height), max(left, 0):min(right, img_width), :] + + patch = img[max(top, 0):min(bottom, img_height), max(left, 0):min(right, img_width), :] + # Calculate the required padding amounts and apply padding if necessary if left < 0: patch = np.hstack(( np.random.normal(scale=noise_intensity, size=(patch.shape[0], abs(left), patch.shape[2])).astype(np.uint8), patch)) + if mask is not None: + mask = np.hstack(( + np.zeros((mask.shape[0], abs(left), mask.shape[2])).astype(np.uint8), + mask)) # Apply padding on the right side if necessary - if right > size_x: + if right > img_width: patch = np.hstack(( patch, - np.random.normal(scale=noise_intensity, size=(patch.shape[0], (right - size_x), patch.shape[2])).astype(np.uint8))) + np.random.normal(scale=noise_intensity, size=(patch.shape[0], (right - img_width), patch.shape[2])).astype(np.uint8))) + if mask is not None: + mask = np.hstack(( + mask, + np.zeros((mask.shape[0], (right - img_width), mask.shape[2])).astype(np.uint8))) # Apply padding on the top side if necessary if top < 0: patch = np.vstack(( np.random.normal(scale=noise_intensity, size=(abs(top), patch.shape[1], patch.shape[2])).astype(np.uint8), patch)) + if mask is not None: + mask = np.vstack(( + np.zeros((abs(top), mask.shape[1], mask.shape[2])).astype(np.uint8), + mask)) # Apply padding on the bottom side if necessary - if bottom > size_y: + if bottom > img_height: patch = np.vstack(( patch, - np.random.normal(scale=noise_intensity, size=(bottom - size_y, patch.shape[1], patch.shape[2])).astype(np.uint8))) - - return patch + np.random.normal(scale=noise_intensity, size=(bottom - img_height, patch.shape[1], patch.shape[2])).astype(np.uint8))) + if mask is not None: + mask = np.vstack(( + mask, + np.zeros((bottom - img_height, mask.shape[1], mask.shape[2])).astype(np.uint8))) + + return patch, mask def get_center_of_mass_and_label(mask: np.ndarray) -> np.ndarray: @@ -146,30 +167,31 @@ def get_centered_patches(img, ''' - patches, instance_labels, class_labels = [], [], [] + patches, patch_masks, instance_labels, class_labels = [], [], [], [] # if image is 2D add an additional dim for channels if img.ndim<3: img = img[:, :, np.newaxis] if mask.ndim<3: mask = mask[:, :, np.newaxis] # compute center of mass of objects centers_of_mass, instance_labels = get_center_of_mass_and_label(mask) # Crop patches around each center of mass - for c, l in zip(centers_of_mass, instance_labels): + for c, obj_label in zip(centers_of_mass, instance_labels): c_x, c_y = c - patch = crop_centered_padded_patch(img.copy(), + patch, patch_mask = crop_centered_padded_patch(img.copy(), (c_x, c_y), (p_size, p_size), - l, - mask=mask, + obj_label, + mask=deepcopy(mask), noise_intensity=noise_intensity) patches.append(patch) + patch_masks.append(patch_mask) if mask_class is not None: # get the class instance for the specific object - instance_labels.append(l) - class_l = int(np.unique(mask_class[mask[:,:,0]==l])) + instance_labels.append(obj_label) + class_l = int(np.unique(mask_class[mask[:,:,0]==obj_label])) #-1 because labels from mask start from 1, we want classes to start from 0 class_labels.append(class_l-1) - return patches, instance_labels, class_labels + return patches, patch_masks, instance_labels, class_labels def get_objects(mask): return find_objects(mask) @@ -218,15 +240,95 @@ def create_patch_dataset(imgs, masks_classes, masks_instances, noise_intensity, max_patch_size = np.max([find_max_patch_size(mask) for mask in masks_instances]) - patches, labels = [], [] + patches, patch_masks, labels = [], [], [] for img, mask_class, mask_instance in zip(imgs, masks_classes, masks_instances): # mask_instance has dimension WxH # mask_class has dimension WxH - patch, _, label = get_centered_patches(img, + patch, patch_mask, _, label = get_centered_patches(img, mask_instance, max_patch_size, noise_intensity=noise_intensity, - mask_class=mask_class) + mask_class=mask_class, + ) patches.extend(patch) + patch_masks.extend(patch_mask) labels.extend(label) - return patches, labels \ No newline at end of file + return patches, patch_masks, labels + + +def get_shape_features(img, mask): + """ + Calculate shape-based radiomic features from an image within the region defined by the mask. + + Args: + - img (np.ndarray): The input image. + - mask (np.ndarray): The mask corresponding to the image. + + Returns: + - np.ndarray: An array containing the calculated shape-based radiomic features, such as: + Elongation, Sphericity, Perimeter surface. + """ + + mask = 255 * ((mask) > 0).astype(np.uint8) + + image = sitk.GetImageFromArray(img.squeeze()) + roi_mask = sitk.GetImageFromArray(mask.squeeze()) + + shape_calculator = shape2D.RadiomicsShape2D(inputImage=image, inputMask=roi_mask, label=255) + # Calculate the shape-based radiomic features + shape_features = shape_calculator.execute() + + return np.array(list(shape_features.values())) + +def extract_intensity_features(image, mask): + """ + Extract intensity-based features from an image within the region defined by the mask. + + Args: + - image (np.ndarray): The input image. + - mask (np.ndarray): The mask defining the region of interest. + + Returns: + - np.ndarray: An array containing the extracted intensity-based features: + median intensity, mean intensity, 25th/75th percentile intensity within the masked region. + + """ + + features = {} + + # Ensure the image and mask have the same dimensions + + if image.shape != mask.shape: + raise ValueError("Image and mask must have the same dimensions") + + masked_image = image[(mask>0)] + # features["min_intensity"] = np.min(masked_image) + # features["max_intensity"] = np.max(masked_image) + features["median_intensity"] = np.median(masked_image) + features["mean_intensity"] = np.mean(masked_image) + features["25th_percentile_intensity"] = np.percentile(masked_image, 25) + features["75th_percentile_intensity"] = np.percentile(masked_image, 75) + + return np.array(list(features.values())) + +def create_dataset_for_rf(imgs, masks): + """ + Extract intensity-based features from an image within the region defined by the mask. + + Args: + - imgs (List): A list of all input images. + - mask (List): A list of all corresponding masks defining the region of interest. + + Returns: + - List: A list of arrays containing shape and intensity-based features + + """ + X = [] + for img, mask in zip(imgs, masks): + + shape_features = get_shape_features(img, mask) + intensity_features = extract_intensity_features(img, mask) + features_list = np.concatenate((shape_features, intensity_features), axis=0) + X.append(features_list) + + return X \ No newline at end of file diff --git a/src/server/requirements.txt b/src/server/requirements.txt index d954704..e3a9efb 100644 --- a/src/server/requirements.txt +++ b/src/server/requirements.txt @@ -1,6 +1,11 @@ +wheel==0.42.0 cellpose>=2.2 bentoml==1.0.16 scikit-image>=0.19.3 torchmetrics>=0.11.4 torch>=2.1.0 pytest>=7.4.3 +numpy +scikit-learn>=1.2.2 +SimpleITK>=2.2.1 +pyradiomics==3.0.1 \ No newline at end of file diff --git a/src/server/test/test_config.cfg b/src/server/test/configs/test_config_RF.cfg similarity index 97% rename from src/server/test/test_config.cfg rename to src/server/test/configs/test_config_RF.cfg index 2c50d33..c09c6af 100644 --- a/src/server/test/test_config.cfg +++ b/src/server/test/configs/test_config_RF.cfg @@ -18,6 +18,7 @@ "model_type": "cyto" }, "classifier":{ + "model_class": "RandomForest", "in_channels": 1, "num_classes": 3, "features":[64,128,256,512], diff --git a/src/server/test/configs/test_config_fcnn.cfg b/src/server/test/configs/test_config_fcnn.cfg new file mode 100644 index 0000000..02039f6 --- /dev/null +++ b/src/server/test/configs/test_config_fcnn.cfg @@ -0,0 +1,69 @@ +{ + "setup": { + "segmentation": "GeneralSegmentation", + "accepted_types": [".jpg", ".jpeg", ".png", ".tiff", ".tif"], + "seg_name_string": "_seg" + }, + + "service": { + "model_to_use": "CustomCellposeModel", + "save_model_path": "mito", + "runner_name": "cellpose_runner", + "service_name": "data-centric-platform", + "port": 7010 + }, + + "model": { + "segmentor": { + "model_type": "cyto" + }, + "classifier":{ + "model_class": "FCNN", + "in_channels": 1, + "num_classes": 3, + "features":[64,128,256,512], + "black_bg": "False", + "include_mask": "False" + } + }, + + "data": { + "data_root": "data" + }, + + "train":{ + "segmentor":{ + "n_epochs": 20, + "channels": [0,0], + "min_train_masks": 1, + "learning_rate":0.01 + }, + "classifier":{ + "train_data":{ + "patch_size": 64, + "noise_intensity": 5, + "num_classes": 3 + }, + "n_epochs": 20, + "lr": 0.005, + "batch_size": 5, + "optimizer": "Adam" + } + }, + + "eval":{ + "segmentor": { + "z_axis": null, + "channel_axis": null, + "rescale": 1, + "batch_size": 1 + }, + "classifier": { + "data":{ + "patch_size": 64, + "noise_intensity": 5 + } + }, + "mask_channel_axis": 0 + } +} \ No newline at end of file diff --git a/src/server/test/test_integration.py b/src/server/test/test_integration.py index e8e9d99..ced69cd 100644 --- a/src/server/test/test_integration.py +++ b/src/server/test/test_integration.py @@ -1,4 +1,5 @@ import sys +from glob import glob import inspect import random @@ -7,6 +8,8 @@ import torch from torchmetrics import JaccardIndex +# from importlib.machinery import SourceFileLoader + sys.path.append(".") import dcp_server.models as models @@ -28,17 +31,22 @@ and not cls_name.startswith("CellClassifier") ] +config_paths = glob("test/configs/*.cfg") @pytest.fixture(params=model_classes) def model_class(request): return request.param +@pytest.fixture(params=config_paths) +def config_path(request): + return request.param + @pytest.fixture() -def model(model_class): +def model(model_class, config_path): - model_config = read_config('model', config_path='test/test_config.cfg') - train_config = read_config('train', config_path='test/test_config.cfg') - eval_config = read_config('eval', config_path='test/test_config.cfg') + model_config = read_config('model', config_path=config_path) + train_config = read_config('train', config_path=config_path) + eval_config = read_config('eval', config_path=config_path) model = model_class(model_config, train_config, eval_config, str(model_class)) @@ -169,12 +177,10 @@ def test_train_eval_run(data_train, data_eval, model): # retrieve the attribute names of the class of the current model attrs = model.__dict__.keys() - if "classifier" in attrs: - assert(model.classifier.loss<0.4) if "metric" in attrs: assert(model.metric>0.1) if "loss" in attrs: - assert(model.loss<0.33) + assert(model.loss<0.75) # for PatchCNN model if pred_mask.ndim > 2: diff --git a/src/server/test/test_models.py b/src/server/test/test_models.py new file mode 100644 index 0000000..84b203c --- /dev/null +++ b/src/server/test/test_models.py @@ -0,0 +1,36 @@ +import pytest +import numpy as np + +import dcp_server.models as models +from dcp_server.utils import read_config + +def test_eval_rf_not_fitted(): + + model_config = read_config('model', config_path='test/configs/test_config_RF.cfg') + train_config = read_config('train', config_path='test/configs/test_config_RF.cfg') + eval_config = read_config('eval', config_path='test/configs/test_config_RF.cfg') + + model_rf = models.CellClassifierShallowModel(model_config,train_config,eval_config) + + X_test = np.array([[1, 2, 3]]) + # if we don't fit the model then the model returns zeros + assert np.all(model_rf.eval(X_test)== np.zeros(X_test.shape)) + +def test_update_configs(): + + model_config = read_config('model', config_path='test/configs/test_config_RF.cfg') + train_config = read_config('train', config_path='test/configs/test_config_RF.cfg') + eval_config = read_config('eval', config_path='test/configs/test_config_RF.cfg') + + model = models.CustomCellposeModel(model_config,train_config,eval_config, "Cellpose") + + new_train_config = {"param1": "value1"} + new_eval_config = {"param2": "value2"} + + model.update_configs(new_train_config, new_eval_config) + + assert model.train_config == new_train_config + assert model.eval_config == new_eval_config + + +