Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reload saved model #38

Merged
merged 12 commits into from
Dec 6, 2023
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ curated/
# model dir
*mytrainedmodel/

#configs
src/client/dcp_client/config.cfg

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
Expand Down Expand Up @@ -152,4 +155,4 @@ docs/
test-napari.pub
data/
BentoML/
models/
models/
6 changes: 3 additions & 3 deletions src/server/dcp_server/config.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
},

"service": {
"model_to_use": "CellposePatchCNN",
"save_model_path": "mytrainedmodel",
"model_to_use": "CustomCellposeModel",
"save_model_path": "mito",
"runner_name": "cellpose_runner",
"service_name": "data-centric-platform",
"port": 7010
Expand All @@ -31,7 +31,7 @@

"train":{
"segmentor":{
"n_epochs": 7,
"n_epochs": 10,
"channels": [0,0],
"min_train_masks": 1
},
Expand Down
47 changes: 22 additions & 25 deletions src/server/dcp_server/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@
#from segment_anything import SamPredictor, sam_model_registry
#from segment_anything.automatic_mask_generator import SamAutomaticMaskGenerator

from dcp_server.utils import get_centered_patches, find_max_patch_size, create_patch_dataset, get_objects
from dcp_server.utils import get_centered_patches, find_max_patch_size, create_patch_dataset

class CustomCellposeModel(models.CellposeModel):
class CustomCellposeModel(models.CellposeModel, nn.Module):
"""Custom cellpose model inheriting the attributes and functions from the original CellposeModel and implementing
additional attributes and methods needed for this project.
"""
Expand All @@ -32,7 +32,14 @@ def __init__(self, model_config, train_config, eval_config):
"""

# Initialize the cellpose model
super().__init__(**model_config["segmentor"])
#super().__init__(**model_config["segmentor"])
nn.Module.__init__(self)
models.CellposeModel.__init__(self, **model_config["segmentor"])
self.mkldnn = False # otherwise we get error with saving model
self.train_config = train_config
self.eval_config = eval_config

def update_configs(self, train_config, eval_config):
self.train_config = train_config
self.eval_config = eval_config

Expand Down Expand Up @@ -66,10 +73,7 @@ def train(self, imgs, masks):
super().train(train_data=deepcopy(imgs), train_labels=masks, **self.train_config["segmentor"])

pred_masks = [self.eval(img) for img in masks]
print(len(pred_masks))
self.metric = np.mean(aggregated_jaccard_index(masks, pred_masks))
# pred_masks = [self.eval(img) for img in masks]

# self.loss = self.loss_fn(masks, pred_masks)

def masks_to_outlines(self, mask):
Expand All @@ -87,7 +91,7 @@ def masks_to_outlines(self, mask):
class CellClassifierFCNN(nn.Module):

'''
Fully convolutional classifier for cell images.
Fully convolutional classifier for cell images. NOTE -> This model cannot be used as a standalone model in DCP

Args:
model_config (dict): Model configuration.
Expand Down Expand Up @@ -128,6 +132,10 @@ def __init__(self, model_config, train_config, eval_config):
self.final_conv = nn.Conv2d(128, self.num_classes, 1)
self.pooling = nn.AdaptiveMaxPool2d(1)

def update_configs(self, train_config, eval_config):
self.train_config = train_config
self.eval_config = eval_config

def forward(self, x):

x = self.layer1(x)
Expand Down Expand Up @@ -200,14 +208,15 @@ def eval(self, img):
return y_hat


class CellposePatchCNN():
class CellposePatchCNN(nn.Module):

"""
Cellpose & patches of cells and then cnn to classify each patch
"""

def __init__(self, model_config, train_config, eval_config):

super().__init__()

self.model_config = model_config
self.train_config = train_config
self.eval_config = eval_config
Expand All @@ -220,22 +229,10 @@ def __init__(self, model_config, train_config, eval_config):
self.train_config,
self.eval_config)

def init_from_checkpoints(self, chpt_classifier=None, chpt_segmentor=None):
"""
Initialize the model from pre-trained checkpoints.
"""

self.segmentor = CustomCellposeModel(
self.model_config.get("segmentor", {}),
self.train_config.get("segmentor", {}),
self.eval_config.get("segmentor", {})
)
self.classifier = CellClassifierFCNN(
self.model_config.get("classifier", {}),
self.train_config.get("classifier", {}),
self.eval_config.get("classifier", {})
)

def update_configs(self, train_config, eval_config):
self.train_config = train_config
self.eval_config = eval_config

def train(self, imgs, masks):
"""Trains the given model. First trains the segmentor and then the clasiffier.

Expand Down
35 changes: 15 additions & 20 deletions src/server/dcp_server/serviceclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from bentoml.io import Text, NumpyNdarray
from typing import List

from dcp_server import models as DCPModels


class CustomRunnable(bentoml.Runnable):
'''
Expand All @@ -16,13 +18,16 @@ class CustomRunnable(bentoml.Runnable):
def __init__(self, model, save_model_path):
"""Constructs all the necessary attributes for the CustomRunnable.

:param model: model to be trained or evaluated
:param model: model to be trained or evaluated - will be one of classes in models.py
:param save_model_path: full path of the model object that it will be saved into
:type save_model_path: str
"""

self.model = model
self.save_model_path = save_model_path
# load model if it already exists to continue training from there?
if self.save_model_path in [model.tag.name for model in bentoml.models.list()]:
self.model = bentoml.pytorch.load_model(self.save_model_path+":latest")

@bentoml.Runnable.method(batchable=False)
def evaluate(self, img: np.ndarray) -> np.ndarray:
Expand All @@ -34,8 +39,10 @@ def evaluate(self, img: np.ndarray) -> np.ndarray:
:type z_axis: int
:return: mask of the image, list of 2D arrays, or single 3D array (if do_3D=True) labelled image.
:rtype: np.ndarray
"""

"""
# load the latest model if it is available (in case train has already occured)
if self.save_model_path in [model.tag.name for model in bentoml.models.list()]:
self.model = bentoml.pytorch.load_model(self.save_model_path+":latest")
mask = self.model.eval(img=img)

return mask
Expand All @@ -51,24 +58,15 @@ def train(self, imgs: List[np.ndarray], masks: List[np.ndarray]) -> str:
:return: path of the saved model
:rtype: str
"""
#s1 = self.model.segmentor.net.state_dict()
#c1 = self.model.classifier.parameters()
self.model.train(imgs, masks)
'''
s2 = self.model.segmentor.net.state_dict()
c2 = self.model.classifier.parameters()
if s1 == s2: print('S1 and S2 COMP: THEY ARE THE SAME!!!!!')
else: print('S1 and S2 COMP: THEY ARE NOOOT THE SAME!!!!!')
for p1, p2 in zip(c1, c2):
if p1.data.ne(p2.data).sum() > 0:
print("C1 and C2 NOT THE SAME")
break
'''
# Save the bentoml model
bentoml.picklable_model.save_model(self.save_model_path, self.model)
#bentoml.picklable_model.save_model(self.save_model_path, self.model)
bentoml.pytorch.save_model(self.save_model_path, # Model name in the local Model Store
self.model, # Model instance being saved
external_modules=[DCPModels]
)

return self.save_model_path


class CustomBentoService():
"""BentoML Service class. Contains all the functions necessary to serve the service with BentoML
Expand Down Expand Up @@ -123,12 +121,9 @@ async def train(input_path):
:rtype: str
"""
print("Calling retrain from server.")

# Train the model
model_path = await self.segmentation.train(input_path)

msg = "Success! Trained model saved in: " + model_path

return msg

return svc
Expand Down
67 changes: 67 additions & 0 deletions src/server/test/test_config.cfg
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
{
"setup": {
"segmentation": "GeneralSegmentation",
"accepted_types": [".jpg", ".jpeg", ".png", ".tiff", ".tif"],
"seg_name_string": "_seg"
},

"service": {
"model_to_use": "CustomCellposeModel",
"save_model_path": "mito",
"runner_name": "cellpose_runner",
"service_name": "data-centric-platform",
"port": 7010
},

"model": {
"segmentor": {
"model_type": "cyto"
},
"classifier":{
"in_channels": 1,
"num_classes": 3,
"black_bg": "False",
"include_mask": "False"
}
},

"data": {
"data_root": "data"
},

"train":{
"segmentor":{
"n_epochs": 20,
"channels": [0,0],
"min_train_masks": 1,
"learning_rate":0.01
},
"classifier":{
"train_data":{
"patch_size": 64,
"noise_intensity": 5,
"num_classes": 3
},
"n_epochs": 8,
"lr": 0.001,
"batch_size": 1,
"optimizer": "Adam"
}
},

"eval":{
"segmentor": {
"z_axis": null,
"channel_axis": null,
"rescale": 1,
"batch_size": 1
},
"classifier": {
"data":{
"patch_size": 64,
"noise_intensity": 5
}
},
"mask_channel_axis": 0
}
}
Loading