From 731a9166d615c6ae86d33460d0f061bcbd17794f Mon Sep 17 00:00:00 2001 From: Christina Bukas Date: Wed, 22 Nov 2023 15:59:12 +0100 Subject: [PATCH] fixed mask resizing to match original image size --- src/server/dcp_server/fsimagestorage.py | 11 ++++++++--- src/server/dcp_server/segmentationclasses.py | 13 ++++++------- 2 files changed, 14 insertions(+), 10 deletions(-) diff --git a/src/server/dcp_server/fsimagestorage.py b/src/server/dcp_server/fsimagestorage.py index bec1b56..0d6a729 100644 --- a/src/server/dcp_server/fsimagestorage.py +++ b/src/server/dcp_server/fsimagestorage.py @@ -148,7 +148,7 @@ def rescale_image(self, img, height, width, channel_ax=None, order=2): rescale_factor = max_dim/512 return rescale(img, 1/rescale_factor, order=order, channel_axis=channel_ax) - def resize_image(self, img, height, width, order): + def resize_image(self, img, height, width, channel_ax=None, order=2): """resize image :param img: image @@ -161,8 +161,13 @@ def resize_image(self, img, height, width, order): :type order: int :return: resized image :rtype: ndarray - """ - return resize(img, (height, width), order=order) + """ + if channel_ax is not None: + n_channel_dim = img.shape[channel_ax] + output_size = [height, width] + output_size.insert(channel_ax, n_channel_dim) + else: output_size = [height, width] + return resize(img, output_size, order=order) def prepare_images_and_masks_for_training(self, train_img_mask_pairs): """Image and mask processing for training. diff --git a/src/server/dcp_server/segmentationclasses.py b/src/server/dcp_server/segmentationclasses.py index e1213d5..01b2c1d 100644 --- a/src/server/dcp_server/segmentationclasses.py +++ b/src/server/dcp_server/segmentationclasses.py @@ -37,18 +37,17 @@ async def segment_image(self, input_path, list_of_images): height, width, z_axis = self.imagestorage.get_image_size_properties(img, utils.get_file_extension(img_filepath)) img = self.imagestorage.rescale_image(img, height, - width, - order=None) + width) # Add channel ax into the model's evaluation parameters dictionary self.model.eval_config['segmentor']['z_axis'] = z_axis # Evaluate the model mask = await self.runner.evaluate.async_run(img = img) # Resize the mask - mask = self.imagestorage.rescale_image(mask, - height, - width, - self.model.eval_config['mask_channel_axis'], - order=0) + mask = self.imagestorage.resize_image(mask, + height, + width, + self.model.eval_config['mask_channel_axis'], + order=0) # Save segmentation seg_name = utils.get_path_stem(img_filepath) + setup_config['seg_name_string'] + '.tiff' self.imagestorage.save_image(os.path.join(input_path, seg_name), mask)