diff --git a/src/server/test/configs/test_config_MultiCellpose.yaml b/src/server/test/configs/test_config_MultiCellpose.yaml index b74476f..46b913d 100644 --- a/src/server/test/configs/test_config_MultiCellpose.yaml +++ b/src/server/test/configs/test_config_MultiCellpose.yaml @@ -31,7 +31,7 @@ "train":{ "segmentor":{ - "n_epochs": 20, + "n_epochs": 30, "channels": [0,0], "min_train_masks": 1, "learning_rate":0.01 diff --git a/src/server/test/configs/test_config_UNet.yaml b/src/server/test/configs/test_config_UNet.yaml index f6ee29b..f4eba07 100644 --- a/src/server/test/configs/test_config_UNet.yaml +++ b/src/server/test/configs/test_config_UNet.yaml @@ -29,7 +29,7 @@ "train":{ "classifier":{ - "n_epochs": 20, + "n_epochs": 30, "lr": 0.005, "batch_size": 5, "optimizer": "Adam" diff --git a/src/server/test/test_integration.py b/src/server/test/test_integration.py index 8637377..6e37ea2 100644 --- a/src/server/test/test_integration.py +++ b/src/server/test/test_integration.py @@ -82,12 +82,21 @@ def test_train_eval_run(data_train, data_eval, model): """ Performs testing, training, and evaluation with the provided data and model. """ - + # train images, masks = data_train if model.model_name == "CustomCellpose": masks = [mask[0] for mask in masks] model.train(images, masks) + # retrieve the attribute names of the class of the current model + attrs = model.__dict__.keys() + + if "metric" in attrs: + assert model.metric > 0.1 + if "loss" in attrs: + assert model.loss < 0.83 + + # validate imgs_test, masks_test = data_eval if model.model_name == "CustomCellpose": masks = [mask[0] for mask in masks_test] @@ -128,15 +137,6 @@ def test_train_eval_run(data_train, data_eval, model): jaccard_index_instances /= len(imgs_test) assert jaccard_index_instances > 0.2 - # retrieve the attribute names of the class of the current model - attrs = model.__dict__.keys() - - if "metric" in attrs: - assert model.metric > 0.1 - if "loss" in attrs: - assert model.loss < 0.83 - - # for PatchCNN model if pred_mask.ndim > 2: jaccard_index_classes /= len(imgs_test) diff --git a/src/server/test/test_models.py b/src/server/test/test_models.py index 86529e1..eddf8f9 100644 --- a/src/server/test/test_models.py +++ b/src/server/test/test_models.py @@ -32,33 +32,3 @@ def test_eval_rf_not_fitted(): # if we don't fit the model then the model returns zeros assert np.all(model_rf.eval(X_test) == np.zeros(X_test.shape)) - -def test_update_configs(): - """ - Tests the update of model training and evaluation configurations. - """ - - model_config = read_config( - "model", config_path="test/configs/test_config_Inst2MultiSeg_RF.yaml" - ) - data_config = read_config( - "data", config_path="test/configs/test_config_Inst2MultiSeg_RF.yaml" - ) - train_config = read_config( - "train", config_path="test/configs/test_config_Inst2MultiSeg_RF.yaml" - ) - eval_config = read_config( - "eval", config_path="test/configs/test_config_Inst2MultiSeg_RF.yaml" - ) - - model = models.CustomCellpose( - "Cellpose", model_config, data_config, train_config, eval_config - ) - - new_train_config = {"param1": "value1"} - new_eval_config = {"param2": "value2"} - - model.update_configs(new_train_config, new_eval_config) - - assert model.train_config == new_train_config - assert model.eval_config == new_eval_config