From 7aa18eed35dcdbb05c0356739abea333cf56e7ff Mon Sep 17 00:00:00 2001 From: Christina Bukas Date: Fri, 8 Dec 2023 16:22:13 +0100 Subject: [PATCH] add assertion to make sure model class is the same after load --- src/server/dcp_server/serviceclasses.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/src/server/dcp_server/serviceclasses.py b/src/server/dcp_server/serviceclasses.py index 53f30cc..bc8375d 100644 --- a/src/server/dcp_server/serviceclasses.py +++ b/src/server/dcp_server/serviceclasses.py @@ -25,9 +25,8 @@ def __init__(self, model, save_model_path): self.model = model self.save_model_path = save_model_path - # load model if it already exists to continue training from there? - if self.save_model_path in [model.tag.name for model in bentoml.models.list()]: - self.model = bentoml.pytorch.load_model(self.save_model_path+":latest") + # update with the latest model if it already exists to continue training from there? + self.check_and_load_model() @bentoml.Runnable.method(batchable=False) def evaluate(self, img: np.ndarray) -> np.ndarray: @@ -40,12 +39,18 @@ def evaluate(self, img: np.ndarray) -> np.ndarray: :return: mask of the image, list of 2D arrays, or single 3D array (if do_3D=True) labelled image. :rtype: np.ndarray """ - # load the latest model if it is available (in case train has already occured) - if self.save_model_path in [model.tag.name for model in bentoml.models.list()]: - self.model = bentoml.pytorch.load_model(self.save_model_path+":latest") + # update with the latest model if it is available (in case train has already occured) + self.check_and_load_model() mask = self.model.eval(img=img) return mask + + def check_and_load_model(self): + bento_model_list = [model.tag.name for model in bentoml.models.list()] + if self.save_model_path in bento_model_list: + loaded_model = bentoml.pytorch.load_model(self.save_model_path+":latest") + assert loaded_model.__class__.__name__ == self.model.__class__.__name__, 'Check your config, loaded model and model to use not the same!' + self.model = loaded_model @bentoml.Runnable.method(batchable=False) def train(self, imgs: List[np.ndarray], masks: List[np.ndarray]) -> str: