Skip to content

Commit

Permalink
Merge pull request #36 from HelmholtzAI-Consultants-Munich/fix-rescaling
Browse files Browse the repository at this point in the history
fixed mask resizing to match original image size
  • Loading branch information
christinab12 authored Nov 22, 2023
2 parents 6ccb734 + 731a916 commit df2d0c0
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 10 deletions.
11 changes: 8 additions & 3 deletions src/server/dcp_server/fsimagestorage.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def rescale_image(self, img, height, width, channel_ax=None, order=2):
rescale_factor = max_dim/512
return rescale(img, 1/rescale_factor, order=order, channel_axis=channel_ax)

def resize_image(self, img, height, width, order):
def resize_image(self, img, height, width, channel_ax=None, order=2):
"""resize image
:param img: image
Expand All @@ -161,8 +161,13 @@ def resize_image(self, img, height, width, order):
:type order: int
:return: resized image
:rtype: ndarray
"""
return resize(img, (height, width), order=order)
"""
if channel_ax is not None:
n_channel_dim = img.shape[channel_ax]
output_size = [height, width]
output_size.insert(channel_ax, n_channel_dim)
else: output_size = [height, width]
return resize(img, output_size, order=order)

def prepare_images_and_masks_for_training(self, train_img_mask_pairs):
"""Image and mask processing for training.
Expand Down
13 changes: 6 additions & 7 deletions src/server/dcp_server/segmentationclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,18 +37,17 @@ async def segment_image(self, input_path, list_of_images):
height, width, z_axis = self.imagestorage.get_image_size_properties(img, utils.get_file_extension(img_filepath))
img = self.imagestorage.rescale_image(img,
height,
width,
order=None)
width)
# Add channel ax into the model's evaluation parameters dictionary
self.model.eval_config['segmentor']['z_axis'] = z_axis
# Evaluate the model
mask = await self.runner.evaluate.async_run(img = img)
# Resize the mask
mask = self.imagestorage.rescale_image(mask,
height,
width,
self.model.eval_config['mask_channel_axis'],
order=0)
mask = self.imagestorage.resize_image(mask,
height,
width,
self.model.eval_config['mask_channel_axis'],
order=0)
# Save segmentation
seg_name = utils.get_path_stem(img_filepath) + setup_config['seg_name_string'] + '.tiff'
self.imagestorage.save_image(os.path.join(input_path, seg_name), mask)
Expand Down

0 comments on commit df2d0c0

Please sign in to comment.