Skip to content

Commit

Permalink
update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Christina Bukas committed Mar 13, 2024
1 parent 091d6bb commit a34e334
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 42 deletions.
2 changes: 1 addition & 1 deletion src/server/test/configs/test_config_MultiCellpose.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@

"train":{
"segmentor":{
"n_epochs": 20,
"n_epochs": 30,
"channels": [0,0],
"min_train_masks": 1,
"learning_rate":0.01
Expand Down
2 changes: 1 addition & 1 deletion src/server/test/configs/test_config_UNet.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@

"train":{
"classifier":{
"n_epochs": 20,
"n_epochs": 30,
"lr": 0.005,
"batch_size": 5,
"optimizer": "Adam"
Expand Down
20 changes: 10 additions & 10 deletions src/server/test/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,12 +82,21 @@ def test_train_eval_run(data_train, data_eval, model):
"""
Performs testing, training, and evaluation with the provided data and model.
"""

# train
images, masks = data_train
if model.model_name == "CustomCellpose":
masks = [mask[0] for mask in masks]
model.train(images, masks)

# retrieve the attribute names of the class of the current model
attrs = model.__dict__.keys()

if "metric" in attrs:
assert model.metric > 0.1
if "loss" in attrs:
assert model.loss < 0.83

# validate
imgs_test, masks_test = data_eval
if model.model_name == "CustomCellpose":
masks = [mask[0] for mask in masks_test]
Expand Down Expand Up @@ -128,15 +137,6 @@ def test_train_eval_run(data_train, data_eval, model):
jaccard_index_instances /= len(imgs_test)
assert jaccard_index_instances > 0.2

# retrieve the attribute names of the class of the current model
attrs = model.__dict__.keys()

if "metric" in attrs:
assert model.metric > 0.1
if "loss" in attrs:
assert model.loss < 0.83

# for PatchCNN model
if pred_mask.ndim > 2:

jaccard_index_classes /= len(imgs_test)
Expand Down
30 changes: 0 additions & 30 deletions src/server/test/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,33 +32,3 @@ def test_eval_rf_not_fitted():
# if we don't fit the model then the model returns zeros
assert np.all(model_rf.eval(X_test) == np.zeros(X_test.shape))


def test_update_configs():
"""
Tests the update of model training and evaluation configurations.
"""

model_config = read_config(
"model", config_path="test/configs/test_config_Inst2MultiSeg_RF.yaml"
)
data_config = read_config(
"data", config_path="test/configs/test_config_Inst2MultiSeg_RF.yaml"
)
train_config = read_config(
"train", config_path="test/configs/test_config_Inst2MultiSeg_RF.yaml"
)
eval_config = read_config(
"eval", config_path="test/configs/test_config_Inst2MultiSeg_RF.yaml"
)

model = models.CustomCellpose(
"Cellpose", model_config, data_config, train_config, eval_config
)

new_train_config = {"param1": "value1"}
new_eval_config = {"param2": "value2"}

model.update_configs(new_train_config, new_eval_config)

assert model.train_config == new_train_config
assert model.eval_config == new_eval_config

0 comments on commit a34e334

Please sign in to comment.