diff --git a/src/server/dcp_server/models.py b/src/server/dcp_server/models.py index d4ee2be..71aefc1 100644 --- a/src/server/dcp_server/models.py +++ b/src/server/dcp_server/models.py @@ -68,14 +68,10 @@ def train(self, imgs, masks): if masks[0].shape[0] == 2: masks = list(masks[:,0,...]) - print(masks[0].shape) - super().train(train_data=deepcopy(imgs), train_labels=masks, **self.train_config["segmentor"]) pred_masks = [self.eval(img) for img in imgs] - print(np.unique(pred_masks[0])) self.metric = np.mean(aggregated_jaccard_index(masks, pred_masks)) - print(self.metric) # self.loss = self.loss_fn(masks, pred_masks) def masks_to_outlines(self, mask): @@ -347,8 +343,6 @@ def __init__(self, model_config, train_config, eval_config): self.train_config = train_config["unet"] self.eval_config = eval_config["unet"] - print(self.model_config) - self.in_channels = self.model_config["in_channels"] self.out_channels = self.model_config["out_channels"] self.features = self.model_config["features"] @@ -411,8 +405,6 @@ def train(self, imgs, masks): # convert to tensor imgs = torch.stack([torch.from_numpy(img.astype(np.float32)) for img in imgs]) imgs = imgs.unsqueeze(1) if imgs.ndim == 3 else imgs - # print(f"Imgs shapes: {imgs.shape}") - # imgs = torch.permute(imgs, (0, 3, 1, 2)) # Your classification label mask masks = np.array(masks) masks = torch.stack([torch.from_numpy(mask[1]) for mask in masks]) @@ -434,7 +426,6 @@ def train(self, imgs, masks): #forward path preds = self.forward(imgs) - print(preds.shape, masks.shape) loss = loss_fn(preds, masks) #backward path diff --git a/src/server/dcp_server/test_config.cfg b/src/server/dcp_server/test_config.cfg index 7034662..aed0007 100644 --- a/src/server/dcp_server/test_config.cfg +++ b/src/server/dcp_server/test_config.cfg @@ -36,7 +36,7 @@ "train":{ "segmentor":{ - "n_epochs": 15, + "n_epochs": 10, "channels": [0,0], "min_train_masks": 1, "batch_size": 2, @@ -66,6 +66,7 @@ "segmentor": { "z_axis": null, "channel_axis": null, + "rescale": 1, "batch_size": 1 }, "classifier": { diff --git a/src/server/test/test_integration.py b/src/server/test/test_integration.py index 080cce2..dc01944 100644 --- a/src/server/test/test_integration.py +++ b/src/server/test/test_integration.py @@ -43,7 +43,7 @@ def model(model_class): @pytest.fixture def data_train(): - print(model_classes) + images, masks = get_synthetic_dataset(num_samples=4) masks = [np.array(mask) for mask in masks] masks_instances = [mask.sum(-1) for mask in masks]