Skip to content

Commit

Permalink
Merge pull request #38 from HelmholtzAI-Consultants-Munich/reload-sav…
Browse files Browse the repository at this point in the history
…ed-model

Reload saved model
  • Loading branch information
christinab12 authored Dec 6, 2023
2 parents 979c175 + d61d6fc commit 0bd6c73
Show file tree
Hide file tree
Showing 6 changed files with 204 additions and 79 deletions.
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ curated/
# model dir
*mytrainedmodel/

#configs
src/client/dcp_client/config.cfg

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
Expand Down Expand Up @@ -152,4 +155,4 @@ docs/
test-napari.pub
data/
BentoML/
models/
models/
6 changes: 3 additions & 3 deletions src/server/dcp_server/config.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
},

"service": {
"model_to_use": "CellposePatchCNN",
"save_model_path": "mytrainedmodel",
"model_to_use": "CustomCellposeModel",
"save_model_path": "mito",
"runner_name": "cellpose_runner",
"service_name": "data-centric-platform",
"port": 7010
Expand All @@ -31,7 +31,7 @@

"train":{
"segmentor":{
"n_epochs": 7,
"n_epochs": 10,
"channels": [0,0],
"min_train_masks": 1
},
Expand Down
47 changes: 22 additions & 25 deletions src/server/dcp_server/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@
#from segment_anything import SamPredictor, sam_model_registry
#from segment_anything.automatic_mask_generator import SamAutomaticMaskGenerator

from dcp_server.utils import get_centered_patches, find_max_patch_size, create_patch_dataset, get_objects
from dcp_server.utils import get_centered_patches, find_max_patch_size, create_patch_dataset

class CustomCellposeModel(models.CellposeModel):
class CustomCellposeModel(models.CellposeModel, nn.Module):
"""Custom cellpose model inheriting the attributes and functions from the original CellposeModel and implementing
additional attributes and methods needed for this project.
"""
Expand All @@ -32,7 +32,14 @@ def __init__(self, model_config, train_config, eval_config):
"""

# Initialize the cellpose model
super().__init__(**model_config["segmentor"])
#super().__init__(**model_config["segmentor"])
nn.Module.__init__(self)
models.CellposeModel.__init__(self, **model_config["segmentor"])
self.mkldnn = False # otherwise we get error with saving model
self.train_config = train_config
self.eval_config = eval_config

def update_configs(self, train_config, eval_config):
self.train_config = train_config
self.eval_config = eval_config

Expand Down Expand Up @@ -66,10 +73,7 @@ def train(self, imgs, masks):
super().train(train_data=deepcopy(imgs), train_labels=masks, **self.train_config["segmentor"])

pred_masks = [self.eval(img) for img in masks]
print(len(pred_masks))
self.metric = np.mean(aggregated_jaccard_index(masks, pred_masks))
# pred_masks = [self.eval(img) for img in masks]

# self.loss = self.loss_fn(masks, pred_masks)

def masks_to_outlines(self, mask):
Expand All @@ -87,7 +91,7 @@ def masks_to_outlines(self, mask):
class CellClassifierFCNN(nn.Module):

'''
Fully convolutional classifier for cell images.
Fully convolutional classifier for cell images. NOTE -> This model cannot be used as a standalone model in DCP
Args:
model_config (dict): Model configuration.
Expand Down Expand Up @@ -128,6 +132,10 @@ def __init__(self, model_config, train_config, eval_config):
self.final_conv = nn.Conv2d(128, self.num_classes, 1)
self.pooling = nn.AdaptiveMaxPool2d(1)

def update_configs(self, train_config, eval_config):
self.train_config = train_config
self.eval_config = eval_config

def forward(self, x):

x = self.layer1(x)
Expand Down Expand Up @@ -200,14 +208,15 @@ def eval(self, img):
return y_hat


class CellposePatchCNN():
class CellposePatchCNN(nn.Module):

"""
Cellpose & patches of cells and then cnn to classify each patch
"""

def __init__(self, model_config, train_config, eval_config):

super().__init__()

self.model_config = model_config
self.train_config = train_config
self.eval_config = eval_config
Expand All @@ -220,22 +229,10 @@ def __init__(self, model_config, train_config, eval_config):
self.train_config,
self.eval_config)

def init_from_checkpoints(self, chpt_classifier=None, chpt_segmentor=None):
"""
Initialize the model from pre-trained checkpoints.
"""

self.segmentor = CustomCellposeModel(
self.model_config.get("segmentor", {}),
self.train_config.get("segmentor", {}),
self.eval_config.get("segmentor", {})
)
self.classifier = CellClassifierFCNN(
self.model_config.get("classifier", {}),
self.train_config.get("classifier", {}),
self.eval_config.get("classifier", {})
)

def update_configs(self, train_config, eval_config):
self.train_config = train_config
self.eval_config = eval_config

def train(self, imgs, masks):
"""Trains the given model. First trains the segmentor and then the clasiffier.
Expand Down
35 changes: 15 additions & 20 deletions src/server/dcp_server/serviceclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from bentoml.io import Text, NumpyNdarray
from typing import List

from dcp_server import models as DCPModels


class CustomRunnable(bentoml.Runnable):
'''
Expand All @@ -16,13 +18,16 @@ class CustomRunnable(bentoml.Runnable):
def __init__(self, model, save_model_path):
"""Constructs all the necessary attributes for the CustomRunnable.
:param model: model to be trained or evaluated
:param model: model to be trained or evaluated - will be one of classes in models.py
:param save_model_path: full path of the model object that it will be saved into
:type save_model_path: str
"""

self.model = model
self.save_model_path = save_model_path
# load model if it already exists to continue training from there?
if self.save_model_path in [model.tag.name for model in bentoml.models.list()]:
self.model = bentoml.pytorch.load_model(self.save_model_path+":latest")

@bentoml.Runnable.method(batchable=False)
def evaluate(self, img: np.ndarray) -> np.ndarray:
Expand All @@ -34,8 +39,10 @@ def evaluate(self, img: np.ndarray) -> np.ndarray:
:type z_axis: int
:return: mask of the image, list of 2D arrays, or single 3D array (if do_3D=True) labelled image.
:rtype: np.ndarray
"""

"""
# load the latest model if it is available (in case train has already occured)
if self.save_model_path in [model.tag.name for model in bentoml.models.list()]:
self.model = bentoml.pytorch.load_model(self.save_model_path+":latest")
mask = self.model.eval(img=img)

return mask
Expand All @@ -51,24 +58,15 @@ def train(self, imgs: List[np.ndarray], masks: List[np.ndarray]) -> str:
:return: path of the saved model
:rtype: str
"""
#s1 = self.model.segmentor.net.state_dict()
#c1 = self.model.classifier.parameters()
self.model.train(imgs, masks)
'''
s2 = self.model.segmentor.net.state_dict()
c2 = self.model.classifier.parameters()
if s1 == s2: print('S1 and S2 COMP: THEY ARE THE SAME!!!!!')
else: print('S1 and S2 COMP: THEY ARE NOOOT THE SAME!!!!!')
for p1, p2 in zip(c1, c2):
if p1.data.ne(p2.data).sum() > 0:
print("C1 and C2 NOT THE SAME")
break
'''
# Save the bentoml model
bentoml.picklable_model.save_model(self.save_model_path, self.model)
#bentoml.picklable_model.save_model(self.save_model_path, self.model)
bentoml.pytorch.save_model(self.save_model_path, # Model name in the local Model Store
self.model, # Model instance being saved
external_modules=[DCPModels]
)

return self.save_model_path


class CustomBentoService():
"""BentoML Service class. Contains all the functions necessary to serve the service with BentoML
Expand Down Expand Up @@ -123,12 +121,9 @@ async def train(input_path):
:rtype: str
"""
print("Calling retrain from server.")

# Train the model
model_path = await self.segmentation.train(input_path)

msg = "Success! Trained model saved in: " + model_path

return msg

return svc
Expand Down
67 changes: 67 additions & 0 deletions src/server/test/test_config.cfg
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
{
"setup": {
"segmentation": "GeneralSegmentation",
"accepted_types": [".jpg", ".jpeg", ".png", ".tiff", ".tif"],
"seg_name_string": "_seg"
},

"service": {
"model_to_use": "CustomCellposeModel",
"save_model_path": "mito",
"runner_name": "cellpose_runner",
"service_name": "data-centric-platform",
"port": 7010
},

"model": {
"segmentor": {
"model_type": "cyto"
},
"classifier":{
"in_channels": 1,
"num_classes": 3,
"black_bg": "False",
"include_mask": "False"
}
},

"data": {
"data_root": "data"
},

"train":{
"segmentor":{
"n_epochs": 20,
"channels": [0,0],
"min_train_masks": 1,
"learning_rate":0.01
},
"classifier":{
"train_data":{
"patch_size": 64,
"noise_intensity": 5,
"num_classes": 3
},
"n_epochs": 8,
"lr": 0.001,
"batch_size": 1,
"optimizer": "Adam"
}
},

"eval":{
"segmentor": {
"z_axis": null,
"channel_axis": null,
"rescale": 1,
"batch_size": 1
},
"classifier": {
"data":{
"patch_size": 64,
"noise_intensity": 5
}
},
"mask_channel_axis": 0
}
}
Loading

0 comments on commit 0bd6c73

Please sign in to comment.